From f3c6cb40d720039e87b331a9762eb74ebd42166e Mon Sep 17 00:00:00 2001 From: Timothy Flynn Date: Thu, 10 Feb 2022 07:46:36 -0500 Subject: [PATCH] LibSQL: Convert SQL expression evaluation to use ResultOr Instead of setting an error in the execution context, we can directly return that error or the successful value. This lets all callers, who were already TRY-capable, simply TRY the expression evaluation. --- Userland/Libraries/LibSQL/AST/AST.h | 20 +-- Userland/Libraries/LibSQL/AST/Expression.cpp | 122 ++++++++----------- Userland/Libraries/LibSQL/AST/Insert.cpp | 5 +- Userland/Libraries/LibSQL/AST/Select.cpp | 16 +-- 4 files changed, 66 insertions(+), 97 deletions(-) diff --git a/Userland/Libraries/LibSQL/AST/AST.h b/Userland/Libraries/LibSQL/AST/AST.h index 50017b580d..8f646ba25f 100644 --- a/Userland/Libraries/LibSQL/AST/AST.h +++ b/Userland/Libraries/LibSQL/AST/AST.h @@ -306,7 +306,7 @@ struct ExecutionContext { class Expression : public ASTNode { public: - virtual Value evaluate(ExecutionContext&) const; + virtual ResultOr evaluate(ExecutionContext&) const; }; class ErrorExpression final : public Expression { @@ -320,7 +320,7 @@ public: } double value() const { return m_value; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: double m_value; @@ -334,7 +334,7 @@ public: } const String& value() const { return m_value; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: String m_value; @@ -355,13 +355,13 @@ private: class NullLiteral : public Expression { public: - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; }; class NestedExpression : public Expression { public: const NonnullRefPtr& expression() const { return m_expression; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; protected: explicit NestedExpression(NonnullRefPtr expression) @@ -432,7 +432,7 @@ public: const String& schema_name() const { return m_schema_name; } const String& table_name() const { return m_table_name; } const String& column_name() const { return m_column_name; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: String m_schema_name; @@ -475,7 +475,7 @@ public: } UnaryOperator type() const { return m_type; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: UnaryOperator m_type; @@ -531,7 +531,7 @@ public: } BinaryOperator type() const { return m_type; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: BinaryOperator m_type; @@ -545,7 +545,7 @@ public: } const NonnullRefPtrVector& expressions() const { return m_expressions; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: NonnullRefPtrVector m_expressions; @@ -638,7 +638,7 @@ public: MatchOperator type() const { return m_type; } const RefPtr& escape() const { return m_escape; } - virtual Value evaluate(ExecutionContext&) const override; + virtual ResultOr evaluate(ExecutionContext&) const override; private: MatchOperator m_type; diff --git a/Userland/Libraries/LibSQL/AST/Expression.cpp b/Userland/Libraries/LibSQL/AST/Expression.cpp index 6b28bbcb23..d9b3cc8696 100644 --- a/Userland/Libraries/LibSQL/AST/Expression.cpp +++ b/Userland/Libraries/LibSQL/AST/Expression.cpp @@ -12,66 +12,55 @@ namespace SQL::AST { static const String s_posix_basic_metacharacters = ".^$*[]+\\"; -Value Expression::evaluate(ExecutionContext&) const +ResultOr Expression::evaluate(ExecutionContext&) const { return Value::null(); } -Value NumericLiteral::evaluate(ExecutionContext& context) const +ResultOr NumericLiteral::evaluate(ExecutionContext&) const { - if (context.result->is_error()) - return Value::null(); Value ret(SQLType::Float); ret = value(); return ret; } -Value StringLiteral::evaluate(ExecutionContext& context) const +ResultOr StringLiteral::evaluate(ExecutionContext&) const { - if (context.result->is_error()) - return Value::null(); Value ret(SQLType::Text); ret = value(); return ret; } -Value NullLiteral::evaluate(ExecutionContext&) const +ResultOr NullLiteral::evaluate(ExecutionContext&) const { return Value::null(); } -Value NestedExpression::evaluate(ExecutionContext& context) const +ResultOr NestedExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); return expression()->evaluate(context); } -Value ChainedExpression::evaluate(ExecutionContext& context) const +ResultOr ChainedExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); Value ret(SQLType::Tuple); Vector values; - for (auto& expression : expressions()) { - values.append(expression.evaluate(context)); - } + for (auto& expression : expressions()) + values.append(TRY(expression.evaluate(context))); ret = values; return ret; } -Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const +ResultOr BinaryOperatorExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); - Value lhs_value = lhs()->evaluate(context); - Value rhs_value = rhs()->evaluate(context); + Value lhs_value = TRY(lhs()->evaluate(context)); + Value rhs_value = TRY(rhs()->evaluate(context)); + switch (type()) { case BinaryOperator::Concatenate: { - if (lhs_value.type() != SQLType::Text) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; - return Value::null(); - } + if (lhs_value.type() != SQLType::Text) + return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; + AK::StringBuilder builder; builder.append(lhs_value.to_string()); builder.append(rhs_value.to_string()); @@ -110,19 +99,17 @@ Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const case BinaryOperator::And: { auto lhs_bool_maybe = lhs_value.to_bool(); auto rhs_bool_maybe = rhs_value.to_bool(); - if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; - return Value::null(); - } + if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) + return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; + return Value(lhs_bool_maybe.release_value() && rhs_bool_maybe.release_value()); } case BinaryOperator::Or: { auto lhs_bool_maybe = lhs_value.to_bool(); auto rhs_bool_maybe = rhs_value.to_bool(); - if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; - return Value::null(); - } + if (!lhs_bool_maybe.has_value() || !rhs_bool_maybe.has_value()) + return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, BinaryOperator_name(type()) }; + return Value(lhs_bool_maybe.release_value() || rhs_bool_maybe.release_value()); } default: @@ -130,17 +117,15 @@ Value BinaryOperatorExpression::evaluate(ExecutionContext& context) const } } -Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const +ResultOr UnaryOperatorExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); - Value expression_value = NestedExpression::evaluate(context); + Value expression_value = TRY(NestedExpression::evaluate(context)); + switch (type()) { case UnaryOperator::Plus: if (expression_value.type() == SQLType::Integer || expression_value.type() == SQLType::Float) return expression_value; - context.result = Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) }; - return Value::null(); + return Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) }; case UnaryOperator::Minus: if (expression_value.type() == SQLType::Integer) { expression_value = -int(expression_value); @@ -150,32 +135,29 @@ Value UnaryOperatorExpression::evaluate(ExecutionContext& context) const expression_value = -double(expression_value); return expression_value; } - context.result = Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) }; - return Value::null(); + return Result { SQLCommand::Unknown, SQLErrorCode::NumericOperatorTypeMismatch, UnaryOperator_name(type()) }; case UnaryOperator::Not: if (expression_value.type() == SQLType::Boolean) { expression_value = !bool(expression_value); return expression_value; } - context.result = Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, UnaryOperator_name(type()) }; - return Value::null(); + return Result { SQLCommand::Unknown, SQLErrorCode::BooleanOperatorTypeMismatch, UnaryOperator_name(type()) }; case UnaryOperator::BitwiseNot: if (expression_value.type() == SQLType::Integer) { expression_value = ~u32(expression_value); return expression_value; } - context.result = Result { SQLCommand::Unknown, SQLErrorCode::IntegerOperatorTypeMismatch, UnaryOperator_name(type()) }; - return Value::null(); + return Result { SQLCommand::Unknown, SQLErrorCode::IntegerOperatorTypeMismatch, UnaryOperator_name(type()) }; + default: + VERIFY_NOT_REACHED(); } - VERIFY_NOT_REACHED(); } -Value ColumnNameExpression::evaluate(ExecutionContext& context) const +ResultOr ColumnNameExpression::evaluate(ExecutionContext& context) const { - if (!context.current_row) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, column_name() }; - return Value::null(); - } + if (!context.current_row) + return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, column_name() }; + auto& descriptor = *context.current_row->descriptor(); VERIFY(context.current_row->size() == descriptor.size()); Optional index_in_row; @@ -184,34 +166,30 @@ Value ColumnNameExpression::evaluate(ExecutionContext& context) const if (!table_name().is_empty() && column_descriptor.table != table_name()) continue; if (column_descriptor.name == column_name()) { - if (index_in_row.has_value()) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::AmbiguousColumnName, column_name() }; - return Value::null(); - } + if (index_in_row.has_value()) + return Result { SQLCommand::Unknown, SQLErrorCode::AmbiguousColumnName, column_name() }; + index_in_row = ix; } } if (index_in_row.has_value()) return (*context.current_row)[index_in_row.value()]; - context.result = Result { SQLCommand::Unknown, SQLErrorCode::ColumnDoesNotExist, column_name() }; - return Value::null(); + + return Result { SQLCommand::Unknown, SQLErrorCode::ColumnDoesNotExist, column_name() }; } -Value MatchExpression::evaluate(ExecutionContext& context) const +ResultOr MatchExpression::evaluate(ExecutionContext& context) const { - if (context.result->is_error()) - return Value::null(); switch (type()) { case MatchOperator::Like: { - Value lhs_value = lhs()->evaluate(context); - Value rhs_value = rhs()->evaluate(context); + Value lhs_value = TRY(lhs()->evaluate(context)); + Value rhs_value = TRY(rhs()->evaluate(context)); + char escape_char = '\0'; if (escape()) { - auto escape_str = escape()->evaluate(context).to_string(); - if (escape_str.length() != 1) { - context.result = Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, "ESCAPE should be a single character" }; - return Value::null(); - } + auto escape_str = TRY(escape()->evaluate(context)).to_string(); + if (escape_str.length() != 1) + return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, "ESCAPE should be a single character" }; escape_char = escape_str[0]; } @@ -237,14 +215,15 @@ Value MatchExpression::evaluate(ExecutionContext& context) const } } builder.append('$'); + // FIXME: We should probably cache this regex. auto regex = Regex(builder.build()); auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode); return Value(invert_expression() ? !result.success : result.success); } case MatchOperator::Regexp: { - Value lhs_value = lhs()->evaluate(context); - Value rhs_value = rhs()->evaluate(context); + Value lhs_value = TRY(lhs()->evaluate(context)); + Value rhs_value = TRY(rhs()->evaluate(context)); auto regex = Regex(rhs_value.to_string()); auto err = regex.parser_result.error; @@ -253,8 +232,7 @@ Value MatchExpression::evaluate(ExecutionContext& context) const builder.append("Regular expression: "); builder.append(get_error_string(err)); - context.result = Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, builder.build() }; - return Value(false); + return Result { SQLCommand::Unknown, SQLErrorCode::SyntaxError, builder.build() }; } auto result = regex.match(lhs_value.to_string(), PosixFlags::Insensitive | PosixFlags::Unicode); diff --git a/Userland/Libraries/LibSQL/AST/Insert.cpp b/Userland/Libraries/LibSQL/AST/Insert.cpp index 7a608b02ca..7490a8fb6b 100644 --- a/Userland/Libraries/LibSQL/AST/Insert.cpp +++ b/Userland/Libraries/LibSQL/AST/Insert.cpp @@ -47,10 +47,7 @@ ResultOr Insert::execute(ExecutionContext& context) const row[column_def.name()] = column_def.default_value(); } - auto row_value = row_expr.evaluate(context); - if (context.result->is_error()) - return context.result.release_value(); - + auto row_value = TRY(row_expr.evaluate(context)); VERIFY(row_value.type() == SQLType::Tuple); auto values = row_value.to_vector().value(); diff --git a/Userland/Libraries/LibSQL/AST/Select.cpp b/Userland/Libraries/LibSQL/AST/Select.cpp index 6df2917cfe..91a7824853 100644 --- a/Userland/Libraries/LibSQL/AST/Select.cpp +++ b/Userland/Libraries/LibSQL/AST/Select.cpp @@ -93,9 +93,7 @@ ResultOr Select::execute(ExecutionContext& context) const context.current_row = &row; if (where_clause()) { - auto where_result = where_clause()->evaluate(context); - if (context.result->is_error()) - return context.result.release_value(); + auto where_result = TRY(where_clause()->evaluate(context)); if (!where_result) continue; } @@ -103,18 +101,14 @@ ResultOr Select::execute(ExecutionContext& context) const tuple.clear(); for (auto& col : columns) { - auto value = col.expression()->evaluate(context); - if (context.result->is_error()) - return context.result.release_value(); + auto value = TRY(col.expression()->evaluate(context)); tuple.append(value); } if (has_ordering) { sort_key.clear(); for (auto& term : m_ordering_term_list) { - auto value = term.expression()->evaluate(context); - if (context.result->is_error()) - return context.result.release_value(); + auto value = TRY(term.expression()->evaluate(context)); sort_key.append(value); } } @@ -126,7 +120,7 @@ ResultOr Select::execute(ExecutionContext& context) const size_t limit_value = NumericLimits::max(); size_t offset_value = 0; - auto limit = m_limit_clause->limit_expression()->evaluate(context); + auto limit = TRY(m_limit_clause->limit_expression()->evaluate(context)); if (!limit.is_null()) { auto limit_value_maybe = limit.to_u32(); if (!limit_value_maybe.has_value()) @@ -136,7 +130,7 @@ ResultOr Select::execute(ExecutionContext& context) const } if (m_limit_clause->offset_expression() != nullptr) { - auto offset = m_limit_clause->offset_expression()->evaluate(context); + auto offset = TRY(m_limit_clause->offset_expression()->evaluate(context)); if (!offset.is_null()) { auto offset_value_maybe = offset.to_u32(); if (!offset_value_maybe.has_value())