From 3425730294e9913627e36b2ddc559f337e0f186b Mon Sep 17 00:00:00 2001 From: Jan de Visser Date: Tue, 2 Nov 2021 16:49:54 -0400 Subject: [PATCH] LibSQL: Implement table joins This patch introduces table joins. It uses a pretty dumb algorithm- starting with a singleton '__unity__' row consisting of a single boolean value, a cartesian product of all tables in the 'FROM' clause is built. This cartesian product is then filtered through the 'WHERE' clause, again without any smarts just using brute force. This patch required a bunch of busy work to allow for example the ColumnNameExpression having to deal with multiple tables potentially having columns with the same name. --- Tests/LibSQL/TestSqlStatementExecution.cpp | 116 ++++++++++++++++--- Userland/Libraries/LibSQL/AST/Expression.cpp | 14 ++- Userland/Libraries/LibSQL/AST/Select.cpp | 94 ++++++++++----- Userland/Libraries/LibSQL/SQLResult.h | 1 + 4 files changed, 175 insertions(+), 50 deletions(-) diff --git a/Tests/LibSQL/TestSqlStatementExecution.cpp b/Tests/LibSQL/TestSqlStatementExecution.cpp index a7f0c0fbe4..e46056e427 100644 --- a/Tests/LibSQL/TestSqlStatementExecution.cpp +++ b/Tests/LibSQL/TestSqlStatementExecution.cpp @@ -46,6 +46,17 @@ void create_table(NonnullRefPtr database) EXPECT(result->inserted() == 1); } +void create_two_tables(NonnullRefPtr database) +{ + create_schema(database); + auto result = execute(database, "CREATE TABLE TestSchema.TestTable1 ( TextColumn1 text, IntColumn integer );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 1); + result = execute(database, "CREATE TABLE TestSchema.TestTable2 ( TextColumn2 text, IntColumn integer );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 1); +} + TEST_CASE(create_schema) { ScopeGuard guard([]() { unlink(db_name); }); @@ -132,15 +143,15 @@ TEST_CASE(select_from_table) ScopeGuard guard([]() { unlink(db_name); }); auto database = SQL::Database::construct(db_name); create_table(database); - auto result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_1', 42 ), ( 'Test_2', 43 );"); + auto result = execute(database, + "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES " + "( 'Test_1', 42 ), " + "( 'Test_2', 43 ), " + "( 'Test_3', 44 ), " + "( 'Test_4', 45 ), " + "( 'Test_5', 46 );"); EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 2); - result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_3', 44 ), ( 'Test_4', 45 );"); - EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 2); - result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_5', 46 );"); - EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 1); + EXPECT(result->inserted() == 5); result = execute(database, "SELECT * FROM TestSchema.TestTable;"); EXPECT(result->error().code == SQL::SQLErrorCode::NoError); EXPECT(result->has_results()); @@ -152,15 +163,15 @@ TEST_CASE(select_with_column_names) ScopeGuard guard([]() { unlink(db_name); }); auto database = SQL::Database::construct(db_name); create_table(database); - auto result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_1', 42 ), ( 'Test_2', 43 );"); + auto result = execute(database, + "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES " + "( 'Test_1', 42 ), " + "( 'Test_2', 43 ), " + "( 'Test_3', 44 ), " + "( 'Test_4', 45 ), " + "( 'Test_5', 46 );"); EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 2); - result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_3', 44 ), ( 'Test_4', 45 );"); - EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 2); - result = execute(database, "INSERT INTO TestSchema.TestTable ( TextColumn, IntColumn ) VALUES ( 'Test_5', 46 );"); - EXPECT(result->error().code == SQL::SQLErrorCode::NoError); - EXPECT(result->inserted() == 1); + EXPECT(result->inserted() == 5); result = execute(database, "SELECT TextColumn FROM TestSchema.TestTable;"); EXPECT(result->error().code == SQL::SQLErrorCode::NoError); EXPECT(result->has_results()); @@ -209,4 +220,77 @@ TEST_CASE(select_with_where) } } +TEST_CASE(select_cross_join) +{ + ScopeGuard guard([]() { unlink(db_name); }); + auto database = SQL::Database::construct(db_name); + create_two_tables(database); + auto result = execute(database, + "INSERT INTO TestSchema.TestTable1 ( TextColumn1, IntColumn ) VALUES " + "( 'Test_1', 42 ), " + "( 'Test_2', 43 ), " + "( 'Test_3', 44 ), " + "( 'Test_4', 45 ), " + "( 'Test_5', 46 );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 5); + result = execute(database, + "INSERT INTO TestSchema.TestTable2 ( TextColumn2, IntColumn ) VALUES " + "( 'Test_10', 40 ), " + "( 'Test_11', 41 ), " + "( 'Test_12', 42 ), " + "( 'Test_13', 47 ), " + "( 'Test_14', 48 );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 5); + result = execute(database, "SELECT * FROM TestSchema.TestTable1, TestSchema.TestTable2;"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->has_results()); + EXPECT_EQ(result->results().size(), 25u); + for (auto& row : result->results()) { + EXPECT(row.size() == 4); + EXPECT(row[1].to_int().value() >= 42); + EXPECT(row[1].to_int().value() <= 46); + EXPECT(row[3].to_int().value() >= 40); + EXPECT(row[3].to_int().value() <= 48); + } +} + +TEST_CASE(select_inner_join) +{ + ScopeGuard guard([]() { unlink(db_name); }); + auto database = SQL::Database::construct(db_name); + create_two_tables(database); + auto result = execute(database, + "INSERT INTO TestSchema.TestTable1 ( TextColumn1, IntColumn ) VALUES " + "( 'Test_1', 42 ), " + "( 'Test_2', 43 ), " + "( 'Test_3', 44 ), " + "( 'Test_4', 45 ), " + "( 'Test_5', 46 );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 5); + result = execute(database, + "INSERT INTO TestSchema.TestTable2 ( TextColumn2, IntColumn ) VALUES " + "( 'Test_10', 40 ), " + "( 'Test_11', 41 ), " + "( 'Test_12', 42 ), " + "( 'Test_13', 47 ), " + "( 'Test_14', 48 );"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->inserted() == 5); + result = execute(database, + "SELECT TestTable1.IntColumn, TextColumn1, TextColumn2 " + "FROM TestSchema.TestTable1, TestSchema.TestTable2 " + "WHERE TestTable1.IntColumn = TestTable2.IntColumn;"); + EXPECT(result->error().code == SQL::SQLErrorCode::NoError); + EXPECT(result->has_results()); + EXPECT_EQ(result->results().size(), 1u); + auto& row = result->results()[0]; + EXPECT_EQ(row.size(), 3u); + EXPECT_EQ(row[0].to_int().value(), 42); + EXPECT_EQ(row[1].to_string(), "Test_1"); + EXPECT_EQ(row[2].to_string(), "Test_12"); +} + } diff --git a/Userland/Libraries/LibSQL/AST/Expression.cpp b/Userland/Libraries/LibSQL/AST/Expression.cpp index d3e1337be2..a7b793e602 100644 --- a/Userland/Libraries/LibSQL/AST/Expression.cpp +++ b/Userland/Libraries/LibSQL/AST/Expression.cpp @@ -170,11 +170,21 @@ Value ColumnNameExpression::evaluate(ExecutionContext& context) const { auto& descriptor = *context.current_row->descriptor(); VERIFY(context.current_row->size() == descriptor.size()); + Optional index_in_row; for (auto ix = 0u; ix < context.current_row->size(); ix++) { auto& column_descriptor = descriptor[ix]; - if (column_descriptor.name == column_name()) - return { (*context.current_row)[ix] }; + if (!table_name().is_empty() && column_descriptor.table != table_name()) + continue; + if (column_descriptor.name == column_name()) { + if (index_in_row.has_value()) { + context.result->set_error(SQLErrorCode::AmbiguousColumnName, column_name()); + return Value::null(); + } + index_in_row = ix; + } } + if (index_in_row.has_value()) + return (*context.current_row)[index_in_row.value()]; context.result->set_error(SQLErrorCode::ColumnDoesNotExist, column_name()); return Value::null(); } diff --git a/Userland/Libraries/LibSQL/AST/Select.cpp b/Userland/Libraries/LibSQL/AST/Select.cpp index 6214e67015..81ef322a57 100644 --- a/Userland/Libraries/LibSQL/AST/Select.cpp +++ b/Userland/Libraries/LibSQL/AST/Select.cpp @@ -13,13 +13,14 @@ namespace SQL::AST { RefPtr Select::execute(ExecutionContext& context) const { - if (table_or_subquery_list().size() == 1 && table_or_subquery_list()[0].is_table()) { - auto table = context.database->get_table(table_or_subquery_list()[0].schema_name(), table_or_subquery_list()[0].table_name()); + NonnullRefPtrVector columns; + for (auto& table_descriptor : table_or_subquery_list()) { + if (!table_descriptor.is_table()) + TODO(); + auto table = context.database->get_table(table_descriptor.schema_name(), table_descriptor.table_name()); if (!table) { - return SQLResult::construct(SQL::SQLCommand::Select, SQL::SQLErrorCode::TableDoesNotExist, table_or_subquery_list()[0].table_name()); + return SQLResult::construct(SQL::SQLCommand::Select, SQL::SQLErrorCode::TableDoesNotExist, table_descriptor.table_name()); } - - NonnullRefPtrVector columns; if (result_column_list().size() == 1 && result_column_list()[0].type() == ResultType::All) { for (auto& col : table->columns()) { columns.append( @@ -27,35 +28,64 @@ RefPtr Select::execute(ExecutionContext& context) const create_ast_node(table->parent()->name(), table->name(), col.name()), "")); } - } else { - for (auto& col : result_column_list()) { - columns.append(col); - } } - context.result = SQLResult::construct(); - AK::NonnullRefPtr descriptor = AK::adopt_ref(*new TupleDescriptor); - Tuple tuple(descriptor); - for (auto& row : context.database->select_all(*table)) { - context.current_row = &row; - if (where_clause()) { - auto where_result = where_clause()->evaluate(context); - if (context.result->has_error()) - return context.result; - if (!where_result) - continue; - } - tuple.clear(); - for (auto& col : columns) { - auto value = col.expression()->evaluate(context); - if (context.result->has_error()) - return context.result; - tuple.append(value); - } - context.result->append(tuple); - } - return context.result; } - return SQLResult::construct(); + + VERIFY(!result_column_list().is_empty()); + if (result_column_list().size() != 1 || result_column_list()[0].type() != ResultType::All) { + for (auto& col : result_column_list()) { + if (col.type() == ResultType::All) + // FIXME can have '*' for example in conjunction with computed columns + return SQLResult::construct(SQL::SQLCommand::Select, SQL::SQLErrorCode::SyntaxError, "*"); + columns.append(col); + } + } + + context.result = SQLResult::construct(); + AK::NonnullRefPtr descriptor = AK::adopt_ref(*new TupleDescriptor); + Tuple tuple(descriptor); + Vector rows; + descriptor->empend("__unity__"); + tuple.append(Value(SQLType::Boolean, true)); + rows.append(tuple); + + for (auto& table_descriptor : table_or_subquery_list()) { + if (!table_descriptor.is_table()) + TODO(); + auto table = context.database->get_table(table_descriptor.schema_name(), table_descriptor.table_name()); + if (table->num_columns() == 0) + continue; + auto old_descriptor_size = descriptor->size(); + descriptor->extend(table->to_tuple_descriptor()); + for (auto cartesian_row = rows.first(); cartesian_row.size() == old_descriptor_size; cartesian_row = rows.first()) { + rows.remove(0); + for (auto& table_row : context.database->select_all(*table)) { + auto new_row = cartesian_row; + new_row.extend(table_row); + rows.append(new_row); + } + } + } + + for (auto& row : rows) { + context.current_row = &row; + if (where_clause()) { + auto where_result = where_clause()->evaluate(context); + if (context.result->has_error()) + return context.result; + if (!where_result) + continue; + } + tuple.clear(); + for (auto& col : columns) { + auto value = col.expression()->evaluate(context); + if (context.result->has_error()) + return context.result; + tuple.append(value); + } + context.result->append(tuple); + } + return context.result; } } diff --git a/Userland/Libraries/LibSQL/SQLResult.h b/Userland/Libraries/LibSQL/SQLResult.h index 33eab92c2a..6f14873655 100644 --- a/Userland/Libraries/LibSQL/SQLResult.h +++ b/Userland/Libraries/LibSQL/SQLResult.h @@ -51,6 +51,7 @@ constexpr char const* command_tag(SQLCommand command) S(SchemaExists, "Schema '{}' already exist") \ S(TableDoesNotExist, "Table '{}' does not exist") \ S(ColumnDoesNotExist, "Column '{}' does not exist") \ + S(AmbiguousColumnName, "Column name '{}' is ambiguous") \ S(TableExists, "Table '{}' already exist") \ S(InvalidType, "Invalid type '{}'") \ S(InvalidDatabaseName, "Invalid database name '{}'") \