From 0906e3c20661f357ebe808274568ed18c35d9a40 Mon Sep 17 00:00:00 2001 From: Mahmoud Mandour Date: Fri, 17 Sep 2021 17:15:55 +0200 Subject: [PATCH] LibSQL: Check data types in INSERT statement parsing Data types are now checked against the table data types. When multiple rows are inserted at once, we check all rows to be matching W.R.T data types. Only then we insert the rows. --- Userland/Libraries/LibSQL/AST/Insert.cpp | 51 +++++++++++++++++------- Userland/Libraries/LibSQL/SQLResult.h | 1 + 2 files changed, 38 insertions(+), 14 deletions(-) diff --git a/Userland/Libraries/LibSQL/AST/Insert.cpp b/Userland/Libraries/LibSQL/AST/Insert.cpp index a6effe5d82..2b1fe3d72a 100644 --- a/Userland/Libraries/LibSQL/AST/Insert.cpp +++ b/Userland/Libraries/LibSQL/AST/Insert.cpp @@ -12,6 +12,19 @@ namespace SQL::AST { +static bool does_value_data_type_match(SQLType expected, SQLType actual) +{ + if (actual == SQLType::Null) { + return false; + } + + if (expected == SQLType::Integer) { + return actual == SQLType::Integer || actual == SQLType::Float; + } + + return expected == actual; +} + RefPtr Insert::execute(ExecutionContext& context) const { auto table_def = context.database->get_table(m_schema_name, m_table_name); @@ -29,6 +42,8 @@ RefPtr Insert::execute(ExecutionContext& context) const } } + Vector inserted_rows; + inserted_rows.ensure_capacity(m_chained_expressions.size()); for (auto& row_expr : m_chained_expressions) { for (auto& column_def : table_def->columns()) { if (!m_column_names.contains_slow(column_def.name())) { @@ -39,22 +54,30 @@ RefPtr Insert::execute(ExecutionContext& context) const VERIFY(row_value.type() == SQLType::Tuple); auto values = row_value.to_vector().value(); - // FIXME: Check that the values[ix] match the data type of the column. - if (m_column_names.size() > 0) { - for (auto ix = 0u; ix < values.size(); ix++) { - auto& column_name = m_column_names[ix]; - row[column_name] = values[ix]; - } - } else { - if (values.size() != row.size()) { - return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, ""); - } - for (auto ix = 0u; ix < values.size(); ix++) { - row[ix] = values[ix]; - } + if (m_column_names.size() == 0 && values.size() != row.size()) { + return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, ""); } - context.database->insert(row); + + for (auto ix = 0u; ix < values.size(); ix++) { + auto input_value_type = values[ix].type(); + auto& tuple_descriptor = *row.descriptor(); + // In case of having column names, this must succeed since we checked for every column name for existence in the table. + auto element_index = (m_column_names.size() == 0) ? ix : tuple_descriptor.find_if([&](auto element) { return element.name == m_column_names[ix]; }).index(); + auto element_type = tuple_descriptor[element_index].type; + + if (!does_value_data_type_match(element_type, input_value_type)) { + return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name()); + } + + row[element_index] = values[ix]; + } + inserted_rows.append(row); } + + for (auto& inserted_row : inserted_rows) { + context.database->insert(inserted_row); + } + return SQLResult::construct(SQLCommand::Insert, 0, m_chained_expressions.size(), 0); } diff --git a/Userland/Libraries/LibSQL/SQLResult.h b/Userland/Libraries/LibSQL/SQLResult.h index a22d15e998..9702bc1eb9 100644 --- a/Userland/Libraries/LibSQL/SQLResult.h +++ b/Userland/Libraries/LibSQL/SQLResult.h @@ -54,6 +54,7 @@ constexpr char const* command_tag(SQLCommand command) S(TableExists, "Table '{}' already exist") \ S(InvalidType, "Invalid type '{}'") \ S(InvalidDatabaseName, "Invalid database name '{}'") \ + S(InvalidValueType, "Invalid type for attribute '{}'") \ S(InvalidNumberOfValues, "Number of values does not match number of columns") enum class SQLErrorCode {