diff --git a/Userland/DevTools/SQLStudio/MainWidget.cpp b/Userland/DevTools/SQLStudio/MainWidget.cpp index 6974fad0b1..25f1f2f5b0 100644 --- a/Userland/DevTools/SQLStudio/MainWidget.cpp +++ b/Userland/DevTools/SQLStudio/MainWidget.cpp @@ -214,13 +214,13 @@ MainWidget::MainWidget() m_statusbar->segment(2).set_fixed_width(font().width("Ln 0000, Col 000"sv) + font().max_glyph_width()); m_sql_client = SQL::SQLClient::try_create().release_value_but_fixme_should_propagate_errors(); - m_sql_client->on_execution_success = [this](auto, auto, auto, auto, auto) { + m_sql_client->on_execution_success = [this](auto, auto, auto, auto, auto, auto) { read_next_sql_statement_of_editor(); }; - m_sql_client->on_next_result = [this](auto, auto const& row) { + m_sql_client->on_next_result = [this](auto, auto, auto const& row) { m_results.append(row); }; - m_sql_client->on_results_exhausted = [this](auto, auto) { + m_sql_client->on_results_exhausted = [this](auto, auto, auto) { if (m_results.size() == 0) return; if (m_results[0].size() == 0) diff --git a/Userland/Libraries/LibSQL/SQLClient.cpp b/Userland/Libraries/LibSQL/SQLClient.cpp index b0a0b3c410..3b1bd35f44 100644 --- a/Userland/Libraries/LibSQL/SQLClient.cpp +++ b/Userland/Libraries/LibSQL/SQLClient.cpp @@ -29,26 +29,26 @@ void SQLClient::connection_error(u64 connection_id, SQLErrorCode const& code, De warnln("Connection error for connection_id {}: {} ({})", connection_id, message, to_underlying(code)); } -void SQLClient::execution_error(u64 statement_id, SQLErrorCode const& code, DeprecatedString const& message) +void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) { if (on_execution_error) - on_execution_error(statement_id, code, message); + on_execution_error(statement_id, execution_id, code, message); else warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code)); } -void SQLClient::execution_success(u64 statement_id, bool has_results, size_t created, size_t updated, size_t deleted) +void SQLClient::execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) { if (on_execution_success) - on_execution_success(statement_id, has_results, created, updated, deleted); + on_execution_success(statement_id, execution_id, has_results, created, updated, deleted); else outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted); } -void SQLClient::next_result(u64 statement_id, Vector const& row) +void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector const& row) { if (on_next_result) { - on_next_result(statement_id, row); + on_next_result(statement_id, execution_id, row); return; } bool first = true; @@ -61,10 +61,10 @@ void SQLClient::next_result(u64 statement_id, Vector const& ro outln(); } -void SQLClient::results_exhausted(u64 statement_id, size_t total_rows) +void SQLClient::results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) { if (on_results_exhausted) - on_results_exhausted(statement_id, total_rows); + on_results_exhausted(statement_id, execution_id, total_rows); else outln("{} total row(s)", total_rows); } diff --git a/Userland/Libraries/LibSQL/SQLClient.h b/Userland/Libraries/LibSQL/SQLClient.h index ea4340141c..52ecb33375 100644 --- a/Userland/Libraries/LibSQL/SQLClient.h +++ b/Userland/Libraries/LibSQL/SQLClient.h @@ -23,10 +23,10 @@ class SQLClient Function on_connected; Function on_disconnected; Function on_connection_error; - Function on_execution_error; - Function on_execution_success; - Function const&)> on_next_result; - Function on_results_exhausted; + Function on_execution_error; + Function on_execution_success; + Function const&)> on_next_result; + Function on_results_exhausted; private: SQLClient(NonnullOwnPtr socket) @@ -36,10 +36,10 @@ private: virtual void connected(u64 connection_id, DeprecatedString const& connected_to_database) override; virtual void connection_error(u64 connection_id, SQLErrorCode const& code, DeprecatedString const& message) override; - virtual void execution_success(u64 statement_id, bool has_results, size_t created, size_t updated, size_t deleted) override; - virtual void next_result(u64 statement_id, Vector const&) override; - virtual void results_exhausted(u64 statement_id, size_t total_rows) override; - virtual void execution_error(u64 statement_id, SQLErrorCode const& code, DeprecatedString const& message) override; + virtual void execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) override; + virtual void next_result(u64 statement_id, u64 execution_id, Vector const&) override; + virtual void results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) override; + virtual void execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) override; virtual void disconnected(u64 connection_id) override; }; diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp index cb4fd185c4..95d54622da 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.cpp +++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp @@ -71,17 +71,19 @@ Messages::SQLServer::PrepareStatementResponse ConnectionFromClient::prepare_stat return { result.value() }; } -void ConnectionFromClient::execute_statement(u64 statement_id, Vector const& placeholder_values) +Messages::SQLServer::ExecuteStatementResponse ConnectionFromClient::execute_statement(u64 statement_id, Vector const& placeholder_values) { dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::execute_query_statement(statement_id: {})", statement_id); + auto statement = SQLStatement::statement_for(statement_id); if (statement && statement->connection()->client_id() == client_id()) { // FIXME: Support taking parameters from IPC requests. - statement->execute(move(const_cast&>(placeholder_values))); - } else { - dbgln_if(SQLSERVER_DEBUG, "Statement has disappeared"); - async_execution_error(statement_id, SQL::SQLErrorCode::StatementUnavailable, DeprecatedString::formatted("{}", statement_id)); + return statement->execute(move(const_cast&>(placeholder_values))); } + + dbgln_if(SQLSERVER_DEBUG, "Statement has disappeared"); + async_execution_error(statement_id, -1, SQL::SQLErrorCode::StatementUnavailable, DeprecatedString::formatted("{}", statement_id)); + return { {} }; } } diff --git a/Userland/Services/SQLServer/ConnectionFromClient.h b/Userland/Services/SQLServer/ConnectionFromClient.h index 05e2186e03..c880d649de 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.h +++ b/Userland/Services/SQLServer/ConnectionFromClient.h @@ -30,7 +30,7 @@ private: virtual Messages::SQLServer::ConnectResponse connect(DeprecatedString const&) override; virtual Messages::SQLServer::PrepareStatementResponse prepare_statement(u64, DeprecatedString const&) override; - virtual void execute_statement(u64, Vector const& placeholder_values) override; + virtual Messages::SQLServer::ExecuteStatementResponse execute_statement(u64, Vector const& placeholder_values) override; virtual void disconnect(u64) override; }; diff --git a/Userland/Services/SQLServer/SQLClient.ipc b/Userland/Services/SQLServer/SQLClient.ipc index 7228851b6e..e5b87110f4 100644 --- a/Userland/Services/SQLServer/SQLClient.ipc +++ b/Userland/Services/SQLServer/SQLClient.ipc @@ -4,9 +4,9 @@ endpoint SQLClient { connected(u64 connection_id, DeprecatedString connected_to_database) =| connection_error(u64 connection_id, SQL::SQLErrorCode code, DeprecatedString message) =| - execution_success(u64 statement_id, bool has_results, size_t created, size_t updated, size_t deleted) =| - next_result(u64 statement_id, Vector row) =| - results_exhausted(u64 statement_id, size_t total_rows) =| - execution_error(u64 statement_id, SQL::SQLErrorCode code, DeprecatedString message) =| + execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) =| + next_result(u64 statement_id, u64 execution_id, Vector row) =| + results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) =| + execution_error(u64 statement_id, u64 execution_id, SQL::SQLErrorCode code, DeprecatedString message) =| disconnected(u64 connection_id) =| } diff --git a/Userland/Services/SQLServer/SQLServer.ipc b/Userland/Services/SQLServer/SQLServer.ipc index 89fb7b8392..ba06d5eb39 100644 --- a/Userland/Services/SQLServer/SQLServer.ipc +++ b/Userland/Services/SQLServer/SQLServer.ipc @@ -4,6 +4,6 @@ endpoint SQLServer { connect(DeprecatedString name) => (u64 connection_id) prepare_statement(u64 connection_id, DeprecatedString statement) => (Optional statement_id) - execute_statement(u64 statement_id, Vector placeholder_values) =| + execute_statement(u64 statement_id, Vector placeholder_values) => (Optional execution_id) disconnect(u64 connection_id) =| } diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp index d4b9f583fc..cce1528f05 100644 --- a/Userland/Services/SQLServer/SQLStatement.cpp +++ b/Userland/Services/SQLServer/SQLStatement.cpp @@ -43,7 +43,7 @@ SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtrasync_execution_error(statement_id(), result.error(), result.error_string()); + client_connection->async_execution_error(statement_id(), execution_id, result.error(), result.error_string()); else warnln("Cannot return execution error. Client disconnected"); m_result = {}; } -void SQLStatement::execute(Vector placeholder_values) +Optional SQLStatement::execute(Vector placeholder_values) { dbgln_if(SQLSERVER_DEBUG, "SQLStatement::execute(statement_id {}", statement_id()); auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); if (!client_connection) { warnln("Cannot yield next result. Client disconnected"); - return; + return {}; } - deferred_invoke([this, placeholder_values = move(placeholder_values)] { + auto execution_id = m_next_execution_id++; + m_ongoing_executions.set(execution_id); + + deferred_invoke([this, placeholder_values = move(placeholder_values), execution_id] { VERIFY(!connection()->database().is_null()); auto execution_result = m_statement->execute(connection()->database().release_nonnull(), placeholder_values); + m_ongoing_executions.remove(execution_id); + if (execution_result.is_error()) { - report_error(execution_result.release_error()); + report_error(execution_result.release_error(), execution_id); return; } @@ -88,13 +93,15 @@ void SQLStatement::execute(Vector placeholder_values) m_result = execution_result.release_value(); if (should_send_result_rows()) { - client_connection->async_execution_success(statement_id(), true, 0, 0, 0); + client_connection->async_execution_success(statement_id(), execution_id, true, 0, 0, 0); m_index = 0; - next(); + next(execution_id); } else { - client_connection->async_execution_success(statement_id(), false, 0, m_result->size(), 0); + client_connection->async_execution_success(statement_id(), execution_id, false, 0, m_result->size(), 0); } }); + + return execution_id; } bool SQLStatement::should_send_result_rows() const @@ -113,22 +120,24 @@ bool SQLStatement::should_send_result_rows() const } } -void SQLStatement::next() +void SQLStatement::next(u64 execution_id) { VERIFY(!m_result->is_empty()); + auto client_connection = ConnectionFromClient::client_connection_for(connection()->client_id()); if (!client_connection) { warnln("Cannot yield next result. Client disconnected"); return; } + if (m_index < m_result->size()) { auto& tuple = m_result->at(m_index++).row; - client_connection->async_next_result(statement_id(), tuple.to_deprecated_string_vector()); - deferred_invoke([this]() { - next(); + client_connection->async_next_result(statement_id(), execution_id, tuple.to_deprecated_string_vector()); + deferred_invoke([this, execution_id]() { + next(execution_id); }); } else { - client_connection->async_results_exhausted(statement_id(), m_index); + client_connection->async_results_exhausted(statement_id(), execution_id, m_index); } } diff --git a/Userland/Services/SQLServer/SQLStatement.h b/Userland/Services/SQLServer/SQLStatement.h index 19e97d0718..7bdf95c910 100644 --- a/Userland/Services/SQLServer/SQLStatement.h +++ b/Userland/Services/SQLServer/SQLStatement.h @@ -28,17 +28,21 @@ public: static RefPtr statement_for(u64 statement_id); u64 statement_id() const { return m_statement_id; } DatabaseConnection* connection() { return dynamic_cast(parent()); } - void execute(Vector placeholder_values); + Optional execute(Vector placeholder_values); private: SQLStatement(DatabaseConnection&, NonnullRefPtr statement); bool should_send_result_rows() const; - void next(); - void report_error(SQL::Result); + void next(u64 execution_id); + void report_error(SQL::Result, u64 execution_id); u64 m_statement_id { 0 }; size_t m_index { 0 }; + + HashTable m_ongoing_executions; + u64 m_next_execution_id { 0 }; + NonnullRefPtr m_statement; Optional m_result {}; }; diff --git a/Userland/Utilities/sql.cpp b/Userland/Utilities/sql.cpp index 9d15132a09..a0c046c9ac 100644 --- a/Userland/Utilities/sql.cpp +++ b/Userland/Utilities/sql.cpp @@ -84,7 +84,7 @@ public: read_sql(); }; - m_sql_client->on_execution_success = [this](auto, auto has_results, auto updated, auto created, auto deleted) { + m_sql_client->on_execution_success = [this](auto, auto, auto has_results, auto updated, auto created, auto deleted) { if (updated != 0 || created != 0 || deleted != 0) { outln("{} row(s) updated, {} created, {} deleted", updated, created, deleted); } @@ -93,13 +93,13 @@ public: } }; - m_sql_client->on_next_result = [](auto, auto const& row) { + m_sql_client->on_next_result = [](auto, auto, auto const& row) { StringBuilder builder; builder.join(", "sv, row); outln("{}", builder.build()); }; - m_sql_client->on_results_exhausted = [this](auto, auto total_rows) { + m_sql_client->on_results_exhausted = [this](auto, auto, auto total_rows) { outln("{} row(s)", total_rows); read_sql(); }; @@ -109,7 +109,7 @@ public: m_loop.quit(to_underlying(code)); }; - m_sql_client->on_execution_error = [this](auto, auto, auto const& message) { + m_sql_client->on_execution_error = [this](auto, auto, auto, auto const& message) { outln("\033[33;1mExecution error:\033[0m {}", message); read_sql(); };