diff --git a/Userland/Libraries/LibSQL/SQLClient.cpp b/Userland/Libraries/LibSQL/SQLClient.cpp index b9d7045823..e37605ab91 100644 --- a/Userland/Libraries/LibSQL/SQLClient.cpp +++ b/Userland/Libraries/LibSQL/SQLClient.cpp @@ -6,6 +6,7 @@ */ #include +#include #include #include @@ -202,6 +203,8 @@ void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector const& row) { + ScopeGuard guard { [&]() { async_ready_for_next_result(statement_id, execution_id); } }; + if (!on_next_result) { StringBuilder builder; builder.join(", "sv, row, "\"{}\""sv); diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp index 426755d8fc..67c4c16856 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.cpp +++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp @@ -97,4 +97,18 @@ Messages::SQLServer::ExecuteStatementResponse ConnectionFromClient::execute_stat return Optional {}; } +void ConnectionFromClient::ready_for_next_result(SQL::StatementID statement_id, SQL::ExecutionID execution_id) +{ + dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::ready_for_next_result(statement_id: {}, execution_id: {})", statement_id, execution_id); + auto statement = SQLStatement::statement_for(statement_id); + + if (statement && statement->connection().client_id() == client_id()) { + statement->ready_for_next_result(execution_id); + return; + } + + dbgln_if(SQLSERVER_DEBUG, "Statement has disappeared"); + async_execution_error(statement_id, execution_id, SQL::SQLErrorCode::StatementUnavailable, ByteString::formatted("{}", statement_id)); +} + } diff --git a/Userland/Services/SQLServer/ConnectionFromClient.h b/Userland/Services/SQLServer/ConnectionFromClient.h index c12c5b5b62..60549a4492 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.h +++ b/Userland/Services/SQLServer/ConnectionFromClient.h @@ -37,6 +37,7 @@ private: virtual Messages::SQLServer::ConnectResponse connect(ByteString const&) override; virtual Messages::SQLServer::PrepareStatementResponse prepare_statement(SQL::ConnectionID, ByteString const&) override; virtual Messages::SQLServer::ExecuteStatementResponse execute_statement(SQL::StatementID, Vector const& placeholder_values) override; + virtual void ready_for_next_result(SQL::StatementID, SQL::ExecutionID) override; virtual void disconnect(SQL::ConnectionID) override; ByteString m_database_path; diff --git a/Userland/Services/SQLServer/SQLServer.ipc b/Userland/Services/SQLServer/SQLServer.ipc index 1ffa538023..e25fcf64a5 100644 --- a/Userland/Services/SQLServer/SQLServer.ipc +++ b/Userland/Services/SQLServer/SQLServer.ipc @@ -5,5 +5,6 @@ endpoint SQLServer connect(ByteString name) => (Optional connection_id) prepare_statement(u64 connection_id, ByteString statement) => (Optional statement_id) execute_statement(u64 statement_id, Vector placeholder_values) => (Optional execution_id) + ready_for_next_result(u64 statement_id, u64 execution_id) =| disconnect(u64 connection_id) => () } diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp index 287fab05e3..2d56f5641c 100644 --- a/Userland/Services/SQLServer/SQLStatement.cpp +++ b/Userland/Services/SQLServer/SQLStatement.cpp @@ -68,11 +68,9 @@ Optional SQLStatement::execute(Vector placeholder_ } auto execution_id = m_next_execution_id++; - m_ongoing_executions.set(execution_id); Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), placeholder_values = move(placeholder_values), execution_id] { auto execution_result = m_statement->execute(connection().database(), placeholder_values); - m_ongoing_executions.remove(execution_id); if (execution_result.is_error()) { report_error(execution_result.release_error(), execution_id); @@ -86,19 +84,20 @@ Optional SQLStatement::execute(Vector placeholder_ } auto result = execution_result.release_value(); + auto result_size = result.size(); if (should_send_result_rows(result)) { client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), true, 0, 0, 0); - auto result_size = result.size(); - next(execution_id, move(result), result_size); + m_ongoing_executions.set(execution_id, { move(result), result_size }); + ready_for_next_result(execution_id); } else { if (result.command() == SQL::SQLCommand::Insert) - client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, result.size(), 0, 0); + client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, result_size, 0, 0); else if (result.command() == SQL::SQLCommand::Update) - client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, result.size(), 0); + client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, result_size, 0); else if (result.command() == SQL::SQLCommand::Delete) - client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, result.size()); + client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, result_size); else client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, 0); } @@ -107,6 +106,29 @@ Optional SQLStatement::execute(Vector placeholder_ return execution_id; } +void SQLStatement::ready_for_next_result(SQL::ExecutionID execution_id) +{ + auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id()); + if (!client_connection) { + warnln("Cannot yield next result. Client disconnected"); + return; + } + + auto execution = m_ongoing_executions.get(execution_id); + if (!execution.has_value()) { + return; + } + + if (execution->result.is_empty()) { + client_connection->async_results_exhausted(statement_id(), execution_id, execution->result_size); + m_ongoing_executions.remove(execution_id); + return; + } + + auto result_row = execution->result.take_first(); + client_connection->async_next_result(statement_id(), execution_id, result_row.row.take_data()); +} + bool SQLStatement::should_send_result_rows(SQL::ResultSet const& result) const { if (result.is_empty()) @@ -121,24 +143,4 @@ bool SQLStatement::should_send_result_rows(SQL::ResultSet const& result) const } } -void SQLStatement::next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size) -{ - auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id()); - if (!client_connection) { - warnln("Cannot yield next result. Client disconnected"); - return; - } - - if (!result.is_empty()) { - auto result_row = result.take_first(); - client_connection->async_next_result(statement_id(), execution_id, result_row.row.take_data()); - - Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), execution_id, result = move(result), result_size]() mutable { - next(execution_id, move(result), result_size); - }); - } else { - client_connection->async_results_exhausted(statement_id(), execution_id, result_size); - } -} - } diff --git a/Userland/Services/SQLServer/SQLStatement.h b/Userland/Services/SQLServer/SQLStatement.h index 10155211fd..866e24efb6 100644 --- a/Userland/Services/SQLServer/SQLStatement.h +++ b/Userland/Services/SQLServer/SQLStatement.h @@ -26,18 +26,22 @@ public: SQL::StatementID statement_id() const { return m_statement_id; } DatabaseConnection& connection() { return m_connection; } Optional execute(Vector placeholder_values); + void ready_for_next_result(SQL::ExecutionID); private: SQLStatement(DatabaseConnection&, NonnullRefPtr statement); bool should_send_result_rows(SQL::ResultSet const& result) const; - void next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size); void report_error(SQL::Result, SQL::ExecutionID execution_id); DatabaseConnection& m_connection; SQL::StatementID m_statement_id { 0 }; - HashTable m_ongoing_executions; + struct Execution { + SQL::ResultSet result; + size_t result_size { 0 }; + }; + HashMap m_ongoing_executions; SQL::ExecutionID m_next_execution_id { 0 }; NonnullRefPtr m_statement;