diff --git a/Tests/LibSQL/TestSqlExpressionParser.cpp b/Tests/LibSQL/TestSqlExpressionParser.cpp index ada10d993f..f308e3fc58 100644 --- a/Tests/LibSQL/TestSqlExpressionParser.cpp +++ b/Tests/LibSQL/TestSqlExpressionParser.cpp @@ -131,6 +131,19 @@ TEST_CASE(null_literal) validate("NULL"sv); } +TEST_CASE(bind_parameter) +{ + auto validate = [](StringView sql) { + auto result = parse(sql); + EXPECT(!result.is_error()); + + auto expression = result.release_value(); + EXPECT(is(*expression)); + }; + + validate("?"sv); +} + TEST_CASE(column_name) { EXPECT(parse(".column_name"sv).is_error()); diff --git a/Tests/LibSQL/TestSqlStatementExecution.cpp b/Tests/LibSQL/TestSqlStatementExecution.cpp index 87228dad59..cc2fd9e425 100644 --- a/Tests/LibSQL/TestSqlStatementExecution.cpp +++ b/Tests/LibSQL/TestSqlStatementExecution.cpp @@ -21,19 +21,19 @@ namespace { constexpr char const* db_name = "/tmp/test.db"; -SQL::ResultOr try_execute(NonnullRefPtr database, DeprecatedString const& sql) +SQL::ResultOr try_execute(NonnullRefPtr database, DeprecatedString const& sql, Vector placeholder_values = {}) { auto parser = SQL::AST::Parser(SQL::AST::Lexer(sql)); auto statement = parser.next_statement(); EXPECT(!parser.has_errors()); if (parser.has_errors()) outln("{}", parser.errors()[0].to_deprecated_string()); - return statement->execute(move(database)); + return statement->execute(move(database), placeholder_values); } -SQL::ResultSet execute(NonnullRefPtr database, DeprecatedString const& sql) +SQL::ResultSet execute(NonnullRefPtr database, DeprecatedString const& sql, Vector placeholder_values = {}) { - auto result = try_execute(move(database), sql); + auto result = try_execute(move(database), sql, move(placeholder_values)); if (result.is_error()) { outln("{}", result.release_error().error_string()); VERIFY_NOT_REACHED(); @@ -41,6 +41,12 @@ SQL::ResultSet execute(NonnullRefPtr database, DeprecatedString c return result.release_value(); } +template +Vector placeholders(Args&&... args) +{ + return { SQL::Value(forward(args))... }; +} + void create_schema(NonnullRefPtr database) { auto result = execute(database, "CREATE SCHEMA TestSchema;"); @@ -175,6 +181,59 @@ TEST_CASE(insert_without_column_names) EXPECT_EQ(rows_or_error.value().size(), 2u); } +TEST_CASE(insert_with_placeholders) +{ + ScopeGuard guard([]() { unlink(db_name); }); + + auto database = SQL::Database::construct(db_name); + EXPECT(!database->open().is_error()); + create_table(database); + + { + auto result = try_execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);"); + EXPECT(result.is_error()); + EXPECT_EQ(result.error().error(), SQL::SQLErrorCode::InvalidNumberOfPlaceholderValues); + + result = try_execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);", placeholders("Test_1"sv)); + EXPECT(result.is_error()); + EXPECT_EQ(result.error().error(), SQL::SQLErrorCode::InvalidNumberOfPlaceholderValues); + + result = try_execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);", placeholders(42, 42)); + EXPECT(result.is_error()); + EXPECT_EQ(result.error().error(), SQL::SQLErrorCode::InvalidValueType); + + result = try_execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);", placeholders("Test_1"sv, "Test_2"sv)); + EXPECT(result.is_error()); + EXPECT_EQ(result.error().error(), SQL::SQLErrorCode::InvalidValueType); + } + { + auto result = execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?);", placeholders("Test_1"sv, 42)); + EXPECT_EQ(result.size(), 1u); + + result = execute(database, "SELECT TextColumn, IntColumn FROM TestSchema.TestTable ORDER BY TextColumn;"); + EXPECT_EQ(result.size(), 1u); + + EXPECT_EQ(result[0].row[0], "Test_1"sv); + EXPECT_EQ(result[0].row[1], 42); + } + { + auto result = execute(database, "INSERT INTO TestSchema.TestTable VALUES (?, ?), (?, ?);", placeholders("Test_2"sv, 43, "Test_3"sv, 44)); + EXPECT_EQ(result.size(), 2u); + + result = execute(database, "SELECT TextColumn, IntColumn FROM TestSchema.TestTable ORDER BY TextColumn;"); + EXPECT_EQ(result.size(), 3u); + + EXPECT_EQ(result[0].row[0], "Test_1"sv); + EXPECT_EQ(result[0].row[1], 42); + + EXPECT_EQ(result[1].row[0], "Test_2"sv); + EXPECT_EQ(result[1].row[1], 43); + + EXPECT_EQ(result[2].row[0], "Test_3"sv); + EXPECT_EQ(result[2].row[1], 44); + } +} + TEST_CASE(select_from_empty_table) { ScopeGuard guard([]() { unlink(db_name); }); diff --git a/Tests/LibSQL/TestSqlStatementParser.cpp b/Tests/LibSQL/TestSqlStatementParser.cpp index 11ff1242cf..e42819a295 100644 --- a/Tests/LibSQL/TestSqlStatementParser.cpp +++ b/Tests/LibSQL/TestSqlStatementParser.cpp @@ -752,6 +752,13 @@ TEST_CASE(nested_subquery_limit) EXPECT(parse(DeprecatedString::formatted("SELECT * FROM ({});"sv, subquery)).is_error()); } +TEST_CASE(bound_parameter_limit) +{ + auto subquery = DeprecatedString::repeated("?, "sv, SQL::AST::Limits::maximum_bound_parameters); + EXPECT(!parse(DeprecatedString::formatted("INSERT INTO table_name VALUES ({}42);"sv, subquery)).is_error()); + EXPECT(parse(DeprecatedString::formatted("INSERT INTO table_name VALUES ({}?);"sv, subquery)).is_error()); +} + TEST_CASE(describe_table) { EXPECT(parse("DESCRIBE"sv).is_error()); diff --git a/Userland/Libraries/LibSQL/AST/AST.h b/Userland/Libraries/LibSQL/AST/AST.h index 4d64c9d48a..fd319289de 100644 --- a/Userland/Libraries/LibSQL/AST/AST.h +++ b/Userland/Libraries/LibSQL/AST/AST.h @@ -300,7 +300,8 @@ private: struct ExecutionContext { NonnullRefPtr database; - class Statement const* statement; + Statement const* statement { nullptr }; + Span placeholder_values {}; Tuple* current_row { nullptr }; }; @@ -361,6 +362,21 @@ public: virtual ResultOr evaluate(ExecutionContext&) const override; }; +class Placeholder : public Expression { +public: + explicit Placeholder(size_t parameter_index) + : m_parameter_index(parameter_index) + { + } + + size_t parameter_index() const { return m_parameter_index; } + + virtual ResultOr evaluate(ExecutionContext&) const override; + +private: + size_t m_parameter_index { 0 }; +}; + class NestedExpression : public Expression { public: NonnullRefPtr const& expression() const { return m_expression; } @@ -729,7 +745,7 @@ private: class Statement : public ASTNode { public: - ResultOr execute(AK::NonnullRefPtr database) const; + ResultOr execute(AK::NonnullRefPtr database, Span placeholder_values = {}) const; virtual ResultOr execute(ExecutionContext&) const { diff --git a/Userland/Libraries/LibSQL/AST/Expression.cpp b/Userland/Libraries/LibSQL/AST/Expression.cpp index 043670b5b4..0a39d4a30b 100644 --- a/Userland/Libraries/LibSQL/AST/Expression.cpp +++ b/Userland/Libraries/LibSQL/AST/Expression.cpp @@ -29,6 +29,13 @@ ResultOr NullLiteral::evaluate(ExecutionContext&) const return Value {}; } +ResultOr Placeholder::evaluate(ExecutionContext& context) const +{ + if (parameter_index() >= context.placeholder_values.size()) + return Result { SQLCommand::Unknown, SQLErrorCode::InvalidNumberOfPlaceholderValues }; + return context.placeholder_values[parameter_index()]; +} + ResultOr NestedExpression::evaluate(ExecutionContext& context) const { return expression()->evaluate(context); diff --git a/Userland/Libraries/LibSQL/AST/Parser.cpp b/Userland/Libraries/LibSQL/AST/Parser.cpp index bccea887aa..81ef007519 100644 --- a/Userland/Libraries/LibSQL/AST/Parser.cpp +++ b/Userland/Libraries/LibSQL/AST/Parser.cpp @@ -401,7 +401,6 @@ NonnullRefPtr Parser::parse_expression() if (match_secondary_expression()) expression = parse_secondary_expression(move(expression)); - // FIXME: Parse 'bind-parameter'. // FIXME: Parse 'function-name'. // FIXME: Parse 'raise-function'. @@ -414,6 +413,9 @@ NonnullRefPtr Parser::parse_primary_expression() if (auto expression = parse_literal_value_expression()) return expression.release_nonnull(); + if (auto expression = parse_bind_parameter_expression()) + return expression.release_nonnull(); + if (auto expression = parse_column_name_expression()) return expression.release_nonnull(); @@ -528,6 +530,21 @@ RefPtr Parser::parse_literal_value_expression() return {}; } +// https://sqlite.org/lang_expr.html#varparam +RefPtr Parser::parse_bind_parameter_expression() +{ + // FIXME: Support ?NNN, :AAAA, @AAAA, and $AAAA forms. + if (consume_if(TokenType::Placeholder)) { + auto parameter = m_parser_state.m_bound_parameters; + if (++m_parser_state.m_bound_parameters > Limits::maximum_bound_parameters) + syntax_error(DeprecatedString::formatted("Exceeded maximum number of bound parameters {}", Limits::maximum_bound_parameters)); + + return create_ast_node(parameter); + } + + return {}; +} + RefPtr Parser::parse_column_name_expression(DeprecatedString with_parsed_identifier, bool with_parsed_period) { if (with_parsed_identifier.is_null() && !match(TokenType::Identifier)) diff --git a/Userland/Libraries/LibSQL/AST/Parser.h b/Userland/Libraries/LibSQL/AST/Parser.h index 64e0d9f6ad..430280f9f2 100644 --- a/Userland/Libraries/LibSQL/AST/Parser.h +++ b/Userland/Libraries/LibSQL/AST/Parser.h @@ -19,6 +19,7 @@ namespace Limits { // https://www.sqlite.org/limits.html constexpr size_t maximum_expression_tree_depth = 1000; constexpr size_t maximum_subquery_depth = 100; +constexpr size_t maximum_bound_parameters = 1000; } class Parser { @@ -52,6 +53,7 @@ private: Vector m_errors; size_t m_current_expression_depth { 0 }; size_t m_current_subquery_depth { 0 }; + size_t m_bound_parameters { 0 }; }; NonnullRefPtr parse_statement(); @@ -71,6 +73,7 @@ private: NonnullRefPtr parse_secondary_expression(NonnullRefPtr primary); bool match_secondary_expression() const; RefPtr parse_literal_value_expression(); + RefPtr parse_bind_parameter_expression(); RefPtr parse_column_name_expression(DeprecatedString with_parsed_identifier = {}, bool with_parsed_period = false); RefPtr parse_unary_operator_expression(); RefPtr parse_binary_operator_expression(NonnullRefPtr lhs); diff --git a/Userland/Libraries/LibSQL/AST/Statement.cpp b/Userland/Libraries/LibSQL/AST/Statement.cpp index 7bc718d3ac..97a1fdb622 100644 --- a/Userland/Libraries/LibSQL/AST/Statement.cpp +++ b/Userland/Libraries/LibSQL/AST/Statement.cpp @@ -11,9 +11,9 @@ namespace SQL::AST { -ResultOr Statement::execute(AK::NonnullRefPtr database) const +ResultOr Statement::execute(AK::NonnullRefPtr database, Span placeholder_values) const { - ExecutionContext context { move(database), this, nullptr }; + ExecutionContext context { move(database), this, placeholder_values, nullptr }; auto result = TRY(execute(context)); // FIXME: When transactional sessions are supported, don't auto-commit modifications. diff --git a/Userland/Libraries/LibSQL/AST/Token.h b/Userland/Libraries/LibSQL/AST/Token.h index 86a55d722f..354f1bdcec 100644 --- a/Userland/Libraries/LibSQL/AST/Token.h +++ b/Userland/Libraries/LibSQL/AST/Token.h @@ -171,6 +171,7 @@ namespace SQL::AST { __ENUMERATE_SQL_TOKEN("_blob_", BlobLiteral, Blob) \ __ENUMERATE_SQL_TOKEN("_eof_", Eof, Invalid) \ __ENUMERATE_SQL_TOKEN("_invalid_", Invalid, Invalid) \ + __ENUMERATE_SQL_TOKEN("?", Placeholder, Operator) \ __ENUMERATE_SQL_TOKEN("&", Ampersand, Operator) \ __ENUMERATE_SQL_TOKEN("*", Asterisk, Operator) \ __ENUMERATE_SQL_TOKEN(",", Comma, Punctuation) \ diff --git a/Userland/Libraries/LibSQL/Result.h b/Userland/Libraries/LibSQL/Result.h index 5f65653f7c..594c2200e3 100644 --- a/Userland/Libraries/LibSQL/Result.h +++ b/Userland/Libraries/LibSQL/Result.h @@ -41,27 +41,28 @@ constexpr char const* command_tag(SQLCommand command) } } -#define ENUMERATE_SQL_ERRORS(S) \ - S(NoError, "No error") \ - S(InternalError, "{}") \ - S(NotYetImplemented, "{}") \ - S(DatabaseUnavailable, "Database Unavailable") \ - S(StatementUnavailable, "Statement with id '{}' Unavailable") \ - S(SyntaxError, "Syntax Error") \ - S(DatabaseDoesNotExist, "Database '{}' does not exist") \ - S(SchemaDoesNotExist, "Schema '{}' does not exist") \ - S(SchemaExists, "Schema '{}' already exist") \ - S(TableDoesNotExist, "Table '{}' does not exist") \ - S(ColumnDoesNotExist, "Column '{}' does not exist") \ - S(AmbiguousColumnName, "Column name '{}' is ambiguous") \ - S(TableExists, "Table '{}' already exist") \ - S(InvalidType, "Invalid type '{}'") \ - S(InvalidDatabaseName, "Invalid database name '{}'") \ - S(InvalidValueType, "Invalid type for attribute '{}'") \ - S(InvalidNumberOfValues, "Number of values does not match number of columns") \ - S(BooleanOperatorTypeMismatch, "Cannot apply '{}' operator to non-boolean operands") \ - S(NumericOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands") \ - S(IntegerOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands") \ +#define ENUMERATE_SQL_ERRORS(S) \ + S(NoError, "No error") \ + S(InternalError, "{}") \ + S(NotYetImplemented, "{}") \ + S(DatabaseUnavailable, "Database Unavailable") \ + S(StatementUnavailable, "Statement with id '{}' Unavailable") \ + S(SyntaxError, "Syntax Error") \ + S(DatabaseDoesNotExist, "Database '{}' does not exist") \ + S(SchemaDoesNotExist, "Schema '{}' does not exist") \ + S(SchemaExists, "Schema '{}' already exist") \ + S(TableDoesNotExist, "Table '{}' does not exist") \ + S(ColumnDoesNotExist, "Column '{}' does not exist") \ + S(AmbiguousColumnName, "Column name '{}' is ambiguous") \ + S(TableExists, "Table '{}' already exist") \ + S(InvalidType, "Invalid type '{}'") \ + S(InvalidDatabaseName, "Invalid database name '{}'") \ + S(InvalidValueType, "Invalid type for attribute '{}'") \ + S(InvalidNumberOfPlaceholderValues, "Number of values does not match number of placeholders") \ + S(InvalidNumberOfValues, "Number of values does not match number of columns") \ + S(BooleanOperatorTypeMismatch, "Cannot apply '{}' operator to non-boolean operands") \ + S(NumericOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands") \ + S(IntegerOperatorTypeMismatch, "Cannot apply '{}' operator to non-numeric operands") \ S(InvalidOperator, "Invalid operator '{}'") enum class SQLErrorCode {