git subrepo pull (merge) --force deps/libchdr
[pcsx_rearmed.git] / deps / libchdr / deps / zstd-1.5.6 / lib / decompress / zstd_decompress.c
 /*-*******************************************************
 *  Dependencies
 *********************************************************/
-#include "../common/allocations.h"  /* ZSTD_customMalloc, ZSTD_customCalloc, ZSTD_customFree */
 #include "../common/zstd_deps.h"   /* ZSTD_memcpy, ZSTD_memmove, ZSTD_memset */
+#include "../common/allocations.h"  /* ZSTD_customMalloc, ZSTD_customCalloc, ZSTD_customFree */
+#include "../common/error_private.h"
+#include "../common/zstd_internal.h"  /* blockProperties_t */
 #include "../common/mem.h"         /* low level memory routines */
+#include "../common/bits.h"  /* ZSTD_highbit32 */
 #define FSE_STATIC_LINKING_ONLY
 #include "../common/fse.h"
 #include "../common/huf.h"
 #include "../common/xxhash.h" /* XXH64_reset, XXH64_update, XXH64_digest, XXH64 */
-#include "../common/zstd_internal.h"  /* blockProperties_t */
 #include "zstd_decompress_internal.h"   /* ZSTD_DCtx */
 #include "zstd_ddict.h"  /* ZSTD_DDictDictContent */
 #include "zstd_decompress_block.h"   /* ZSTD_decompressBlock_internal */
-#include "../common/bits.h"  /* ZSTD_highbit32 */
 
 #if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT>=1)
 #  include "../legacy/zstd_legacy.h"
@@ -245,6 +246,7 @@ static void ZSTD_DCtx_resetParameters(ZSTD_DCtx* dctx)
     dctx->forceIgnoreChecksum = ZSTD_d_validateChecksum;
     dctx->refMultipleDDicts = ZSTD_rmd_refSingleDDict;
     dctx->disableHufAsm = 0;
+    dctx->maxBlockSizeParam = 0;
 }
 
 static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx)
@@ -265,6 +267,7 @@ static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx)
 #endif
     dctx->noForwardProgress = 0;
     dctx->oversizedDuration = 0;
+    dctx->isFrameDecompression = 1;
 #if DYNAMIC_BMI2
     dctx->bmi2 = ZSTD_cpuSupportsBmi2();
 #endif
@@ -726,17 +729,17 @@ static ZSTD_frameSizeInfo ZSTD_errorFrameSizeInfo(size_t ret)
     return frameSizeInfo;
 }
 
-static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize)
+static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize, ZSTD_format_e format)
 {
     ZSTD_frameSizeInfo frameSizeInfo;
     ZSTD_memset(&frameSizeInfo, 0, sizeof(ZSTD_frameSizeInfo));
 
 #if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1)
-    if (ZSTD_isLegacy(src, srcSize))
+    if (format == ZSTD_f_zstd1 && ZSTD_isLegacy(src, srcSize))
         return ZSTD_findFrameSizeInfoLegacy(src, srcSize);
 #endif
 
-    if ((srcSize >= ZSTD_SKIPPABLEHEADERSIZE)
+    if (format == ZSTD_f_zstd1 && (srcSize >= ZSTD_SKIPPABLEHEADERSIZE)
         && (MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
         frameSizeInfo.compressedSize = readSkippableFrameSize(src, srcSize);
         assert(ZSTD_isError(frameSizeInfo.compressedSize) ||
@@ -750,7 +753,7 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
         ZSTD_frameHeader zfh;
 
         /* Extract Frame Header */
-        {   size_t const ret = ZSTD_getFrameHeader(&zfh, src, srcSize);
+        {   size_t const ret = ZSTD_getFrameHeader_advanced(&zfh, src, srcSize, format);
             if (ZSTD_isError(ret))
                 return ZSTD_errorFrameSizeInfo(ret);
             if (ret > 0)
@@ -793,15 +796,17 @@ static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize
     }
 }
 
+static size_t ZSTD_findFrameCompressedSize_advanced(const void *src, size_t srcSize, ZSTD_format_e format) {
+    ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, format);
+    return frameSizeInfo.compressedSize;
+}
+
 /** ZSTD_findFrameCompressedSize() :
- *  compatible with legacy mode
- *  `src` must point to the start of a ZSTD frame, ZSTD legacy frame, or skippable frame
- *  `srcSize` must be at least as large as the frame contained
- *  @return : the compressed size of the frame starting at `src` */
+ * See docs in zstd.h
+ * Note: compatible with legacy mode */
 size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
 {
-    ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
-    return frameSizeInfo.compressedSize;
+    return ZSTD_findFrameCompressedSize_advanced(src, srcSize, ZSTD_f_zstd1);
 }
 
 /** ZSTD_decompressBound() :
@@ -815,7 +820,7 @@ unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize)
     unsigned long long bound = 0;
     /* Iterate over each frame */
     while (srcSize > 0) {
-        ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
+        ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1);
         size_t const compressedSize = frameSizeInfo.compressedSize;
         unsigned long long const decompressedBound = frameSizeInfo.decompressedBound;
         if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR)
@@ -835,7 +840,7 @@ size_t ZSTD_decompressionMargin(void const* src, size_t srcSize)
 
     /* Iterate over each frame */
     while (srcSize > 0) {
-        ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
+        ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize, ZSTD_f_zstd1);
         size_t const compressedSize = frameSizeInfo.compressedSize;
         unsigned long long const decompressedBound = frameSizeInfo.decompressedBound;
         ZSTD_frameHeader zfh;
@@ -971,6 +976,10 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
         ip += frameHeaderSize; remainingSrcSize -= frameHeaderSize;
     }
 
+    /* Shrink the blockSizeMax if enabled */
+    if (dctx->maxBlockSizeParam != 0)
+        dctx->fParams.blockSizeMax = MIN(dctx->fParams.blockSizeMax, (unsigned)dctx->maxBlockSizeParam);
+
     /* Loop on each block */
     while (1) {
         BYTE* oBlockEnd = oend;
@@ -1003,7 +1012,8 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
         switch(blockProperties.blockType)
         {
         case bt_compressed:
-            decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, /* frame */ 1, not_streaming);
+            assert(dctx->isFrameDecompression == 1);
+            decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oBlockEnd-op), ip, cBlockSize, not_streaming);
             break;
         case bt_raw :
             /* Use oend instead of oBlockEnd because this function is safe to overlap. It uses memmove. */
@@ -1016,12 +1026,14 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
         default:
             RETURN_ERROR(corruption_detected, "invalid block type");
         }
-
-        if (ZSTD_isError(decodedSize)) return decodedSize;
-        if (dctx->validateChecksum)
+        FORWARD_IF_ERROR(decodedSize, "Block decompression failure");
+        DEBUGLOG(5, "Decompressed block of dSize = %u", (unsigned)decodedSize);
+        if (dctx->validateChecksum) {
             XXH64_update(&dctx->xxhState, op, decodedSize);
-        if (decodedSize != 0)
+        }
+        if (decodedSize) /* support dst = NULL,0 */ {
             op += decodedSize;
+        }
         assert(ip != NULL);
         ip += cBlockSize;
         remainingSrcSize -= cBlockSize;
@@ -1051,7 +1063,9 @@ static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
     return (size_t)(op-ostart);
 }
 
-static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx,
+static
+ZSTD_ALLOW_POINTER_OVERFLOW_ATTR
+size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx,
                                         void* dst, size_t dstCapacity,
                                   const void* src, size_t srcSize,
                                   const void* dict, size_t dictSize,
@@ -1071,7 +1085,7 @@ static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx,
     while (srcSize >= ZSTD_startingInputLength(dctx->format)) {
 
 #if defined(ZSTD_LEGACY_SUPPORT) && (ZSTD_LEGACY_SUPPORT >= 1)
-        if (ZSTD_isLegacy(src, srcSize)) {
+        if (dctx->format == ZSTD_f_zstd1 && ZSTD_isLegacy(src, srcSize)) {
             size_t decodedSize;
             size_t const frameSize = ZSTD_findFrameCompressedSizeLegacy(src, srcSize);
             if (ZSTD_isError(frameSize)) return frameSize;
@@ -1081,6 +1095,15 @@ static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx,
             decodedSize = ZSTD_decompressLegacy(dst, dstCapacity, src, frameSize, dict, dictSize);
             if (ZSTD_isError(decodedSize)) return decodedSize;
 
+            {
+                unsigned long long const expectedSize = ZSTD_getFrameContentSize(src, srcSize);
+                RETURN_ERROR_IF(expectedSize == ZSTD_CONTENTSIZE_ERROR, corruption_detected, "Corrupted frame header!");
+                if (expectedSize != ZSTD_CONTENTSIZE_UNKNOWN) {
+                    RETURN_ERROR_IF(expectedSize != decodedSize, corruption_detected,
+                        "Frame header size does not match decoded size!");
+                }
+            }
+
             assert(decodedSize <= dstCapacity);
             dst = (BYTE*)dst + decodedSize;
             dstCapacity -= decodedSize;
@@ -1092,7 +1115,7 @@ static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx,
         }
 #endif
 
-        if (srcSize >= 4) {
+        if (dctx->format == ZSTD_f_zstd1 && srcSize >= 4) {
             U32 const magicNumber = MEM_readLE32(src);
             DEBUGLOG(5, "reading magic number %08X", (unsigned)magicNumber);
             if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
@@ -1319,7 +1342,8 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c
             {
             case bt_compressed:
                 DEBUGLOG(5, "ZSTD_decompressContinue: case bt_compressed");
-                rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, /* frame */ 1, is_streaming);
+                assert(dctx->isFrameDecompression == 1);
+                rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, is_streaming);
                 dctx->expected = 0;  /* Streaming not supported */
                 break;
             case bt_raw :
@@ -1388,6 +1412,7 @@ size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, c
     case ZSTDds_decodeSkippableHeader:
         assert(src != NULL);
         assert(srcSize <= ZSTD_SKIPPABLEHEADERSIZE);
+        assert(dctx->format != ZSTD_f_zstd1_magicless);
         ZSTD_memcpy(dctx->headerBuffer + (ZSTD_SKIPPABLEHEADERSIZE - srcSize), src, srcSize);   /* complete skippable header */
         dctx->expected = MEM_readLE32(dctx->headerBuffer + ZSTD_FRAMEIDSIZE);   /* note : dctx->expected can grow seriously large, beyond local buffer size */
         dctx->stage = ZSTDds_skipFrame;
@@ -1548,6 +1573,7 @@ size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx)
     dctx->litEntropy = dctx->fseEntropy = 0;
     dctx->dictID = 0;
     dctx->bType = bt_reserved;
+    dctx->isFrameDecompression = 1;
     ZSTD_STATIC_ASSERT(sizeof(dctx->entropy.rep) == sizeof(repStartValue));
     ZSTD_memcpy(dctx->entropy.rep, repStartValue, sizeof(repStartValue));  /* initial repcodes */
     dctx->LLTptr = dctx->entropy.LLTable;
@@ -1819,6 +1845,10 @@ ZSTD_bounds ZSTD_dParam_getBounds(ZSTD_dParameter dParam)
             bounds.lowerBound = 0;
             bounds.upperBound = 1;
             return bounds;
+        case ZSTD_d_maxBlockSize:
+            bounds.lowerBound = ZSTD_BLOCKSIZE_MAX_MIN;
+            bounds.upperBound = ZSTD_BLOCKSIZE_MAX;
+            return bounds;
 
         default:;
     }
@@ -1863,6 +1893,9 @@ size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParameter param, int* value
         case ZSTD_d_disableHuffmanAssembly:
             *value = (int)dctx->disableHufAsm;
             return 0;
+        case ZSTD_d_maxBlockSize:
+            *value = dctx->maxBlockSizeParam;
+            return 0;
         default:;
     }
     RETURN_ERROR(parameter_unsupported, "");
@@ -1900,6 +1933,10 @@ size_t ZSTD_DCtx_setParameter(ZSTD_DCtx* dctx, ZSTD_dParameter dParam, int value
             CHECK_DBOUNDS(ZSTD_d_disableHuffmanAssembly, value);
             dctx->disableHufAsm = value != 0;
             return 0;
+        case ZSTD_d_maxBlockSize:
+            if (value != 0) CHECK_DBOUNDS(ZSTD_d_maxBlockSize, value);
+            dctx->maxBlockSizeParam = value;
+            return 0;
         default:;
     }
     RETURN_ERROR(parameter_unsupported, "");
@@ -1911,6 +1948,7 @@ size_t ZSTD_DCtx_reset(ZSTD_DCtx* dctx, ZSTD_ResetDirective reset)
       || (reset == ZSTD_reset_session_and_parameters) ) {
         dctx->streamStage = zdss_init;
         dctx->noForwardProgress = 0;
+        dctx->isFrameDecompression = 1;
     }
     if ( (reset == ZSTD_reset_parameters)
       || (reset == ZSTD_reset_session_and_parameters) ) {
@@ -1927,11 +1965,17 @@ size_t ZSTD_sizeof_DStream(const ZSTD_DStream* dctx)
     return ZSTD_sizeof_DCtx(dctx);
 }
 
-size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize)
+static size_t ZSTD_decodingBufferSize_internal(unsigned long long windowSize, unsigned long long frameContentSize, size_t blockSizeMax)
 {
-    size_t const blockSize = (size_t) MIN(windowSize, ZSTD_BLOCKSIZE_MAX);
-    /* space is needed to store the litbuffer after the output of a given block without stomping the extDict of a previous run, as well as to cover both windows against wildcopy*/
-    unsigned long long const neededRBSize = windowSize + blockSize + ZSTD_BLOCKSIZE_MAX + (WILDCOPY_OVERLENGTH * 2);
+    size_t const blockSize = MIN((size_t)MIN(windowSize, ZSTD_BLOCKSIZE_MAX), blockSizeMax);
+    /* We need blockSize + WILDCOPY_OVERLENGTH worth of buffer so that if a block
+     * ends at windowSize + WILDCOPY_OVERLENGTH + 1 bytes, we can start writing
+     * the block at the beginning of the output buffer, and maintain a full window.
+     *
+     * We need another blockSize worth of buffer so that we can store split
+     * literals at the end of the block without overwriting the extDict window.
+     */
+    unsigned long long const neededRBSize = windowSize + (blockSize * 2) + (WILDCOPY_OVERLENGTH * 2);
     unsigned long long const neededSize = MIN(frameContentSize, neededRBSize);
     size_t const minRBSize = (size_t) neededSize;
     RETURN_ERROR_IF((unsigned long long)minRBSize != neededSize,
@@ -1939,6 +1983,11 @@ size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long
     return minRBSize;
 }
 
+size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize)
+{
+    return ZSTD_decodingBufferSize_internal(windowSize, frameContentSize, ZSTD_BLOCKSIZE_MAX);
+}
+
 size_t ZSTD_estimateDStreamSize(size_t windowSize)
 {
     size_t const blockSize = MIN(windowSize, ZSTD_BLOCKSIZE_MAX);
@@ -2134,12 +2183,12 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
             if (zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN
                 && zds->fParams.frameType != ZSTD_skippableFrame
                 && (U64)(size_t)(oend-op) >= zds->fParams.frameContentSize) {
-                size_t const cSize = ZSTD_findFrameCompressedSize(istart, (size_t)(iend-istart));
+                size_t const cSize = ZSTD_findFrameCompressedSize_advanced(istart, (size_t)(iend-istart), zds->format);
                 if (cSize <= (size_t)(iend-istart)) {
                     /* shortcut : using single-pass mode */
                     size_t const decompressedSize = ZSTD_decompress_usingDDict(zds, op, (size_t)(oend-op), istart, cSize, ZSTD_getDDict(zds));
                     if (ZSTD_isError(decompressedSize)) return decompressedSize;
-                    DEBUGLOG(4, "shortcut to single-pass ZSTD_decompress_usingDDict()")
+                    DEBUGLOG(4, "shortcut to single-pass ZSTD_decompress_usingDDict()");
                     assert(istart != NULL);
                     ip = istart + cSize;
                     op = op ? op + decompressedSize : op; /* can occur if frameContentSize = 0 (empty frame) */
@@ -2161,7 +2210,8 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
             DEBUGLOG(4, "Consume header");
             FORWARD_IF_ERROR(ZSTD_decompressBegin_usingDDict(zds, ZSTD_getDDict(zds)), "");
 
-            if ((MEM_readLE32(zds->headerBuffer) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {  /* skippable frame */
+            if (zds->format == ZSTD_f_zstd1
+                && (MEM_readLE32(zds->headerBuffer) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {  /* skippable frame */
                 zds->expected = MEM_readLE32(zds->headerBuffer + ZSTD_FRAMEIDSIZE);
                 zds->stage = ZSTDds_skipFrame;
             } else {
@@ -2177,11 +2227,13 @@ size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inB
             zds->fParams.windowSize = MAX(zds->fParams.windowSize, 1U << ZSTD_WINDOWLOG_ABSOLUTEMIN);
             RETURN_ERROR_IF(zds->fParams.windowSize > zds->maxWindowSize,
                             frameParameter_windowTooLarge, "");
+            if (zds->maxBlockSizeParam != 0)
+                zds->fParams.blockSizeMax = MIN(zds->fParams.blockSizeMax, (unsigned)zds->maxBlockSizeParam);
 
             /* Adapt buffer sizes to frame header instructions */
             {   size_t const neededInBuffSize = MAX(zds->fParams.blockSizeMax, 4 /* frame checksum */);
                 size_t const neededOutBuffSize = zds->outBufferMode == ZSTD_bm_buffered
-                        ? ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize)
+                        ? ZSTD_decodingBufferSize_internal(zds->fParams.windowSize, zds->fParams.frameContentSize, zds->fParams.blockSizeMax)
                         : 0;
 
                 ZSTD_DCtx_updateOversizedDuration(zds, neededInBuffSize, neededOutBuffSize);