mirror of
				https://github.com/RGBCube/serenity
				synced 2025-10-31 16:12:44 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			227 lines
		
	
	
	
		
			7.3 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
			
		
		
	
	
			227 lines
		
	
	
	
		
			7.3 KiB
		
	
	
	
		
			C++
		
	
	
	
	
	
| /*
 | |
|  * Copyright (c) 2021, Jan de Visser <jan@de-visser.net>
 | |
|  * Copyright (c) 2022, the SerenityOS developers.
 | |
|  *
 | |
|  * SPDX-License-Identifier: BSD-2-Clause
 | |
|  */
 | |
| 
 | |
| #include <AK/DeprecatedString.h>
 | |
| #include <AK/String.h>
 | |
| #include <LibSQL/SQLClient.h>
 | |
| 
 | |
| #if !defined(AK_OS_SERENITY)
 | |
| #    include <LibCore/DeprecatedFile.h>
 | |
| #    include <LibCore/Directory.h>
 | |
| #    include <LibCore/SocketAddress.h>
 | |
| #    include <LibCore/StandardPaths.h>
 | |
| #    include <LibCore/System.h>
 | |
| #endif
 | |
| 
 | |
| namespace SQL {
 | |
| 
 | |
| #if !defined(AK_OS_SERENITY)
 | |
| 
 | |
| // This is heavily based on how SystemServer's Service creates its socket.
 | |
| static ErrorOr<int> create_database_socket(DeprecatedString const& socket_path)
 | |
| {
 | |
|     if (Core::DeprecatedFile::exists(socket_path))
 | |
|         TRY(Core::System::unlink(socket_path));
 | |
| 
 | |
| #    ifdef SOCK_NONBLOCK
 | |
|     auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM | SOCK_NONBLOCK | SOCK_CLOEXEC, 0));
 | |
| #    else
 | |
|     auto socket_fd = TRY(Core::System::socket(AF_LOCAL, SOCK_STREAM, 0));
 | |
| 
 | |
|     int option = 1;
 | |
|     TRY(Core::System::ioctl(socket_fd, FIONBIO, &option));
 | |
|     TRY(Core::System::fcntl(socket_fd, F_SETFD, FD_CLOEXEC));
 | |
| #    endif
 | |
| 
 | |
| #    if !defined(AK_OS_BSD_GENERIC)
 | |
|     TRY(Core::System::fchmod(socket_fd, 0600));
 | |
| #    endif
 | |
| 
 | |
|     auto socket_address = Core::SocketAddress::local(socket_path);
 | |
|     auto socket_address_un = socket_address.to_sockaddr_un().release_value();
 | |
| 
 | |
|     TRY(Core::System::bind(socket_fd, reinterpret_cast<sockaddr*>(&socket_address_un), sizeof(socket_address_un)));
 | |
|     TRY(Core::System::listen(socket_fd, 16));
 | |
| 
 | |
|     return socket_fd;
 | |
| }
 | |
| 
 | |
| static ErrorOr<void> launch_server(DeprecatedString const& socket_path, DeprecatedString const& pid_path, Vector<String> candidate_server_paths)
 | |
| {
 | |
|     auto server_fd_or_error = create_database_socket(socket_path);
 | |
|     if (server_fd_or_error.is_error()) {
 | |
|         warnln("Failed to create a database socket at {}: {}", socket_path, server_fd_or_error.error());
 | |
|         return server_fd_or_error.release_error();
 | |
|     }
 | |
|     auto server_fd = server_fd_or_error.value();
 | |
|     auto server_pid = TRY(Core::System::fork());
 | |
| 
 | |
|     if (server_pid == 0) {
 | |
|         TRY(Core::System::setsid());
 | |
|         TRY(Core::System::signal(SIGCHLD, SIG_IGN));
 | |
|         server_pid = TRY(Core::System::fork());
 | |
| 
 | |
|         if (server_pid != 0) {
 | |
|             auto server_pid_file = TRY(Core::File::open(pid_path, Core::File::OpenMode::Write));
 | |
|             TRY(server_pid_file->write(DeprecatedString::number(server_pid).bytes()));
 | |
| 
 | |
|             TRY(Core::System::kill(getpid(), SIGTERM));
 | |
|         }
 | |
| 
 | |
|         server_fd = TRY(Core::System::dup(server_fd));
 | |
| 
 | |
|         auto takeover_string = DeprecatedString::formatted("SQLServer:{}", server_fd);
 | |
|         TRY(Core::System::setenv("SOCKET_TAKEOVER"sv, takeover_string, true));
 | |
| 
 | |
|         ErrorOr<void> result;
 | |
|         for (auto const& server_path : candidate_server_paths) {
 | |
|             auto arguments = Array {
 | |
|                 server_path.bytes_as_string_view(),
 | |
|                 "--pid-file"sv,
 | |
|                 pid_path,
 | |
|             };
 | |
|             result = Core::System::exec(arguments[0], arguments, Core::System::SearchInPath::Yes);
 | |
|             if (!result.is_error())
 | |
|                 break;
 | |
|         }
 | |
|         if (result.is_error()) {
 | |
|             warnln("Could not launch any of {}: {}", candidate_server_paths, result.error());
 | |
|             TRY(Core::System::unlink(pid_path));
 | |
|         }
 | |
| 
 | |
|         VERIFY_NOT_REACHED();
 | |
|     }
 | |
| 
 | |
|     TRY(Core::System::waitpid(server_pid));
 | |
|     return {};
 | |
| }
 | |
| 
 | |
| static ErrorOr<bool> should_launch_server(DeprecatedString const& pid_path)
 | |
| {
 | |
|     if (!Core::DeprecatedFile::exists(pid_path))
 | |
|         return true;
 | |
| 
 | |
|     Optional<pid_t> pid;
 | |
|     {
 | |
|         auto server_pid_file = Core::File::open(pid_path, Core::File::OpenMode::Read);
 | |
|         if (server_pid_file.is_error()) {
 | |
|             warnln("Could not open SQLServer PID file '{}': {}", pid_path, server_pid_file.error());
 | |
|             return server_pid_file.release_error();
 | |
|         }
 | |
| 
 | |
|         auto contents = server_pid_file.value()->read_until_eof();
 | |
|         if (contents.is_error()) {
 | |
|             warnln("Could not read SQLServer PID file '{}': {}", pid_path, contents.error());
 | |
|             return contents.release_error();
 | |
|         }
 | |
| 
 | |
|         pid = StringView { contents.value() }.to_int<pid_t>();
 | |
|     }
 | |
| 
 | |
|     if (!pid.has_value()) {
 | |
|         warnln("SQLServer PID file '{}' exists, but with an invalid PID", pid_path);
 | |
|         TRY(Core::System::unlink(pid_path));
 | |
|         return true;
 | |
|     }
 | |
|     if (kill(*pid, 0) < 0) {
 | |
|         warnln("SQLServer PID file '{}' exists with PID {}, but process cannot be found", pid_path, *pid);
 | |
|         TRY(Core::System::unlink(pid_path));
 | |
|         return true;
 | |
|     }
 | |
| 
 | |
|     return false;
 | |
| }
 | |
| 
 | |
| ErrorOr<NonnullRefPtr<SQLClient>> SQLClient::launch_server_and_create_client(Vector<String> candidate_server_paths)
 | |
| {
 | |
|     auto runtime_directory = TRY(Core::StandardPaths::runtime_directory());
 | |
|     auto socket_path = DeprecatedString::formatted("{}/SQLServer.socket", runtime_directory);
 | |
|     auto pid_path = DeprecatedString::formatted("{}/SQLServer.pid", runtime_directory);
 | |
| 
 | |
|     if (TRY(should_launch_server(pid_path)))
 | |
|         TRY(launch_server(socket_path, pid_path, move(candidate_server_paths)));
 | |
| 
 | |
|     auto socket = TRY(Core::LocalSocket::connect(move(socket_path)));
 | |
|     TRY(socket->set_blocking(true));
 | |
| 
 | |
|     return adopt_nonnull_ref_or_enomem(new (nothrow) SQLClient(move(socket)));
 | |
| }
 | |
| 
 | |
| #endif
 | |
| 
 | |
| void SQLClient::execution_success(u64 statement_id, u64 execution_id, Vector<DeprecatedString> const& column_names, bool has_results, size_t created, size_t updated, size_t deleted)
 | |
| {
 | |
|     if (!on_execution_success) {
 | |
|         outln("{} row(s) created, {} updated, {} deleted", created, updated, deleted);
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     ExecutionSuccess success {
 | |
|         .statement_id = statement_id,
 | |
|         .execution_id = execution_id,
 | |
|         .column_names = move(const_cast<Vector<DeprecatedString>&>(column_names)),
 | |
|         .has_results = has_results,
 | |
|         .rows_created = created,
 | |
|         .rows_updated = updated,
 | |
|         .rows_deleted = deleted,
 | |
|     };
 | |
| 
 | |
|     on_execution_success(move(success));
 | |
| }
 | |
| 
 | |
| void SQLClient::execution_error(u64 statement_id, u64 execution_id, SQLErrorCode const& code, DeprecatedString const& message)
 | |
| {
 | |
|     if (!on_execution_error) {
 | |
|         warnln("Execution error for statement_id {}: {} ({})", statement_id, message, to_underlying(code));
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     ExecutionError error {
 | |
|         .statement_id = statement_id,
 | |
|         .execution_id = execution_id,
 | |
|         .error_code = code,
 | |
|         .error_message = move(const_cast<DeprecatedString&>(message)),
 | |
|     };
 | |
| 
 | |
|     on_execution_error(move(error));
 | |
| }
 | |
| 
 | |
| void SQLClient::next_result(u64 statement_id, u64 execution_id, Vector<Value> const& row)
 | |
| {
 | |
|     if (!on_next_result) {
 | |
|         StringBuilder builder;
 | |
|         builder.join(", "sv, row, "\"{}\""sv);
 | |
|         outln("{}", builder.string_view());
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     ExecutionResult result {
 | |
|         .statement_id = statement_id,
 | |
|         .execution_id = execution_id,
 | |
|         .values = move(const_cast<Vector<Value>&>(row)),
 | |
|     };
 | |
| 
 | |
|     on_next_result(move(result));
 | |
| }
 | |
| 
 | |
| void SQLClient::results_exhausted(u64 statement_id, u64 execution_id, size_t total_rows)
 | |
| {
 | |
|     if (!on_results_exhausted) {
 | |
|         outln("{} total row(s)", total_rows);
 | |
|         return;
 | |
|     }
 | |
| 
 | |
|     ExecutionComplete success {
 | |
|         .statement_id = statement_id,
 | |
|         .execution_id = execution_id,
 | |
|         .total_rows = total_rows,
 | |
|     };
 | |
| 
 | |
|     on_results_exhausted(move(success));
 | |
| }
 | |
| 
 | |
| }
 | 
