diff --git a/Tests/LibSQL/TestSqlStatementExecution.cpp b/Tests/LibSQL/TestSqlStatementExecution.cpp index a7f0c0fbe4..e46056e427 100644 --- a/Tests/LibSQL/TestSqlStatementExecution.cpp +++ b/Tests/LibSQL/TestSqlStatementExecution.cpp @@ -46,6 +46,17 @@ void create_table(NonnullRefPtr database) EXPECT(result->inserted() == 1); } +void create_two_tables(NonnullRefPtr database) +{ + create_schema(database); + auto result = execute(database, "CREATE TABLE TestSchema.TestTable1 ( TextColumn1 text, IntColumn integer );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 1); + result = execute(database, "CREATE TABLE TestSchema.TestTable2 ( TextColumn2 text, IntColumn integer );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 1); +} + TEST_CASE(create_schema) { ScopeGuard guard([]() { unlink(db_name); }); @@ -132,15 +143,15 @@ TEST_CASE(select_from_table) ScopeGuard guard([]() { unlink(db_name); }); auto database = SQL::Database::construct(db_name); create_table(database); - auto result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_1', 42 ), ( 'Test_2', 43 );"); + auto result = execute(database, + "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES " + "( 'Test_1', 42 ), " + "( 'Test_2', 43 ), " + "( 'Test_3', 44 ), " + "( 'Test_4', 45 ), " + "( 'Test_5', 46 );"); EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 2); - result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_3', 44 ), ( 'Test_4', 45 );"); - EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 2); - result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_5', 46 );"); - EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 1); + EXPECT(result->inserted() == 5); result = execute(database, "SELECT * FROM TestSchema.TestTable;"); EXPECT(result->error().code == SQL::SQLErrorCode::NoError); EXPECT(result->has_results()); @@ -152,15 +163,15 @@ TEST_CASE(select_with_column_names) ScopeGuard guard([]() { unlink(db_name); }); auto database = SQL::Database::construct(db_name); create_table(database); - auto result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_1', 42 ), ( 'Test_2', 43 );"); + auto result = execute(database, + "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES " + "( 'Test_1', 42 ), " + "( 'Test_2', 43 ), " + "( 'Test_3', 44 ), " + "( 'Test_4', 45 ), " + "( 'Test_5', 46 );"); EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 2); - result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_3', 44 ), ( 'Test_4', 45 );"); - EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 2); - result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_5', 46 );"); - EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 1); + EXPECT(result->inserted() == 5); result = execute(database, "SELECT TextColumn FROM TestSchema.TestTable;"); EXPECT(result->error().code == SQL::SQLErrorCode::NoError); EXPECT(result->has_results()); @@ -209,4 +220,77 @@ TEST_CASE(select_with_where) } } +TEST_CASE(select_cross_join) +{ + ScopeGuard guard([]() { unlink(db_name); }); + auto database = SQL::Database::construct(db_name); + create_two_tables(database); + auto result = execute(database, + "INSERT INTO TestSchema.TestTable1 ( TextColumn1, IntColumn ) VALUES " + "( 'Test_1', 42 ), " + "( 'Test_2', 43 ), " + "( 'Test_3', 44 ), " + "( 'Test_4', 45 ), " + "( 'Test_5', 46 );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 5); + result = execute(database, + "INSERT INTO TestSchema.TestTable2 ( TextColumn2, IntColumn ) VALUES " + "( 'Test_10', 40 ), " + "( 'Test_11', 41 ), " + "( 'Test_12', 42 ), " + "( 'Test_13', 47 ), " + "( 'Test_14', 48 );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 5); + result = execute(database, "SELECT * FROM TestSchema.TestTable1, TestSchema.TestTable2;"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->has_results()); + EXPECT_EQ(result->results().size(), 25u); + for (auto& row : result->results()) { + EXPECT(row.size() == 4); + EXPECT(row[1].to_int().value() >= 42); + EXPECT(row[1].to_int().value() <= 46); + EXPECT(row[3].to_int().value() >= 40); + EXPECT(row[3].to_int().value() <= 48); + } +} + +TEST_CASE(select_inner_join) +{ + ScopeGuard guard([]() { unlink(db_name); }); + auto database = SQL::Database::construct(db_name); + create_two_tables(database); + auto result = execute(database, + "INSERT INTO TestSchema.TestTable1 ( TextColumn1, IntColumn ) VALUES " + "( 'Test_1', 42 ), " + "( 'Test_2', 43 ), " + "( 'Test_3', 44 ), " + "( 'Test_4', 45 ), " + "( 'Test_5', 46 );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 5); + result = execute(database, + "INSERT INTO TestSchema.TestTable2 ( TextColumn2, IntColumn ) VALUES " + "( 'Test_10', 40 ), " + "( 'Test_11', 41 ), " + "( 'Test_12', 42 ), " + "( 'Test_13', 47 ), " + "( 'Test_14', 48 );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 5); + result = execute(database, + "SELECT TestTable1.IntColumn, TextColumn1, TextColumn2 " + "FROM TestSchema.TestTable1, TestSchema.TestTable2 " + "WHERE TestTable1.IntColumn = TestTable2.IntColumn;"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->has_results()); + EXPECT_EQ(result->results().size(), 1u); + auto& row = result->results()[0]; + EXPECT_EQ(row.size(), 3u); + EXPECT_EQ(row[0].to_int().value(), 42); + EXPECT_EQ(row[1].to_string(), "Test_1"); + EXPECT_EQ(row[2].to_string(), "Test_12"); +} + } diff --git a/Userland/Libraries/LibSQL/AST/Expression.cpp b/Userland/Libraries/LibSQL/AST/Expression.cpp index d3e1337be2..a7b793e602 100644 --- a/Userland/Libraries/LibSQL/AST/Expression.cpp +++ b/Userland/Libraries/LibSQL/AST/Expression.cpp @@ -170,11 +170,21 @@ Value ColumnNameExpression::evaluate(ExecutionContext& context) const { auto& descriptor = *context.current_row->descriptor(); VERIFY(context.current_row->size() == descriptor.size()); + Optional index_in_row; for (auto ix = 0u; ix < context.current_row->size(); ix++) { auto& column_descriptor = descriptor[ix]; - if (column_descriptor.name == column_name()) - return { (*context.current_row)[ix] }; + if (!table_name().is_empty() && column_descriptor.table != table_name()) + continue; + if (column_descriptor.name == column_name()) { + if (index_in_row.has_value()) { + context.result->set_error(SQLErrorCode::AmbiguousColumnName, column_name()); + return Value::null(); + } + index_in_row = ix; + } } + if (index_in_row.has_value()) + return (*context.current_row)[index_in_row.value()]; context.result->set_error(SQLErrorCode::ColumnDoesNotExist, column_name()); return Value::null(); } diff --git a/Userland/Libraries/LibSQL/AST/Select.cpp b/Userland/Libraries/LibSQL/AST/Select.cpp index 6214e67015..81ef322a57 100644 --- a/Userland/Libraries/LibSQL/AST/Select.cpp +++ b/Userland/Libraries/LibSQL/AST/Select.cpp @@ -13,13 +13,14 @@ namespace SQL::AST { RefPtr Select::execute(ExecutionContext& context) const { - if (table_or_subquery_list().size() == 1 && table_or_subquery_list()[0].is_table()) { - auto table = context.database->get_table(table_or_subquery_list()[0].schema_name(), table_or_subquery_list()[0].table_name()); + NonnullRefPtrVector columns; + for (auto& table_descriptor : table_or_subquery_list()) { + if (!table_descriptor.is_table()) + TODO(); + auto table = context.database->get_table(table_descriptor.schema_name(), table_descriptor.table_name()); if (!table) { - return SQLResult::construct(SQL::SQLCommand::Select, SQL::SQLErrorCode::TableDoesNotExist, table_or_subquery_list()[0].table_name()); + return SQLResult::construct(SQL::SQLCommand::Select, SQL::SQLErrorCode::TableDoesNotExist, table_descriptor.table_name()); } - - NonnullRefPtrVector columns; if (result_column_list().size() == 1 && result_column_list()[0].type() == ResultType::All) { for (auto& col : table->columns()) { columns.append( @@ -27,35 +28,64 @@ RefPtr Select::execute(ExecutionContext& context) const create_ast_node(table->parent()->name(), table->name(), col.name()), "")); } - } else { - for (auto& col : result_column_list()) { - columns.append(col); - } } - context.result = SQLResult::construct(); - AK::NonnullRefPtr descriptor = AK::adopt_ref(*new TupleDescriptor); - Tuple tuple(descriptor); - for (auto& row : context.database->select_all(*table)) { - context.current_row = &row; - if (where_clause()) { - auto where_result = where_clause()->evaluate(context); - if (context.result->has_error()) - return context.result; - if (!where_result) - continue; - } - tuple.clear(); - for (auto& col : columns) { - auto value = col.expression()->evaluate(context); - if (context.result->has_error()) - return context.result; - tuple.append(value); - } - context.result->append(tuple); - } - return context.result; } - return SQLResult::construct(); + + VERIFY(!result_column_list().is_empty()); + if (result_column_list().size() != 1 || result_column_list()[0].type() != ResultType::All) { + for (auto& col : result_column_list()) { + if (col.type() == ResultType::All) + // FIXME can have '*' for example in conjunction with computed columns + return SQLResult::construct(SQL::SQLCommand::Select, SQL::SQLErrorCode::SyntaxError, "*"); + columns.append(col); + } + } + + context.result = SQLResult::construct(); + AK::NonnullRefPtr descriptor = AK::adopt_ref(*new TupleDescriptor); + Tuple tuple(descriptor); + Vector rows; + descriptor->empend("__unity__"); + tuple.append(Value(SQLType::Boolean, true)); + rows.append(tuple); + + for (auto& table_descriptor : table_or_subquery_list()) { + if (!table_descriptor.is_table()) + TODO(); + auto table = context.database->get_table(table_descriptor.schema_name(), table_descriptor.table_name()); + if (table->num_columns() == 0) + continue; + auto old_descriptor_size = descriptor->size(); + descriptor->extend(table->to_tuple_descriptor()); + for (auto cartesian_row = rows.first(); cartesian_row.size() == old_descriptor_size; cartesian_row = rows.first()) { + rows.remove(0); + for (auto& table_row : context.database->select_all(*table)) { + auto new_row = cartesian_row; + new_row.extend(table_row); + rows.append(new_row); + } + } + } + + for (auto& row : rows) { + context.current_row = &row; + if (where_clause()) { + auto where_result = where_clause()->evaluate(context); + if (context.result->has_error()) + return context.result; + if (!where_result) + continue; + } + tuple.clear(); + for (auto& col : columns) { + auto value = col.expression()->evaluate(context); + if (context.result->has_error()) + return context.result; + tuple.append(value); + } + context.result->append(tuple); + } + return context.result; } } diff --git a/Userland/Libraries/LibSQL/SQLResult.h b/Userland/Libraries/LibSQL/SQLResult.h index 33eab92c2a..6f14873655 100644 --- a/Userland/Libraries/LibSQL/SQLResult.h +++ b/Userland/Libraries/LibSQL/SQLResult.h @@ -51,6 +51,7 @@ constexpr char const* command_tag(SQLCommand command) S(SchemaExists, "Schema '{}' already exist") \ S(TableDoesNotExist, "Table '{}' does not exist") \ S(ColumnDoesNotExist, "Column '{}' does not exist") \ + S(AmbiguousColumnName, "Column name '{}' is ambiguous") \ S(TableExists, "Table '{}' already exist") \ S(InvalidType, "Invalid type '{}'") \ S(InvalidDatabaseName, "Invalid database name '{}'") \