From 7d27273dc7c0b00a3407791b19538dba88232fb1 Mon Sep 17 00:00:00 2001 From: Zaggy1024 Date: Sun, 25 Sep 2022 03:18:55 -0500 Subject: [PATCH] LibVideo: Ensure that syntax element counts don't overflow Integer overflow could sometimes occur due to counts going above 255, where the values should instead be clamped at their maximum to avoid wrapping to 0. --- .../Libraries/LibVideo/VP9/TreeParser.cpp | 49 ++++++++++--------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/Userland/Libraries/LibVideo/VP9/TreeParser.cpp b/Userland/Libraries/LibVideo/VP9/TreeParser.cpp index 63179c3e70..274201a787 100644 --- a/Userland/Libraries/LibVideo/VP9/TreeParser.cpp +++ b/Userland/Libraries/LibVideo/VP9/TreeParser.cpp @@ -668,80 +668,83 @@ u8 TreeParser::calculate_token_probability(u8 node) void TreeParser::count_syntax_element(SyntaxElementType type, int value) { + auto increment = [](u8& count) { + count = min(static_cast(count) + 1, 255); + }; switch (type) { case SyntaxElementType::Partition: - m_decoder.m_syntax_element_counter->m_counts_partition[m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_partition[m_ctx][value]); return; case SyntaxElementType::IntraMode: case SyntaxElementType::SubIntraMode: - m_decoder.m_syntax_element_counter->m_counts_intra_mode[m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_intra_mode[m_ctx][value]); return; case SyntaxElementType::UVMode: - m_decoder.m_syntax_element_counter->m_counts_uv_mode[m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_uv_mode[m_ctx][value]); return; case SyntaxElementType::Skip: - m_decoder.m_syntax_element_counter->m_counts_skip[m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_skip[m_ctx][value]); return; case SyntaxElementType::IsInter: - m_decoder.m_syntax_element_counter->m_counts_is_inter[m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_is_inter[m_ctx][value]); return; case SyntaxElementType::CompMode: - m_decoder.m_syntax_element_counter->m_counts_comp_mode[m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_comp_mode[m_ctx][value]); return; case SyntaxElementType::CompRef: - m_decoder.m_syntax_element_counter->m_counts_comp_ref[m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_comp_ref[m_ctx][value]); return; case SyntaxElementType::SingleRefP1: - m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][0][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][0][value]); return; case SyntaxElementType::SingleRefP2: - m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][1][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][1][value]); return; case SyntaxElementType::MVSign: - m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]); return; case SyntaxElementType::MVClass0Bit: - m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]); return; case SyntaxElementType::MVBit: VERIFY(m_mv_bit < MV_OFFSET_BITS); - m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]); m_mv_bit = 0xFF; return; case SyntaxElementType::TXSize: - m_decoder.m_syntax_element_counter->m_counts_tx_size[m_decoder.m_max_tx_size][m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_tx_size[m_decoder.m_max_tx_size][m_ctx][value]); return; case SyntaxElementType::InterMode: - m_decoder.m_syntax_element_counter->m_counts_inter_mode[m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_inter_mode[m_ctx][value]); return; case SyntaxElementType::InterpFilter: - m_decoder.m_syntax_element_counter->m_counts_interp_filter[m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_interp_filter[m_ctx][value]); return; case SyntaxElementType::MVJoint: - m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]); return; case SyntaxElementType::MVClass: - m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]); return; case SyntaxElementType::MVClass0FR: VERIFY(m_mv_class0_bit < CLASS0_SIZE); - m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]); m_mv_class0_bit = 0xFF; return; case SyntaxElementType::MVClass0HP: - m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]); return; case SyntaxElementType::MVFR: - m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]); return; case SyntaxElementType::MVHP: - m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]); return; case SyntaxElementType::Token: - m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]++; + increment(m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]); return; case SyntaxElementType::MoreCoefs: - m_decoder.m_syntax_element_counter->m_counts_more_coefs[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][value]++; + increment(m_decoder.m_syntax_element_counter->m_counts_more_coefs[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][value]); return; case SyntaxElementType::DefaultIntraMode: case SyntaxElementType::DefaultUVMode: