diff --git a/Userland/Libraries/LibSQL/AST/Insert.cpp b/Userland/Libraries/LibSQL/AST/Insert.cpp index a6effe5d82..2b1fe3d72a 100644 --- a/Userland/Libraries/LibSQL/AST/Insert.cpp +++ b/Userland/Libraries/LibSQL/AST/Insert.cpp @@ -12,6 +12,19 @@ namespace SQL::AST { +static bool does_value_data_type_match(SQLType expected, SQLType actual) +{ + if (actual == SQLType::Null) { + return false; + } + + if (expected == SQLType::Integer) { + return actual == SQLType::Integer || actual == SQLType::Float; + } + + return expected == actual; +} + RefPtr Insert::execute(ExecutionContext& context) const { auto table_def = context.database->get_table(m_schema_name, m_table_name); @@ -29,6 +42,8 @@ RefPtr Insert::execute(ExecutionContext& context) const } } + Vector inserted_rows; + inserted_rows.ensure_capacity(m_chained_expressions.size()); for (auto& row_expr : m_chained_expressions) { for (auto& column_def : table_def->columns()) { if (!m_column_names.contains_slow(column_def.name())) { @@ -39,22 +54,30 @@ RefPtr Insert::execute(ExecutionContext& context) const VERIFY(row_value.type() == SQLType::Tuple); auto values = row_value.to_vector().value(); - // FIXME: Check that the values[ix] match the data type of the column. - if (m_column_names.size() > 0) { - for (auto ix = 0u; ix < values.size(); ix++) { - auto& column_name = m_column_names[ix]; - row[column_name] = values[ix]; - } - } else { - if (values.size() != row.size()) { - return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, ""); - } - for (auto ix = 0u; ix < values.size(); ix++) { - row[ix] = values[ix]; - } + if (m_column_names.size() == 0 && values.size() != row.size()) { + return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::InvalidNumberOfValues, ""); } - context.database->insert(row); + + for (auto ix = 0u; ix < values.size(); ix++) { + auto input_value_type = values[ix].type(); + auto& tuple_descriptor = *row.descriptor(); + // In case of having column names, this must succeed since we checked for every column name for existence in the table. + auto element_index = (m_column_names.size() == 0) ? ix : tuple_descriptor.find_if([&](auto element) { return element.name == m_column_names[ix]; }).index(); + auto element_type = tuple_descriptor[element_index].type; + + if (!does_value_data_type_match(element_type, input_value_type)) { + return SQLResult::construct(SQLCommand::Insert, SQLErrorCode::InvalidValueType, table_def->columns()[element_index].name()); + } + + row[element_index] = values[ix]; + } + inserted_rows.append(row); } + + for (auto& inserted_row : inserted_rows) { + context.database->insert(inserted_row); + } + return SQLResult::construct(SQLCommand::Insert, 0, m_chained_expressions.size(), 0); } diff --git a/Userland/Libraries/LibSQL/SQLResult.h b/Userland/Libraries/LibSQL/SQLResult.h index a22d15e998..9702bc1eb9 100644 --- a/Userland/Libraries/LibSQL/SQLResult.h +++ b/Userland/Libraries/LibSQL/SQLResult.h @@ -54,6 +54,7 @@ constexpr char const* command_tag(SQLCommand command) S(TableExists, "Table '{}' already exist") \ S(InvalidType, "Invalid type '{}'") \ S(InvalidDatabaseName, "Invalid database name '{}'") \ + S(InvalidValueType, "Invalid type for attribute '{}'") \ S(InvalidNumberOfValues, "Number of values does not match number of columns") enum class SQLErrorCode {