diff --git a/Userland/Utilities/sql.cpp b/Userland/Utilities/sql.cpp index 4ef193e917..1be7bd4dd7 100644 --- a/Userland/Utilities/sql.cpp +++ b/Userland/Utilities/sql.cpp @@ -21,6 +21,10 @@ String s_history_path = String::formatted("{}/.sql-history", Core::StandardPaths RefPtr s_editor; int s_repl_line_level = 0; bool s_keep_running = true; +String s_pending_database = ""; +String s_current_database = ""; +AK::RefPtr s_sql_client; +int s_connection_id = 0; String prompt_for_level(int level) { @@ -88,12 +92,29 @@ String read_next_piece() return piece.to_string(); } +void connect(String const& database_name) +{ + if (s_current_database.is_empty()) { + s_sql_client->connect(database_name); + } else { + s_pending_database = database_name; + s_sql_client->async_disconnect(s_connection_id); + } +} + void handle_command(StringView command) { - if (command == ".exit" || command == ".quit") + if (command == ".exit" || command == ".quit") { s_keep_running = false; - else + } else if (command.starts_with(".connect ")) { + auto parts = command.split_view(' '); + if (parts.size() == 2) + connect(parts[1]); + else + outln("\033[33;1mUsage: .connect \033[0m {}", command); + } else { outln("\033[33;1mUnrecognized command:\033[0m {}", command); + } } } @@ -158,7 +179,7 @@ int main(int argc, char** argv) }; Core::EventLoop loop; - auto sql_client = SQL::SQLClient::construct(); + s_sql_client = SQL::SQLClient::construct(); int the_connection_id; auto read_sql = [&]() { @@ -172,21 +193,23 @@ int main(int argc, char** argv) if (piece.starts_with('.')) { handle_command(piece); } else { - auto statement_id = sql_client->sql_statement(the_connection_id, piece); - sql_client->async_statement_execute(statement_id); + auto statement_id = s_sql_client->sql_statement(the_connection_id, piece); + s_sql_client->async_statement_execute(statement_id); return; } } while (s_keep_running); - sql_client->async_disconnect(the_connection_id); + s_sql_client->async_disconnect(the_connection_id); }; - sql_client->on_connected = [&](int connection_id, String const& connected_to_database) { + s_sql_client->on_connected = [&](int connection_id, String const& connected_to_database) { outln("** Connected to {} **", connected_to_database); - the_connection_id = connection_id; + s_current_database = connected_to_database; + s_pending_database = ""; + s_connection_id = connection_id; read_sql(); }; - sql_client->on_execution_success = [&](int, bool has_results, int updated, int created, int deleted) { + s_sql_client->on_execution_success = [&](int, bool has_results, int updated, int created, int deleted) { if (updated != 0 || created != 0 || deleted != 0) { outln("{} row(s) updated, {} created, {} deleted", updated, created, deleted); } @@ -195,32 +218,38 @@ int main(int argc, char** argv) } }; - sql_client->on_next_result = [&](int, Vector const& row) { + s_sql_client->on_next_result = [&](int, Vector const& row) { StringBuilder builder; builder.join(", ", row); outln("{}", builder.build()); }; - sql_client->on_results_exhausted = [&](int, int total_rows) { + s_sql_client->on_results_exhausted = [&](int, int total_rows) { outln("{} row(s)", total_rows); read_sql(); }; - sql_client->on_connection_error = [&](int, int code, String const& message) { + s_sql_client->on_connection_error = [&](int, int code, String const& message) { outln("\033[33;1mConnection error:\033[0m {}", message); loop.quit(code); }; - sql_client->on_execution_error = [&](int, int, String const& message) { + s_sql_client->on_execution_error = [&](int, int, String const& message) { outln("\033[33;1mExecution error:\033[0m {}", message); read_sql(); }; - sql_client->on_disconnected = [&](int) { - loop.quit(0); + s_sql_client->on_disconnected = [&](int) { + if (s_pending_database.is_empty()) { + loop.quit(0); + } else { + outln("** Disconnected from {} **", s_current_database); + s_current_database = ""; + s_sql_client->connect(s_pending_database); + } }; - sql_client->connect(database_name); + connect(database_name); auto rc = loop.exec(); s_editor->save_history(s_history_path);