diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp index 69b0517cf7..80d5a85f47 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.cpp +++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp @@ -87,7 +87,7 @@ Messages::SQLServer::ExecuteStatementResponse ConnectionFromClient::execute_stat dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::execute_query_statement(statement_id: {})", statement_id); auto statement = SQLStatement::statement_for(statement_id); - if (statement && statement->connection()->client_id() == client_id()) { + if (statement && statement->connection().client_id() == client_id()) { // FIXME: Support taking parameters from IPC requests. return statement->execute(move(const_cast&>(placeholder_values))); } diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp index 6e61888536..5c49efc62d 100644 --- a/Userland/Services/SQLServer/SQLStatement.cpp +++ b/Userland/Services/SQLServer/SQLStatement.cpp @@ -35,7 +35,7 @@ SQL::ResultOr> SQLStatement::create(DatabaseConnecti } SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtr statement) - : Core::EventReceiver(&connection) + : m_connection(connection) , m_statement_id(s_next_statement_id++) , m_statement(move(statement)) { @@ -47,10 +47,9 @@ void SQLStatement::report_error(SQL::Result result, SQL::ExecutionID execution_i { dbgln_if(SQLSERVER_DEBUG, "SQLStatement::report_error(statement_id {}, error {}", statement_id(), result.error_string()); - auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); + auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id()); s_statements.remove(statement_id()); - remove_from_parent(); if (client_connection) client_connection->async_execution_error(statement_id(), execution_id, result.error(), result.error_string()); @@ -62,7 +61,7 @@ Optional SQLStatement::execute(Vector placeholder_ { dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id()); - auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); + auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id()); if (!client_connection) { warnln("Cannot yield next result. Client disconnected"); return {}; @@ -71,8 +70,8 @@ Optional SQLStatement::execute(Vector placeholder_ auto execution_id = m_next_execution_id++; m_ongoing_executions.set(execution_id); - deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] { - auto execution_result = m_statement->execute(connection()->database(), placeholder_values); + Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), placeholder_values = move(placeholder_values), execution_id] { + auto execution_result = m_statement->execute(connection().database(), placeholder_values); m_ongoing_executions.remove(execution_id); if (execution_result.is_error()) { @@ -80,7 +79,7 @@ Optional SQLStatement::execute(Vector placeholder_ return; } - auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); + auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id()); if (!client_connection) { warnln("Cannot return statement execution results. Client disconnected"); return; @@ -124,7 +123,7 @@ bool SQLStatement::should_send_result_rows(SQL::ResultSet const& result) const void SQLStatement::next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size) { - auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); + auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id()); if (!client_connection) { warnln("Cannot yield next result. Client disconnected"); return; @@ -134,7 +133,7 @@ void SQLStatement::next(SQL::ExecutionID execution_id, SQL::ResultSet result, si auto result_row = result.take_first(); client_connection->async_next_result(statement_id(), execution_id, result_row.row.take_data()); - deferred_invoke([this, execution_id, result = move(result), result_size]() mutable { + Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), execution_id, result = move(result), result_size]() mutable { next(execution_id, move(result), result_size); }); } else { diff --git a/Userland/Services/SQLServer/SQLStatement.h b/Userland/Services/SQLServer/SQLStatement.h index 233a309c97..10155211fd 100644 --- a/Userland/Services/SQLServer/SQLStatement.h +++ b/Userland/Services/SQLServer/SQLStatement.h @@ -7,8 +7,8 @@ #pragma once #include +#include #include -#include #include #include #include @@ -18,16 +18,13 @@ namespace SQLServer { -class SQLStatement final : public Core::EventReceiver { - C_OBJECT_ABSTRACT(SQLStatement) - +class SQLStatement final : public RefCounted { public: static SQL::ResultOr> create(DatabaseConnection&, StringView sql); - ~SQLStatement() override = default; static RefPtr statement_for(SQL::StatementID statement_id); SQL::StatementID statement_id() const { return m_statement_id; } - DatabaseConnection* connection() { return dynamic_cast(parent()); } + DatabaseConnection& connection() { return m_connection; } Optional execute(Vector placeholder_values); private: @@ -37,6 +34,7 @@ private: void next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size); void report_error(SQL::Result, SQL::ExecutionID execution_id); + DatabaseConnection& m_connection; SQL::StatementID m_statement_id { 0 }; HashTable m_ongoing_executions;