diff --git a/Userland/Services/SQLServer/ConnectionFromClient.cpp b/Userland/Services/SQLServer/ConnectionFromClient.cpp index 3c358dd9ae..0445244aaa 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.cpp +++ b/Userland/Services/SQLServer/ConnectionFromClient.cpp @@ -4,8 +4,8 @@ * SPDX-License-Identifier: BSD-2-Clause */ -#include #include +#include #include #include #include @@ -23,8 +23,14 @@ RefPtr ConnectionFromClient::client_connection_for(int cli return nullptr; } +void ConnectionFromClient::set_database_path(DeprecatedString database_path) +{ + m_database_path = move(database_path); +} + ConnectionFromClient::ConnectionFromClient(NonnullOwnPtr socket, int client_id) : IPC::ConnectionFromClient(*this, move(socket), client_id) + , m_database_path(DeprecatedString::formatted("{}/sql", Core::StandardPaths::data_directory())) { s_connections.set(client_id, *this); } @@ -38,7 +44,7 @@ Messages::SQLServer::ConnectResponse ConnectionFromClient::connect(DeprecatedStr { dbgln_if(SQLSERVER_DEBUG, "ConnectionFromClient::connect(database_name: {})", database_name); - if (auto database_connection = DatabaseConnection::create(database_name, client_id()); !database_connection.is_error()) + if (auto database_connection = DatabaseConnection::create(m_database_path, database_name, client_id()); !database_connection.is_error()) return { database_connection.value()->connection_id() }; return { {} }; } diff --git a/Userland/Services/SQLServer/ConnectionFromClient.h b/Userland/Services/SQLServer/ConnectionFromClient.h index c880d649de..783ab55883 100644 --- a/Userland/Services/SQLServer/ConnectionFromClient.h +++ b/Userland/Services/SQLServer/ConnectionFromClient.h @@ -6,6 +6,7 @@ #pragma once +#include #include #include #include @@ -25,6 +26,8 @@ public: static RefPtr client_connection_for(int client_id); + void set_database_path(DeprecatedString); + private: explicit ConnectionFromClient(NonnullOwnPtr, int client_id); @@ -32,6 +35,8 @@ private: virtual Messages::SQLServer::PrepareStatementResponse prepare_statement(u64, DeprecatedString const&) override; virtual Messages::SQLServer::ExecuteStatementResponse execute_statement(u64, Vector const& placeholder_values) override; virtual void disconnect(u64) override; + + DeprecatedString m_database_path; }; } diff --git a/Userland/Services/SQLServer/DatabaseConnection.cpp b/Userland/Services/SQLServer/DatabaseConnection.cpp index 46f3232000..9ba9cf2b73 100644 --- a/Userland/Services/SQLServer/DatabaseConnection.cpp +++ b/Userland/Services/SQLServer/DatabaseConnection.cpp @@ -21,12 +21,14 @@ RefPtr DatabaseConnection::connection_for(u64 connection_id) return nullptr; } -ErrorOr> DatabaseConnection::create(DeprecatedString database_name, int client_id) +ErrorOr> DatabaseConnection::create(StringView database_path, DeprecatedString database_name, int client_id) { if (LexicalPath path(database_name); (path.title() != database_name) || (path.dirname() != ".")) return Error::from_string_view("Invalid database name"sv); - auto database = SQL::Database::construct(DeprecatedString::formatted("/home/anon/sql/{}.db", database_name)); + auto database_file = DeprecatedString::formatted("{}/{}.db", database_path, database_name); + auto database = SQL::Database::construct(move(database_file)); + if (auto result = database->open(); result.is_error()) { warnln("Could not open database: {}", result.error().error_string()); return Error::from_string_view("Could not open database"sv); diff --git a/Userland/Services/SQLServer/DatabaseConnection.h b/Userland/Services/SQLServer/DatabaseConnection.h index eea9acbfe2..60f852bc8e 100644 --- a/Userland/Services/SQLServer/DatabaseConnection.h +++ b/Userland/Services/SQLServer/DatabaseConnection.h @@ -18,7 +18,7 @@ class DatabaseConnection final : public Core::Object { C_OBJECT_ABSTRACT(DatabaseConnection) public: - static ErrorOr> create(DeprecatedString database_name, int client_id); + static ErrorOr> create(StringView database_path, DeprecatedString database_name, int client_id); ~DatabaseConnection() override = default; static RefPtr connection_for(u64 connection_id); diff --git a/Userland/Services/SQLServer/main.cpp b/Userland/Services/SQLServer/main.cpp index 3727fdd81d..0c8f2c754b 100644 --- a/Userland/Services/SQLServer/main.cpp +++ b/Userland/Services/SQLServer/main.cpp @@ -4,24 +4,22 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include +#include #include #include #include #include -#include -#include ErrorOr serenity_main(Main::Arguments) { TRY(Core::System::pledge("stdio accept unix rpath wpath cpath")); - if (mkdir("/home/anon/sql", 0700) < 0 && errno != EEXIST) { - perror("mkdir"); - return 1; - } + auto database_path = DeprecatedString::formatted("{}/sql", Core::StandardPaths::data_directory()); + TRY(Core::Directory::create(database_path, Core::Directory::CreateDirectories::Yes)); - TRY(Core::System::unveil("/home/anon/sql", "rwc")); + TRY(Core::System::unveil(database_path, "rwc"sv)); TRY(Core::System::unveil(nullptr, nullptr)); Core::EventLoop event_loop;