From b13527b8b29adb45625bef529e43f6ed7ac1c3a7 Mon Sep 17 00:00:00 2001 From: Timothy Flynn Date: Fri, 2 Dec 2022 08:04:05 -0500 Subject: [PATCH] SQLServer: Parse SQL a single time to actually "prepare" the statement One of the benefits of prepared statements is that the SQL string is parsed just once and re-used. This updates SQLStatement to do just that and store the parsed result. --- .../SQLServer/ConnectionFromClient.cpp | 16 ++++++--- .../Services/SQLServer/DatabaseConnection.cpp | 15 ++++---- .../Services/SQLServer/DatabaseConnection.h | 3 +- Userland/Services/SQLServer/SQLStatement.cpp | 34 ++++++++----------- Userland/Services/SQLServer/SQLStatement.h | 11 +++--- 5 files changed, 40 insertions(+), 39 deletions(-) diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp index 6589152b73..fddcb1576e 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.cpp +++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp @@ -54,15 +54,21 @@ void ConnectionFromClient::disconnect(int connection_id) Messages::SQLServer::PrepareStatementResponse ConnectionFromClient::prepare_statement(int connection_id, DeprecatedString const& sql) { dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::prepare_statement(connection_id: {}, sql: '{}')", connection_id, sql); + auto database_connection = DatabaseConnection::connection_for(connection_id); - if (database_connection) { - auto statement_id = database_connection->prepare_statement(sql); - dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::prepare_statement -> statement_id = {}", statement_id); - return { statement_id }; - } else { + if (!database_connection) { dbgln("Database connection has disappeared"); return { -1 }; } + + auto result = database_connection->prepare_statement(sql); + if (result.is_error()) { + dbgln_if(SQLSERVER_DEBUG, "Could not parse SQL statement: {}", result.error().error_string()); + return { -1 }; + } + + dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::prepare_statement -> statement_id = {}", result.value()); + return { result.value() }; } void ConnectionFromClient::execute_statement(int statement_id) diff --git a/Userland/Services/SQLServer/DatabaseConnection.cpp b/Userland/Services/SQLServer/DatabaseConnection.cpp index f018d60e40..0f3bebf224 100644 --- a/Userland/Services/SQLServer/DatabaseConnection.cpp +++ b/Userland/Services/SQLServer/DatabaseConnection.cpp @@ -67,19 +67,20 @@ void DatabaseConnection::disconnect() }); } -int DatabaseConnection::prepare_statement(DeprecatedString const& sql) +SQL::ResultOr DatabaseConnection::prepare_statement(StringView sql) { dbgln_if(SQLSERVER_DEBUG, "DatabaseConnection::prepare_statement(connection_id {}, database '{}', sql '{}'", connection_id(), m_database_name, sql); + + if (!m_accept_statements) + return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::DatabaseUnavailable }; + auto client_connection = ConnectionFromClient::client_connection_for(client_id()); if (!client_connection) { warnln("Cannot notify client of database disconnection. Client disconnected"); - return -1; + return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::InternalError, "Client disconnected"sv }; } - if (!m_accept_statements) { - client_connection->async_execution_error(-1, (int)SQL::SQLErrorCode::DatabaseUnavailable, m_database_name); - return -1; - } - auto statement = SQLStatement::construct(*this, sql); + + auto statement = TRY(SQLStatement::create(*this, sql)); return statement->statement_id(); } diff --git a/Userland/Services/SQLServer/DatabaseConnection.h b/Userland/Services/SQLServer/DatabaseConnection.h index 9244d7d398..77632f0938 100644 --- a/Userland/Services/SQLServer/DatabaseConnection.h +++ b/Userland/Services/SQLServer/DatabaseConnection.h @@ -8,6 +8,7 @@ #include #include +#include #include namespace SQLServer { @@ -23,7 +24,7 @@ public: int client_id() const { return m_client_id; } RefPtr database() { return m_database; } void disconnect(); - int prepare_statement(DeprecatedString const& sql); + SQL::ResultOr prepare_statement(StringView sql); private: DatabaseConnection(DeprecatedString database_name, int client_id); diff --git a/Userland/Services/SQLServer/SQLStatement.cpp b/Userland/Services/SQLServer/SQLStatement.cpp index c9d61fcd06..6e33444b48 100644 --- a/Userland/Services/SQLServer/SQLStatement.cpp +++ b/Userland/Services/SQLServer/SQLStatement.cpp @@ -24,12 +24,23 @@ RefPtr SQLStatement::statement_for(int statement_id) static int s_next_statement_id = 0; -SQLStatement::SQLStatement(DatabaseConnection& connection, DeprecatedString sql) +SQL::ResultOr> SQLStatement::create(DatabaseConnection& connection, StringView sql) +{ + auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql)); + auto statement = parser.next_statement(); + + if (parser.has_errors()) + return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_deprecated_string() }; + + return TRY(adopt_nonnull_ref_or_enomem(new (nothrow) SQLStatement(connection, move(statement)))); +} + +SQLStatement::SQLStatement(DatabaseConnection& connection, NonnullRefPtr statement) : Core::Object(&connection) , m_statement_id(s_next_statement_id++) - , m_sql(move(sql)) + , m_statement(move(statement)) { - dbgln_if(SQLSERVER_DEBUG, "SQLStatement({}, {})", connection.connection_id(), sql); + dbgln_if(SQLSERVER_DEBUG, "SQLStatement({})", connection.connection_id()); s_statements.set(m_statement_id, *this); } @@ -47,7 +58,6 @@ void SQLStatement::report_error(SQL::Result result) else warnln("Cannot return execution error. Client disconnected"); - m_statement = nullptr; m_result = {}; } @@ -61,12 +71,6 @@ void SQLStatement::execute() } deferred_invoke([this] { - auto parse_result = parse(); - if (parse_result.is_error()) { - report_error(parse_result.release_error()); - return; - } - VERIFY(!connection()->database().is_null()); auto execution_result = m_statement->execute(connection()->database().release_nonnull()); @@ -93,16 +97,6 @@ void SQLStatement::execute() }); } -SQL::ResultOr SQLStatement::parse() -{ - auto parser = SQL::AST::Parser(SQL::AST::Lexer(m_sql)); - m_statement = parser.next_statement(); - - if (parser.has_errors()) - return SQL::Result { SQL::SQLCommand::Unknown, SQL::SQLErrorCode::SyntaxError, parser.errors()[0].to_deprecated_string() }; - return {}; -} - bool SQLStatement::should_send_result_rows() const { VERIFY(m_result.has_value()); diff --git a/Userland/Services/SQLServer/SQLStatement.h b/Userland/Services/SQLServer/SQLStatement.h index c4bae21119..3088e23766 100644 --- a/Userland/Services/SQLServer/SQLStatement.h +++ b/Userland/Services/SQLServer/SQLStatement.h @@ -18,28 +18,27 @@ namespace SQLServer { class SQLStatement final : public Core::Object { - C_OBJECT(SQLStatement) + C_OBJECT_ABSTRACT(SQLStatement) public: + static SQL::ResultOr> create(DatabaseConnection&, StringView sql); ~SQLStatement() override = default; static RefPtr statement_for(int statement_id); int statement_id() const { return m_statement_id; } - DeprecatedString const& sql() const { return m_sql; } DatabaseConnection* connection() { return dynamic_cast(parent()); } void execute(); private: - SQLStatement(DatabaseConnection&, DeprecatedString sql); - SQL::ResultOr parse(); + SQLStatement(DatabaseConnection&, NonnullRefPtr statement); + bool should_send_result_rows() const; void next(); void report_error(SQL::Result); int m_statement_id; - DeprecatedString m_sql; size_t m_index { 0 }; - RefPtr m_statement { nullptr }; + NonnullRefPtr m_statement; Optional m_result {}; };