From 246ab432ffc579b4ae4374b1a699bd1e689dc506 Mon Sep 17 00:00:00 2001 From: Ali Mohammad Pur Date: Sun, 12 Sep 2021 17:30:27 +0430 Subject: [PATCH] LibRegex: Add a basic optimization pass This currently tries to convert forking loops to atomic groups, and unify the left side of alternations. --- Tests/LibRegex/Regex.cpp | 19 + Userland/Libraries/LibRegex/CMakeLists.txt | 1 + Userland/Libraries/LibRegex/RegexByteCode.cpp | 60 ++- Userland/Libraries/LibRegex/RegexByteCode.h | 39 +- .../LibRegex/RegexBytecodeStreamOptimizer.h | 18 + Userland/Libraries/LibRegex/RegexMatch.h | 3 + Userland/Libraries/LibRegex/RegexMatcher.cpp | 69 ++- Userland/Libraries/LibRegex/RegexMatcher.h | 15 + .../Libraries/LibRegex/RegexOptimizer.cpp | 477 ++++++++++++++++++ 9 files changed, 677 insertions(+), 24 deletions(-) create mode 100644 Userland/Libraries/LibRegex/RegexBytecodeStreamOptimizer.h create mode 100644 Userland/Libraries/LibRegex/RegexOptimizer.cpp diff --git a/Tests/LibRegex/Regex.cpp b/Tests/LibRegex/Regex.cpp index 5203a277a0..4aeffb8da8 100644 --- a/Tests/LibRegex/Regex.cpp +++ b/Tests/LibRegex/Regex.cpp @@ -8,6 +8,7 @@ #include // import first, to prevent warning of VERIFY* redefinition #include +#include #include #include #include @@ -887,3 +888,21 @@ BENCHMARK_CASE(fork_performance) auto result = re.match(g_lots_of_a_s); EXPECT_EQ(result.success, true); } + +TEST_CASE(optimizer_atomic_groups) +{ + Array tests { + // Fork -> ForkReplace + Tuple { "a*b"sv, "aaaaa"sv, false }, + Tuple { "a+b"sv, "aaaaa"sv, false }, + // Alternative fuse + Tuple { "(abcfoo|abcbar|abcbaz).*x"sv, "abcbarx"sv, true }, + Tuple { "(a|a)"sv, "a"sv, true }, + }; + + for (auto& test : tests) { + Regex re(test.get<0>()); + auto result = re.match(test.get<1>()); + EXPECT_EQ(result.success, test.get<2>()); + } +} diff --git a/Userland/Libraries/LibRegex/CMakeLists.txt b/Userland/Libraries/LibRegex/CMakeLists.txt index c1a25fe4b0..745ff01424 100644 --- a/Userland/Libraries/LibRegex/CMakeLists.txt +++ b/Userland/Libraries/LibRegex/CMakeLists.txt @@ -3,6 +3,7 @@ set(SOURCES RegexByteCode.cpp RegexLexer.cpp RegexMatcher.cpp + RegexOptimizer.cpp RegexParser.cpp ) diff --git a/Userland/Libraries/LibRegex/RegexByteCode.cpp b/Userland/Libraries/LibRegex/RegexByteCode.cpp index ded6ff79ea..231c076d10 100644 --- a/Userland/Libraries/LibRegex/RegexByteCode.cpp +++ b/Userland/Libraries/LibRegex/RegexByteCode.cpp @@ -245,12 +245,26 @@ ALWAYS_INLINE ExecutionResult OpCode_ForkJump::execute(MatchInput const&, MatchS return ExecutionResult::Fork_PrioHigh; } +ALWAYS_INLINE ExecutionResult OpCode_ForkReplaceJump::execute(MatchInput const& input, MatchState& state) const +{ + state.fork_at_position = state.instruction_position + size() + offset(); + input.fork_to_replace = state.instruction_position; + return ExecutionResult::Fork_PrioHigh; +} + ALWAYS_INLINE ExecutionResult OpCode_ForkStay::execute(MatchInput const&, MatchState& state) const { state.fork_at_position = state.instruction_position + size() + offset(); return ExecutionResult::Fork_PrioLow; } +ALWAYS_INLINE ExecutionResult OpCode_ForkReplaceStay::execute(MatchInput const& input, MatchState& state) const +{ + state.fork_at_position = state.instruction_position + size() + offset(); + input.fork_to_replace = state.instruction_position; + return ExecutionResult::Fork_PrioLow; +} + ALWAYS_INLINE ExecutionResult OpCode_CheckBegin::execute(MatchInput const& input, MatchState& state) const { if (0 == state.string_position && (input.regex_options & AllFlags::MatchNotBeginOfLine)) @@ -778,6 +792,40 @@ String const OpCode_Compare::arguments_string() const return String::formatted("argc={}, args={} ", arguments_count(), arguments_size()); } +Vector OpCode_Compare::flat_compares() const +{ + Vector result; + + size_t offset { state().instruction_position + 3 }; + + for (size_t i = 0; i < arguments_count(); ++i) { + auto compare_type = (CharacterCompareType)m_bytecode->at(offset++); + + if (compare_type == CharacterCompareType::Char) { + auto ch = m_bytecode->at(offset++); + result.append({ compare_type, ch }); + } else if (compare_type == CharacterCompareType::Reference) { + auto ref = m_bytecode->at(offset++); + result.append({ compare_type, ref }); + } else if (compare_type == CharacterCompareType::String) { + auto& length = m_bytecode->at(offset++); + if (length > 0) + result.append({ compare_type, m_bytecode->at(offset) }); + StringBuilder str_builder; + offset += length; + } else if (compare_type == CharacterCompareType::CharClass) { + auto character_class = m_bytecode->at(offset++); + result.append({ compare_type, character_class }); + } else if (compare_type == CharacterCompareType::CharRange) { + auto value = m_bytecode->at(offset++); + result.append({ compare_type, value }); + } else { + result.append({ compare_type, 0 }); + } + } + return result; +} + Vector const OpCode_Compare::variable_arguments_to_string(Optional input) const { Vector result; @@ -834,7 +882,7 @@ Vector const OpCode_Compare::variable_arguments_to_string(Optional view.length() ? 0 : 1).to_string())); } else if (compare_type == CharacterCompareType::CharRange) { auto value = (CharRange)m_bytecode->at(offset++); - result.empend(String::formatted("ch_range='{:c}'-'{:c}'", value.from, value.to)); + result.empend(String::formatted("ch_range={:x}-{:x}", value.from, value.to)); if (!view.is_null() && view.length() > state().string_position) result.empend(String::formatted( "compare against: '{}'", @@ -896,6 +944,16 @@ ALWAYS_INLINE ExecutionResult OpCode_JumpNonEmpty::execute(MatchInput const& inp if (form == OpCodeId::ForkStay) return ExecutionResult::Fork_PrioLow; + + if (form == OpCodeId::ForkReplaceStay) { + input.fork_to_replace = state.instruction_position; + return ExecutionResult::Fork_PrioLow; + } + + if (form == OpCodeId::ForkReplaceJump) { + input.fork_to_replace = state.instruction_position; + return ExecutionResult::Fork_PrioHigh; + } } return ExecutionResult::Continue; diff --git a/Userland/Libraries/LibRegex/RegexByteCode.h b/Userland/Libraries/LibRegex/RegexByteCode.h index 90520dd444..c0e1c97ee5 100644 --- a/Userland/Libraries/LibRegex/RegexByteCode.h +++ b/Userland/Libraries/LibRegex/RegexByteCode.h @@ -6,6 +6,7 @@ #pragma once +#include "RegexBytecodeStreamOptimizer.h" #include "RegexMatch.h" #include "RegexOptions.h" @@ -30,6 +31,8 @@ using ByteCodeValueType = u64; __ENUMERATE_OPCODE(JumpNonEmpty) \ __ENUMERATE_OPCODE(ForkJump) \ __ENUMERATE_OPCODE(ForkStay) \ + __ENUMERATE_OPCODE(ForkReplaceJump) \ + __ENUMERATE_OPCODE(ForkReplaceStay) \ __ENUMERATE_OPCODE(FailForks) \ __ENUMERATE_OPCODE(SaveLeftCaptureGroup) \ __ENUMERATE_OPCODE(SaveRightCaptureGroup) \ @@ -306,7 +309,7 @@ public: VERIFY_NOT_REACHED(); } - void insert_bytecode_alternation(ByteCode&& left, ByteCode&& right) + void insert_bytecode_alternation(ByteCode left, ByteCode right) { // FORKJUMP _ALT @@ -316,21 +319,8 @@ public: // REGEXP ALT1 // LABEL _END - ByteCode byte_code; - - empend(static_cast(OpCodeId::ForkJump)); - empend(right.size() + 2); // Jump to the _ALT label - - extend(right); - - empend(static_cast(OpCodeId::Jump)); - empend(left.size()); // Jump to the _END label - - // LABEL _ALT = bytecode.size() + 2 - - extend(left); - - // LABEL _END = alterantive_bytecode.size + // Optimisation: Eliminate extra work by unifying common pre-and-postfix exprs. + Optimizer::append_alternation(*this, left, right); } template @@ -625,7 +615,7 @@ public: } }; -class OpCode_ForkJump final : public OpCode { +class OpCode_ForkJump : public OpCode { public: ExecutionResult execute(MatchInput const& input, MatchState& state) const override; ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::ForkJump; } @@ -637,7 +627,13 @@ public: } }; -class OpCode_ForkStay final : public OpCode { +class OpCode_ForkReplaceJump final : public OpCode_ForkJump { +public: + ExecutionResult execute(MatchInput const& input, MatchState& state) const override; + ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::ForkReplaceJump; } +}; + +class OpCode_ForkStay : public OpCode { public: ExecutionResult execute(MatchInput const& input, MatchState& state) const override; ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::ForkStay; } @@ -649,6 +645,12 @@ public: } }; +class OpCode_ForkReplaceStay final : public OpCode_ForkStay { +public: + ExecutionResult execute(MatchInput const& input, MatchState& state) const override; + ALWAYS_INLINE OpCodeId opcode_id() const override { return OpCodeId::ForkReplaceStay; } +}; + class OpCode_CheckBegin final : public OpCode { public: ExecutionResult execute(MatchInput const& input, MatchState& state) const override; @@ -725,6 +727,7 @@ public: ALWAYS_INLINE size_t arguments_size() const { return argument(1); } String const arguments_string() const override; Vector const variable_arguments_to_string(Optional input = {}) const; + Vector flat_compares() const; private: ALWAYS_INLINE static void compare_char(MatchInput const& input, MatchState& state, u32 ch1, bool inverse, bool& inverse_matched); diff --git a/Userland/Libraries/LibRegex/RegexBytecodeStreamOptimizer.h b/Userland/Libraries/LibRegex/RegexBytecodeStreamOptimizer.h new file mode 100644 index 0000000000..bc1c703402 --- /dev/null +++ b/Userland/Libraries/LibRegex/RegexBytecodeStreamOptimizer.h @@ -0,0 +1,18 @@ +/* + * Copyright (c) 2021, Ali Mohammad Pur + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#pragma once + +#include "Forward.h" + +namespace regex { + +class Optimizer { +public: + static void append_alternation(ByteCode& target, ByteCode& left, ByteCode& right); +}; + +} diff --git a/Userland/Libraries/LibRegex/RegexMatch.h b/Userland/Libraries/LibRegex/RegexMatch.h index efd5bd0e4a..60a1241eb3 100644 --- a/Userland/Libraries/LibRegex/RegexMatch.h +++ b/Userland/Libraries/LibRegex/RegexMatch.h @@ -6,6 +6,7 @@ #pragma once +#include "Forward.h" #include "RegexOptions.h" #include @@ -514,6 +515,7 @@ struct MatchInput { mutable Vector saved_positions; mutable Vector saved_code_unit_positions; mutable HashMap checkpoints; + mutable Optional fork_to_replace; }; struct MatchState { @@ -522,6 +524,7 @@ struct MatchState { size_t string_position_in_code_units { 0 }; size_t instruction_position { 0 }; size_t fork_at_position { 0 }; + Optional initiating_fork; Vector matches; Vector> capture_group_matches; Vector repetition_marks; diff --git a/Userland/Libraries/LibRegex/RegexMatcher.cpp b/Userland/Libraries/LibRegex/RegexMatcher.cpp index 12f63c016e..b0c8708bad 100644 --- a/Userland/Libraries/LibRegex/RegexMatcher.cpp +++ b/Userland/Libraries/LibRegex/RegexMatcher.cpp @@ -39,6 +39,7 @@ Regex::Regex(String pattern, typename ParserTraits::OptionsType Parser parser(lexer, regex_options); parser_result = parser.parse(); + run_optimization_passes(); if (parser_result.error == regex::Error::NoError) matcher = make>(this, regex_options); } @@ -48,6 +49,7 @@ Regex::Regex(regex::Parser::Result parse_result, String pattern, typenam : pattern_value(move(pattern)) , parser_result(move(parse_result)) { + run_optimization_passes(); if (parser_result.error == regex::Error::NoError) matcher = make>(this, regex_options); } @@ -370,6 +372,9 @@ public: return m_first == nullptr; } + auto reverse_begin() { return ReverseIterator(m_last); } + auto reverse_end() { return ReverseIterator(); } + private: struct Node { T value; @@ -377,6 +382,27 @@ private: Node* previous { nullptr }; }; + struct ReverseIterator { + ReverseIterator() = default; + explicit ReverseIterator(Node* node) + : m_node(node) + { + } + + T* operator->() { return &m_node->value; } + T& operator*() { return m_node->value; } + bool operator==(ReverseIterator const& it) const { return m_node == it.m_node; } + ReverseIterator& operator++() + { + if (m_node) + m_node = m_node->previous; + return *this; + } + + private: + Node* m_node; + }; + UniformBumpAllocator m_allocator; Node* m_first { nullptr }; Node* m_last { nullptr }; @@ -413,15 +439,48 @@ Optional Matcher::execute(MatchInput const& input, MatchState& sta state.instruction_position += opcode.size(); switch (result) { - case ExecutionResult::Fork_PrioLow: - states_to_try_next.append(state); - states_to_try_next.last().instruction_position = state.fork_at_position; + case ExecutionResult::Fork_PrioLow: { + bool found = false; + if (input.fork_to_replace.has_value()) { + for (auto it = states_to_try_next.reverse_begin(); it != states_to_try_next.reverse_end(); ++it) { + if (it->initiating_fork == input.fork_to_replace.value()) { + (*it) = state; + it->instruction_position = state.fork_at_position; + it->initiating_fork = *input.fork_to_replace; + found = true; + break; + } + } + input.fork_to_replace.clear(); + } + if (!found) { + states_to_try_next.append(state); + states_to_try_next.last().initiating_fork = state.instruction_position - opcode.size(); + states_to_try_next.last().instruction_position = state.fork_at_position; + } continue; - case ExecutionResult::Fork_PrioHigh: - states_to_try_next.append(state); + } + case ExecutionResult::Fork_PrioHigh: { + bool found = false; + if (input.fork_to_replace.has_value()) { + for (auto it = states_to_try_next.reverse_begin(); it != states_to_try_next.reverse_end(); ++it) { + if (it->initiating_fork == input.fork_to_replace.value()) { + (*it) = state; + it->initiating_fork = *input.fork_to_replace; + found = true; + break; + } + } + input.fork_to_replace.clear(); + } + if (!found) { + states_to_try_next.append(state); + states_to_try_next.last().initiating_fork = state.instruction_position - opcode.size(); + } state.instruction_position = state.fork_at_position; ++recursion_level; continue; + } case ExecutionResult::Continue: continue; case ExecutionResult::Succeeded: diff --git a/Userland/Libraries/LibRegex/RegexMatcher.h b/Userland/Libraries/LibRegex/RegexMatcher.h index e66aa8dc5b..a7d0b629fd 100644 --- a/Userland/Libraries/LibRegex/RegexMatcher.h +++ b/Userland/Libraries/LibRegex/RegexMatcher.h @@ -24,6 +24,15 @@ namespace regex { +namespace Detail { + +struct Block { + size_t start; + size_t end; +}; + +} + static constexpr const size_t c_max_recursion = 5000; static constexpr const size_t c_match_preallocation_count = 0; @@ -217,6 +226,12 @@ public: RegexResult result = matcher->match(views, AllOptions { regex_options.value_or({}) } | AllFlags::SkipSubExprResults); return result.success; } + +private: + void run_optimization_passes(); + using BasicBlockList = Vector; + BasicBlockList split_basic_blocks(); + void attempt_rewrite_loops_as_atomic_groups(BasicBlockList const&); }; // free standing functions for match, search and has_match diff --git a/Userland/Libraries/LibRegex/RegexOptimizer.cpp b/Userland/Libraries/LibRegex/RegexOptimizer.cpp new file mode 100644 index 0000000000..5c63243ee5 --- /dev/null +++ b/Userland/Libraries/LibRegex/RegexOptimizer.cpp @@ -0,0 +1,477 @@ +/* + * Copyright (c) 2021, Ali Mohammad Pur + * + * SPDX-License-Identifier: BSD-2-Clause + */ + +#include +#include +#include +#include +#include + +namespace regex { + +using Detail::Block; + +template +void Regex::run_optimization_passes() +{ + // Rewrite fork loops as atomic groups + // e.g. a*b -> (ATOMIC a*)b + attempt_rewrite_loops_as_atomic_groups(split_basic_blocks()); +} + +template +typename Regex::BasicBlockList Regex::split_basic_blocks() +{ + BasicBlockList block_boundaries; + auto& bytecode = parser_result.bytecode; + size_t end_of_last_block = 0; + + MatchState state; + state.instruction_position = 0; + auto check_jump = [&](OpCode const& opcode) { + auto& op = static_cast(opcode); + ssize_t jump_offset = op.size() + op.offset(); + if (jump_offset >= 0) { + block_boundaries.append({ end_of_last_block, state.instruction_position }); + end_of_last_block = state.instruction_position + opcode.size(); + } else { + // This op jumps back, see if that's within this "block". + if (jump_offset + state.instruction_position > end_of_last_block) { + // Split the block! + block_boundaries.append({ end_of_last_block, jump_offset + state.instruction_position }); + block_boundaries.append({ jump_offset + state.instruction_position, state.instruction_position }); + end_of_last_block = state.instruction_position + opcode.size(); + } else { + // Nope, it's just a jump to another block + block_boundaries.append({ end_of_last_block, state.instruction_position }); + end_of_last_block = state.instruction_position + opcode.size(); + } + } + }; + for (;;) { + auto& opcode = bytecode.get_opcode(state); + + switch (opcode.opcode_id()) { + case OpCodeId::Jump: + check_jump.template operator()(opcode); + break; + case OpCodeId::JumpNonEmpty: + check_jump.template operator()(opcode); + break; + case OpCodeId::ForkJump: + check_jump.template operator()(opcode); + break; + case OpCodeId::ForkStay: + check_jump.template operator()(opcode); + break; + case OpCodeId::FailForks: + block_boundaries.append({ end_of_last_block, state.instruction_position }); + end_of_last_block = state.instruction_position + opcode.size(); + break; + case OpCodeId::Repeat: { + // Repeat produces two blocks, one containing its repeated expr, and one after that. + auto repeat_start = state.instruction_position - static_cast(opcode).offset(); + if (repeat_start > end_of_last_block) + block_boundaries.append({ end_of_last_block, repeat_start }); + block_boundaries.append({ repeat_start, state.instruction_position }); + end_of_last_block = state.instruction_position + opcode.size(); + break; + } + default: + break; + } + + auto next_ip = state.instruction_position + opcode.size(); + if (next_ip < bytecode.size()) + state.instruction_position = next_ip; + else + break; + } + + if (end_of_last_block < bytecode.size()) + block_boundaries.append({ end_of_last_block, bytecode.size() }); + + quick_sort(block_boundaries, [](auto& a, auto& b) { return a.start < b.start; }); + + return block_boundaries; +} + +static bool block_satisfies_atomic_rewrite_precondition(ByteCode const& bytecode, Block const& repeated_block, Block const& following_block) +{ + Vector> repeated_values; + MatchState state; + for (state.instruction_position = repeated_block.start; state.instruction_position < repeated_block.end;) { + auto& opcode = bytecode.get_opcode(state); + switch (opcode.opcode_id()) { + case OpCodeId::Compare: { + auto compares = static_cast(opcode).flat_compares(); + if (repeated_values.is_empty() && any_of(compares, [](auto& compare) { return compare.type == CharacterCompareType::AnyChar; })) + return false; + repeated_values.append(move(compares)); + break; + } + case OpCodeId::CheckBegin: + case OpCodeId::CheckEnd: + if (repeated_values.is_empty()) + return true; + break; + case OpCodeId::CheckBoundary: + // FIXME: What should we do with these? for now, let's fail. + return false; + case OpCodeId::Restore: + case OpCodeId::GoBack: + return false; + default: + break; + } + + state.instruction_position += opcode.size(); + } + dbgln_if(REGEX_DEBUG, "Found {} entries in reference", repeated_values.size()); + + // Find the first compare in the following block, it must NOT match any of the values in `repeated_values'. + for (state.instruction_position = following_block.start; state.instruction_position < following_block.end;) { + auto& opcode = bytecode.get_opcode(state); + switch (opcode.opcode_id()) { + case OpCodeId::Compare: { + // We found a compare, let's see what it has. + auto compares = static_cast(opcode).flat_compares(); + if (compares.is_empty()) + break; + + // If either side can match _anything_, fail. + if (any_of(compares, [](auto& compare) { return compare.type == CharacterCompareType::AnyChar; })) + return false; + + for (auto& repeated_value : repeated_values) { + // FIXME: This is too naive! + if (any_of(repeated_value, [](auto& compare) { return compare.type == CharacterCompareType::AnyChar; })) + return false; + + for (auto& repeated_compare : repeated_value) { + // FIXME: This is too naive! it will miss _tons_ of cases since it doesn't check ranges! + if (any_of(compares, [&](auto& compare) { return compare.type == repeated_compare.type && compare.value == repeated_compare.value; })) + return false; + } + } + return true; + } + case OpCodeId::CheckBegin: + case OpCodeId::CheckEnd: + return true; // Nothing can match the end! + case OpCodeId::CheckBoundary: + // FIXME: What should we do with these? For now, consider them a failure. + return false; + default: + break; + } + + state.instruction_position += opcode.size(); + } + + return true; +} + +template +void Regex::attempt_rewrite_loops_as_atomic_groups(BasicBlockList const& basic_blocks) +{ + auto& bytecode = parser_result.bytecode; + if constexpr (REGEX_DEBUG) { + RegexDebug dbg; + dbg.print_bytecode(*this); + for (auto& block : basic_blocks) + dbgln("block from {} to {}", block.start, block.end); + } + + // A pattern such as: + // bb0 | RE0 + // | ForkX bb0 + // ------------------------- + // bb1 | RE1 + // can be rewritten as: + // loop.hdr | ForkStay bb1 + // ------------------------- + // bb0 | RE0 + // | ForkReplaceX bb0 + // ------------------------- + // bb1 | RE1 + // provided that first(RE1) not-in end(RE0), which is to say + // that RE1 cannot start with whatever RE0 has matched (ever). + // + // Alternatively, a second form of this pattern can also occur: + // bb0 | * + // | ForkX bb2 + // ------------------------ + // bb1 | RE0 + // | Jump bb0 + // ------------------------ + // bb2 | RE1 + // which can be transformed (with the same preconditions) to: + // bb0 | * + // | ForkReplaceX bb2 + // ------------------------ + // bb1 | RE0 + // | Jump bb0 + // ------------------------ + // bb2 | RE1 + + enum class AlternateForm { + DirectLoopWithoutHeader, // loop without proper header, a block forking to itself. i.e. the first form. + DirectLoopWithHeader, // loop with proper header, i.e. the second form. + }; + struct CandidateBlock { + Block forking_block; + Optional new_target_block; + AlternateForm form; + }; + Vector candidate_blocks; + + auto is_an_eligible_jump = [](OpCode const& opcode, size_t ip, size_t block_start, AlternateForm alternate_form) { + switch (opcode.opcode_id()) { + case OpCodeId::JumpNonEmpty: { + auto& op = static_cast(opcode); + auto form = op.form(); + if (form != OpCodeId::Jump && alternate_form == AlternateForm::DirectLoopWithHeader) + return false; + if (form != OpCodeId::ForkJump && form != OpCodeId::ForkStay && alternate_form == AlternateForm::DirectLoopWithoutHeader) + return false; + return op.offset() + ip + opcode.size() == block_start; + } + case OpCodeId::ForkJump: + if (alternate_form == AlternateForm::DirectLoopWithHeader) + return false; + return static_cast(opcode).offset() + ip + opcode.size() == block_start; + case OpCodeId::ForkStay: + if (alternate_form == AlternateForm::DirectLoopWithHeader) + return false; + return static_cast(opcode).offset() + ip + opcode.size() == block_start; + case OpCodeId::Jump: + // Infinite loop does *not* produce forks. + if (alternate_form == AlternateForm::DirectLoopWithoutHeader) + return false; + if (alternate_form == AlternateForm::DirectLoopWithHeader) + return static_cast(opcode).offset() + ip + opcode.size() == block_start; + VERIFY_NOT_REACHED(); + default: + return false; + } + }; + for (size_t i = 0; i < basic_blocks.size(); ++i) { + auto forking_block = basic_blocks[i]; + Optional fork_fallback_block; + if (i + 1 < basic_blocks.size()) + fork_fallback_block = basic_blocks[i + 1]; + MatchState state; + // Check if the last instruction in this block is a jump to the block itself: + { + state.instruction_position = forking_block.end; + auto& opcode = bytecode.get_opcode(state); + if (is_an_eligible_jump(opcode, state.instruction_position, forking_block.start, AlternateForm::DirectLoopWithoutHeader)) { + // We've found RE0 (and RE1 is just the following block, if any), let's see if the precondition applies. + // if RE1 is empty, there's no first(RE1), so this is an automatic pass. + if (!fork_fallback_block.has_value() || fork_fallback_block->end == fork_fallback_block->start) { + candidate_blocks.append({ forking_block, fork_fallback_block, AlternateForm::DirectLoopWithoutHeader }); + break; + } + + if (block_satisfies_atomic_rewrite_precondition(bytecode, forking_block, *fork_fallback_block)) { + candidate_blocks.append({ forking_block, fork_fallback_block, AlternateForm::DirectLoopWithoutHeader }); + break; + } + } + } + // Check if the last instruction in the last block is a direct jump to this block + if (fork_fallback_block.has_value()) { + state.instruction_position = fork_fallback_block->end; + auto& opcode = bytecode.get_opcode(state); + if (is_an_eligible_jump(opcode, state.instruction_position, forking_block.start, AlternateForm::DirectLoopWithHeader)) { + // We've found bb1 and bb0, let's just make sure that bb0 forks to bb2. + state.instruction_position = forking_block.end; + auto& opcode = bytecode.get_opcode(state); + if (opcode.opcode_id() == OpCodeId::ForkJump || opcode.opcode_id() == OpCodeId::ForkStay) { + Optional block_following_fork_fallback; + if (i + 2 < basic_blocks.size()) + block_following_fork_fallback = basic_blocks[i + 2]; + if (!block_following_fork_fallback.has_value() || block_satisfies_atomic_rewrite_precondition(bytecode, *fork_fallback_block, *block_following_fork_fallback)) { + candidate_blocks.append({ forking_block, {}, AlternateForm::DirectLoopWithHeader }); + break; + } + } + } + } + } + + dbgln_if(REGEX_DEBUG, "Found {} candidate blocks", candidate_blocks.size()); + if (candidate_blocks.is_empty()) { + dbgln_if(REGEX_DEBUG, "Failed to find anything for {}", pattern_value); + return; + } + + RedBlackTree needed_patches; + + // Reverse the blocks, so we can patch the bytecode without messing with the latter patches. + quick_sort(candidate_blocks, [](auto& a, auto& b) { return b.forking_block.start > a.forking_block.start; }); + for (auto& candidate : candidate_blocks) { + // Note that both forms share a ForkReplace patch in forking_block. + // Patch the ForkX in forking_block to be a ForkReplaceX instead. + auto& opcode_id = bytecode[candidate.forking_block.end]; + if (opcode_id == (ByteCodeValueType)OpCodeId::ForkStay) { + opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceStay; + } else if (opcode_id == (ByteCodeValueType)OpCodeId::ForkJump) { + opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceJump; + } else if (opcode_id == (ByteCodeValueType)OpCodeId::JumpNonEmpty) { + auto& jump_opcode_id = bytecode[candidate.forking_block.end + 3]; + if (jump_opcode_id == (ByteCodeValueType)OpCodeId::ForkStay) + jump_opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceStay; + else if (jump_opcode_id == (ByteCodeValueType)OpCodeId::ForkJump) + jump_opcode_id = (ByteCodeValueType)OpCodeId::ForkReplaceJump; + else + VERIFY_NOT_REACHED(); + } else { + VERIFY_NOT_REACHED(); + } + + if (candidate.form == AlternateForm::DirectLoopWithoutHeader) { + if (candidate.new_target_block.has_value()) { + // Insert a fork-stay targeted at the second block. + bytecode.insert(candidate.forking_block.start, (ByteCodeValueType)OpCodeId::ForkStay); + bytecode.insert(candidate.forking_block.start + 1, candidate.new_target_block->start - candidate.forking_block.start); + needed_patches.insert(candidate.forking_block.start, 2u); + } + } + } + + if (!needed_patches.is_empty()) { + MatchState state; + state.instruction_position = 0; + struct Patch { + ssize_t value; + size_t offset; + bool should_negate { false }; + }; + for (;;) { + if (state.instruction_position >= bytecode.size()) + break; + + auto& opcode = bytecode.get_opcode(state); + Stack patch_points; + + switch (opcode.opcode_id()) { + case OpCodeId::Jump: + patch_points.push({ static_cast(opcode).offset(), state.instruction_position + 1 }); + break; + case OpCodeId::JumpNonEmpty: + patch_points.push({ static_cast(opcode).offset(), state.instruction_position + 1 }); + patch_points.push({ static_cast(opcode).checkpoint(), state.instruction_position + 2 }); + break; + case OpCodeId::ForkJump: + patch_points.push({ static_cast(opcode).offset(), state.instruction_position + 1 }); + break; + case OpCodeId::ForkStay: + patch_points.push({ static_cast(opcode).offset(), state.instruction_position + 1 }); + break; + case OpCodeId::Repeat: + patch_points.push({ -(ssize_t) static_cast(opcode).offset(), state.instruction_position + 1, true }); + break; + default: + break; + } + + while (!patch_points.is_empty()) { + auto& patch_point = patch_points.top(); + auto target_offset = patch_point.value + state.instruction_position + opcode.size(); + + constexpr auto do_patch = [](auto& patch_it, auto& patch_point, auto& target_offset, auto& bytecode, auto ip) { + if (patch_it.key() == ip) + return; + + if (patch_point.value < 0 && target_offset < patch_it.key() && ip > patch_it.key()) + bytecode[patch_point.offset] += (patch_point.should_negate ? 1 : -1) * (*patch_it); + else if (patch_point.value > 0 && target_offset > patch_it.key() && ip < patch_it.key()) + bytecode[patch_point.offset] += (patch_point.should_negate ? -1 : 1) * (*patch_it); + }; + + if (auto patch_it = needed_patches.find_largest_not_above_iterator(target_offset); !patch_it.is_end()) + do_patch(patch_it, patch_point, target_offset, bytecode, state.instruction_position); + else if (auto patch_it = needed_patches.find_largest_not_above_iterator(state.instruction_position); !patch_it.is_end()) + do_patch(patch_it, patch_point, target_offset, bytecode, state.instruction_position); + + patch_points.pop(); + } + + state.instruction_position += opcode.size(); + } + } + + if constexpr (REGEX_DEBUG) { + warnln("Transformed to:"); + RegexDebug dbg; + dbg.print_bytecode(*this); + } +} + +void Optimizer::append_alternation(ByteCode& target, ByteCode& left, ByteCode& right) +{ + if (left.is_empty()) { + target.extend(right); + return; + } + + if (right.is_empty()) { + target.extend(left); + return; + } + + size_t left_skip = 0; + MatchState state; + for (state.instruction_position = 0; state.instruction_position < left.size() && state.instruction_position < right.size();) { + auto left_size = left.get_opcode(state).size(); + auto right_size = right.get_opcode(state).size(); + if (left_size != right_size) + break; + + if (left.span().slice(state.instruction_position, left_size) == right.span().slice(state.instruction_position, right_size)) + left_skip = state.instruction_position + left_size; + else + break; + + state.instruction_position += left_size; + } + + // FIXME: Implement postfix unification too. + size_t right_skip = 0; + + if (left_skip) + target.append(left.data(), left_skip); + + dbgln_if(REGEX_DEBUG, "Skipping {}/{} bytecode entries from {}/{}", left_skip, right_skip, left.size(), right.size()); + + auto left_slice = left.span().slice(left_skip, left.size() - left_skip - right_skip); + auto right_slice = right.span().slice(left_skip, right.size() - left_skip - right_skip); + + target.empend(static_cast(OpCodeId::ForkJump)); + target.empend(right_slice.size() + 2); // Jump to the _ALT label + + target.append(right_slice.data(), right_slice.size()); + + if (!left_slice.is_empty()) { + target.empend(static_cast(OpCodeId::Jump)); + target.empend(left_slice.size()); // Jump to the _END label + } + + // LABEL _ALT = bytecode.size() + 2 + + target.append(left_slice.data(), left_slice.size()); + + // LABEL _END = alterantive_bytecode.size + if (right_skip) + target.append(left.span().slice_from_end(right_skip).data(), right_skip); +} + +template void Regex::run_optimization_passes(); +template void Regex::run_optimization_passes(); +template void Regex::run_optimization_passes(); +}