diff --git a/libavcodec/jpegxl_parser.c b/libavcodec/jpegxl_parser.c index dde36b0d6e..630fc8a60b 100644 --- a/libavcodec/jpegxl_parser.c +++ b/libavcodec/jpegxl_parser.c @@ -683,7 +683,7 @@ static int read_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolD int repeat_count_prev = 0, repeat_count_zero = 0, prev = 8; int total_code = 0, len, hskip, num_codes = 0, ret; - VLC level1_vlc; + VLC level1_vlc = { 0 }; if (dist->alphabet_size == 1) { dist->vlc.bits = 0; @@ -709,8 +709,10 @@ static int read_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolD } } - if (total_code != 32 && num_codes >= 2 || num_codes < 1) - return AVERROR_INVALIDDATA; + if (total_code != 32 && num_codes >= 2 || num_codes < 1) { + ret = AVERROR_INVALIDDATA; + goto end; + } for (int i = 1; i < 19; i++) level1_codecounts[i] += level1_codecounts[i - 1]; @@ -726,7 +728,7 @@ static int read_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolD if (ret < 0) goto end; - buf = av_mallocz(dist->alphabet_size * (2 * sizeof(int8_t) + sizeof(int16_t) + sizeof(uint32_t)) + buf = av_mallocz(MAX_PREFIX_ALPHABET_SIZE * (2 * sizeof(int8_t) + sizeof(int16_t) + sizeof(uint32_t)) + sizeof(uint32_t)); if (!buf) { ret = AVERROR(ENOMEM); @@ -734,21 +736,22 @@ static int read_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolD } level2_lens = (int8_t *)buf; - level2_lens_s = (int8_t *)(buf + dist->alphabet_size * sizeof(int8_t)); - level2_syms = (int16_t *)(buf + dist->alphabet_size * (2 * sizeof(int8_t))); - level2_codecounts = (uint32_t *)(buf + dist->alphabet_size * (2 * sizeof(int8_t) + sizeof(int16_t))); + level2_lens_s = (int8_t *)(buf + MAX_PREFIX_ALPHABET_SIZE * sizeof(int8_t)); + level2_syms = (int16_t *)(buf + MAX_PREFIX_ALPHABET_SIZE * (2 * sizeof(int8_t))); + level2_codecounts = (uint32_t *)(buf + MAX_PREFIX_ALPHABET_SIZE * (2 * sizeof(int8_t) + sizeof(int16_t))); total_code = 0; for (int i = 0; i < dist->alphabet_size; i++) { len = get_vlc2(gb, level1_vlc.table, 5, 1); + if (get_bits_left(gb) < 0) { + ret = AVERROR_BUFFER_TOO_SMALL; + goto end; + } if (len == 16) { int extra = 3 + get_bits(gb, 2); if (repeat_count_prev) - extra = 4 * (repeat_count_prev - 2) - repeat_count_prev + extra; - if (i + extra > dist->alphabet_size) { - ret = AVERROR_INVALIDDATA; - goto end; - } + extra += 4 * (repeat_count_prev - 2) - repeat_count_prev; + extra = FFMIN(extra, dist->alphabet_size - i); for (int j = 0; j < extra; j++) level2_lens[i + j] = prev; total_code += (32768 >> prev) * extra; @@ -759,7 +762,8 @@ static int read_vlc_prefix(GetBitContext *gb, JXLEntropyDecoder *dec, JXLSymbolD } else if (len == 17) { int extra = 3 + get_bits(gb, 3); if (repeat_count_zero > 0) - extra = 8 * (repeat_count_zero - 2) - repeat_count_zero + extra; + extra += 8 * (repeat_count_zero - 2) - repeat_count_zero; + extra = FFMIN(extra, dist->alphabet_size - i); i += extra - 1; repeat_count_prev = 0; repeat_count_zero += extra;