From 896d1e4f42d30f90ce1c983c0b1c5e7bf9578fd1 Mon Sep 17 00:00:00 2001 From: Tim Ledbetter Date: Wed, 12 Apr 2023 23:47:00 +0100 Subject: [PATCH] LibSQL: Handle statements with malformed exists expressions correctly Previously, statements containing malformed exists expressions such as: `INSERT INTO t(a) VALUES (SELECT 1)`; could cause the parser to crash. The parser will now return an error message instead. --- Tests/LibSQL/TestSqlStatementParser.cpp | 40 +++++++++++++++++++ Userland/Libraries/LibSQL/AST/Parser.cpp | 51 +++++++++++++++--------- Userland/Libraries/LibSQL/AST/Parser.h | 4 +- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/Tests/LibSQL/TestSqlStatementParser.cpp b/Tests/LibSQL/TestSqlStatementParser.cpp index 335df91a01..0793276fe3 100644 --- a/Tests/LibSQL/TestSqlStatementParser.cpp +++ b/Tests/LibSQL/TestSqlStatementParser.cpp @@ -307,6 +307,14 @@ TEST_CASE(insert) EXPECT(parse("INSERT INTO table_name VALUES"sv).is_error()); EXPECT(parse("INSERT INTO table_name VALUES ();"sv).is_error()); EXPECT(parse("INSERT INTO table_name VALUES (1)"sv).is_error()); + EXPECT(parse("INSERT INTO table_name VALUES SELECT"sv).is_error()); + EXPECT(parse("INSERT INTO table_name VALUES EXISTS"sv).is_error()); + EXPECT(parse("INSERT INTO table_name VALUES NOT"sv).is_error()); + EXPECT(parse("INSERT INTO table_name VALUES EXISTS (SELECT 1)"sv).is_error()); + EXPECT(parse("INSERT INTO table_name VALUES (SELECT)"sv).is_error()); + EXPECT(parse("INSERT INTO table_name VALUES (EXISTS SELECT)"sv).is_error()); + EXPECT(parse("INSERT INTO table_name VALUES ((SELECT))"sv).is_error()); + EXPECT(parse("INSERT INTO table_name VALUES (EXISTS (SELECT))"sv).is_error()); EXPECT(parse("INSERT INTO table_name SELECT"sv).is_error()); EXPECT(parse("INSERT INTO table_name SELECT * from table_name"sv).is_error()); EXPECT(parse("INSERT OR INTO table_name DEFAULT VALUES;"sv).is_error()); @@ -367,6 +375,12 @@ TEST_CASE(insert) validate("INSERT INTO table_name VALUES (1, 2);"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2 }, false); validate("INSERT INTO table_name VALUES (1, 2), (3, 4, 5);"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2, 3 }, false); + validate("INSERT INTO table_name VALUES ((SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 1 }, false); + validate("INSERT INTO table_name VALUES (EXISTS (SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 1 }, false); + validate("INSERT INTO table_name VALUES (NOT EXISTS (SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 1 }, false); + validate("INSERT INTO table_name VALUES ((SELECT 1), (SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2 }, false); + validate("INSERT INTO table_name VALUES ((SELECT 1), (SELECT 1)), ((SELECT 1), (SELECT 1), (SELECT 1));"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, { 2, 3 }, false); + validate("INSERT INTO table_name SELECT * FROM table_name;"sv, resolution, {}, "TABLE_NAME"sv, {}, {}, {}, true); } @@ -379,11 +393,21 @@ TEST_CASE(update) EXPECT(parse("UPDATE table_name SET column_name=4"sv).is_error()); EXPECT(parse("UPDATE table_name SET column_name=4, ;"sv).is_error()); EXPECT(parse("UPDATE table_name SET (column_name)=4"sv).is_error()); + EXPECT(parse("UPDATE table_name SET (column_name)=EXISTS"sv).is_error()); + EXPECT(parse("UPDATE table_name SET (column_name)=SELECT"sv).is_error()); + EXPECT(parse("UPDATE table_name SET (column_name)=(SELECT)"sv).is_error()); + EXPECT(parse("UPDATE table_name SET (column_name)=NOT (SELECT 1)"sv).is_error()); EXPECT(parse("UPDATE table_name SET (column_name)=4, ;"sv).is_error()); EXPECT(parse("UPDATE table_name SET (column_name, )=4;"sv).is_error()); EXPECT(parse("UPDATE table_name SET column_name=4 FROM"sv).is_error()); EXPECT(parse("UPDATE table_name SET column_name=4 FROM table_name"sv).is_error()); EXPECT(parse("UPDATE table_name SET column_name=4 WHERE"sv).is_error()); + EXPECT(parse("UPDATE table_name SET column_name=4 WHERE EXISTS"sv).is_error()); + EXPECT(parse("UPDATE table_name SET column_name=4 WHERE NOT"sv).is_error()); + EXPECT(parse("UPDATE table_name SET column_name=4 WHERE NOT EXISTS"sv).is_error()); + EXPECT(parse("UPDATE table_name SET column_name=4 WHERE SELECT"sv).is_error()); + EXPECT(parse("UPDATE table_name SET column_name=4 WHERE (SELECT)"sv).is_error()); + EXPECT(parse("UPDATE table_name SET column_name=4 WHERE NOT (SELECT)"sv).is_error()); EXPECT(parse("UPDATE table_name SET column_name=4 WHERE 1==1"sv).is_error()); EXPECT(parse("UPDATE table_name SET column_name=4 RETURNING"sv).is_error()); EXPECT(parse("UPDATE table_name SET column_name=4 RETURNING *"sv).is_error()); @@ -452,11 +476,18 @@ TEST_CASE(update) validate("UPDATE table_name AS foo SET column_name=1;"sv, resolution, {}, "TABLE_NAME"sv, "FOO"sv, update_columns, false, false, {}); validate("UPDATE table_name SET column_name=1;"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, false, false, {}); + validate("UPDATE table_name SET column_name=(SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, false, false, {}); + validate("UPDATE table_name SET column_name=EXISTS (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, false, false, {}); + validate("UPDATE table_name SET column_name=NOT EXISTS (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, false, false, {}); validate("UPDATE table_name SET column1=1, column2=2;"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN1"sv }, { "COLUMN2"sv } }, false, false, {}); validate("UPDATE table_name SET (column1, column2)=1, column3=2;"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN1"sv, "COLUMN2"sv }, { "COLUMN3"sv } }, false, false, {}); validate("UPDATE table_name SET column_name=1 WHERE 1==1;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, true, false, {}); + validate("UPDATE table_name SET column_name=1 WHERE (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, true, false, {}); + validate("UPDATE table_name SET column_name=1 WHERE EXISTS (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, true, false, {}); + validate("UPDATE table_name SET column_name=1 WHERE NOT EXISTS (SELECT 1);"sv, resolution, {}, "TABLE_NAME"sv, {}, { { "COLUMN_NAME"sv } }, true, false, {}); + validate("UPDATE table_name SET column_name=1 RETURNING *;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, false, true, {}); validate("UPDATE table_name SET column_name=1 RETURNING column_name;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, false, true, { {} }); validate("UPDATE table_name SET column_name=1 RETURNING column_name AS alias;"sv, resolution, {}, "TABLE_NAME"sv, {}, update_columns, false, true, { "ALIAS"sv }); @@ -469,6 +500,12 @@ TEST_CASE(delete_) EXPECT(parse("DELETE FROM"sv).is_error()); EXPECT(parse("DELETE FROM table_name"sv).is_error()); EXPECT(parse("DELETE FROM table_name WHERE"sv).is_error()); + EXPECT(parse("DELETE FROM table_name WHERE EXISTS"sv).is_error()); + EXPECT(parse("DELETE FROM table_name WHERE NOT"sv).is_error()); + EXPECT(parse("DELETE FROM table_name WHERE NOT (SELECT 1)"sv).is_error()); + EXPECT(parse("DELETE FROM table_name WHERE NOT EXISTS"sv).is_error()); + EXPECT(parse("DELETE FROM table_name WHERE SELECT"sv).is_error()); + EXPECT(parse("DELETE FROM table_name WHERE (SELECT)"sv).is_error()); EXPECT(parse("DELETE FROM table_name WHERE 15"sv).is_error()); EXPECT(parse("DELETE FROM table_name WHERE 15 RETURNING"sv).is_error()); EXPECT(parse("DELETE FROM table_name WHERE 15 RETURNING *"sv).is_error()); @@ -514,6 +551,9 @@ TEST_CASE(delete_) validate("DELETE FROM schema_name.table_name;"sv, "SCHEMA_NAME"sv, "TABLE_NAME"sv, {}, false, false, {}); validate("DELETE FROM schema_name.table_name AS alias;"sv, "SCHEMA_NAME"sv, "TABLE_NAME"sv, "ALIAS"sv, false, false, {}); validate("DELETE FROM table_name WHERE (1 == 1);"sv, {}, "TABLE_NAME"sv, {}, true, false, {}); + validate("DELETE FROM table_name WHERE EXISTS (SELECT 1);"sv, {}, "TABLE_NAME"sv, {}, true, false, {}); + validate("DELETE FROM table_name WHERE NOT EXISTS (SELECT 1);"sv, {}, "TABLE_NAME"sv, {}, true, false, {}); + validate("DELETE FROM table_name WHERE (SELECT 1);"sv, {}, "TABLE_NAME"sv, {}, true, false, {}); validate("DELETE FROM table_name RETURNING *;"sv, {}, "TABLE_NAME"sv, {}, false, true, {}); validate("DELETE FROM table_name RETURNING column_name;"sv, {}, "TABLE_NAME"sv, {}, false, true, { {} }); validate("DELETE FROM table_name RETURNING column_name AS alias;"sv, {}, "TABLE_NAME"sv, {}, false, true, { "ALIAS"sv }); diff --git a/Userland/Libraries/LibSQL/AST/Parser.cpp b/Userland/Libraries/LibSQL/AST/Parser.cpp index 26d88ec205..33f944a403 100644 --- a/Userland/Libraries/LibSQL/AST/Parser.cpp +++ b/Userland/Libraries/LibSQL/AST/Parser.cpp @@ -219,7 +219,7 @@ NonnullRefPtr Parser::parse_insert_statement(RefPtr(chained_expression.ptr()); + auto* chained_expr = verify_cast(chained_expression.ptr()); if ((column_names.size() > 0) && (chained_expr->expressions().size() != column_names.size())) { syntax_error("Number of expressions does not match number of columns"); } else { @@ -422,17 +422,34 @@ NonnullRefPtr Parser::parse_primary_expression() if (auto expression = parse_unary_operator_expression()) return expression.release_nonnull(); - if (auto expression = parse_chained_expression()) - return expression.release_nonnull(); - if (auto expression = parse_cast_expression()) return expression.release_nonnull(); if (auto expression = parse_case_expression()) return expression.release_nonnull(); - if (auto expression = parse_exists_expression(false)) - return expression.release_nonnull(); + if (auto invert_expression = consume_if(TokenType::Not); invert_expression || consume_if(TokenType::Exists)) { + if (auto expression = parse_exists_expression(invert_expression)) + return expression.release_nonnull(); + + expected("Exists expression"sv); + } + + if (consume_if(TokenType::ParenOpen)) { + // Encountering a Select token at this point means this must be an ExistsExpression with no EXISTS keyword. + if (match(TokenType::Select)) { + auto select_statement = parse_select_statement({}); + consume(TokenType::ParenClose); + return create_ast_node(move(select_statement), false); + } + + if (auto expression = parse_chained_expression(false)) { + consume(TokenType::ParenClose); + return expression.release_nonnull(); + } + + expected("Chained expression"sv); + } expected("Primary Expression"sv); consume(); @@ -662,17 +679,16 @@ RefPtr Parser::parse_binary_operator_expression(NonnullRefPtr Parser::parse_chained_expression() +RefPtr Parser::parse_chained_expression(bool surrounded_by_parentheses) { - if (!consume_if(TokenType::ParenOpen)) + if (surrounded_by_parentheses && !consume_if(TokenType::ParenOpen)) return {}; - if (match(TokenType::Select)) - return parse_exists_expression(false, TokenType::Select); - Vector> expressions; parse_comma_separated_list(false, [&]() { expressions.append(parse_expression()); }); - consume(TokenType::ParenClose); + + if (surrounded_by_parentheses) + consume(TokenType::ParenClose); return create_ast_node(move(expressions)); } @@ -726,15 +742,14 @@ RefPtr Parser::parse_case_expression() return create_ast_node(move(case_expression), move(when_then_clauses), move(else_expression)); } -RefPtr Parser::parse_exists_expression(bool invert_expression, TokenType opening_token) +RefPtr Parser::parse_exists_expression(bool invert_expression) { - VERIFY((opening_token == TokenType::Exists) || (opening_token == TokenType::Select)); - - if ((opening_token == TokenType::Exists) && !consume_if(TokenType::Exists)) + if (!(match(TokenType::Exists) || match(TokenType::ParenOpen))) return {}; - if (opening_token == TokenType::Exists) - consume(TokenType::ParenOpen); + consume_if(TokenType::Exists); + consume(TokenType::ParenOpen); + auto select_statement = parse_select_statement({}); consume(TokenType::ParenClose); diff --git a/Userland/Libraries/LibSQL/AST/Parser.h b/Userland/Libraries/LibSQL/AST/Parser.h index 430280f9f2..2721e45043 100644 --- a/Userland/Libraries/LibSQL/AST/Parser.h +++ b/Userland/Libraries/LibSQL/AST/Parser.h @@ -77,10 +77,10 @@ private: RefPtr parse_column_name_expression(DeprecatedString with_parsed_identifier = {}, bool with_parsed_period = false); RefPtr parse_unary_operator_expression(); RefPtr parse_binary_operator_expression(NonnullRefPtr lhs); - RefPtr parse_chained_expression(); + RefPtr parse_chained_expression(bool surrounded_by_parentheses = true); RefPtr parse_cast_expression(); RefPtr parse_case_expression(); - RefPtr parse_exists_expression(bool invert_expression, TokenType opening_token = TokenType::Exists); + RefPtr parse_exists_expression(bool invert_expression); RefPtr parse_collate_expression(NonnullRefPtr expression); RefPtr parse_is_expression(NonnullRefPtr expression); RefPtr parse_match_expression(NonnullRefPtr lhs, bool invert_expression);