diff --git a/Userland/Libraries/LibVideo/VP9/Parser.cpp b/Userland/Libraries/LibVideo/VP9/Parser.cpp index 418ea12150..2ea9056dfe 100644 --- a/Userland/Libraries/LibVideo/VP9/Parser.cpp +++ b/Userland/Libraries/LibVideo/VP9/Parser.cpp @@ -1338,7 +1338,7 @@ DecoderErrorOr Parser::read_mv(u8 ref) { m_use_hp = m_allow_high_precision_mv && use_mv_hp(m_best_mv[ref]); MotionVector diff_mv; - auto mv_joint = TRY_READ(m_tree_parser->parse_tree(SyntaxElementType::MVJoint)); + auto mv_joint = TRY_READ(TreeParser::parse_motion_vector_joint(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter)); if (mv_joint == MvJointHzvnz || mv_joint == MvJointHnzvnz) diff_mv.set_row(TRY(read_mv_component(0))); if (mv_joint == MvJointHnzvz || mv_joint == MvJointHnzvnz) @@ -1352,27 +1352,26 @@ DecoderErrorOr Parser::read_mv(u8 ref) DecoderErrorOr Parser::read_mv_component(u8 component) { - m_tree_parser->set_mv_component(component); - auto mv_sign = TRY_READ(m_tree_parser->parse_tree(SyntaxElementType::MVSign)); - auto mv_class = TRY_READ(m_tree_parser->parse_tree(SyntaxElementType::MVClass)); - u32 mag; + auto mv_sign = TRY_READ(TreeParser::parse_motion_vector_sign(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component)); + auto mv_class = TRY_READ(TreeParser::parse_motion_vector_class(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component)); + u32 magnitude; if (mv_class == MvClass0) { - u32 mv_class0_bit = TRY_READ(m_tree_parser->parse_tree(SyntaxElementType::MVClass0Bit)); - u32 mv_class0_fr = TRY_READ(m_tree_parser->parse_mv_class0_fr(mv_class0_bit)); - u32 mv_class0_hp = TRY_READ(m_tree_parser->parse_tree(SyntaxElementType::MVClass0HP)); - mag = ((mv_class0_bit << 3) | (mv_class0_fr << 1) | mv_class0_hp) + 1; + auto mv_class0_bit = TRY_READ(TreeParser::parse_motion_vector_class0_bit(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component)); + auto mv_class0_fr = TRY_READ(TreeParser::parse_motion_vector_class0_fr(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component, mv_class0_bit)); + auto mv_class0_hp = TRY_READ(TreeParser::parse_motion_vector_class0_hp(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component, m_use_hp)); + magnitude = ((mv_class0_bit << 3) | (mv_class0_fr << 1) | mv_class0_hp) + 1; } else { - u32 d = 0; + u32 bits = 0; for (u8 i = 0; i < mv_class; i++) { - u32 mv_bit = TRY_READ(m_tree_parser->parse_mv_bit(i)); - d |= mv_bit << i; + auto mv_bit = TRY_READ(TreeParser::parse_motion_vector_bit(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component, i)); + bits |= mv_bit << i; } - mag = CLASS0_SIZE << (mv_class + 2); - u32 mv_fr = TRY_READ(m_tree_parser->parse_tree(SyntaxElementType::MVFR)); - u32 mv_hp = TRY_READ(m_tree_parser->parse_tree(SyntaxElementType::MVHP)); - mag += ((d << 3) | (mv_fr << 1) | mv_hp) + 1; + magnitude = CLASS0_SIZE << (mv_class + 2); + auto mv_fr = TRY_READ(TreeParser::parse_motion_vector_fr(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component)); + auto mv_hp = TRY_READ(TreeParser::parse_motion_vector_hp(*m_bit_stream, *m_probability_tables, *m_syntax_element_counter, component, m_use_hp)); + magnitude += ((bits << 3) | (mv_fr << 1) | mv_hp) + 1; } - return (mv_sign ? -1 : 1) * static_cast(mag); + return (mv_sign ? -1 : 1) * static_cast(magnitude); } Gfx::Point Parser::get_decoded_point_for_plane(u32 column, u32 row, u8 plane) diff --git a/Userland/Libraries/LibVideo/VP9/TreeParser.cpp b/Userland/Libraries/LibVideo/VP9/TreeParser.cpp index f23210f674..03353236af 100644 --- a/Userland/Libraries/LibVideo/VP9/TreeParser.cpp +++ b/Userland/Libraries/LibVideo/VP9/TreeParser.cpp @@ -581,29 +581,85 @@ ErrorOr TreeParser::parse_single_ref_part_2(BitStream& bit_stream, Probabi return value; } +ErrorOr TreeParser::parse_motion_vector_joint(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter) +{ + auto value = TRY(parse_tree_new(bit_stream, { mv_joint_tree }, [&](u8 node) { return probability_table.mv_joint_probs()[node]; })); + increment_counter(counter.m_counts_mv_joint[value]); + return value; +} + +ErrorOr TreeParser::parse_motion_vector_sign(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component) +{ + auto value = TRY(parse_tree_new(bit_stream, { binary_tree }, [&](u8) { return probability_table.mv_sign_prob()[component]; })); + increment_counter(counter.m_counts_mv_sign[component][value]); + return value; +} + +ErrorOr TreeParser::parse_motion_vector_class(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component) +{ + // Spec doesn't mention node, but the probabilities table has an extra dimension + // so we will use node for that. + auto value = TRY(parse_tree_new(bit_stream, { mv_class_tree }, [&](u8 node) { return probability_table.mv_class_probs()[component][node]; })); + increment_counter(counter.m_counts_mv_class[component][value]); + return value; +} + +ErrorOr TreeParser::parse_motion_vector_class0_bit(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component) +{ + auto value = TRY(parse_tree_new(bit_stream, { binary_tree }, [&](u8) { return probability_table.mv_class0_bit_prob()[component]; })); + increment_counter(counter.m_counts_mv_class0_bit[component][value]); + return value; +} + +ErrorOr TreeParser::parse_motion_vector_class0_fr(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, bool class_0_bit) +{ + auto value = TRY(parse_tree_new(bit_stream, { mv_fr_tree }, [&](u8 node) { return probability_table.mv_class0_fr_probs()[component][class_0_bit][node]; })); + increment_counter(counter.m_counts_mv_class0_fr[component][class_0_bit][value]); + return value; +} + +ErrorOr TreeParser::parse_motion_vector_class0_hp(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, bool use_hp) +{ + TreeParser::TreeSelection tree { 1 }; + if (use_hp) + tree = { binary_tree }; + auto value = TRY(parse_tree_new(bit_stream, tree, [&](u8) { return probability_table.mv_class0_hp_prob()[component]; })); + increment_counter(counter.m_counts_mv_class0_hp[component][value]); + return value; +} + +ErrorOr TreeParser::parse_motion_vector_bit(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, u8 bit_index) +{ + auto value = TRY(parse_tree_new(bit_stream, { binary_tree }, [&](u8) { return probability_table.mv_bits_prob()[component][bit_index]; })); + increment_counter(counter.m_counts_mv_bits[component][bit_index][value]); + return value; +} + +ErrorOr TreeParser::parse_motion_vector_fr(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component) +{ + auto value = TRY(parse_tree_new(bit_stream, { mv_fr_tree }, [&](u8 node) { return probability_table.mv_fr_probs()[component][node]; })); + increment_counter(counter.m_counts_mv_fr[component][value]); + return value; +} + +ErrorOr TreeParser::parse_motion_vector_hp(BitStream& bit_stream, ProbabilityTables const& probability_table, SyntaxElementCounter& counter, u8 component, bool use_hp) +{ + TreeParser::TreeSelection tree { 1 }; + if (use_hp) + tree = { binary_tree }; + auto value = TRY(parse_tree_new(bit_stream, tree, [&](u8) { return probability_table.mv_hp_prob()[component]; })); + increment_counter(counter.m_counts_mv_hp[component][value]); + return value; +} + /* * Select a tree value based on the type of syntax element being parsed, as well as some parser state, as specified in section 9.3.1 */ TreeParser::TreeSelection TreeParser::select_tree(SyntaxElementType type) { switch (type) { - case SyntaxElementType::MVSign: - case SyntaxElementType::MVClass0Bit: - case SyntaxElementType::MVBit: case SyntaxElementType::MoreCoefs: return { binary_tree }; - case SyntaxElementType::MVJoint: - return { mv_joint_tree }; - case SyntaxElementType::MVClass: - return { mv_class_tree }; - case SyntaxElementType::MVClass0FR: - case SyntaxElementType::MVFR: - return { mv_fr_tree }; - case SyntaxElementType::MVClass0HP: - case SyntaxElementType::MVHP: - if (m_decoder.m_use_hp) - return { binary_tree }; - return { 1 }; case SyntaxElementType::Token: return { token_tree }; default: @@ -618,28 +674,6 @@ TreeParser::TreeSelection TreeParser::select_tree(SyntaxElementType type) u8 TreeParser::select_tree_probability(SyntaxElementType type, u8 node) { switch (type) { - case SyntaxElementType::MVSign: - return m_decoder.m_probability_tables->mv_sign_prob()[m_mv_component]; - case SyntaxElementType::MVClass0Bit: - return m_decoder.m_probability_tables->mv_class0_bit_prob()[m_mv_component]; - case SyntaxElementType::MVBit: - VERIFY(m_mv_bit < MV_OFFSET_BITS); - return m_decoder.m_probability_tables->mv_bits_prob()[m_mv_component][m_mv_bit]; - case SyntaxElementType::MVJoint: - return m_decoder.m_probability_tables->mv_joint_probs()[node]; - case SyntaxElementType::MVClass: - // Spec doesn't mention node, but the probabilities table has an extra dimension - // so we will use node for that. - return m_decoder.m_probability_tables->mv_class_probs()[m_mv_component][node]; - case SyntaxElementType::MVClass0FR: - VERIFY(m_mv_class0_bit < CLASS0_SIZE); - return m_decoder.m_probability_tables->mv_class0_fr_probs()[m_mv_component][m_mv_class0_bit][node]; - case SyntaxElementType::MVClass0HP: - return m_decoder.m_probability_tables->mv_class0_hp_prob()[m_mv_component]; - case SyntaxElementType::MVFR: - return m_decoder.m_probability_tables->mv_fr_probs()[m_mv_component][node]; - case SyntaxElementType::MVHP: - return m_decoder.m_probability_tables->mv_hp_prob()[m_mv_component]; case SyntaxElementType::Token: return calculate_token_probability(node); case SyntaxElementType::MoreCoefs: @@ -738,37 +772,6 @@ void TreeParser::count_syntax_element(SyntaxElementType type, int value) increment_counter(count); }; switch (type) { - case SyntaxElementType::MVSign: - increment(m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]); - return; - case SyntaxElementType::MVClass0Bit: - increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]); - return; - case SyntaxElementType::MVBit: - VERIFY(m_mv_bit < MV_OFFSET_BITS); - increment(m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]); - m_mv_bit = 0xFF; - return; - case SyntaxElementType::MVJoint: - increment(m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]); - return; - case SyntaxElementType::MVClass: - increment(m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]); - return; - case SyntaxElementType::MVClass0FR: - VERIFY(m_mv_class0_bit < CLASS0_SIZE); - increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]); - m_mv_class0_bit = 0xFF; - return; - case SyntaxElementType::MVClass0HP: - increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]); - return; - case SyntaxElementType::MVFR: - increment(m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]); - return; - case SyntaxElementType::MVHP: - increment(m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]); - return; case SyntaxElementType::Token: increment(m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]); return; diff --git a/Userland/Libraries/LibVideo/VP9/TreeParser.h b/Userland/Libraries/LibVideo/VP9/TreeParser.h index 7427c9a417..0227c3a7a6 100644 --- a/Userland/Libraries/LibVideo/VP9/TreeParser.h +++ b/Userland/Libraries/LibVideo/VP9/TreeParser.h @@ -81,6 +81,16 @@ public: static ErrorOr parse_single_ref_part_1(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, Optional above_single, Optional left_single, Optional above_intra, Optional left_intra, Optional above_ref_frame, Optional left_ref_frame); static ErrorOr parse_single_ref_part_2(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, Optional above_single, Optional left_single, Optional above_intra, Optional left_intra, Optional above_ref_frame, Optional left_ref_frame); + static ErrorOr parse_motion_vector_joint(BitStream&, ProbabilityTables const&, SyntaxElementCounter&); + static ErrorOr parse_motion_vector_sign(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component); + static ErrorOr parse_motion_vector_class(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component); + static ErrorOr parse_motion_vector_class0_bit(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component); + static ErrorOr parse_motion_vector_class0_fr(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component, bool class_0_bit); + static ErrorOr parse_motion_vector_class0_hp(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component, bool use_hp); + static ErrorOr parse_motion_vector_bit(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component, u8 bit_index); + static ErrorOr parse_motion_vector_fr(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component); + static ErrorOr parse_motion_vector_hp(BitStream&, ProbabilityTables const&, SyntaxElementCounter&, u8 component, bool use_hp); + void set_default_intra_mode_variables(u8 idx, u8 idy) { m_idx = idx; @@ -95,23 +105,6 @@ public: m_start_y = start_y; } - void set_mv_component(u8 component) - { - m_mv_component = component; - } - - ErrorOr parse_mv_bit(u8 bit) - { - m_mv_bit = bit; - return parse_tree(SyntaxElementType::MVBit); - } - - ErrorOr parse_mv_class0_fr(bool mv_class0_bit) - { - m_mv_class0_bit = mv_class0_bit; - return parse_tree(SyntaxElementType::MVClass0FR); - } - private: u8 calculate_token_probability(u8 node); u8 calculate_more_coefs_probability();