2 * Copyright (c) Meta Platforms, Inc. and affiliates.
5 * This source code is licensed under both the BSD-style license (found in the
6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7 * in the COPYING file in the root directory of this source tree).
8 * You may select, at your option, one of the above-listed licenses.
11 /// Zstandard educational decoder implementation
12 /// See https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
14 #include <stdint.h> // uint8_t, etc.
15 #include <stdlib.h> // malloc, free, exit
16 #include <stdio.h> // fprintf
17 #include <string.h> // memset, memcpy
18 #include "zstd_decompress.h"
21 /******* IMPORTANT CONSTANTS *********************************************/
// Magic number that decode_frame compares against the first 32 bits it reads
// from the input.
25 // 4 Bytes, little-endian format. Value : 0xFD2FB528"
26 #define ZSTD_MAGIC_NUMBER 0xFD2FB528U
28 // The size of `Block_Content` is limited by `Block_Maximum_Size`,
// Expands to 128 * 1024; the (size_t) cast gives the constant type size_t so
// it compares directly against size_t lengths.
29 #define ZSTD_BLOCK_SIZE_MAX ((size_t)128 * 1024)
31 // literal blocks can't be larger than their block
// Both literals decoders compare the decoded size against this bound before
// allocating the literals buffer.
32 #define MAX_LITERALS_SIZE ZSTD_BLOCK_SIZE_MAX
35 /******* UTILITY MACROS AND TYPES *********************************************/
// NOTE: MAX and MIN expand each argument more than once, so arguments with
// side effects (e.g. `i++`) must not be passed.
36 #define MAX(a, b) ((a) > (b) ? (a) : (b))
37 #define MIN(a, b) ((a) < (b) ? (a) : (b))
// NOTE(review): the body of the ZDEC_NO_MESSAGE branch is not visible here;
// presumably MESSAGE expands to nothing when it is defined - confirm.
39 #if defined(ZDEC_NO_MESSAGE)
// The empty string literal is concatenated with the first argument, so the
// format passed to MESSAGE has to be a string literal itself.
42 #define MESSAGE(...) fprintf(stderr, "" __VA_ARGS__)
45 /// This decoder calls exit(1) when it encounters an error, however a production
46 /// library should propagate error codes
49 MESSAGE("Error: %s\n", s); \
53 ERROR("Input buffer smaller than it should be or input is " \
55 #define OUT_SIZE() ERROR("Output buffer too small for output")
56 #define CORRUPTION() ERROR("Corruption detected while decompressing")
57 #define BAD_ALLOC() ERROR("Memory allocation error")
58 #define IMPOSSIBLE() ERROR("An impossibility has occurred")
// The four wrappers above each report one fixed message through ERROR.
69 /******* END UTILITY MACROS AND TYPES *****************************************/
71 /******* IMPLEMENTATION PRIMITIVE PROTOTYPES **********************************/
72 /// The implementations for these functions can be found at the bottom of this
73 /// file. They implement low-level functionality needed for the higher level
74 /// decompression functions.
76 /*** IO STREAM OPERATIONS *************/
78 /// ostream_t/istream_t are used to wrap the pointers/length data passed into
79 /// ZSTD_decompress, so that all IO operations are safely bounds checked
80 /// They are written/read forward, and reads are treated as little-endian
81 /// They should be used opaquely to ensure safety
// NOTE(review): the two struct definitions are not visible here. Code further
// down reads `out->len` and `out.ptr`, so ostream_t has at least those fields.
91 // Input often reads a few bits at a time, so maintain an internal offset
95 /// The following two functions are the only ones that allow the istream to be
/// (`num` in the two comments below refers to the `num_bits` parameter.)
98 /// Reads `num` bits from a bitstream, and updates the internal offset
99 static inline u64 IO_read_bits(istream_t *const in, const int num_bits);
100 /// Backs-up the stream by `num` bits so they can be read again
101 static inline void IO_rewind_bits(istream_t *const in, const int num_bits);
102 /// If the remaining bits in a byte will be unused, advance to the end of the
104 static inline void IO_align_stream(istream_t *const in);
106 /// Write the given byte into the output stream
107 static inline void IO_write_byte(ostream_t *const out, u8 symb);
109 /// Returns the number of bytes left to be read in this stream. The stream must
111 static inline size_t IO_istream_len(const istream_t *const in);
113 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
114 /// was skipped. The stream must be byte aligned.
115 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len);
116 /// Advances the stream by `len` bytes, and returns a pointer to the chunk that
117 /// was skipped so it can be written to.
118 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len);
120 /// Advance the inner state by `len` bytes. The stream must be byte aligned.
121 static inline void IO_advance_input(istream_t *const in, size_t len);
123 /// Returns an `ostream_t` constructed from the given pointer and length.
124 static inline ostream_t IO_make_ostream(u8 *out, size_t len);
125 /// Returns an `istream_t` constructed from the given pointer and length.
126 static inline istream_t IO_make_istream(const u8 *in, size_t len);
128 /// Returns an `istream_t` with the same base as `in`, and length `len`.
129 /// Then, advance `in` to account for the consumed bytes.
130 /// `in` must be byte aligned.
/// The sub-stream is returned by value; only the parent `in` is modified.
131 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len);
132 /*** END IO STREAM OPERATIONS *********/
134 /*** BITSTREAM OPERATIONS *************/
135 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits,
136 /// and return them interpreted as a little-endian unsigned integer.
/// (`num` is the `num_bits` parameter.) `offset` is passed by value, so the
/// caller keeps track of its own position.
137 static inline u64 read_bits_LE(const u8 *src, const int num_bits,
138 const size_t offset);
140 /// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so
141 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
142 /// `src + offset`. If the offset becomes negative, the extra bits at the
143 /// bottom are filled in with `0` bits instead of reading from before `src`.
/// NOTE(review): the final parameter line of this prototype (the `offset`
/// that the comment says is updated) is not visible here.
144 static inline u64 STREAM_read_bits(const u8 *src, const int bits,
146 /*** END BITSTREAM OPERATIONS *********/
148 /*** BIT COUNTING OPERATIONS **********/
149 /// Returns the index of the highest set bit in `num`, or `-1` if `num == 0`
/// The return type is a signed int so the `-1` sentinel can be represented.
150 static inline int highest_set_bit(const u64 num);
151 /*** END BIT COUNTING OPERATIONS ******/
153 /*** HUFFMAN PRIMITIVES ***************/
154 // Table decode method uses exponential memory, so we need to limit depth
155 #define HUF_MAX_BITS (16)
157 // Limit the maximum number of symbols to 256 so we can store a symbol in a byte
// decode_huf_table sizes its local weights array with this constant.
158 #define HUF_MAX_SYMBS (256)
160 /// Structure containing all tables necessary for efficient Huffman decoding
// NOTE(review): the struct body is not visible here. decode_literals_compressed
// tests its `symbols` member to detect a table that was never initialized.
167 /// Decode a single symbol and read in enough bits to refresh the state
168 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
169 u16 *const state, const u8 *const src,
171 /// Read in a full state's worth of bits to initialize it
172 static inline void HUF_init_state(const HUF_dtable *const dtable,
173 u16 *const state, const u8 *const src,
176 /// Decompresses a single Huffman stream, returns the number of bytes decoded.
177 /// `src_len` must be the exact length of the Huffman-coded block.
/// NOTE(review): this prototype has no `src_len` parameter; presumably the
/// requirement now applies to the length of the `in` stream - confirm.
178 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
179 ostream_t *const out, istream_t *const in);
180 /// Same as previous but decodes 4 streams, formatted as in the Zstandard
182 /// `src_len` must be the exact length of the Huffman-coded block.
/// NOTE(review): as above, there is no `src_len` parameter; presumably the
/// length of `in` is meant - confirm.
183 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
184 ostream_t *const out, istream_t *const in);
186 /// Initialize a Huffman decoding table using the table of bit counts provided
187 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
188 const int num_symbs);
189 /// Initialize a Huffman decoding table using the table of weights provided
190 /// Weights follow the definition provided in the Zstandard specification
/// decode_huf_table calls this with its decoded weights and symbol count.
191 static void HUF_init_dtable_usingweights(HUF_dtable *const table,
192 const u8 *const weights,
193 const int num_symbs);
195 /// Free the malloc'ed parts of a decoding table
196 static void HUF_free_dtable(HUF_dtable *const dtable);
197 /*** END HUFFMAN PRIMITIVES ***********/
199 /*** FSE PRIMITIVES *******************/
200 /// For more description of FSE see
201 /// https://github.com/Cyan4973/FiniteStateEntropy/
203 // FSE table decoding uses exponential memory, so limit the maximum accuracy
204 #define FSE_MAX_ACCURACY_LOG (15)
205 // Limit the maximum number of symbols so they can be stored in a single byte
206 #define FSE_MAX_SYMBS (256)
208 /// The tables needed to decode FSE encoded streams
// NOTE(review): the struct body is not visible here.
216 /// Return the symbol for the current state
217 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
219 /// Read the number of bits necessary to update state, update, and shift offset
220 /// back to reflect the bits read
221 static inline void FSE_update_state(const FSE_dtable *const dtable,
222 u16 *const state, const u8 *const src,
225 /// Combine peek and update: decode a symbol and update the state
226 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
227 u16 *const state, const u8 *const src,
230 /// Read bits from the stream to initialize the state and shift offset back
231 static inline void FSE_init_state(const FSE_dtable *const dtable,
232 u16 *const state, const u8 *const src,
235 /// Decompress two interleaved bitstreams (e.g. compressed Huffman weights)
236 /// using an FSE decoding table. `src_len` must be the exact length of the
/// NOTE(review): this prototype has no `src_len` parameter; presumably the
/// length of `in` is meant - confirm. fse_decode_hufweights stores the return
/// value as its symbol count.
238 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
239 ostream_t *const out,
240 istream_t *const in);
242 /// Initialize a decoding table using normalized frequencies.
243 static void FSE_init_dtable(FSE_dtable *const dtable,
244 const i16 *const norm_freqs, const int num_symbs,
245 const int accuracy_log);
247 /// Decode an FSE header as defined in the Zstandard format specification and
248 /// use the decoded frequencies to initialize a decoding table.
/// Callers pass the maximum accuracy log allowed for the table being decoded;
/// fse_decode_hufweights passes 7 for Huffman weights.
249 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
250 const int max_accuracy_log);
252 /// Initialize an FSE table that will always return the same symbol and consume
253 /// 0 bits per symbol, to be used for RLE mode in sequence commands
254 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb);
256 /// Free the malloc'ed parts of a decoding table
257 static void FSE_free_dtable(FSE_dtable *const dtable);
258 /*** END FSE PRIMITIVES ***************/
260 /******* END IMPLEMENTATION PRIMITIVE PROTOTYPES ******************************/
262 /******* ZSTD HELPER STRUCTS AND PROTOTYPES ***********************************/
264 /// A small structure that can be reused in various places that need to access
265 /// frame header information
// NOTE(review): the typedef lines and several members of the structs in this
// section are not visible here.
267 // The size of window that we need to be able to contiguously store for
270 // The total output size of this compressed frame
// parse_frame_header stores 0 here when the header has no content size field.
271 size_t frame_content_size;
273 // The dictionary id if this frame uses one
276 // Whether or not the content of this frame has a checksum
277 int content_checksum_flag;
278 // Whether or not the output for this frame is in a single segment
279 int single_segment_flag;
282 /// The context needed to decode blocks in a frame
284 frame_header_t header;
286 // The total amount of data available for backreferences, to determine if an
287 // offset is too large to be correct
// decompress_data adds the block length to this after each raw or RLE block.
288 size_t current_total_output;
290 const u8 *dict_content;
291 size_t dict_content_len;
293 // Entropy encoding tables so they can be repeated by future blocks instead
295 HUF_dtable literals_dtable;
296 FSE_dtable ll_dtable;
297 FSE_dtable ml_dtable;
298 FSE_dtable of_dtable;
300 // The last 3 offsets for the special "repeat offsets".
// init_frame_context seeds these with 1, 4 and 8.
301 u64 previous_offsets[3];
304 /// The decoded contents of a dictionary so that it doesn't have to be repeated
305 /// for each frame that uses it
306 struct dictionary_s {
308 HUF_dtable literals_dtable;
309 FSE_dtable ll_dtable;
310 FSE_dtable ml_dtable;
311 FSE_dtable of_dtable;
313 // Raw content for backreferences
317 // Offset history to prepopulate the frame's history
318 u64 previous_offsets[3];
323 /// A tuple containing the parts necessary to decode and execute a ZSTD sequence
329 } sequence_command_t;
331 /// The decoder works top-down, starting at the high level like Zstd frames, and
332 /// working down to lower more technical levels such as blocks, literals, and
333 /// sequences. The high-level functions roughly follow the outline of the
334 /// format specification:
335 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md
337 /// Before the implementation of each high-level function declared here, the
338 /// prototypes for their helper functions are defined and explained
340 /// Decode a single Zstd frame, or error if the input is not a valid frame.
341 /// Accepts a dict argument, which may be NULL indicating no dictionary.
343 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame-concatenation
344 static void decode_frame(ostream_t *const out, istream_t *const in,
345 const dictionary_t *const dict);
347 // Decode data in a compressed block
// `in` is the sub-stream that decompress_data creates for one block.
348 static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
349 istream_t *const in);
351 // Decode the literals section of a block
// The literals buffer is malloc'ed by the callee and handed back through
// `literals`.
352 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
353 u8 **const literals);
355 // Decode the sequences part of a block
// The sequence array is handed back through `sequences`; decompress_block
// passes the return value to execute_sequences.
356 static size_t decode_sequences(frame_context_t *const ctx, istream_t *const in,
357 sequence_command_t **const sequences);
359 // Execute the decoded sequences on the literals block
360 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
361 const u8 *const literals,
362 const size_t literals_len,
363 const sequence_command_t *const sequences,
364 const size_t num_sequences);
366 // Copies literals and returns the total literal length that was copied
367 static u32 copy_literals(const size_t seq, istream_t *litstream,
368 ostream_t *const out);
370 // Given an offset code from a sequence command (either an actual offset value
371 // or an index for previous offset), computes the correct offset and updates
372 // the offset history
373 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist);
375 // Given an offset, match length, and total output, as well as the frame
376 // context for the dictionary, determines if the dictionary is used and
377 // executes the copy operation
378 static void execute_match_copy(frame_context_t *const ctx, size_t offset,
379 size_t match_length, size_t total_output,
380 ostream_t *const out);
382 /******* END ZSTD HELPER STRUCTS AND PROTOTYPES *******************************/
/// Decompresses `src` (`src_len` bytes) into `dst` (`dst_len` bytes) without a
/// caller-supplied dictionary.
/// It obtains a dictionary object from create_dictionary(), forwards all
/// arguments to ZSTD_decompress_with_dict, and frees the dictionary again.
/// NOTE(review): the return statement is not visible here; presumably
/// `decomp_size` is returned - confirm.
384 size_t ZSTD_decompress(void *const dst, const size_t dst_len,
385 const void *const src, const size_t src_len) {
// The dictionary is never filled in, hence the name `uninit_dict`.
386 dictionary_t* const uninit_dict = create_dictionary();
387 size_t const decomp_size = ZSTD_decompress_with_dict(dst, dst_len, src,
388 src_len, uninit_dict);
389 free_dictionary(uninit_dict);
/// Decompresses one frame from `src` (`src_len` bytes) into `dst` (`dst_len`
/// bytes), passing `parsed_dict` on to decode_frame.
/// Returns the number of bytes written, computed as the distance between the
/// output stream's pointer and `dst`.
393 size_t ZSTD_decompress_with_dict(void *const dst, const size_t dst_len,
394 const void *const src, const size_t src_len,
395 dictionary_t* parsed_dict) {
// Wrap both raw buffers in streams; all reads and writes go through them.
397 istream_t in = IO_make_istream(src, src_len);
398 ostream_t out = IO_make_ostream(dst, dst_len);
400 // "A content compressed by Zstandard is transformed into a Zstandard frame.
401 // Multiple frames can be appended into a single file or stream. A frame is
402 // totally independent, has a defined beginning and end, and a set of
403 // parameters which tells the decoder how to decompress it."
405 /* this decoder assumes decompression of a single frame */
406 decode_frame(&out, &in, parsed_dict);
// `out.ptr` is read after decoding and `dst` is the start of the buffer.
408 return (size_t)(out.ptr - (u8 *)dst);
411 /******* FRAME DECODING ******************************************************/
// Forward declarations for the helpers used by the frame-decoding code in
// this section.
// NOTE(review): the middle parameter line of init_frame_context is not
// visible here; its definition reads an `in` stream.
413 static void decode_data_frame(ostream_t *const out, istream_t *const in,
414 const dictionary_t *const dict);
415 static void init_frame_context(frame_context_t *const context,
417 const dictionary_t *const dict);
418 static void free_frame_context(frame_context_t *const context);
419 static void parse_frame_header(frame_header_t *const header,
420 istream_t *const in);
421 static void frame_context_apply_dict(frame_context_t *const ctx,
422 const dictionary_t *const dict);
424 static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
425 istream_t *const in);
/// Reads the 32-bit magic number from `in`. When it equals ZSTD_MAGIC_NUMBER
/// the remainder of the frame is decoded into `out` by decode_data_frame,
/// which also receives `dict`.
/// NOTE(review): the lines closing the `if` are not visible here; presumably
/// the branch returns, so the ERROR below fires only for other magic numbers.
427 static void decode_frame(ostream_t *const out, istream_t *const in,
428 const dictionary_t *const dict) {
429 const u32 magic_number = (u32)IO_read_bits(in, 32);
430 if (magic_number == ZSTD_MAGIC_NUMBER) {
432 decode_data_frame(out, in, dict);
437 // not a real frame or a skippable frame
438 ERROR("Tried to decode non-ZSTD frame");
441 /// Decode a frame that contains compressed data. Not all frames do as there
442 /// are skippable frames.
444 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#general-structure-of-zstandard-frame-format
/// Steps: build the frame context from the header in `in` and from `dict`,
/// compare the declared content size with the output capacity, decompress all
/// blocks into `out`, then free the context's tables.
445 static void decode_data_frame(ostream_t *const out, istream_t *const in,
446 const dictionary_t *const dict) {
449 // Initialize the context that needs to be carried from block to block
450 init_frame_context(&ctx, in, dict);
// parse_frame_header stores 0 when the header has no content size field, so
// the capacity comparison is skipped for 0.
// NOTE(review): the body of this `if` is not visible here; presumably it
// reports OUT_SIZE().
452 if (ctx.header.frame_content_size != 0 &&
453 ctx.header.frame_content_size > out->len) {
457 decompress_data(&ctx, out, in);
459 free_frame_context(&ctx);
462 /// Takes the information provided in the header and dictionary, and initializes
463 /// the context for this frame
/// The frame header is read from the `in` stream; `dict` is forwarded to
/// frame_context_apply_dict unchanged.
464 static void init_frame_context(frame_context_t *const context,
466 const dictionary_t *const dict) {
467 // Most fields in context are correct when initialized to 0
// Zeroing also clears literals_dtable.symbols, which
// decode_literals_compressed tests before reusing a previous Huffman table.
468 memset(context, 0, sizeof(frame_context_t));
470 // Parse data from the frame header
471 parse_frame_header(&context->header, in);
473 // Set up the offset history for the repeat offset commands
474 context->previous_offsets[0] = 1;
475 context->previous_offsets[1] = 4;
476 context->previous_offsets[2] = 8;
478 // Apply details from the dict if it exists
479 frame_context_apply_dict(context, dict);
/// Releases the Huffman literals table and the three FSE tables (literal
/// length, match length, offset) held by `context`, then zeroes the whole
/// struct.
482 static void free_frame_context(frame_context_t *const context) {
483 HUF_free_dtable(&context->literals_dtable);
485 FSE_free_dtable(&context->ll_dtable);
486 FSE_free_dtable(&context->ml_dtable);
487 FSE_free_dtable(&context->of_dtable);
// Clear every field so no stale table pointers remain in the context.
489 memset(context, 0, sizeof(frame_context_t));
/// Parses the frame header from `in` into `header`. In order it reads the
/// descriptor byte, the window descriptor (unless Single_Segment_flag is set),
/// the dictionary id (if flagged) and the frame content size (if present).
/// Fields that are absent are stored as 0.
492 static void parse_frame_header(frame_header_t *const header,
493 istream_t *const in) {
494 // "The first header's byte is called the Frame_Header_Descriptor. It tells
495 // which other fields are present. Decoding this byte is enough to tell the
496 // size of Frame_Header.
498 // Bit number Field name
499 // 7-6 Frame_Content_Size_flag
500 // 5 Single_Segment_flag
503 // 2 Content_Checksum_flag
504 // 1-0 Dictionary_ID_flag"
505 const u8 descriptor = (u8)IO_read_bits(in, 8);
507 // decode frame header descriptor into flags
// Bit 4 of the descriptor is not extracted or checked by this function.
508 const u8 frame_content_size_flag = descriptor >> 6;
509 const u8 single_segment_flag = (descriptor >> 5) & 1;
510 const u8 reserved_bit = (descriptor >> 3) & 1;
511 const u8 content_checksum_flag = (descriptor >> 2) & 1;
512 const u8 dictionary_id_flag = descriptor & 3;
// NOTE(review): the body of this `if` is not visible here; presumably a set
// reserved bit is reported as corruption.
514 if (reserved_bit != 0) {
518 header->single_segment_flag = single_segment_flag;
519 header->content_checksum_flag = content_checksum_flag;
521 // decode window size
522 if (!single_segment_flag) {
523 // "Provides guarantees on maximum back-reference distance that will be
524 // used within compressed data. This information is important for
525 // decoders to allocate enough memory.
527 // Bit numbers 7-3 2-0
528 // Field name Exponent Mantissa"
529 u8 window_descriptor = (u8)IO_read_bits(in, 8);
530 u8 exponent = window_descriptor >> 3;
531 u8 mantissa = window_descriptor & 7;
533 // Use the algorithm from the specification to compute window size
534 // https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
// NOTE(review): `exponent` is a 5-bit value (0-31), so the shift count can
// reach 41. That is only defined when size_t is wider than 41 bits; with a
// 32-bit size_t large exponents make this shift undefined.
535 size_t window_base = (size_t)1 << (10 + exponent);
536 size_t window_add = (window_base / 8) * mantissa;
537 header->window_size = window_base + window_add;
540 // decode dictionary id if it exists
541 if (dictionary_id_flag) {
542 // "This is a variable size field, which contains the ID of the
543 // dictionary required to properly decode the frame. Note that this
544 // field is optional. When it's not present, it's up to the caller to
545 // make sure it uses the correct dictionary. Format is little-endian."
// Maps Dictionary_ID_flag (0-3) to the field width in bytes.
546 const int bytes_array[] = {0, 1, 2, 4};
547 const int bytes = bytes_array[dictionary_id_flag];
549 header->dictionary_id = (u32)IO_read_bits(in, bytes * 8);
551 header->dictionary_id = 0;
554 // decode frame content size if it exists
555 if (single_segment_flag || frame_content_size_flag) {
556 // "This is the original (uncompressed) size. This information is
557 // optional. The Field_Size is provided according to value of
558 // Frame_Content_Size_flag. The Field_Size can be equal to 0 (not
559 // present), 1, 2, 4 or 8 bytes. Format is little-endian."
561 // if frame_content_size_flag == 0 but single_segment_flag is set, we
562 // still have a 1 byte field
// Maps Frame_Content_Size_flag (0-3) to the field width in bytes.
563 const int bytes_array[] = {1, 2, 4, 8};
564 const int bytes = bytes_array[frame_content_size_flag];
566 header->frame_content_size = IO_read_bits(in, bytes * 8);
// NOTE(review): the condition guarding the next statement is not visible
// here; per the quoted spec line it should apply only to the 2-byte field.
568 // "When Field_Size is 2, the offset of 256 is added."
569 header->frame_content_size += 256;
572 header->frame_content_size = 0;
575 if (single_segment_flag) {
576 // "The Window_Descriptor byte is optional. It is absent when
577 // Single_Segment_flag is set. In this case, the maximum back-reference
578 // distance is the content size itself, which can be any value from 1 to
579 // 2^64-1 bytes (16 EB)."
580 header->window_size = header->frame_content_size;
584 /// Decompress the data from a frame block by block
/// Each iteration reads a block header from `in` (1-bit last-block flag,
/// 2-bit type, 21-bit size) and handles the block according to its type. The
/// loop ends after the block whose last-block flag is set. If the frame header
/// flagged a checksum, 4 more input bytes are skipped at the end.
/// NOTE(review): the `case` labels and `break` statements of the switch are
/// not visible here; the comments inside each arm identify the block type.
585 static void decompress_data(frame_context_t *const ctx, ostream_t *const out,
586 istream_t *const in) {
587 // "A frame encapsulates one or multiple blocks. Each block can be
588 // compressed or not, and has a guaranteed maximum content size, which
589 // depends on frame parameters. Unlike frames, each block depends on
590 // previous blocks for proper decoding. However, each block can be
591 // decompressed without waiting for its successor, allowing streaming
597 // The lowest bit signals if this block is the last one. Frame ends
598 // right after this block.
600 // Block_Type and Block_Size
602 // The next 2 bits represent the Block_Type, while the remaining 21 bits
603 // represent the Block_Size. Format is little-endian."
604 last_block = (int)IO_read_bits(in, 1);
605 const int block_type = (int)IO_read_bits(in, 2);
606 const size_t block_len = IO_read_bits(in, 21);
608 switch (block_type) {
610 // "Raw_Block - this is an uncompressed block. Block_Size is the
611 // number of bytes to read and copy."
// Both streams are advanced by block_len before the copy.
612 const u8 *const read_ptr = IO_get_read_ptr(in, block_len);
613 u8 *const write_ptr = IO_get_write_ptr(out, block_len);
615 // Copy the raw data into the output
616 memcpy(write_ptr, read_ptr, block_len);
618 ctx->current_total_output += block_len;
622 // "RLE_Block - this is a single byte, repeated N times. In which
623 // case, Block_Size is the size to regenerate, while the
624 // "compressed" block is just 1 byte (the byte to repeat)."
// Only one input byte is consumed, but block_len output bytes are reserved.
625 const u8 *const read_ptr = IO_get_read_ptr(in, 1);
626 u8 *const write_ptr = IO_get_write_ptr(out, block_len);
628 // Copy `block_len` copies of `read_ptr[0]` to the output
629 memset(write_ptr, read_ptr[0], block_len);
631 ctx->current_total_output += block_len;
635 // "Compressed_Block - this is a Zstandard compressed block,
636 // detailed in another section of this specification. Block_Size is
637 // the compressed size.
639 // Create a sub-stream for the block
// The sub-stream has length block_len, and `in` is advanced past the block.
640 istream_t block_stream = IO_make_sub_istream(in, block_len);
641 decompress_block(ctx, out, &block_stream);
645 // "Reserved - this is not a block. This value cannot be used with
646 // current version of this specification."
652 } while (!last_block);
654 if (ctx->header.content_checksum_flag) {
655 // This program does not support checking the checksum, so skip over it
657 IO_advance_input(in, 4);
660 /******* END FRAME DECODING ***************************************************/
662 /******* BLOCK DECOMPRESSION **************************************************/
/// Decodes one compressed block from `in` in three steps: decode the literals
/// section, decode the sequences section, then execute the sequences against
/// the literals to produce output in `out`.
/// NOTE(review): the declaration of `literals`, the last argument line of the
/// execute_sequences call, and any cleanup of `literals`/`sequences` are not
/// visible here - confirm both buffers are freed after execution.
663 static void decompress_block(frame_context_t *const ctx, ostream_t *const out,
664 istream_t *const in) {
665 // "A compressed block consists of 2 sections :
668 // Sequences_Section"
671 // Part 1: decode the literals block
// decode_literals stores the literals buffer in `literals` and returns its size.
673 const size_t literals_size = decode_literals(ctx, in, &literals);
675 // Part 2: decode the sequences block
676 sequence_command_t *sequences = NULL;
677 const size_t num_sequences =
678 decode_sequences(ctx, in, &sequences);
680 // Part 3: combine literals and sequence commands to generate output
681 execute_sequences(ctx, out, literals, literals_size, sequences,
686 /******* END BLOCK DECOMPRESSION **********************************************/
688 /******* LITERALS DECODING ****************************************************/
// Forward declarations for the literals-decoding helpers defined below.
// NOTE(review): two parameter lines of decode_literals_compressed are not
// visible here; decode_literals calls it as (ctx, in, literals, block_type, ...).
689 static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
690 const int block_type,
691 const int size_format);
692 static size_t decode_literals_compressed(frame_context_t *const ctx,
695 const int block_type,
696 const int size_format);
697 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in);
// `num_symbs` is an output parameter that receives the decoded weight count.
698 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
699 int *const num_symbs);
/// Reads the 2-bit Literals_Block_Type and the next 2 bits (Size_Format) from
/// `in`, then dispatches. Block types 0 and 1 go to decode_literals_simple and
/// the remaining types to decode_literals_compressed. The helper's return
/// value is returned unchanged.
701 static size_t decode_literals(frame_context_t *const ctx, istream_t *const in,
702 u8 **const literals) {
703 // "Literals can be stored uncompressed or compressed using Huffman prefix
704 // codes. When compressed, an optional tree description can be present,
705 // followed by 1 or 4 streams."
707 // "Literals_Section_Header
709 // Header is in charge of describing how literals are packed. It's a
710 // byte-aligned variable-size bitfield, ranging from 1 to 5 bytes, using
711 // little-endian convention."
713 // "Literals_Block_Type
715 // This field uses 2 lowest bits of first byte, describing 4 different block
718 // size_format takes between 1 and 2 bits
// Two bits are always read here; decode_literals_simple rewinds one bit when
// the format only uses one.
719 int block_type = (int)IO_read_bits(in, 2);
720 int size_format = (int)IO_read_bits(in, 2);
722 if (block_type <= 1) {
723 // Raw or RLE literals block
724 return decode_literals_simple(in, literals, block_type,
727 // Huffman compressed literals
728 return decode_literals_compressed(ctx, in, literals, block_type,
733 /// Decodes literals blocks in raw or RLE form
/// Reads the regenerated size in the width selected by `size_format`,
/// allocates `*literals` with malloc, and fills it. A raw block copies `size`
/// bytes from `in`; an RLE block repeats one input byte `size` times.
/// NOTE(review): the `case` labels, the bodies of the checks and the final
/// return are not visible here; presumably `size` is returned.
734 static size_t decode_literals_simple(istream_t *const in, u8 **const literals,
735 const int block_type,
736 const int size_format) {
738 switch (size_format) {
739 // These cases are in the form ?0
740 // In this case, the ? bit is actually part of the size field
743 // "Size_Format uses 1 bit. Regenerated_Size uses 5 bits (0-31)."
// The caller consumed 2 bits for size_format; give one back so that it is
// read again as the lowest bit of the 5-bit size.
744 IO_rewind_bits(in, 1);
745 size = IO_read_bits(in, 5);
748 // "Size_Format uses 2 bits. Regenerated_Size uses 12 bits (0-4095)."
749 size = IO_read_bits(in, 12);
752 // "Size_Format uses 2 bits. Regenerated_Size uses 20 bits (0-1048575)."
753 size = IO_read_bits(in, 20);
756 // Size format is in range 0-3
760 if (size > MAX_LITERALS_SIZE) {
// NOTE(review): `size` can be 0, and malloc(0) is allowed to return NULL.
// Confirm the allocation check (not visible here) does not treat that as an
// allocation failure.
764 *literals = malloc(size);
769 switch (block_type) {
771 // "Raw_Literals_Block - Literals are stored uncompressed."
772 const u8 *const read_ptr = IO_get_read_ptr(in, size);
773 memcpy(*literals, read_ptr, size);
777 // "RLE_Literals_Block - Literals consist of a single byte value repeated N times."
778 const u8 *const read_ptr = IO_get_read_ptr(in, 1);
779 memset(*literals, read_ptr[0], size);
789 /// Decodes Huffman compressed literals
/// Reads Regenerated_Size and then Compressed_Size in the width selected by
/// `size_format`, allocates `*literals` with malloc, and decodes the Huffman
/// streams into it. Returns `regenerated_size`.
/// NOTE(review): two parameter lines, the `case` labels, the `num_streams`
/// declaration and the bodies of the error checks are not visible here.
790 static size_t decode_literals_compressed(frame_context_t *const ctx,
793 const int block_type,
794 const int size_format) {
795 size_t regenerated_size, compressed_size;
796 // Only size_format=0 has 1 stream, so default to 4
798 switch (size_format) {
800 // "A single stream. Both Compressed_Size and Regenerated_Size use 10
803 // Fall through as it has the same size format
806 // "4 streams. Both Compressed_Size and Regenerated_Size use 10 bits
808 regenerated_size = IO_read_bits(in, 10);
809 compressed_size = IO_read_bits(in, 10);
812 // "4 streams. Both Compressed_Size and Regenerated_Size use 14 bits
814 regenerated_size = IO_read_bits(in, 14);
815 compressed_size = IO_read_bits(in, 14);
818 // "4 streams. Both Compressed_Size and Regenerated_Size use 18 bits
820 regenerated_size = IO_read_bits(in, 18);
821 compressed_size = IO_read_bits(in, 18);
827 if (regenerated_size > MAX_LITERALS_SIZE) {
831 *literals = malloc(regenerated_size);
// The output stream covers exactly the allocated buffer. The input sub-stream
// covers exactly compressed_size bytes, and `in` is advanced past them.
836 ostream_t lit_stream = IO_make_ostream(*literals, regenerated_size);
837 istream_t huf_stream = IO_make_sub_istream(in, compressed_size);
839 if (block_type == 2) {
840 // Decode the provided Huffman table
841 // "This section is only present when Literals_Block_Type type is
842 // Compressed_Literals_Block (2)."
// The table description is read from huf_stream, so its bytes are part of
// compressed_size. The previous table is freed before being replaced.
844 HUF_free_dtable(&ctx->literals_dtable);
845 decode_huf_table(&ctx->literals_dtable, &huf_stream);
847 // If the previous Huffman table is being repeated, ensure it exists
848 if (!ctx->literals_dtable.symbols) {
853 size_t symbols_decoded;
854 if (num_streams == 1) {
855 symbols_decoded = HUF_decompress_1stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
857 symbols_decoded = HUF_decompress_4stream(&ctx->literals_dtable, &lit_stream, &huf_stream);
// The decoded byte count must equal the size declared in the header.
860 if (symbols_decoded != regenerated_size) {
864 return regenerated_size;
867 // Decode the Huffman table description
// Reads a one-byte header from `in`, obtains the list of weights either
// directly (4 bits each) or from an FSE-compressed stream, and initializes
// `dtable` from those weights.
// NOTE(review): the condition choosing between the two representations and
// the declaration of `num_symbs` are not visible here.
868 static void decode_huf_table(HUF_dtable *const dtable, istream_t *const in) {
869 // "All literal values from zero (included) to last present one (excluded)
870 // are represented by Weight with values from 0 to Max_Number_of_Bits."
872 // "This is a single byte value (0-255), which describes how to decode the list of weights."
873 const u8 header = IO_read_bits(in, 8);
// All weights start at 0, so entries that are never written remain 0.
875 u8 weights[HUF_MAX_SYMBS];
876 memset(weights, 0, sizeof(weights));
881 // "This is a direct representation, where each Weight is written
882 // directly as a 4 bits field (0-15). The full representation occupies
883 // ((Number_of_Symbols+1)/2) bytes, meaning it uses a last full byte
884 // even if Number_of_Symbols is odd. Number_of_Symbols = headerByte -
886 num_symbs = header - 127;
// Two weights per byte, rounded up.
887 const size_t bytes = (num_symbs + 1) / 2;
889 const u8 *const weight_src = IO_get_read_ptr(in, bytes);
891 for (int i = 0; i < num_symbs; i++) {
892 // "They are encoded forward, 2
893 // weights to a byte with the first weight taking the top four bits
894 // and the second taking the bottom four (e.g. the following
895 // operations could be used to read the weights: Weight[0] =
896 // (Byte[0] >> 4), Weight[1] = (Byte[0] & 0xf), etc.)."
898 weights[i] = weight_src[i / 2] >> 4;
900 weights[i] = weight_src[i / 2] & 0xf;
904 // The weights are FSE encoded, decode them before we can construct the
// In this branch `header` is used as the byte length of the FSE-compressed
// weights. Output is bounded by the HUF_MAX_SYMBS-byte weights array.
906 istream_t fse_stream = IO_make_sub_istream(in, header);
907 ostream_t weight_stream = IO_make_ostream(weights, HUF_MAX_SYMBS);
908 fse_decode_hufweights(&weight_stream, &fse_stream, &num_symbs);
911 // Construct the table using the decoded weights
912 HUF_init_dtable_usingweights(dtable, weights, num_symbs);
/// Decodes FSE-compressed Huffman weights from `in` into the `weights` stream.
/// It builds a temporary FSE table from the header at the start of `in`,
/// decodes the two interleaved streams, stores the decoded count in
/// `*num_symbs`, and frees the temporary table.
/// NOTE(review): the declaration of the local `dtable` is not visible here.
915 static void fse_decode_hufweights(ostream_t *weights, istream_t *const in,
916 int *const num_symbs) {
917 const int MAX_ACCURACY_LOG = 7;
921 // "An FSE bitstream starts by a header, describing probabilities
922 // distribution. It will create a Decoding Table. For a list of Huffman
923 // weights, maximum accuracy is 7 bits."
924 FSE_decode_header(&dtable, in, MAX_ACCURACY_LOG);
926 // Decode the weights
// The size_t return value is converted to int on assignment.
927 *num_symbs = FSE_decompress_interleaved2(&dtable, weights, in);
929 FSE_free_dtable(&dtable);
931 /******* END LITERALS DECODING ************************************************/
933 /******* SEQUENCE DECODING ****************************************************/
934 /// The combination of FSE states needed to decode sequences
945 /// Different modes to signal to decode_seq_tables what to do
947 seq_literal_length = 0,
949 seq_match_length = 2,
959 /// The predefined FSE distribution tables for `seq_predefined` mode
960 static const i16 SEQ_LITERAL_LENGTH_DEFAULT_DIST[36] = {
961 4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2,
962 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1, 1, 1, -1, -1, -1, -1};
963 static const i16 SEQ_OFFSET_DEFAULT_DIST[29] = {
964 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1,
965 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1};
966 static const i16 SEQ_MATCH_LENGTH_DEFAULT_DIST[53] = {
967 1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,
968 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
969 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1};
971 /// The sequence decoding baseline and number of additional bits to read/add
972 /// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#the-codes-for-literals-lengths-match-lengths-and-offsets
// decode_sequence indexes these tables with the decoded code: the final length
// is BASELINES[code] plus EXTRA_BITS[code] raw bits read from the bitstream.
// Literal-length codes 0..35.
973 static const u32 SEQ_LITERAL_LENGTH_BASELINES[36] = {
974 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
975 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40,
976 48, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536};
977 static const u8 SEQ_LITERAL_LENGTH_EXTRA_BITS[36] = {
978 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
979 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
// Match-length codes 0..52; the smallest baseline is 3.
981 static const u32 SEQ_MATCH_LENGTH_BASELINES[53] = {
982 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
983 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
984 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83,
985 99, 131, 259, 515, 1027, 2051, 4099, 8195, 16387, 32771, 65539};
986 static const u8 SEQ_MATCH_LENGTH_EXTRA_BITS[53] = {
987 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
988 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,
989 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
991 /// Offset decoding is simpler so we just need a maximum code value
// Indexed by sequence part (literal length, offset, match length). 35 and 52
// are the last valid indices of the 36- and 53-entry baseline tables; the
// offset slot is (u8)-1 and decode_sequence does not compare against it.
992 static const u8 SEQ_MAX_CODES[3] = {35, (u8)-1, 52};
// Forward declarations for the sequence-decoding helpers defined further down
// in this section.
// NOTE(review): some parameter lines of these prototypes are not visible in
// this view; compare against the definitions before editing them.
994 static void decompress_sequences(frame_context_t *const ctx,
996 sequence_command_t *const sequences,
997 const size_t num_sequences);
998 static sequence_command_t decode_sequence(sequence_states_t *const state,
1001 static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1002 const seq_part_t type, const seq_mode_t mode);
1004 static size_t decode_sequences(frame_context_t *const ctx, istream_t *in,
1005 sequence_command_t **const sequences) {
1006 // "A compressed block is a succession of sequences . A sequence is a
1007 // literal copy command, followed by a match copy command. A literal copy
1008 // command specifies a length. It is the number of bytes to be copied (or
1009 // extracted) from the literal section. A match copy command specifies an
1010 // offset and a length. The offset gives the position to copy from, which
1011 // can be within a previous block."
1013 size_t num_sequences;
1015 // "Number_of_Sequences
1017 // This is a variable size field using between 1 and 3 bytes. Let's call its
1018 // first byte byte0."
1019 u8 header = IO_read_bits(in, 8);
1021 // "There are no sequences. The sequence section stops there.
1022 // Regenerated content is defined entirely by literals section."
1025 } else if (header < 128) {
1026 // "Number_of_Sequences = byte0 . Uses 1 byte."
1027 num_sequences = header;
1028 } else if (header < 255) {
1029 // "Number_of_Sequences = ((byte0-128) << 8) + byte1 . Uses 2 bytes."
1030 num_sequences = ((header - 128) << 8) + IO_read_bits(in, 8);
1032 // "Number_of_Sequences = byte1 + (byte2<<8) + 0x7F00 . Uses 3 bytes."
1033 num_sequences = IO_read_bits(in, 16) + 0x7F00;
1036 *sequences = malloc(num_sequences * sizeof(sequence_command_t));
1041 decompress_sequences(ctx, in, *sequences, num_sequences);
1042 return num_sequences;
1045 /// Decompress the FSE encoded sequence commands
1046 static void decompress_sequences(frame_context_t *const ctx, istream_t *in,
1047 sequence_command_t *const sequences,
1048 const size_t num_sequences) {
1049 // "The Sequences_Section regroup all symbols required to decode commands.
1050 // There are 3 symbol types : literals lengths, offsets and match lengths.
1051 // They are encoded together, interleaved, in a single bitstream."
1053 // "Symbol compression modes
1055 // This is a single byte, defining the compression mode of each symbol
1058 // Bit number : Field name
1059 // 7-6 : Literals_Lengths_Mode
1060 // 5-4 : Offsets_Mode
1061 // 3-2 : Match_Lengths_Mode
1063 u8 compression_modes = IO_read_bits(in, 8);
1065 if ((compression_modes & 3) != 0) {
1066 // Reserved bits set
1070 // "Following the header, up to 3 distribution tables can be described. When
1071 // present, they are in this order :
1076 // Update the tables we have stored in the context
1077 decode_seq_table(&ctx->ll_dtable, in, seq_literal_length,
1078 (compression_modes >> 6) & 3);
1080 decode_seq_table(&ctx->of_dtable, in, seq_offset,
1081 (compression_modes >> 4) & 3);
1083 decode_seq_table(&ctx->ml_dtable, in, seq_match_length,
1084 (compression_modes >> 2) & 3);
1087 sequence_states_t states;
1089 // Initialize the decoding tables
1091 states.ll_table = ctx->ll_dtable;
1092 states.of_table = ctx->of_dtable;
1093 states.ml_table = ctx->ml_dtable;
1096 const size_t len = IO_istream_len(in);
1097 const u8 *const src = IO_get_read_ptr(in, len);
1099 // "After writing the last bit containing information, the compressor writes
1100 // a single 1-bit and then fills the byte with 0-7 0 bits of padding."
1101 const int padding = 8 - highest_set_bit(src[len - 1]);
1102 // The offset starts at the end because FSE streams are read backwards
1103 i64 bit_offset = (i64)(len * 8 - (size_t)padding);
1105 // "The bitstream starts with initial state values, each using the required
1106 // number of bits in their respective accuracy, decoded previously from
1107 // their normalized distribution.
1109 // It starts by Literals_Length_State, followed by Offset_State, and finally
1110 // Match_Length_State."
1111 FSE_init_state(&states.ll_table, &states.ll_state, src, &bit_offset);
1112 FSE_init_state(&states.of_table, &states.of_state, src, &bit_offset);
1113 FSE_init_state(&states.ml_table, &states.ml_state, src, &bit_offset);
1115 for (size_t i = 0; i < num_sequences; i++) {
1116 // Decode sequences one by one
1117 sequences[i] = decode_sequence(&states, src, &bit_offset);
1120 if (bit_offset != 0) {
1125 // Decode a single sequence and update the state
1126 static sequence_command_t decode_sequence(sequence_states_t *const states,
1127 const u8 *const src,
1128 i64 *const offset) {
1129 // "Each symbol is a code in its own context, which specifies Baseline and
1130 // Number_of_Bits to add. Codes are FSE compressed, and interleaved with raw
1131 // additional bits in the same bitstream."
1133 // Decode symbols, but don't update states
1134 const u8 of_code = FSE_peek_symbol(&states->of_table, states->of_state);
1135 const u8 ll_code = FSE_peek_symbol(&states->ll_table, states->ll_state);
1136 const u8 ml_code = FSE_peek_symbol(&states->ml_table, states->ml_state);
1138 // Offset doesn't need a max value as it's not decoded using a table
1139 if (ll_code > SEQ_MAX_CODES[seq_literal_length] ||
1140 ml_code > SEQ_MAX_CODES[seq_match_length]) {
1144 // Read the interleaved bits
1145 sequence_command_t seq;
1146 // "Decoding starts by reading the Number_of_Bits required to decode Offset.
1147 // It then does the same for Match_Length, and then for Literals_Length."
1148 seq.offset = ((u32)1 << of_code) + STREAM_read_bits(src, of_code, offset);
1151 SEQ_MATCH_LENGTH_BASELINES[ml_code] +
1152 STREAM_read_bits(src, SEQ_MATCH_LENGTH_EXTRA_BITS[ml_code], offset);
1154 seq.literal_length =
1155 SEQ_LITERAL_LENGTH_BASELINES[ll_code] +
1156 STREAM_read_bits(src, SEQ_LITERAL_LENGTH_EXTRA_BITS[ll_code], offset);
1158 // "If it is not the last sequence in the block, the next operation is to
1159 // update states. Using the rules pre-calculated in the decoding tables,
1160 // Literals_Length_State is updated, followed by Match_Length_State, and
1161 // then Offset_State."
1162 // If the stream is complete don't read bits to update state
1164 FSE_update_state(&states->ll_table, &states->ll_state, src, offset);
1165 FSE_update_state(&states->ml_table, &states->ml_state, src, offset);
1166 FSE_update_state(&states->of_table, &states->of_state, src, offset);
1172 /// Given a sequence part and table mode, decode the FSE distribution
1173 /// Errors if the mode is `seq_repeat` without a pre-existing table in `table`
1174 static void decode_seq_table(FSE_dtable *const table, istream_t *const in,
1175 const seq_part_t type, const seq_mode_t mode) {
1176 // Constant arrays indexed by seq_part_t
1177 const i16 *const default_distributions[] = {SEQ_LITERAL_LENGTH_DEFAULT_DIST,
1178 SEQ_OFFSET_DEFAULT_DIST,
1179 SEQ_MATCH_LENGTH_DEFAULT_DIST};
1180 const size_t default_distribution_lengths[] = {36, 29, 53};
1181 const size_t default_distribution_accuracies[] = {6, 5, 6};
1183 const size_t max_accuracies[] = {9, 8, 9};
1185 if (mode != seq_repeat) {
1186 // Free old one before overwriting
1187 FSE_free_dtable(table);
1191 case seq_predefined: {
1192 // "Predefined_Mode : uses a predefined distribution table."
1193 const i16 *distribution = default_distributions[type];
1194 const size_t symbs = default_distribution_lengths[type];
1195 const size_t accuracy_log = default_distribution_accuracies[type];
1197 FSE_init_dtable(table, distribution, symbs, accuracy_log);
1201 // "RLE_Mode : it's a single code, repeated Number_of_Sequences times."
1202 const u8 symb = IO_get_read_ptr(in, 1)[0];
1203 FSE_init_dtable_rle(table, symb);
1207 // "FSE_Compressed_Mode : standard FSE compression. A distribution table
1208 // will be present "
1209 FSE_decode_header(table, in, max_accuracies[type]);
1213 // "Repeat_Mode : re-use distribution table from previous compressed
1215 // Nothing to do here, table will be unchanged
1216 if (!table->symbols) {
1217 // This mode is invalid if we don't already have a table
1222 // Impossible, as mode is from 0-3
1228 /******* END SEQUENCE DECODING ************************************************/
1230 /******* SEQUENCE EXECUTION ***************************************************/
1231 static void execute_sequences(frame_context_t *const ctx, ostream_t *const out,
1232 const u8 *const literals,
1233 const size_t literals_len,
1234 const sequence_command_t *const sequences,
1235 const size_t num_sequences) {
1236 istream_t litstream = IO_make_istream(literals, literals_len);
1238 u64 *const offset_hist = ctx->previous_offsets;
1239 size_t total_output = ctx->current_total_output;
1241 for (size_t i = 0; i < num_sequences; i++) {
1242 const sequence_command_t seq = sequences[i];
1244 const u32 literals_size = copy_literals(seq.literal_length, &litstream, out);
1245 total_output += literals_size;
1248 size_t const offset = compute_offset(seq, offset_hist);
1250 size_t const match_length = seq.match_length;
1252 execute_match_copy(ctx, offset, match_length, total_output, out);
1254 total_output += match_length;
1257 // Copy any leftover literals
1259 size_t len = IO_istream_len(&litstream);
1260 copy_literals(len, &litstream, out);
1261 total_output += len;
1264 ctx->current_total_output = total_output;
1267 static u32 copy_literals(const size_t literal_length, istream_t *litstream,
1268 ostream_t *const out) {
1269 // If the sequence asks for more literals than are left, the
1270 // sequence must be corrupted
1271 if (literal_length > IO_istream_len(litstream)) {
1275 u8 *const write_ptr = IO_get_write_ptr(out, literal_length);
1276 const u8 *const read_ptr =
1277 IO_get_read_ptr(litstream, literal_length);
1278 // Copy literals to output
1279 memcpy(write_ptr, read_ptr, literal_length);
1281 return literal_length;
1284 static size_t compute_offset(sequence_command_t seq, u64 *const offset_hist) {
1286 // Offsets are special, we need to handle the repeat offsets
1287 if (seq.offset <= 3) {
1288 // "The first 3 values define a repeated offset and we will call
1289 // them Repeated_Offset1, Repeated_Offset2, and Repeated_Offset3.
1290 // They are sorted in recency order, with Repeated_Offset1 meaning
1291 // 'most recent one'".
1293 // Use 0 indexing for the array
1294 u32 idx = seq.offset - 1;
1295 if (seq.literal_length == 0) {
1296 // "There is an exception though, when current sequence's
1297 // literals length is 0. In this case, repeated offsets are
1298 // shifted by one, so Repeated_Offset1 becomes Repeated_Offset2,
1299 // Repeated_Offset2 becomes Repeated_Offset3, and
1300 // Repeated_Offset3 becomes Repeated_Offset1 - 1_byte."
1305 offset = offset_hist[0];
1307 // If idx == 3 then literal length was 0 and the offset was 3,
1308 // as per the exception listed above
1309 offset = idx < 3 ? offset_hist[idx] : offset_hist[0] - 1;
1311 // If idx == 1 we don't need to modify offset_hist[2], since
1312 // we're using the second-most recent code
1314 offset_hist[2] = offset_hist[1];
1316 offset_hist[1] = offset_hist[0];
1317 offset_hist[0] = offset;
1320 // When it's not a repeat offset:
1321 // "if (Offset_Value > 3) offset = Offset_Value - 3;"
1322 offset = seq.offset - 3;
1324 // Shift back history
1325 offset_hist[2] = offset_hist[1];
1326 offset_hist[1] = offset_hist[0];
1327 offset_hist[0] = offset;
1332 static void execute_match_copy(frame_context_t *const ctx, size_t offset,
1333 size_t match_length, size_t total_output,
1334 ostream_t *const out) {
1335 u8 *write_ptr = IO_get_write_ptr(out, match_length);
1336 if (total_output <= ctx->header.window_size) {
1337 // In this case offset might go back into the dictionary
1338 if (offset > total_output + ctx->dict_content_len) {
1339 // The offset goes beyond even the dictionary
1343 if (offset > total_output) {
1344 // "The rest of the dictionary is its content. The content act
1345 // as a "past" in front of data to compress or decompress, so it
1346 // can be referenced in sequence commands."
1347 const size_t dict_copy =
1348 MIN(offset - total_output, match_length);
1349 const size_t dict_offset =
1350 ctx->dict_content_len - (offset - total_output);
1352 memcpy(write_ptr, ctx->dict_content + dict_offset, dict_copy);
1353 write_ptr += dict_copy;
1354 match_length -= dict_copy;
1356 } else if (offset > ctx->header.window_size) {
1360 // We must copy byte by byte because the match length might be larger
1362 // ex: if the output so far was "abc", a command with offset=3 and
1363 // match_length=6 would produce "abcabcabc" as the new output
1364 for (size_t j = 0; j < match_length; j++) {
1365 *write_ptr = *(write_ptr - offset);
1369 /******* END SEQUENCE EXECUTION ***********************************************/
1371 /******* OUTPUT SIZE COUNTING *************************************************/
1372 /// Get the decompressed size of an input stream so memory can be allocated in
1374 /// This implementation assumes `src` points to a single ZSTD-compressed frame
1375 size_t ZSTD_get_decompressed_size(const void *src, const size_t src_len) {
1376 istream_t in = IO_make_istream(src, src_len);
1378 // get decompressed size from ZSTD frame header
1380 const u32 magic_number = (u32)IO_read_bits(&in, 32);
1382 if (magic_number == ZSTD_MAGIC_NUMBER) {
1384 frame_header_t header;
1385 parse_frame_header(&header, &in);
1387 if (header.frame_content_size == 0 && !header.single_segment_flag) {
1388 // Content size not provided, we can't tell
1392 return header.frame_content_size;
1394 // not a real frame or skippable frame
1395 ERROR("ZSTD frame magic number did not match");
1399 /******* END OUTPUT SIZE COUNTING *********************************************/
1401 /******* DICTIONARY PARSING ***************************************************/
1402 dictionary_t* create_dictionary() {
1403 dictionary_t* const dict = calloc(1, sizeof(dictionary_t));
1410 /// Free an allocated dictionary
1411 void free_dictionary(dictionary_t *const dict) {
1412 HUF_free_dtable(&dict->literals_dtable);
1413 FSE_free_dtable(&dict->ll_dtable);
1414 FSE_free_dtable(&dict->of_dtable);
1415 FSE_free_dtable(&dict->ml_dtable);
1417 free(dict->content);
1419 memset(dict, 0, sizeof(dictionary_t));
1425 #if !defined(ZDEC_NO_DICTIONARY)
1426 #define DICT_SIZE_ERROR() ERROR("Dictionary size cannot be less than 8 bytes")
1427 #define NULL_SRC() ERROR("Tried to create dictionary with pointer to null src");
1429 static void init_dictionary_content(dictionary_t *const dict,
1430 istream_t *const in);
1432 void parse_dictionary(dictionary_t *const dict, const void *src,
1434 const u8 *byte_src = (const u8 *)src;
1435 memset(dict, 0, sizeof(dictionary_t));
1436 if (src == NULL) { /* cannot initialize dictionary with null src */
1443 istream_t in = IO_make_istream(byte_src, src_len);
1445 const u32 magic_number = IO_read_bits(&in, 32);
1446 if (magic_number != 0xEC30A437) {
1448 IO_rewind_bits(&in, 32);
1449 init_dictionary_content(dict, &in);
1453 dict->dictionary_id = IO_read_bits(&in, 32);
1455 // "Entropy_Tables : following the same format as the tables in compressed
1456 // blocks. They are stored in following order : Huffman tables for literals,
1457 // FSE table for offsets, FSE table for match lengths, and FSE table for
1458 // literals lengths. It's finally followed by 3 offset values, populating
1459 // recent offsets (instead of using {1,4,8}), stored in order, 4-bytes
1460 // little-endian each, for a total of 12 bytes. Each recent offset must have
1461 // a value < dictionary size."
1462 decode_huf_table(&dict->literals_dtable, &in);
1463 decode_seq_table(&dict->of_dtable, &in, seq_offset, seq_fse);
1464 decode_seq_table(&dict->ml_dtable, &in, seq_match_length, seq_fse);
1465 decode_seq_table(&dict->ll_dtable, &in, seq_literal_length, seq_fse);
1467 // Read in the previous offset history
1468 dict->previous_offsets[0] = IO_read_bits(&in, 32);
1469 dict->previous_offsets[1] = IO_read_bits(&in, 32);
1470 dict->previous_offsets[2] = IO_read_bits(&in, 32);
1472 // Ensure the provided offsets aren't too large
1473 // "Each recent offset must have a value < dictionary size."
1474 for (int i = 0; i < 3; i++) {
1475 if (dict->previous_offsets[i] > src_len) {
1476 ERROR("Dictionary corrupted");
1480 // "Content : The rest of the dictionary is its content. The content act as
1481 // a "past" in front of data to compress or decompress, so it can be
1482 // referenced in sequence commands."
1483 init_dictionary_content(dict, &in);
1486 static void init_dictionary_content(dictionary_t *const dict,
1487 istream_t *const in) {
1488 // Copy in the content
1489 dict->content_size = IO_istream_len(in);
1490 dict->content = malloc(dict->content_size);
1491 if (!dict->content) {
1495 const u8 *const content = IO_get_read_ptr(in, dict->content_size);
1497 memcpy(dict->content, content, dict->content_size);
1500 static void HUF_copy_dtable(HUF_dtable *const dst,
1501 const HUF_dtable *const src) {
1502 if (src->max_bits == 0) {
1503 memset(dst, 0, sizeof(HUF_dtable));
1507 const size_t size = (size_t)1 << src->max_bits;
1508 dst->max_bits = src->max_bits;
1510 dst->symbols = malloc(size);
1511 dst->num_bits = malloc(size);
1512 if (!dst->symbols || !dst->num_bits) {
1516 memcpy(dst->symbols, src->symbols, size);
1517 memcpy(dst->num_bits, src->num_bits, size);
1520 static void FSE_copy_dtable(FSE_dtable *const dst, const FSE_dtable *const src) {
1521 if (src->accuracy_log == 0) {
1522 memset(dst, 0, sizeof(FSE_dtable));
1526 size_t size = (size_t)1 << src->accuracy_log;
1527 dst->accuracy_log = src->accuracy_log;
1529 dst->symbols = malloc(size);
1530 dst->num_bits = malloc(size);
1531 dst->new_state_base = malloc(size * sizeof(u16));
1532 if (!dst->symbols || !dst->num_bits || !dst->new_state_base) {
1536 memcpy(dst->symbols, src->symbols, size);
1537 memcpy(dst->num_bits, src->num_bits, size);
1538 memcpy(dst->new_state_base, src->new_state_base, size * sizeof(u16));
1541 /// A dictionary acts as initializing values for the frame context before
1542 /// decompression, so we implement it by applying it's predetermined
1543 /// tables and content to the context before beginning decompression
1544 static void frame_context_apply_dict(frame_context_t *const ctx,
1545 const dictionary_t *const dict) {
1546 // If the content pointer is NULL then it must be an empty dict
1547 if (!dict || !dict->content)
1550 // If the requested dictionary_id is non-zero, the correct dictionary must
1552 if (ctx->header.dictionary_id != 0 &&
1553 ctx->header.dictionary_id != dict->dictionary_id) {
1554 ERROR("Wrong dictionary provided");
1557 // Copy the dict content to the context for references during sequence
1559 ctx->dict_content = dict->content;
1560 ctx->dict_content_len = dict->content_size;
1562 // If it's a formatted dict copy the precomputed tables in so they can
1563 // be used in the table repeat modes
1564 if (dict->dictionary_id != 0) {
1565 // Deep copy the entropy tables so they can be freed independently of
1566 // the dictionary struct
1567 HUF_copy_dtable(&ctx->literals_dtable, &dict->literals_dtable);
1568 FSE_copy_dtable(&ctx->ll_dtable, &dict->ll_dtable);
1569 FSE_copy_dtable(&ctx->of_dtable, &dict->of_dtable);
1570 FSE_copy_dtable(&ctx->ml_dtable, &dict->ml_dtable);
1572 // Copy the repeated offsets
1573 memcpy(ctx->previous_offsets, dict->previous_offsets,
1574 sizeof(ctx->previous_offsets));
1578 #else // ZDEC_NO_DICTIONARY is defined
1580 static void frame_context_apply_dict(frame_context_t *const ctx,
1581 const dictionary_t *const dict) {
1583 if (dict && dict->content) ERROR("dictionary not supported");
1587 /******* END DICTIONARY PARSING ***********************************************/
1589 /******* IO STREAM OPERATIONS *************************************************/
1591 /// Reads `num` bits from a bitstream, and updates the internal offset
1592 static inline u64 IO_read_bits(istream_t *const in, const int num_bits) {
1593 if (num_bits > 64 || num_bits <= 0) {
1594 ERROR("Attempt to read an invalid number of bits");
1597 const size_t bytes = (num_bits + in->bit_offset + 7) / 8;
1598 const size_t full_bytes = (num_bits + in->bit_offset) / 8;
1599 if (bytes > in->len) {
1603 const u64 result = read_bits_LE(in->ptr, num_bits, in->bit_offset);
1605 in->bit_offset = (num_bits + in->bit_offset) % 8;
1606 in->ptr += full_bytes;
1607 in->len -= full_bytes;
1612 /// If a non-zero number of bits have been read from the current byte, advance
1613 /// the offset to the next byte
1614 static inline void IO_rewind_bits(istream_t *const in, int num_bits) {
1616 ERROR("Attempting to rewind stream by a negative number of bits");
1619 // move the offset back by `num_bits` bits
1620 const int new_offset = in->bit_offset - num_bits;
1621 // determine the number of whole bytes we have to rewind, rounding up to an
1622 // integer number (e.g. if `new_offset == -5`, `bytes == 1`)
1623 const i64 bytes = -(new_offset - 7) / 8;
1627 // make sure the resulting `bit_offset` is positive, as mod in C does not
1628 // convert numbers from negative to positive (e.g. -22 % 8 == -6)
1629 in->bit_offset = ((new_offset % 8) + 8) % 8;
1632 /// If the remaining bits in a byte will be unused, advance to the end of the
1634 static inline void IO_align_stream(istream_t *const in) {
1635 if (in->bit_offset != 0) {
1645 /// Write the given byte into the output stream
1646 static inline void IO_write_byte(ostream_t *const out, u8 symb) {
1647 if (out->len == 0) {
1656 /// Returns the number of bytes left to be read in this stream. The stream must
1657 /// be byte aligned.
1658 static inline size_t IO_istream_len(const istream_t *const in) {
1662 /// Returns a pointer where `len` bytes can be read, and advances the internal
1663 /// state. The stream must be byte aligned.
1664 static inline const u8 *IO_get_read_ptr(istream_t *const in, size_t len) {
1665 if (len > in->len) {
1668 if (in->bit_offset != 0) {
1669 ERROR("Attempting to operate on a non-byte aligned stream");
1671 const u8 *const ptr = in->ptr;
1677 /// Returns a pointer to write `len` bytes to, and advances the internal state
1678 static inline u8 *IO_get_write_ptr(ostream_t *const out, size_t len) {
1679 if (len > out->len) {
1682 u8 *const ptr = out->ptr;
1689 /// Advance the inner state by `len` bytes
1690 static inline void IO_advance_input(istream_t *const in, size_t len) {
1691 if (len > in->len) {
1694 if (in->bit_offset != 0) {
1695 ERROR("Attempting to operate on a non-byte aligned stream");
1702 /// Returns an `ostream_t` constructed from the given pointer and length
1703 static inline ostream_t IO_make_ostream(u8 *out, size_t len) {
1704 return (ostream_t) { out, len };
1707 /// Returns an `istream_t` constructed from the given pointer and length
1708 static inline istream_t IO_make_istream(const u8 *in, size_t len) {
1709 return (istream_t) { in, len, 0 };
1712 /// Returns an `istream_t` with the same base as `in`, and length `len`
1713 /// Then, advance `in` to account for the consumed bytes
1714 /// `in` must be byte aligned
1715 static inline istream_t IO_make_sub_istream(istream_t *const in, size_t len) {
1716 // Consume `len` bytes of the parent stream
1717 const u8 *const ptr = IO_get_read_ptr(in, len);
1719 // Make a substream using the pointer to those `len` bytes
1720 return IO_make_istream(ptr, len);
1722 /******* END IO STREAM OPERATIONS *********************************************/
1724 /******* BITSTREAM OPERATIONS *************************************************/
1725 /// Read `num` bits (up to 64) from `src + offset`, where `offset` is in bits
1726 static inline u64 read_bits_LE(const u8 *src, const int num_bits,
1727 const size_t offset) {
1728 if (num_bits > 64) {
1729 ERROR("Attempt to read an invalid number of bits");
1732 // Skip over bytes that aren't in range
1734 size_t bit_offset = offset % 8;
1738 int left = num_bits;
1740 u64 mask = left >= 8 ? 0xff : (((u64)1 << left) - 1);
1741 // Read the next byte, shift it to account for the offset, and then mask
1742 // out the top part if we don't need all the bits
1743 res += (((u64)*src++ >> bit_offset) & mask) << shift;
1744 shift += 8 - bit_offset;
1745 left -= 8 - bit_offset;
1752 /// Read bits from the end of a HUF or FSE bitstream. `offset` is in bits, so
1753 /// it updates `offset` to `offset - bits`, and then reads `bits` bits from
1754 /// `src + offset`. If the offset becomes negative, the extra bits at the
1755 /// bottom are filled in with `0` bits instead of reading from before `src`.
1756 static inline u64 STREAM_read_bits(const u8 *const src, const int bits,
1757 i64 *const offset) {
1758 *offset = *offset - bits;
1759 size_t actual_off = *offset;
1760 size_t actual_bits = bits;
1761 // Don't actually read bits from before the start of src, so if `*offset <
1762 // 0` fix actual_off and actual_bits to reflect the quantity to read
1764 actual_bits += *offset;
1767 u64 res = read_bits_LE(src, actual_bits, actual_off);
1770 // Fill in the bottom "overflowed" bits with 0's
1771 res = -*offset >= 64 ? 0 : (res << -*offset);
1775 /******* END BITSTREAM OPERATIONS *********************************************/
1777 /******* BIT COUNTING OPERATIONS **********************************************/
1778 /// Returns `x`, where `2^x` is the largest power of 2 less than or equal to
1779 /// `num`, or `-1` if `num == 0`.
1780 static inline int highest_set_bit(const u64 num) {
1781 for (int i = 63; i >= 0; i--) {
1782 if (((u64)1 << i) <= num) {
1788 /******* END BIT COUNTING OPERATIONS ******************************************/
1790 /******* HUFFMAN PRIMITIVES ***************************************************/
1791 static inline u8 HUF_decode_symbol(const HUF_dtable *const dtable,
1792 u16 *const state, const u8 *const src,
1793 i64 *const offset) {
1794 // Look up the symbol and number of bits to read
1795 const u8 symb = dtable->symbols[*state];
1796 const u8 bits = dtable->num_bits[*state];
1797 const u16 rest = STREAM_read_bits(src, bits, offset);
1798 // Shift `bits` bits out of the state, keeping the low order bits that
1799 // weren't necessary to determine this symbol. Then add in the new bits
1800 // read from the stream.
1801 *state = ((*state << bits) + rest) & (((u16)1 << dtable->max_bits) - 1);
1806 static inline void HUF_init_state(const HUF_dtable *const dtable,
1807 u16 *const state, const u8 *const src,
1808 i64 *const offset) {
1809 // Read in a full `dtable->max_bits` bits to initialize the state
1810 const u8 bits = dtable->max_bits;
1811 *state = STREAM_read_bits(src, bits, offset);
1814 static size_t HUF_decompress_1stream(const HUF_dtable *const dtable,
1815 ostream_t *const out,
1816 istream_t *const in) {
1817 const size_t len = IO_istream_len(in);
1821 const u8 *const src = IO_get_read_ptr(in, len);
1823 // "Each bitstream must be read backward, that is starting from the end down
1824 // to the beginning. Therefore it's necessary to know the size of each
1827 // It's also necessary to know exactly which bit is the latest. This is
1828 // detected by a final bit flag : the highest bit of latest byte is a
1829 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
1830 // final-bit-flag itself is not part of the useful bitstream. Hence, the
1831 // last byte contains between 0 and 7 useful bits."
1832 const int padding = 8 - highest_set_bit(src[len - 1]);
1834 // Offset starts at the end because HUF streams are read backwards
1835 i64 bit_offset = len * 8 - padding;
1838 HUF_init_state(dtable, &state, src, &bit_offset);
1840 size_t symbols_written = 0;
1841 while (bit_offset > -dtable->max_bits) {
1842 // Iterate over the stream, decoding one symbol at a time
1843 IO_write_byte(out, HUF_decode_symbol(dtable, &state, src, &bit_offset));
1846 // "The process continues up to reading the required number of symbols per
1847 // stream. If a bitstream is not entirely and exactly consumed, hence
1848 // reaching exactly its beginning position with all bits consumed, the
1849 // decoding process is considered faulty."
1851 // When all symbols have been decoded, the final state value shouldn't have
1852 // any data from the stream, so it should have "read" dtable->max_bits from
1853 // before the start of `src`
1854 // Therefore `offset`, the edge to start reading new bits at, should be
1855 // dtable->max_bits before the start of the stream
1856 if (bit_offset != -dtable->max_bits) {
1860 return symbols_written;
1863 static size_t HUF_decompress_4stream(const HUF_dtable *const dtable,
1864 ostream_t *const out, istream_t *const in) {
1865 // "Compressed size is provided explicitly : in the 4-streams variant,
1866 // bitstreams are preceded by 3 unsigned little-endian 16-bits values. Each
1867 // value represents the compressed size of one stream, in order. The last
1868 // stream size is deducted from total compressed size and from previously
1869 // decoded stream sizes"
1870 const size_t csize1 = IO_read_bits(in, 16);
1871 const size_t csize2 = IO_read_bits(in, 16);
1872 const size_t csize3 = IO_read_bits(in, 16);
1874 istream_t in1 = IO_make_sub_istream(in, csize1);
1875 istream_t in2 = IO_make_sub_istream(in, csize2);
1876 istream_t in3 = IO_make_sub_istream(in, csize3);
1877 istream_t in4 = IO_make_sub_istream(in, IO_istream_len(in));
1879 size_t total_output = 0;
1880 // Decode each stream independently for simplicity
1881 // If we wanted to we could decode all 4 at the same time for speed,
1882 // utilizing more execution units
1883 total_output += HUF_decompress_1stream(dtable, out, &in1);
1884 total_output += HUF_decompress_1stream(dtable, out, &in2);
1885 total_output += HUF_decompress_1stream(dtable, out, &in3);
1886 total_output += HUF_decompress_1stream(dtable, out, &in4);
1888 return total_output;
1891 /// Initializes a Huffman table using canonical Huffman codes
1892 /// For more explanation on canonical Huffman codes see
1893 /// https://www.cs.scranton.edu/~mccloske/courses/cmps340/huff_canonical_dec2015.html
1894 /// Codes within a level are allocated in symbol order (i.e. smaller symbols get
// `bits[i]` is the code length of symbol `i`, for `num_symbs` symbols. The
// table is zeroed first, and `symbols` / `num_bits` are allocated with one
// entry per state (1 << max_bits).
// NOTE(review): several short lines of this function (closing braces, error
// macro calls, the `max_bits` declaration) are not visible in this view.
1896 static void HUF_init_dtable(HUF_dtable *const table, const u8 *const bits,
1897 const int num_symbs) {
1898 memset(table, 0, sizeof(HUF_dtable));
1899 if (num_symbs > HUF_MAX_SYMBS) {
1900 ERROR("Too many symbols for Huffman");
// rank_count[b] = how many symbols have a code length of b bits
1904 u16 rank_count[HUF_MAX_BITS + 1];
1905 memset(rank_count, 0, sizeof(rank_count));
1907 // Count the number of symbols for each number of bits, and determine the
1908 // depth of the tree
1909 for (int i = 0; i < num_symbs; i++) {
1910 if (bits[i] > HUF_MAX_BITS) {
1911 ERROR("Huffman table depth too large");
1913 max_bits = MAX(max_bits, bits[i]);
1914 rank_count[bits[i]]++;
// The table is indexed by a max_bits-wide state, so it has 2^max_bits entries
1917 const size_t table_size = 1 << max_bits;
1918 table->max_bits = max_bits;
1919 table->symbols = malloc(table_size);
1920 table->num_bits = malloc(table_size);
// free(NULL) is a no-op, so both pointers can be released unconditionally
1922 if (!table->symbols || !table->num_bits) {
1923 free(table->symbols);
1924 free(table->num_bits);
1928 // "Symbols are sorted by Weight. Within same Weight, symbols keep natural
1929 // order. Symbols with a Weight of zero are removed. Then, starting from
1930 // lowest weight, prefix codes are distributed in order."
// rank_idx[b] = first table index handed out to codes of length b. The loop
// walks from the longest codes down, so longer codes get the lower indices.
1932 u32 rank_idx[HUF_MAX_BITS + 1];
1933 // Initialize the starting codes for each rank (number of bits)
1934 rank_idx[max_bits] = 0;
1935 for (int i = max_bits; i >= 1; i--) {
// A code of length i covers 2^(max_bits - i) table entries
1936 rank_idx[i - 1] = rank_idx[i] + rank_count[i] * (1 << (max_bits - i));
1937 // The entire range takes the same number of bits so we can memset it
1938 memset(&table->num_bits[rank_idx[i]], i, rank_idx[i - 1] - rank_idx[i]);
// The ranges of all code lengths together must fill the table exactly
1941 if (rank_idx[0] != table_size) {
1945 // Allocate codes and fill in the table
1946 for (int i = 0; i < num_symbs; i++) {
1948 // Allocate a code for this symbol and set its range in the table
1949 const u16 code = rank_idx[bits[i]];
1950 // Since the code doesn't care about the bottom `max_bits - bits[i]`
1951 // bits of state, it gets a range that spans all possible values of
1953 const u16 len = 1 << (max_bits - bits[i]);
1954 memset(&table->symbols[code], i, len);
// Move this rank's cursor past the range just assigned
1955 rank_idx[bits[i]] += len;
1960 static void HUF_init_dtable_usingweights(HUF_dtable *const table,
1961 const u8 *const weights,
1962 const int num_symbs) {
1963 // +1 because the last weight is not transmitted in the header
1964 if (num_symbs + 1 > HUF_MAX_SYMBS) {
1965 ERROR("Too many symbols for Huffman");
1968 u8 bits[HUF_MAX_SYMBS];
1971 for (int i = 0; i < num_symbs; i++) {
1972 // Weights are in the same range as bit count
1973 if (weights[i] > HUF_MAX_BITS) {
1976 weight_sum += weights[i] > 0 ? (u64)1 << (weights[i] - 1) : 0;
1979 // Find the first power of 2 larger than the sum
1980 const int max_bits = highest_set_bit(weight_sum) + 1;
1981 const u64 left_over = ((u64)1 << max_bits) - weight_sum;
1982 // If the left over isn't a power of 2, the weights are invalid
1983 if (left_over & (left_over - 1)) {
1987 // left_over is used to find the last weight as it's not transmitted
1988 // by inverting 2^(weight - 1) we can determine the value of last_weight
1989 const int last_weight = highest_set_bit(left_over) + 1;
1991 for (int i = 0; i < num_symbs; i++) {
1992 // "Number_of_Bits = Number_of_Bits ? Max_Number_of_Bits + 1 - Weight : 0"
1993 bits[i] = weights[i] > 0 ? (max_bits + 1 - weights[i]) : 0;
1996 max_bits + 1 - last_weight; // Last weight is always non-zero
1998 HUF_init_dtable(table, bits, num_symbs + 1);
2001 static void HUF_free_dtable(HUF_dtable *const dtable) {
2002 free(dtable->symbols);
2003 free(dtable->num_bits);
2004 memset(dtable, 0, sizeof(HUF_dtable));
2006 /******* END HUFFMAN PRIMITIVES ***********************************************/
2008 /******* FSE PRIMITIVES *******************************************************/
2009 /// For more description of FSE see
2010 /// https://github.com/Cyan4973/FiniteStateEntropy/
2012 /// Allow a symbol to be decoded without updating state
2013 static inline u8 FSE_peek_symbol(const FSE_dtable *const dtable,
2015 return dtable->symbols[state];
2018 /// Consumes bits from the input and uses the current state to determine the
2020 static inline void FSE_update_state(const FSE_dtable *const dtable,
2021 u16 *const state, const u8 *const src,
2022 i64 *const offset) {
2023 const u8 bits = dtable->num_bits[*state];
2024 const u16 rest = STREAM_read_bits(src, bits, offset);
2025 *state = dtable->new_state_base[*state] + rest;
2028 /// Decodes a single FSE symbol and updates the offset
2029 static inline u8 FSE_decode_symbol(const FSE_dtable *const dtable,
2030 u16 *const state, const u8 *const src,
2031 i64 *const offset) {
2032 const u8 symb = FSE_peek_symbol(dtable, *state);
2033 FSE_update_state(dtable, state, src, offset);
2037 static inline void FSE_init_state(const FSE_dtable *const dtable,
2038 u16 *const state, const u8 *const src,
2039 i64 *const offset) {
2040 // Read in a full `accuracy_log` bits to initialize the state
2041 const u8 bits = dtable->accuracy_log;
2042 *state = STREAM_read_bits(src, bits, offset);
2045 static size_t FSE_decompress_interleaved2(const FSE_dtable *const dtable,
2046 ostream_t *const out,
2047 istream_t *const in) {
2048 const size_t len = IO_istream_len(in);
2052 const u8 *const src = IO_get_read_ptr(in, len);
2054 // "Each bitstream must be read backward, that is starting from the end down
2055 // to the beginning. Therefore it's necessary to know the size of each
2058 // It's also necessary to know exactly which bit is the latest. This is
2059 // detected by a final bit flag : the highest bit of latest byte is a
2060 // final-bit-flag. Consequently, a last byte of 0 is not possible. And the
2061 // final-bit-flag itself is not part of the useful bitstream. Hence, the
2062 // last byte contains between 0 and 7 useful bits."
2063 const int padding = 8 - highest_set_bit(src[len - 1]);
2064 i64 offset = len * 8 - padding;
2067 // "The first state (State1) encodes the even indexed symbols, and the
2068 // second (State2) encodes the odd indexes. State1 is initialized first, and
2069 // then State2, and they take turns decoding a single symbol and updating
2071 FSE_init_state(dtable, &state1, src, &offset);
2072 FSE_init_state(dtable, &state2, src, &offset);
2074 // Decode until we overflow the stream
2075 // Since we decode in reverse order, overflowing the stream is offset going
2077 size_t symbols_written = 0;
2079 // "The number of symbols to decode is determined by tracking bitStream
2080 // overflow condition: If updating state after decoding a symbol would
2081 // require more bits than remain in the stream, it is assumed the extra
2082 // bits are 0. Then, the symbols for each of the final states are
2083 // decoded and the process is complete."
2084 IO_write_byte(out, FSE_decode_symbol(dtable, &state1, src, &offset));
2087 // There's still a symbol to decode in state2
2088 IO_write_byte(out, FSE_peek_symbol(dtable, state2));
2093 IO_write_byte(out, FSE_decode_symbol(dtable, &state2, src, &offset));
2096 // There's still a symbol to decode in state1
2097 IO_write_byte(out, FSE_peek_symbol(dtable, state1));
2103 return symbols_written;
2106 static void FSE_init_dtable(FSE_dtable *const dtable,
2107 const i16 *const norm_freqs, const int num_symbs,
2108 const int accuracy_log) {
2109 if (accuracy_log > FSE_MAX_ACCURACY_LOG) {
2110 ERROR("FSE accuracy too large");
2112 if (num_symbs > FSE_MAX_SYMBS) {
2113 ERROR("Too many symbols for FSE");
2116 dtable->accuracy_log = accuracy_log;
2118 const size_t size = (size_t)1 << accuracy_log;
2119 dtable->symbols = malloc(size * sizeof(u8));
2120 dtable->num_bits = malloc(size * sizeof(u8));
2121 dtable->new_state_base = malloc(size * sizeof(u16));
2123 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2127 // Used to determine how many bits need to be read for each state,
2128 // and where the destination range should start
2129 // Needs to be u16 because max value is 2 * max number of symbols,
2130 // which can be larger than a byte can store
2131 u16 state_desc[FSE_MAX_SYMBS];
2133 // "Symbols are scanned in their natural order for "less than 1"
2134 // probabilities. Symbols with this probability are being attributed a
2135 // single cell, starting from the end of the table. These symbols define a
2136 // full state reset, reading Accuracy_Log bits."
2137 int high_threshold = size;
2138 for (int s = 0; s < num_symbs; s++) {
2139 // Scan for low probability symbols to put at the top
2140 if (norm_freqs[s] == -1) {
2141 dtable->symbols[--high_threshold] = s;
2146 // "All remaining symbols are sorted in their natural order. Starting from
2147 // symbol 0 and table position 0, each symbol gets attributed as many cells
2148 // as its probability. Cell allocation is spread, not linear."
2149 // Place the rest in the table
2150 const u16 step = (size >> 1) + (size >> 3) + 3;
2151 const u16 mask = size - 1;
2153 for (int s = 0; s < num_symbs; s++) {
2154 if (norm_freqs[s] <= 0) {
2158 state_desc[s] = norm_freqs[s];
2160 for (int i = 0; i < norm_freqs[s]; i++) {
2161 // Give `norm_freqs[s]` states to symbol s
2162 dtable->symbols[pos] = s;
2163 // "A position is skipped if already occupied, typically by a "less
2164 // than 1" probability symbol."
2166 pos = (pos + step) & mask;
2169 // Note: no other collision checking is necessary as `step` is
2170 // coprime to `size`, so the cycle will visit each position exactly
2178 // Now we can fill baseline and num bits
2179 for (size_t i = 0; i < size; i++) {
2180 u8 symbol = dtable->symbols[i];
2181 u16 next_state_desc = state_desc[symbol]++;
2182 // Fills in the table appropriately, next_state_desc increases by symbol
2183 // over time, decreasing number of bits
2184 dtable->num_bits[i] = (u8)(accuracy_log - highest_set_bit(next_state_desc));
2185 // Baseline increases until the bit threshold is passed, at which point
2187 dtable->new_state_base[i] =
2188 ((u16)next_state_desc << dtable->num_bits[i]) - size;
2192 /// Decode an FSE header as defined in the Zstandard format specification and
2193 /// use the decoded frequencies to initialize a decoding table.
2194 static void FSE_decode_header(FSE_dtable *const dtable, istream_t *const in,
2195 const int max_accuracy_log) {
2196 // "An FSE distribution table describes the probabilities of all symbols
2197 // from 0 to the last present one (included) on a normalized scale of 1 <<
2200 // It's a bitstream which is read forward, in little-endian fashion. It's
2201 // not necessary to know its exact size, since it will be discovered and
2202 // reported by the decoding process.
2203 if (max_accuracy_log > FSE_MAX_ACCURACY_LOG) {
2204 ERROR("FSE accuracy too large");
2207 // The bitstream starts by reporting on which scale it operates.
2208 // Accuracy_Log = low4bits + 5. Note that maximum Accuracy_Log for literal
2209 // and match lengths is 9, and for offsets is 8. Higher values are
2210 // considered errors."
2211 const int accuracy_log = 5 + IO_read_bits(in, 4);
2212 if (accuracy_log > max_accuracy_log) {
2213 ERROR("FSE accuracy too large");
2216 // "Then follows each symbol value, from 0 to last present one. The number
2217 // of bits used by each field is variable. It depends on :
2219 // Remaining probabilities + 1 : example : Presuming an Accuracy_Log of 8,
2220 // and presuming 100 probabilities points have already been distributed, the
2221 // decoder may read any value from 0 to 255 - 100 + 1 == 156 (inclusive).
2222 // Therefore, it must read log2sup(156) == 8 bits.
2224 // Value decoded : small values use 1 less bit : example : Presuming values
2225 // from 0 to 156 (inclusive) are possible, 255-156 = 99 values are remaining
2226 // in an 8-bits field. They are used this way : first 99 values (hence from
2227 // 0 to 98) use only 7 bits, values from 99 to 156 use 8 bits. "
2229 i32 remaining = 1 << accuracy_log;
2230 i16 frequencies[FSE_MAX_SYMBS];
2233 while (remaining > 0 && symb < FSE_MAX_SYMBS) {
2234 // Log of the number of possible values we could read
2235 int bits = highest_set_bit(remaining + 1) + 1;
2237 u16 val = IO_read_bits(in, bits);
2239 // Try to mask out the lower bits to see if it qualifies for the "small
2241 const u16 lower_mask = ((u16)1 << (bits - 1)) - 1;
2242 const u16 threshold = ((u16)1 << bits) - 1 - (remaining + 1);
2244 if ((val & lower_mask) < threshold) {
2245 IO_rewind_bits(in, 1);
2246 val = val & lower_mask;
2247 } else if (val > lower_mask) {
2248 val = val - threshold;
2251 // "Probability is obtained from Value decoded by following formula :
2252 // Proba = value - 1"
2253 const i16 proba = (i16)val - 1;
2255 // "It means value 0 becomes negative probability -1. -1 is a special
2256 // probability, which means "less than 1". Its effect on distribution
2257 // table is described in next paragraph. For the purpose of calculating
2258 // cumulated distribution, it counts as one."
2259 remaining -= proba < 0 ? -proba : proba;
2261 frequencies[symb] = proba;
2264 // "When a symbol has a probability of zero, it is followed by a 2-bits
2265 // repeat flag. This repeat flag tells how many probabilities of zeroes
2266 // follow the current one. It provides a number ranging from 0 to 3. If
2267 // it is a 3, another 2-bits repeat flag follows, and so on."
2269 // Read the next two bits to see how many more 0s
2270 int repeat = IO_read_bits(in, 2);
2273 for (int i = 0; i < repeat && symb < FSE_MAX_SYMBS; i++) {
2274 frequencies[symb++] = 0;
2277 repeat = IO_read_bits(in, 2);
2284 IO_align_stream(in);
2286 // "When last symbol reaches cumulated total of 1 << Accuracy_Log, decoding
2287 // is complete. If the last symbol makes cumulated total go above 1 <<
2288 // Accuracy_Log, distribution is considered corrupted."
2289 if (remaining != 0 || symb >= FSE_MAX_SYMBS) {
2293 // Initialize the decoding table using the determined weights
2294 FSE_init_dtable(dtable, frequencies, symb, accuracy_log);
2297 static void FSE_init_dtable_rle(FSE_dtable *const dtable, const u8 symb) {
2298 dtable->symbols = malloc(sizeof(u8));
2299 dtable->num_bits = malloc(sizeof(u8));
2300 dtable->new_state_base = malloc(sizeof(u16));
2302 if (!dtable->symbols || !dtable->num_bits || !dtable->new_state_base) {
2306 // This setup will always have a state of 0, always return symbol `symb`,
2307 // and never consume any bits
2308 dtable->symbols[0] = symb;
2309 dtable->num_bits[0] = 0;
2310 dtable->new_state_base[0] = 0;
2311 dtable->accuracy_log = 0;
2314 static void FSE_free_dtable(FSE_dtable *const dtable) {
2315 free(dtable->symbols);
2316 free(dtable->num_bits);
2317 free(dtable->new_state_base);
2318 memset(dtable, 0, sizeof(FSE_dtable));
2320 /******* END FSE PRIMITIVES ***************************************************/