diff --git a/Userland/Libraries/LibSQL/AST/AST.h b/Userland/Libraries/LibSQL/AST/AST.h index 50017b580d..8f646ba25f 100644 --- a/Userland/Libraries/LibSQL/AST/AST.h +++ b/Userland/Libraries/LibSQL/AST/AST.h @@ -306,7 +306,7 @@ struct ExecutionContext { class Expression : public ASTNode { public: - virtual Value evaluate(ExecutionContext&) const; + virtual ResultOr evaluate(ExecutionContext&) const; }; class ErrorExpression final : public Expression { @@ -320,7 +320,7 @@ public: } double value() const { return m_value; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: double m_value; @@ -334,7 +334,7 @@ public: } const String& value() const { return m_value; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: String m_value; @@ -355,13 +355,13 @@ private: class NullLiteral : public Expression { public: - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; }; class NestedExpression : public Expression { public: const NonnullRefPtr& expression() const { return m_expression; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; protected: explicit NestedExpression(NonnullRefPtr expression) @@ -432,7 +432,7 @@ public: const String& schema_name() const { return m_schema_name; } const String& table_name() const { return m_table_name; } const String& column_name() const { return m_column_name; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: String m_schema_name; @@ -475,7 +475,7 @@ public: } UnaryOperator type() const { return m_type; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: UnaryOperator m_type; @@ -531,7 +531,7 @@ public: } BinaryOperator type() const { return m_type; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: BinaryOperator m_type; @@ -545,7 +545,7 @@ public: } const NonnullRefPtrVector& expressions() const { return m_expressions; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: NonnullRefPtrVector m_expressions; @@ -638,7 +638,7 @@ public: MatchOperator type() const { return m_type; } const RefPtr& escape() const { return m_escape; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: MatchOperator m_type; diff --git a/Userland/Libraries/LibSQL/AST/Expression.cpp b/Userland/Libraries/LibSQL/AST/Expression.cpp index 6b28bbcb23..d9b3cc8696 100644 --- a/Userland/Libraries/LibSQL/AST/Expression.cpp +++ b/Userland/Libraries/LibSQL/AST/Expression.cpp @@ -12,66 +12,55 @@ namespace SQL::AST { static const String s_posix_basic_metacharacters = ".^$*[]+\\"; -Value Expression::evaluate(ExecutionContext&) const +ResultOr Expression::evaluate(ExecutionContext&) const { return Value::null(); } -Value NumericLiteral::evaluate(ExecutionContext& context) const +ResultOr NumericLiteral::evaluate(ExecutionContext&) const { - if (context.result->is_error()) - return Value::null(); Value ret(SQLType::Float); ret = value(); return ret; } -Value StringLiteral::evaluate(ExecutionContext& context) const +ResultOr StringLiteral::evaluate(ExecutionContext&) const { - if (context.result->is_error()) - return Value::null(); Value ret(SQLType::Text); ret = value(); return ret; } -Value NullLiteral::evaluate(ExecutionContext&) const +ResultOr NullLiteral::evaluate(ExecutionContext&) const { return Value::null(); } -Value NestedExpression::evaluate(ExecutionContext& context) const +ResultOr NestedExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); return expression()->evaluate(context); } -Value ChainedExpression::evaluate(ExecutionContext& context) const +ResultOr ChainedExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); Value ret(SQLType::Tuple); Vector values; - for (auto& expression : expressions()) { - values.append(expression.evaluate(context)); - } + for (auto& expression : expressions()) + values.append(TRY(expression.evaluate(context))); ret = values; return ret; } -Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const +ResultOr BinaryOperatorExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); - Value lhs_value = lhs()->evaluate(context); - Value rhs_value = rhs()->evaluate(context); + Value lhs_value = TRY(lhs()->evaluate(context)); + Value rhs_value = TRY(rhs()->evaluate(context)); + switch (type()) { case BinaryOperator::Concatenate: { - if (lhs_value.type() != SQLType::Text) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; - return Value::null(); - } + if (lhs_value.type() != SQLType::Text) + return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; + AK::StringBuilder builder; builder.append(lhs_value.to_string()); builder.append(rhs_value.to_string()); @@ -110,19 +99,17 @@ Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const case BinaryOperator::And: { auto lhs_bool_maybe = lhs_value.to_bool(); auto rhs_bool_maybe = rhs_value.to_bool(); - if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; - return Value::null(); - } + if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) + return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; + return Value(lhs_bool_maybe.release_value() && rhs_bool_maybe.release_value()); } case BinaryOperator::Or: { auto lhs_bool_maybe = lhs_value.to_bool(); auto rhs_bool_maybe = rhs_value.to_bool(); - if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; - return Value::null(); - } + if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) + return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; + return Value(lhs_bool_maybe.release_value() || rhs_bool_maybe.release_value()); } default: @@ -130,17 +117,15 @@ Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const } } -Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const +ResultOr UnaryOperatorExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); - Value expression_value = NestedExpression::evaluate(context); + Value expression_value = TRY(NestedExpression::evaluate(context)); + switch (type()) { case UnaryOperator::Plus: if (expression_value.type() == SQLType::Integer || expression_value.type() == SQLType::Float) return expression_value; - context.result = Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) }; - return Value::null(); + return Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) }; case UnaryOperator::Minus: if (expression_value.type() == SQLType::Integer) { expression_value = -int(expression_value); @@ -150,32 +135,29 @@ Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const expression_value = -double(expression_value); return expression_value; } - context.result = Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) }; - return Value::null(); + return Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) }; case UnaryOperator::Not: if (expression_value.type() == SQLType::Boolean) { expression_value = !bool(expression_value); return expression_value; } - context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, UnaryOperator_name(type()) }; - return Value::null(); + return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, UnaryOperator_name(type()) }; case UnaryOperator::BitwiseNot: if (expression_value.type() == SQLType::Integer) { expression_value = ~u32(expression_value); return expression_value; } - context.result = Result { SQLCommand::Unknown, SQLErrorCode::IntegerOperatorTypeMismatch, UnaryOperator_name(type()) }; - return Value::null(); + return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOperatorTypeMismatch, UnaryOperator_name(type()) }; + default: + VERIFY_NOT_REACHED(); } - VERIFY_NOT_REACHED(); } -Value ColumnNameExpression::evaluate(ExecutionContext& context) const +ResultOr ColumnNameExpression::evaluate(ExecutionContext& context) const { - if (!context.current_row) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, column_name() }; - return Value::null(); - } + if (!context.current_row) + return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, column_name() }; + auto& descriptor = *context.current_row->descriptor(); VERIFY(context.current_row->size() == descriptor.size()); Optional index_in_row; @@ -184,34 +166,30 @@ Value ColumnNameExpression::evaluate(ExecutionContext& context) const if (!table_name().is_empty() && column_descriptor.table != table_name()) continue; if (column_descriptor.name == column_name()) { - if (index_in_row.has_value()) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::AmbiguousColumnName, column_name() }; - return Value::null(); - } + if (index_in_row.has_value()) + return Result { SQLCommand::Unknown, SQLErrorCode::AmbiguousColumnName, column_name() }; + index_in_row = ix; } } if (index_in_row.has_value()) return (*context.current_row)[index_in_row.value()]; - context.result = Result { SQLCommand::Unknown, SQLErrorCode::ColumnDoesNotExist, column_name() }; - return Value::null(); + + return Result { SQLCommand::Unknown, SQLErrorCode::ColumnDoesNotExist, column_name() }; } -Value MatchExpression::evaluate(ExecutionContext& context) const +ResultOr MatchExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); switch (type()) { case MatchOperator::Like: { - Value lhs_value = lhs()->evaluate(context); - Value rhs_value = rhs()->evaluate(context); + Value lhs_value = TRY(lhs()->evaluate(context)); + Value rhs_value = TRY(rhs()->evaluate(context)); + char escape_char = '\0'; if (escape()) { - auto escape_str = escape()->evaluate(context).to_string(); - if (escape_str.length() != 1) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, "ESCAPE should be a single character" }; - return Value::null(); - } + auto escape_str = TRY(escape()->evaluate(context)).to_string(); + if (escape_str.length() != 1) + return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, "ESCAPE should be a single character" }; escape_char = escape_str[0]; } @@ -237,14 +215,15 @@ Value MatchExpression::evaluate(ExecutionContext& context) const } } builder.append('$'); + // FIXME: We should probably cache this regex. auto regex = Regex(builder.build()); auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode); return Value(invert_expression() ? !result.success : result.success); } case MatchOperator::Regexp: { - Value lhs_value = lhs()->evaluate(context); - Value rhs_value = rhs()->evaluate(context); + Value lhs_value = TRY(lhs()->evaluate(context)); + Value rhs_value = TRY(rhs()->evaluate(context)); auto regex = Regex(rhs_value.to_string()); auto err = regex.parser_result.error; @@ -253,8 +232,7 @@ Value MatchExpression::evaluate(ExecutionContext& context) const builder.append("Regular expression: "); builder.append(get_error_string(err)); - context.result = Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, builder.build() }; - return Value(false); + return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, builder.build() }; } auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode); diff --git a/Userland/Libraries/LibSQL/AST/Insert.cpp b/Userland/Libraries/LibSQL/AST/Insert.cpp index 7a608b02ca..7490a8fb6b 100644 --- a/Userland/Libraries/LibSQL/AST/Insert.cpp +++ b/Userland/Libraries/LibSQL/AST/Insert.cpp @@ -47,10 +47,7 @@ ResultOr Insert::execute(ExecutionContext& context) const row[column_def.name()] = column_def.default_value(); } - auto row_value = row_expr.evaluate(context); - if (context.result->is_error()) - return context.result.release_value(); - + auto row_value = TRY(row_expr.evaluate(context)); VERIFY(row_value.type() == SQLType::Tuple); auto values = row_value.to_vector().value(); diff --git a/Userland/Libraries/LibSQL/AST/Select.cpp b/Userland/Libraries/LibSQL/AST/Select.cpp index 6df2917cfe..91a7824853 100644 --- a/Userland/Libraries/LibSQL/AST/Select.cpp +++ b/Userland/Libraries/LibSQL/AST/Select.cpp @@ -93,9 +93,7 @@ ResultOr Select::execute(ExecutionContext& context) const context.current_row = &row; if (where_clause()) { - auto where_result = where_clause()->evaluate(context); - if (context.result->is_error()) - return context.result.release_value(); + auto where_result = TRY(where_clause()->evaluate(context)); if (!where_result) continue; } @@ -103,18 +101,14 @@ ResultOr Select::execute(ExecutionContext& context) const tuple.clear(); for (auto& col : columns) { - auto value = col.expression()->evaluate(context); - if (context.result->is_error()) - return context.result.release_value(); + auto value = TRY(col.expression()->evaluate(context)); tuple.append(value); } if (has_ordering) { sort_key.clear(); for (auto& term : m_ordering_term_list) { - auto value = term.expression()->evaluate(context); - if (context.result->is_error()) - return context.result.release_value(); + auto value = TRY(term.expression()->evaluate(context)); sort_key.append(value); } } @@ -126,7 +120,7 @@ ResultOr Select::execute(ExecutionContext& context) const size_t limit_value = NumericLimits::max(); size_t offset_value = 0; - auto limit = m_limit_clause->limit_expression()->evaluate(context); + auto limit = TRY(m_limit_clause->limit_expression()->evaluate(context)); if (!limit.is_null()) { auto limit_value_maybe = limit.to_u32(); if (!limit_value_maybe.has_value()) @@ -136,7 +130,7 @@ ResultOr Select::execute(ExecutionContext& context) const } if (m_limit_clause->offset_expression() != nullptr) { - auto offset = m_limit_clause->offset_expression()->evaluate(context); + auto offset = TRY(m_limit_clause->offset_expression()->evaluate(context)); if (!offset.is_null()) { auto offset_value_maybe = offset.to_u32(); if (!offset_value_maybe.has_value())