diff --git a/Userland/Libraries/LibSQL/SQLClient.cpp b/Userland/Libraries/LibSQL/SQLClient.cpp index 94514ccf32..2f54f84bbc 100644 --- a/Userland/Libraries/LibSQL/SQLClient.cpp +++ b/Userland/Libraries/LibSQL/SQLClient.cpp @@ -5,10 +5,143 @@ * SPDX-License-Identifier: BSD-2-Clause */ +#include #include +#if !defined(AK_OS_SERENITY) +# include +# include +# include +# include +# include +# include +#endif + namespace SQL { +#if !defined(AK_OS_SERENITY) + +// This is heavily based on how SystemServer's Service creates its socket. +static ErrorOr create_database_socket(DeprecatedString const& socket_path) +{ + if (Core::File::exists(socket_path)) + TRY(Core::System::unlink(socket_path)); + +# ifdef SOCK_NONBLOCK + auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0)); +# else + auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM, 0)); + + int option = 1; + TRY(Core::System::ioctl(socket_fd, FIONBIO, &option)); + TRY(Core::System::fcntl(socket_fd, F_SETFD, FD_CLOEXEC)); +# endif + + TRY(Core::System::fchmod(socket_fd, 0600)); + + auto socket_address = Core::SocketAddress::local(socket_path); + auto socket_address_un = socket_address.to_sockaddr_un().release_value(); + + TRY(Core::System::bind(socket_fd, reinterpret_cast(&socket_address_un), sizeof(socket_address_un))); + TRY(Core::System::listen(socket_fd, 16)); + + return socket_fd; +} + +static ErrorOr launch_server(DeprecatedString const& socket_path, DeprecatedString const& pid_path, StringView server_path) +{ + auto server_fd = TRY(create_database_socket(socket_path)); + auto server_pid = TRY(Core::System::fork()); + + if (server_pid == 0) { + TRY(Core::System::setsid()); + TRY(Core::System::signal(SIGCHLD, SIG_IGN)); + server_pid = TRY(Core::System::fork()); + + if (server_pid != 0) { + auto server_pid_file = TRY(Core::Stream::File::open(pid_path, Core::Stream::OpenMode::Write)); + TRY(server_pid_file->write(DeprecatedString::number(server_pid).bytes())); + + exit(0); + } + + server_fd = TRY(Core::System::dup(server_fd)); + + auto takeover_string = DeprecatedString::formatted("{}:{}", socket_path, server_fd); + TRY(Core::System::setenv("SOCKET_TAKEOVER"sv, takeover_string, true)); + + auto arguments = Array { + server_path, + "--pid-file"sv, + pid_path, + }; + + auto result = Core::System::exec(arguments[0], arguments, Core::System::SearchInPath::Yes); + if (result.is_error()) { + warnln("Could not launch {}: {}", server_path, result.error()); + TRY(Core::System::unlink(pid_path)); + } + + VERIFY_NOT_REACHED(); + } + + TRY(Core::System::waitpid(server_pid)); + return {}; +} + +static ErrorOr should_launch_server(DeprecatedString const& pid_path) +{ + if (!Core::File::exists(pid_path)) + return true; + + Optional pid; + { + auto server_pid_file = Core::Stream::File::open(pid_path, Core::Stream::OpenMode::Read); + if (server_pid_file.is_error()) { + warnln("Could not open SQLServer PID file '{}': {}", pid_path, server_pid_file.error()); + return server_pid_file.release_error(); + } + + auto contents = server_pid_file.value()->read_all(); + if (contents.is_error()) { + warnln("Could not read SQLServer PID file '{}': {}", pid_path, contents.error()); + return contents.release_error(); + } + + pid = StringView { contents.value() }.to_int(); + } + + if (!pid.has_value()) { + warnln("SQLServer PID file '{}' exists, but with an invalid PID", pid_path); + TRY(Core::System::unlink(pid_path)); + return true; + } + if (kill(*pid, 0) < 0) { + warnln("SQLServer PID file '{}' exists with PID {}, but process cannot be found", pid_path, *pid); + TRY(Core::System::unlink(pid_path)); + return true; + } + + return false; +} + +ErrorOr> SQLClient::launch_server_and_create_client(StringView server_path) +{ + auto runtime_directory = TRY(Core::StandardPaths::runtime_directory()); + auto socket_path = DeprecatedString::formatted("{}/SQLServer.socket", runtime_directory); + auto pid_path = DeprecatedString::formatted("{}/SQLServer.pid", runtime_directory); + + if (TRY(should_launch_server(pid_path))) + TRY(launch_server(socket_path, pid_path, server_path)); + + auto socket = TRY(Core::Stream::LocalSocket::connect(move(socket_path))); + TRY(socket->set_blocking(true)); + + return adopt_nonnull_ref_or_enomem(new (nothrow) SQLClient(std::move(socket))); +} + +#endif + void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message) { if (on_execution_error) diff --git a/Userland/Libraries/LibSQL/SQLClient.h b/Userland/Libraries/LibSQL/SQLClient.h index 3e1ceab937..ed8c597baa 100644 --- a/Userland/Libraries/LibSQL/SQLClient.h +++ b/Userland/Libraries/LibSQL/SQLClient.h @@ -7,6 +7,7 @@ #pragma once +#include #include #include #include @@ -20,10 +21,9 @@ class SQLClient IPC_CLIENT_CONNECTION(SQLClient, "/tmp/session/%sid/portal/sql"sv) public: - explicit SQLClient(NonnullOwnPtr socket) - : IPC::ConnectionToServer(*this, move(socket)) - { - } +#if !defined(AK_OS_SERENITY) + static ErrorOr> launch_server_and_create_client(StringView server_path); +#endif virtual ~SQLClient() = default; @@ -33,6 +33,11 @@ public: Function on_results_exhausted; private: + explicit SQLClient(NonnullOwnPtr socket) + : IPC::ConnectionToServer(*this, move(socket)) + { + } + virtual void execution_success(u64 statement_id, u64 execution_id, bool has_results, size_t created, size_t updated, size_t deleted) override; virtual void next_result(u64 statement_id, u64 execution_id, Vector const&) override; virtual void results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows) override;