From 1205d39fc35fdf9cd79d434f70ce14ae6335e0f8 Mon Sep 17 00:00:00 2001 From: Timothy Flynn Date: Tue, 9 Jan 2024 19:32:44 -0500 Subject: [PATCH] LibSQL+SQLServer: Inform SQLServer when the client has processed results The architecture of SQLServer is currently such that it sends results over IPC one row at a time. After the rows are exhausted, it sends a completion IPC. However, it does not wait for the client to finish processing a row before sending another row or the completion signal. This can result in clients hanging if the completion comes in while a row is being processed. At least in the case of WebView::Database, the result is that the completion signal is dropped, and the browser then hangs forever waiting for that signal (after it finishes processing the row). This patch makes SQLServer asynchronously wait for the client to tell it that the row has been processed and the next row (or completion) may be sent. We repurpose the `m_ongoing_executions` in SQLStatement for this purpose (this member was oddly being written to, but otherwise unused). --- Userland/Libraries/LibSQL/SQLClient.cpp | 3 + .../SQLServer/ConnectionFromClient.cpp | 14 +++++ .../Services/SQLServer/ConnectionFromClient.h | 1 + Userland/Services/SQLServer/SQLServer.ipc | 1 + Userland/Services/SQLServer/SQLStatement.cpp | 56 ++++++++++--------- Userland/Services/SQLServer/SQLStatement.h | 8 ++- 6 files changed, 54 insertions(+), 29 deletions(-) diff --git a/Userland/Libraries/LibSQL/SQLClient.cpp b/Userland/Libraries/LibSQL/SQLClient.cpp index b9d7045823..e37605ab91 100644 --- a/Userland/Libraries/LibSQL/SQLClient.cpp +++ b/Userland/Libraries/LibSQL/SQLClient.cpp @@ -6,6 +6,7 @@ */ #include +#include #include #include @@ -202,6 +203,8 @@ void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector const& row) { + ScopeGuard guard { [&]() { async_ready_for_next_result(statement_id, execution_id); } }; + if (!on_next_result) { StringBuilder builder; builder.join(", "sv, row, "\"{}\""sv); diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp index 426755d8fc..67c4c16856 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.cpp +++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp @@ -97,4 +97,18 @@ Messages::SQLServer::ExecuteStatementResponse ConnectionFromClient::execute_stat return Optional {}; } +void ConnectionFromClient::ready_for_next_result(SQL::StatementID statement_id, SQL::ExecutionID execution_id) +{ + dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::ready_for_next_result(statement_id: {}, execution_id: {})", statement_id, execution_id); + auto statement = SQLStatement::statement_for(statement_id); + + if (statement && statement->connection().client_id() == client_id()) { + statement->ready_for_next_result(execution_id); + return; + } + + dbgln_if(SQLSERVER_DEBUG, "Statement has disappeared"); + async_execution_error(statement_id, execution_id, SQL::SQLErrorCode::StatementUnavailable, ByteString::formatted("{}", statement_id)); +} + } diff --git a/Userland/Services/SQLServer/ConnectionFromClient.h b/Userland/Services/SQLServer/ConnectionFromClient.h index c12c5b5b62..60549a4492 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.h +++ b/Userland/Services/SQLServer/ConnectionFromClient.h @@ -37,6 +37,7 @@ private: virtual Messages::SQLServer::ConnectResponse connect(ByteString const&) override; virtual Messages::SQLServer::PrepareStatementResponse prepare_statement(SQL::ConnectionID, ByteString const&) override; virtual Messages::SQLServer::ExecuteStatementResponse execute_statement(SQL::StatementID, Vector const& placeholder_values) override; + virtual void ready_for_next_result(SQL::StatementID, SQL::ExecutionID) override; virtual void disconnect(SQL::ConnectionID) override; ByteString m_database_path; diff --git a/Userland/Services/SQLServer/SQLServer.ipc b/Userland/Services/SQLServer/SQLServer.ipc index 1ffa538023..e25fcf64a5 100644 --- a/Userland/Services/SQLServer/SQLServer.ipc +++ b/Userland/Services/SQLServer/SQLServer.ipc @@ -5,5 +5,6 @@ endpoint SQLServer connect(ByteString name) => (Optional connection_id) prepare_statement(u64 connection_id, ByteString statement) => (Optional statement_id) execute_statement(u64 statement_id, Vector placeholder_values) => (Optional execution_id) + ready_for_next_result(u64 statement_id, u64 execution_id) =| disconnect(u64 connection_id) => () } diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp index 287fab05e3..2d56f5641c 100644 --- a/Userland/Services/SQLServer/SQLStatement.cpp +++ b/Userland/Services/SQLServer/SQLStatement.cpp @@ -68,11 +68,9 @@ Optional SQLStatement::execute(Vector placeholder_ } auto execution_id = m_next_execution_id++; - m_ongoing_executions.set(execution_id); Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), placeholder_values = move(placeholder_values), execution_id] { auto execution_result = m_statement->execute(connection().database(), placeholder_values); - m_ongoing_executions.remove(execution_id); if (execution_result.is_error()) { report_error(execution_result.release_error(), execution_id); @@ -86,19 +84,20 @@ Optional SQLStatement::execute(Vector placeholder_ } auto result = execution_result.release_value(); + auto result_size = result.size(); if (should_send_result_rows(result)) { client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), true, 0, 0, 0); - auto result_size = result.size(); - next(execution_id, move(result), result_size); + m_ongoing_executions.set(execution_id, { move(result), result_size }); + ready_for_next_result(execution_id); } else { if (result.command() == SQL::SQLCommand::Insert) - client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, result.size(), 0, 0); + client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, result_size, 0, 0); else if (result.command() == SQL::SQLCommand::Update) - client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, result.size(), 0); + client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, result_size, 0); else if (result.command() == SQL::SQLCommand::Delete) - client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, result.size()); + client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, result_size); else client_connection->async_execution_success(statement_id(), execution_id, result.column_names(), false, 0, 0, 0); } @@ -107,6 +106,29 @@ Optional SQLStatement::execute(Vector placeholder_ return execution_id; } +void SQLStatement::ready_for_next_result(SQL::ExecutionID execution_id) +{ + auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id()); + if (!client_connection) { + warnln("Cannot yield next result. Client disconnected"); + return; + } + + auto execution = m_ongoing_executions.get(execution_id); + if (!execution.has_value()) { + return; + } + + if (execution->result.is_empty()) { + client_connection->async_results_exhausted(statement_id(), execution_id, execution->result_size); + m_ongoing_executions.remove(execution_id); + return; + } + + auto result_row = execution->result.take_first(); + client_connection->async_next_result(statement_id(), execution_id, result_row.row.take_data()); +} + bool SQLStatement::should_send_result_rows(SQL::ResultSet const& result) const { if (result.is_empty()) @@ -121,24 +143,4 @@ bool SQLStatement::should_send_result_rows(SQL::ResultSet const& result) const } } -void SQLStatement::next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size) -{ - auto client_connection = ConnectionFromClient::client_connection_for(connection().client_id()); - if (!client_connection) { - warnln("Cannot yield next result. Client disconnected"); - return; - } - - if (!result.is_empty()) { - auto result_row = result.take_first(); - client_connection->async_next_result(statement_id(), execution_id, result_row.row.take_data()); - - Core::deferred_invoke([this, strong_this = NonnullRefPtr(*this), execution_id, result = move(result), result_size]() mutable { - next(execution_id, move(result), result_size); - }); - } else { - client_connection->async_results_exhausted(statement_id(), execution_id, result_size); - } -} - } diff --git a/Userland/Services/SQLServer/SQLStatement.h b/Userland/Services/SQLServer/SQLStatement.h index 10155211fd..866e24efb6 100644 --- a/Userland/Services/SQLServer/SQLStatement.h +++ b/Userland/Services/SQLServer/SQLStatement.h @@ -26,18 +26,22 @@ public: SQL::StatementID statement_id() const { return m_statement_id; } DatabaseConnection& connection() { return m_connection; } Optional execute(Vector placeholder_values); + void ready_for_next_result(SQL::ExecutionID); private: SQLStatement(DatabaseConnection&, NonnullRefPtr statement); bool should_send_result_rows(SQL::ResultSet const& result) const; - void next(SQL::ExecutionID execution_id, SQL::ResultSet result, size_t result_size); void report_error(SQL::Result, SQL::ExecutionID execution_id); DatabaseConnection& m_connection; SQL::StatementID m_statement_id { 0 }; - HashTable m_ongoing_executions; + struct Execution { + SQL::ResultSet result; + size_t result_size { 0 }; + }; + HashMap m_ongoing_executions; SQL::ExecutionID m_next_execution_id { 0 }; NonnullRefPtr m_statement;