diff --git a/Tests/LibSQL/TestSqlDatabase.cpp b/Tests/LibSQL/TestSqlDatabase.cpp index 1623734a44..2e5c6b5ba7 100644 --- a/Tests/LibSQL/TestSqlDatabase.cpp +++ b/Tests/LibSQL/TestSqlDatabase.cpp @@ -172,7 +172,6 @@ TEST_CASE(get_schema_from_database) EXPECT(!db->open().is_error()); auto schema_or_error = db->get_schema("TestSchema"); EXPECT(!schema_or_error.is_error()); - EXPECT(schema_or_error.value()); } } diff --git a/Tests/LibSQL/TestSqlStatementExecution.cpp b/Tests/LibSQL/TestSqlStatementExecution.cpp index bc5cf1035d..2e8e06b5ed 100644 --- a/Tests/LibSQL/TestSqlStatementExecution.cpp +++ b/Tests/LibSQL/TestSqlStatementExecution.cpp @@ -71,7 +71,6 @@ TEST_CASE(create_schema) create_schema(database); auto schema_or_error = database->get_schema("TESTSCHEMA"); EXPECT(!schema_or_error.is_error()); - EXPECT(schema_or_error.value()); } TEST_CASE(create_table) diff --git a/Userland/Libraries/LibSQL/AST/CreateSchema.cpp b/Userland/Libraries/LibSQL/AST/CreateSchema.cpp index a41cc3a269..5f60f009a0 100644 --- a/Userland/Libraries/LibSQL/AST/CreateSchema.cpp +++ b/Userland/Libraries/LibSQL/AST/CreateSchema.cpp @@ -12,17 +12,13 @@ namespace SQL::AST { ResultOr CreateSchema::execute(ExecutionContext& context) const { - auto schema_def = TRY(context.database->get_schema(m_schema_name)); + auto schema_def = SchemaDef::construct(m_schema_name); - if (schema_def) { - if (m_is_error_if_schema_exists) - return Result { SQLCommand::Create, SQLErrorCode::SchemaExists, m_schema_name }; - return ResultSet { SQLCommand::Create }; + if (auto result = context.database->add_schema(*schema_def); result.is_error()) { + if (result.error().error() != SQLErrorCode::SchemaExists || m_is_error_if_schema_exists) + return result.release_error(); } - schema_def = SchemaDef::construct(m_schema_name); - TRY(context.database->add_schema(*schema_def)); - return ResultSet { SQLCommand::Create }; } diff --git a/Userland/Libraries/LibSQL/AST/CreateTable.cpp b/Userland/Libraries/LibSQL/AST/CreateTable.cpp index a427e59f25..2312e0f23d 100644 --- a/Userland/Libraries/LibSQL/AST/CreateTable.cpp +++ b/Userland/Libraries/LibSQL/AST/CreateTable.cpp @@ -11,13 +11,8 @@ namespace SQL::AST { ResultOr CreateTable::execute(ExecutionContext& context) const { - auto schema_name = m_schema_name.is_empty() ? String { "default"sv } : m_schema_name; - - auto schema_def = TRY(context.database->get_schema(schema_name)); - if (!schema_def) - return Result { SQLCommand::Create, SQLErrorCode::SchemaDoesNotExist, schema_name }; - - auto table_def = TRY(context.database->get_table(schema_name, m_table_name)); + auto schema_def = TRY(context.database->get_schema(m_schema_name)); + auto table_def = TRY(context.database->get_table(m_schema_name, m_table_name)); if (table_def) { if (m_is_error_if_table_exists) return Result { SQLCommand::Create, SQLErrorCode::TableExists, m_table_name }; diff --git a/Userland/Libraries/LibSQL/Database.cpp b/Userland/Libraries/LibSQL/Database.cpp index 424c0e7ba8..401d18fb21 100644 --- a/Userland/Libraries/LibSQL/Database.cpp +++ b/Userland/Libraries/LibSQL/Database.cpp @@ -24,9 +24,10 @@ Database::Database(String name) { } -ErrorOr Database::open() +ResultOr Database::open() { TRY(m_heap->open()); + m_schemas = BTree::construct(m_serializer, SchemaDef::index_def()->to_tuple_descriptor(), m_heap->schemas_root()); m_schemas->on_new_root = [&]() { m_heap->set_schemas_root(m_schemas->root()); @@ -43,17 +44,22 @@ ErrorOr Database::open() }; m_open = true; - auto default_schema = TRY(get_schema("default")); - if (!default_schema) { - default_schema = SchemaDef::construct("default"); - TRY(add_schema(*default_schema)); - } - auto master_schema = TRY(get_schema("master")); - if (!master_schema) { - master_schema = SchemaDef::construct("master"); - TRY(add_schema(*master_schema)); - } + auto ensure_schema_exists = [&](auto schema_name) -> ResultOr> { + if (auto result = get_schema(schema_name); result.is_error()) { + if (result.error().error() != SQLErrorCode::SchemaDoesNotExist) + return result.release_error(); + + auto schema_def = SchemaDef::construct(schema_name); + TRY(add_schema(*schema_def)); + return schema_def; + } else { + return result.release_value(); + } + }; + + (void)TRY(ensure_schema_exists("default"sv)); + auto master_schema = TRY(ensure_schema_exists("master"sv)); auto table_def = TRY(get_table("master", "internal_describe_table")); if (!table_def) { @@ -75,13 +81,12 @@ ErrorOr Database::commit() return {}; } -ErrorOr Database::add_schema(SchemaDef const& schema) +ResultOr Database::add_schema(SchemaDef const& schema) { VERIFY(is_open()); - if (!m_schemas->insert(schema.key())) { - warnln("Duplicate schema name {}"sv, schema.name()); - return Error::from_string_literal("Duplicate schema name"); - } + + if (!m_schemas->insert(schema.key())) + return Result { SQLCommand::Unknown, SQLErrorCode::SchemaExists, schema.name() }; return {}; } @@ -92,24 +97,25 @@ Key Database::get_schema_key(String const& schema_name) return key; } -ErrorOr> Database::get_schema(String const& schema) +ResultOr> Database::get_schema(String const& schema) { VERIFY(is_open()); + auto schema_name = schema; - if (schema.is_null() || schema.is_empty()) - schema_name = "default"; + if (schema.is_empty()) + schema_name = "default"sv; + Key key = get_schema_key(schema_name); - auto schema_def_opt = m_schema_cache.get(key.hash()); - if (schema_def_opt.has_value()) { - return RefPtr(schema_def_opt.value()); - } + if (auto it = m_schema_cache.find(key.hash()); it != m_schema_cache.end()) + return it->value; + auto schema_iterator = m_schemas->find(key); - if (schema_iterator.is_end() || (*schema_iterator != key)) { - return RefPtr(nullptr); - } - auto ret = SchemaDef::construct(*schema_iterator); - m_schema_cache.set(key.hash(), ret); - return RefPtr(ret); + if (schema_iterator.is_end() || (*schema_iterator != key)) + return Result { SQLCommand::Unknown, SQLErrorCode::SchemaDoesNotExist, schema_name }; + + auto schema_def = SchemaDef::construct(*schema_iterator); + m_schema_cache.set(key.hash(), schema_def); + return schema_def; } ErrorOr Database::add_table(TableDef& table) @@ -132,7 +138,7 @@ Key Database::get_table_key(String const& schema_name, String const& table_name) return key; } -ErrorOr> Database::get_table(String const& schema, String const& name) +ResultOr> Database::get_table(String const& schema, String const& name) { VERIFY(is_open()); auto schema_name = schema; @@ -147,10 +153,6 @@ ErrorOr> Database::get_table(String const& schema, String const return RefPtr(nullptr); } auto schema_def = TRY(get_schema(schema)); - if (!schema_def) { - warnln("Schema '{}' does not exist"sv, schema); - return Error::from_string_literal("Schema does not exist"); - } auto ret = TableDef::construct(schema_def, name); ret->set_pointer((*table_iterator).pointer()); m_table_cache.set(key.hash(), ret); diff --git a/Userland/Libraries/LibSQL/Database.h b/Userland/Libraries/LibSQL/Database.h index cafd08cd59..29c97c089c 100644 --- a/Userland/Libraries/LibSQL/Database.h +++ b/Userland/Libraries/LibSQL/Database.h @@ -13,6 +13,7 @@ #include #include #include +#include #include namespace SQL { @@ -28,17 +29,17 @@ class Database : public Core::Object { public: ~Database() override; - ErrorOr open(); + ResultOr open(); bool is_open() const { return m_open; } ErrorOr commit(); - ErrorOr add_schema(SchemaDef const&); + ResultOr add_schema(SchemaDef const&); static Key get_schema_key(String const&); - ErrorOr> get_schema(String const&); + ResultOr> get_schema(String const&); ErrorOr add_table(TableDef& table); static Key get_table_key(String const&, String const&); - ErrorOr> get_table(String const&, String const&); + ResultOr> get_table(String const&, String const&); ErrorOr> select_all(TableDef const&); ErrorOr> match(TableDef const&, Key const&); @@ -55,7 +56,7 @@ private: RefPtr m_tables; RefPtr m_table_columns; - HashMap> m_schema_cache; + HashMap> m_schema_cache; HashMap> m_table_cache; }; diff --git a/Userland/Services/SQLServer/DatabaseConnection.cpp b/Userland/Services/SQLServer/DatabaseConnection.cpp index 723af326cb..ce76b3d248 100644 --- a/Userland/Services/SQLServer/DatabaseConnection.cpp +++ b/Userland/Services/SQLServer/DatabaseConnection.cpp @@ -41,7 +41,7 @@ DatabaseConnection::DatabaseConnection(String database_name, int client_id) m_database = SQL::Database::construct(String::formatted("/home/anon/sql/{}.db", m_database_name)); auto client_connection = ConnectionFromClient::client_connection_for(m_client_id); if (auto maybe_error = m_database->open(); maybe_error.is_error()) { - client_connection->async_connection_error(m_connection_id, (int)SQL::SQLErrorCode::InternalError, maybe_error.error().string_literal()); + client_connection->async_connection_error(m_connection_id, to_underlying(maybe_error.error().error()), maybe_error.error().error_string()); return; } m_accept_statements = true;