diff --git a/Userland/Applications/Browser/Database.cpp b/Userland/Applications/Browser/Database.cpp index e74ab95da3..6931be0b28 100644 --- a/Userland/Applications/Browser/Database.cpp +++ b/Userland/Applications/Browser/Database.cpp @@ -30,11 +30,11 @@ Database::Database(NonnullRefPtr sql_client, SQL::ConnectionID c : m_sql_client(move(sql_client)) , m_connection_id(connection_id) { - m_sql_client->on_execution_success = [this](auto statement_id, auto execution_id, auto has_results, auto, auto, auto) { - if (has_results) + m_sql_client->on_execution_success = [this](auto result) { + if (result.has_results) return; - if (auto it = m_pending_executions.find({ statement_id, execution_id }); it != m_pending_executions.end()) { + if (auto it = find_pending_execution(result); it != m_pending_executions.end()) { auto in_progress_statement = move(it->value); m_pending_executions.remove(it); @@ -43,15 +43,15 @@ Database::Database(NonnullRefPtr sql_client, SQL::ConnectionID c } }; - m_sql_client->on_next_result = [this](auto statement_id, auto execution_id, auto row) { - if (auto it = m_pending_executions.find({ statement_id, execution_id }); it != m_pending_executions.end()) { + m_sql_client->on_next_result = [this](auto result) { + if (auto it = find_pending_execution(result); it != m_pending_executions.end()) { if (it->value.on_result) - it->value.on_result(row); + it->value.on_result(result.values); } }; - m_sql_client->on_results_exhausted = [this](auto statement_id, auto execution_id, auto) { - if (auto it = m_pending_executions.find({ statement_id, execution_id }); it != m_pending_executions.end()) { + m_sql_client->on_results_exhausted = [this](auto result) { + if (auto it = find_pending_execution(result); it != m_pending_executions.end()) { auto in_progress_statement = move(it->value); m_pending_executions.remove(it); @@ -60,13 +60,13 @@ Database::Database(NonnullRefPtr sql_client, SQL::ConnectionID c } }; - m_sql_client->on_execution_error = [this](auto statement_id, auto execution_id, auto, auto const& message) { - if (auto it = m_pending_executions.find({ statement_id, execution_id }); it != m_pending_executions.end()) { + m_sql_client->on_execution_error = [this](auto result) { + if (auto it = find_pending_execution(result); it != m_pending_executions.end()) { auto in_progress_statement = move(it->value); m_pending_executions.remove(it); if (in_progress_statement.on_error) - in_progress_statement.on_error(message); + in_progress_statement.on_error(result.error_message); } }; } diff --git a/Userland/Applications/Browser/Database.h b/Userland/Applications/Browser/Database.h index 74eb58bac3..bbe05e198c 100644 --- a/Userland/Applications/Browser/Database.h +++ b/Userland/Applications/Browser/Database.h @@ -68,6 +68,12 @@ private: Database(NonnullRefPtr sql_client, SQL::ConnectionID connection_id); void execute_statement(SQL::StatementID statement_id, Vector placeholder_values, PendingExecution pending_execution); + template + auto find_pending_execution(ResultData const& result_data) + { + return m_pending_executions.find({ result_data.statement_id, result_data.execution_id }); + } + NonnullRefPtr m_sql_client; SQL::ConnectionID m_connection_id { 0 }; diff --git a/Userland/DevTools/SQLStudio/MainWidget.cpp b/Userland/DevTools/SQLStudio/MainWidget.cpp index 4de878b9cf..08cfa90d4c 100644 --- a/Userland/DevTools/SQLStudio/MainWidget.cpp +++ b/Userland/DevTools/SQLStudio/MainWidget.cpp @@ -253,23 +253,23 @@ MainWidget::MainWidget() }; m_sql_client = SQL::SQLClient::try_create().release_value_but_fixme_should_propagate_errors(); - m_sql_client->on_execution_success = [this](auto, auto, auto, auto, auto, auto) { + m_sql_client->on_execution_success = [this](auto) { read_next_sql_statement_of_editor(); }; - m_sql_client->on_execution_error = [this](auto, auto, auto, auto message) { + m_sql_client->on_execution_error = [this](auto result) { auto* editor = active_editor(); VERIFY(editor); - GUI::MessageBox::show_error(window(), DeprecatedString::formatted("Error executing {}\n{}", editor->path(), message)); + GUI::MessageBox::show_error(window(), DeprecatedString::formatted("Error executing {}\n{}", editor->path(), result.error_message)); }; - m_sql_client->on_next_result = [this](auto, auto, auto row) { + m_sql_client->on_next_result = [this](auto result) { m_results.append({}); - m_results.last().ensure_capacity(row.size()); + m_results.last().ensure_capacity(result.values.size()); - for (auto const& value : row) + for (auto const& value : result.values) m_results.last().unchecked_append(value.to_deprecated_string()); }; - m_sql_client->on_results_exhausted = [this](auto, auto, auto) { + m_sql_client->on_results_exhausted = [this](auto) { if (m_results.size() == 0) return; if (m_results[0].size() == 0) diff --git a/Userland/Libraries/LibSQL/SQLClient.cpp b/Userland/Libraries/LibSQL/SQLClient.cpp index bab488ef0f..1b0705ed16 100644 --- a/Userland/Libraries/LibSQL/SQLClient.cpp +++ b/Userland/Libraries/LibSQL/SQLClient.cpp @@ -154,45 +154,74 @@ ErrorOr> SQLClient::launch_server_and_create_client(Vec #endif -void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) -{ - if (on_execution_error) - on_execution_error(statement_id, execution_id, code, message); - else - warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code)); -} - void SQLClient::execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) { - if (on_execution_success) - on_execution_success(statement_id, execution_id, has_results, created, updated, deleted); - else + if (!on_execution_success) { outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted); -} - -void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector const& row) -{ - if (on_next_result) { - on_next_result(statement_id, execution_id, row); return; } - bool first = true; - for (auto& column : row) { - if (!first) - out(", "); - out("\"{}\"", column); - first = false; + ExecutionSuccess success { + .statement_id = statement_id, + .execution_id = execution_id, + .has_results = has_results, + .rows_created = created, + .rows_updated = updated, + .rows_deleted = deleted, + }; + + on_execution_success(move(success)); +} + +void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) +{ + if (!on_execution_error) { + warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code)); + return; } - outln(); + + ExecutionError error { + .statement_id = statement_id, + .execution_id = execution_id, + .error_code = code, + .error_message = move(const_cast(message)), + }; + + on_execution_error(move(error)); +} + +void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector const& row) +{ + if (!on_next_result) { + StringBuilder builder; + builder.join(", "sv, row, "\"{}\""sv); + outln("{}", builder.string_view()); + return; + } + + ExecutionResult result { + .statement_id = statement_id, + .execution_id = execution_id, + .values = move(const_cast&>(row)), + }; + + on_next_result(move(result)); } void SQLClient::results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) { - if (on_results_exhausted) - on_results_exhausted(statement_id, execution_id, total_rows); - else + if (!on_results_exhausted) { outln("{} total row(s)", total_rows); + return; + } + + ExecutionComplete success { + .statement_id = statement_id, + .execution_id = execution_id, + .total_rows = total_rows, + }; + + on_results_exhausted(move(success)); } } diff --git a/Userland/Libraries/LibSQL/SQLClient.h b/Userland/Libraries/LibSQL/SQLClient.h index 76c4a94f7e..97c37a7ea8 100644 --- a/Userland/Libraries/LibSQL/SQLClient.h +++ b/Userland/Libraries/LibSQL/SQLClient.h @@ -15,6 +15,38 @@ namespace SQL { +struct ExecutionSuccess { + u64 statement_id { 0 }; + u64 execution_id { 0 }; + + bool has_results { false }; + size_t rows_created { 0 }; + size_t rows_updated { 0 }; + size_t rows_deleted { 0 }; +}; + +struct ExecutionError { + u64 statement_id { 0 }; + u64 execution_id { 0 }; + + SQLErrorCode error_code; + DeprecatedString error_message; +}; + +struct ExecutionResult { + u64 statement_id { 0 }; + u64 execution_id { 0 }; + + Vector values; +}; + +struct ExecutionComplete { + u64 statement_id { 0 }; + u64 execution_id { 0 }; + + size_t total_rows { 0 }; +}; + class SQLClient : public IPC::ConnectionToServer , public SQLClientEndpoint { @@ -27,10 +59,10 @@ public: virtual ~SQLClient() = default; - Function on_execution_error; - Function on_execution_success; - Function)> on_next_result; - Function on_results_exhausted; + Function on_execution_success; + Function on_execution_error; + Function on_next_result; + Function on_results_exhausted; private: explicit SQLClient(NonnullOwnPtr socket) @@ -39,9 +71,9 @@ private: } virtual void execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) override; + virtual void execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) override; virtual void next_result(u64 statement_id, u64 execution_id, Vector const&) override; virtual void results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) override; - virtual void execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) override; }; } diff --git a/Userland/Utilities/sql.cpp b/Userland/Utilities/sql.cpp index b263973b01..d4d630cbf2 100644 --- a/Userland/Utilities/sql.cpp +++ b/Userland/Utilities/sql.cpp @@ -76,28 +76,26 @@ public: m_editor->set_prompt(prompt_for_level(open_indents)); }; - m_sql_client->on_execution_success = [this](auto, auto, auto has_results, auto created, auto updated, auto deleted) { - if (updated != 0 || created != 0 || deleted != 0) { - outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted); - } - if (!has_results) { + m_sql_client->on_execution_success = [this](auto result) { + if (result.rows_updated != 0 || result.rows_created != 0 || result.rows_deleted != 0) + outln("{} row(s) created, {} updated, {} deleted", result.rows_created, result.rows_updated, result.rows_deleted); + if (!result.has_results) read_sql(); - } }; - m_sql_client->on_next_result = [](auto, auto, auto row) { + m_sql_client->on_next_result = [](auto result) { StringBuilder builder; - builder.join(", "sv, row); + builder.join(", "sv, result.values); outln("{}", builder.to_deprecated_string()); }; - m_sql_client->on_results_exhausted = [this](auto, auto, auto total_rows) { - outln("{} row(s)", total_rows); + m_sql_client->on_results_exhausted = [this](auto result) { + outln("{} row(s)", result.total_rows); read_sql(); }; - m_sql_client->on_execution_error = [this](auto, auto, auto, auto const& message) { - outln("\033[33;1mExecution error:\033[0m {}", message); + m_sql_client->on_execution_error = [this](auto result) { + outln("\033[33;1mExecution error:\033[0m {}", result.error_message); read_sql(); };