1 /*
2  * Copyright (c) Yann Collet, Facebook, Inc.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10 
11 
/* ***************************************************************
*  Tuning parameters
*****************************************************************/
/*!
 * HEAPMODE :
 * Chooses where the one-shot ZSTD_decompress() places its decompression context:
 * 0 = on the stack, 1 = on the heap (default; needs malloc()).
 * Entry points that take an explicit context, such as ZSTD_decompressDCtx(),
 * are not affected by this setting.
 */
#if !defined(ZSTD_HEAPMODE)
#  define ZSTD_HEAPMODE 1
#endif

/*!
*  LEGACY_SUPPORT :
*  if set to 1+, ZSTD_decompress() can decode older formats (v0.1+)
*  NOTE(review): no ZSTD_LEGACY_SUPPORT definition or legacy-magic handling is
*  visible in this part of the file; confirm whether legacy decoding is compiled in.
*/

/*!
 *  MAXWINDOWSIZE_DEFAULT :
 *  Largest window size a DStream accepts __by default__.
 *  Frames that require more memory than this are rejected.
 *  A different limit can be set per context with ZSTD_DCtx_setMaxWindowSize().
 */
#if !defined(ZSTD_MAXWINDOWSIZE_DEFAULT)
#  define ZSTD_MAXWINDOWSIZE_DEFAULT (((U32)1 << ZSTD_WINDOWLOG_LIMIT_DEFAULT) + 1)
#endif

/*!
 *  NO_FORWARD_PROGRESS_MAX :
 *  Number of ZSTD_decompressStream() calls allowed to make no forward progress
 *  (no input byte consumed and no output byte flushed) before an error is triggered.
 */
#if !defined(ZSTD_NO_FORWARD_PROGRESS_MAX)
#  define ZSTD_NO_FORWARD_PROGRESS_MAX 16
#endif
50 
51 
52 /*-*******************************************************
53 *  Dependencies
54 *********************************************************/
55 #include "../common/zstd_deps.h"   /* ZSTD_memcpy, ZSTD_memmove, ZSTD_memset */
56 #include "../common/cpu.h"         /* bmi2 */
57 #include "../common/mem.h"         /* low level memory routines */
58 #define FSE_STATIC_LINKING_ONLY
59 #include "../common/fse.h"
60 #define HUF_STATIC_LINKING_ONLY
61 #include "../common/huf.h"
62 #include <linux/xxhash.h> /* xxh64_reset, xxh64_update, xxh64_digest, XXH64 */
63 #include "../common/zstd_internal.h"  /* blockProperties_t */
64 #include "zstd_decompress_internal.h"   /* ZSTD_DCtx */
65 #include "zstd_ddict.h"  /* ZSTD_DDictDictContent */
66 #include "zstd_decompress_block.h"   /* ZSTD_decompressBlock_internal */
67 
68 
69 
70 
/* ***********************************
 * Multiple DDicts Hashset internals *
 *************************************/

/* Slot count of a freshly created DDict hash set, and the factor by which the
 * slot count is multiplied on every expansion. Together they keep the table size
 * a power of two, which the index computation relies on (it masks with size - 1). */
#define DDICT_HASHSET_TABLE_BASE_SIZE 64
#define DDICT_HASHSET_RESIZE_FACTOR 2

/* Intended maximum load factor of the DDict hash set, expressed without floating
 * point as the integer ratio SIZE_MULT / COUNT_MULT (currently 3 / 4 = 0.75). */
#define DDICT_HASHSET_MAX_LOAD_FACTOR_COUNT_MULT 4
#define DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT 3
84 
85 /* Hash function to determine starting position of dict insertion within the table
86  * Returns an index between [0, hashSet->ddictPtrTableSize]
87  */
ZSTD_DDictHashSet_getIndex(const ZSTD_DDictHashSet * hashSet,U32 dictID)88 static size_t ZSTD_DDictHashSet_getIndex(const ZSTD_DDictHashSet* hashSet, U32 dictID) {
89     const U64 hash = xxh64(&dictID, sizeof(U32), 0);
90     /* DDict ptr table size is a multiple of 2, use size - 1 as mask to get index within [0, hashSet->ddictPtrTableSize) */
91     return hash & (hashSet->ddictPtrTableSize - 1);
92 }
93 
94 /* Adds DDict to a hashset without resizing it.
95  * If inserting a DDict with a dictID that already exists in the set, replaces the one in the set.
96  * Returns 0 if successful, or a zstd error code if something went wrong.
97  */
ZSTD_DDictHashSet_emplaceDDict(ZSTD_DDictHashSet * hashSet,const ZSTD_DDict * ddict)98 static size_t ZSTD_DDictHashSet_emplaceDDict(ZSTD_DDictHashSet* hashSet, const ZSTD_DDict* ddict) {
99     const U32 dictID = ZSTD_getDictID_fromDDict(ddict);
100     size_t idx = ZSTD_DDictHashSet_getIndex(hashSet, dictID);
101     const size_t idxRangeMask = hashSet->ddictPtrTableSize - 1;
102     RETURN_ERROR_IF(hashSet->ddictPtrCount == hashSet->ddictPtrTableSize, GENERIC, "Hash set is full!");
103     DEBUGLOG(4, "Hashed index: for dictID: %u is %zu", dictID, idx);
104     while (hashSet->ddictPtrTable[idx] != NULL) {
105         /* Replace existing ddict if inserting ddict with same dictID */
106         if (ZSTD_getDictID_fromDDict(hashSet->ddictPtrTable[idx]) == dictID) {
107             DEBUGLOG(4, "DictID already exists, replacing rather than adding");
108             hashSet->ddictPtrTable[idx] = ddict;
109             return 0;
110         }
111         idx &= idxRangeMask;
112         idx++;
113     }
114     DEBUGLOG(4, "Final idx after probing for dictID %u is: %zu", dictID, idx);
115     hashSet->ddictPtrTable[idx] = ddict;
116     hashSet->ddictPtrCount++;
117     return 0;
118 }
119 
120 /* Expands hash table by factor of DDICT_HASHSET_RESIZE_FACTOR and
121  * rehashes all values, allocates new table, frees old table.
122  * Returns 0 on success, otherwise a zstd error code.
123  */
ZSTD_DDictHashSet_expand(ZSTD_DDictHashSet * hashSet,ZSTD_customMem customMem)124 static size_t ZSTD_DDictHashSet_expand(ZSTD_DDictHashSet* hashSet, ZSTD_customMem customMem) {
125     size_t newTableSize = hashSet->ddictPtrTableSize * DDICT_HASHSET_RESIZE_FACTOR;
126     const ZSTD_DDict** newTable = (const ZSTD_DDict**)ZSTD_customCalloc(sizeof(ZSTD_DDict*) * newTableSize, customMem);
127     const ZSTD_DDict** oldTable = hashSet->ddictPtrTable;
128     size_t oldTableSize = hashSet->ddictPtrTableSize;
129     size_t i;
130 
131     DEBUGLOG(4, "Expanding DDict hash table! Old size: %zu new size: %zu", oldTableSize, newTableSize);
132     RETURN_ERROR_IF(!newTable, memory_allocation, "Expanded hashset allocation failed!");
133     hashSet->ddictPtrTable = newTable;
134     hashSet->ddictPtrTableSize = newTableSize;
135     hashSet->ddictPtrCount = 0;
136     for (i = 0; i < oldTableSize; ++i) {
137         if (oldTable[i] != NULL) {
138             FORWARD_IF_ERROR(ZSTD_DDictHashSet_emplaceDDict(hashSet, oldTable[i]), "");
139         }
140     }
141     ZSTD_customFree((void*)oldTable, customMem);
142     DEBUGLOG(4, "Finished re-hash");
143     return 0;
144 }
145 
146 /* Fetches a DDict with the given dictID
147  * Returns the ZSTD_DDict* with the requested dictID. If it doesn't exist, then returns NULL.
148  */
ZSTD_DDictHashSet_getDDict(ZSTD_DDictHashSet * hashSet,U32 dictID)149 static const ZSTD_DDict* ZSTD_DDictHashSet_getDDict(ZSTD_DDictHashSet* hashSet, U32 dictID) {
150     size_t idx = ZSTD_DDictHashSet_getIndex(hashSet, dictID);
151     const size_t idxRangeMask = hashSet->ddictPtrTableSize - 1;
152     DEBUGLOG(4, "Hashed index: for dictID: %u is %zu", dictID, idx);
153     for (;;) {
154         size_t currDictID = ZSTD_getDictID_fromDDict(hashSet->ddictPtrTable[idx]);
155         if (currDictID == dictID || currDictID == 0) {
156             /* currDictID == 0 implies a NULL ddict entry */
157             break;
158         } else {
159             idx &= idxRangeMask;    /* Goes to start of table when we reach the end */
160             idx++;
161         }
162     }
163     DEBUGLOG(4, "Final idx after probing for dictID %u is: %zu", dictID, idx);
164     return hashSet->ddictPtrTable[idx];
165 }
166 
167 /* Allocates space for and returns a ddict hash set
168  * The hash set's ZSTD_DDict* table has all values automatically set to NULL to begin with.
169  * Returns NULL if allocation failed.
170  */
ZSTD_createDDictHashSet(ZSTD_customMem customMem)171 static ZSTD_DDictHashSet* ZSTD_createDDictHashSet(ZSTD_customMem customMem) {
172     ZSTD_DDictHashSet* ret = (ZSTD_DDictHashSet*)ZSTD_customMalloc(sizeof(ZSTD_DDictHashSet), customMem);
173     DEBUGLOG(4, "Allocating new hash set");
174     if (!ret)
175         return NULL;
176     ret->ddictPtrTable = (const ZSTD_DDict**)ZSTD_customCalloc(DDICT_HASHSET_TABLE_BASE_SIZE * sizeof(ZSTD_DDict*), customMem);
177     if (!ret->ddictPtrTable) {
178         ZSTD_customFree(ret, customMem);
179         return NULL;
180     }
181     ret->ddictPtrTableSize = DDICT_HASHSET_TABLE_BASE_SIZE;
182     ret->ddictPtrCount = 0;
183     return ret;
184 }
185 
186 /* Frees the table of ZSTD_DDict* within a hashset, then frees the hashset itself.
187  * Note: The ZSTD_DDict* within the table are NOT freed.
188  */
ZSTD_freeDDictHashSet(ZSTD_DDictHashSet * hashSet,ZSTD_customMem customMem)189 static void ZSTD_freeDDictHashSet(ZSTD_DDictHashSet* hashSet, ZSTD_customMem customMem) {
190     DEBUGLOG(4, "Freeing ddict hash set");
191     if (hashSet && hashSet->ddictPtrTable) {
192         ZSTD_customFree((void*)hashSet->ddictPtrTable, customMem);
193     }
194     if (hashSet) {
195         ZSTD_customFree(hashSet, customMem);
196     }
197 }
198 
199 /* Public function: Adds a DDict into the ZSTD_DDictHashSet, possibly triggering a resize of the hash set.
200  * Returns 0 on success, or a ZSTD error.
201  */
ZSTD_DDictHashSet_addDDict(ZSTD_DDictHashSet * hashSet,const ZSTD_DDict * ddict,ZSTD_customMem customMem)202 static size_t ZSTD_DDictHashSet_addDDict(ZSTD_DDictHashSet* hashSet, const ZSTD_DDict* ddict, ZSTD_customMem customMem) {
203     DEBUGLOG(4, "Adding dict ID: %u to hashset with - Count: %zu Tablesize: %zu", ZSTD_getDictID_fromDDict(ddict), hashSet->ddictPtrCount, hashSet->ddictPtrTableSize);
204     if (hashSet->ddictPtrCount * DDICT_HASHSET_MAX_LOAD_FACTOR_COUNT_MULT / hashSet->ddictPtrTableSize * DDICT_HASHSET_MAX_LOAD_FACTOR_SIZE_MULT != 0) {
205         FORWARD_IF_ERROR(ZSTD_DDictHashSet_expand(hashSet, customMem), "");
206     }
207     FORWARD_IF_ERROR(ZSTD_DDictHashSet_emplaceDDict(hashSet, ddict), "");
208     return 0;
209 }
210 
211 /*-*************************************************************
212 *   Context management
213 ***************************************************************/
ZSTD_sizeof_DCtx(const ZSTD_DCtx * dctx)214 size_t ZSTD_sizeof_DCtx (const ZSTD_DCtx* dctx)
215 {
216     if (dctx==NULL) return 0;   /* support sizeof NULL */
217     return sizeof(*dctx)
218            + ZSTD_sizeof_DDict(dctx->ddictLocal)
219            + dctx->inBuffSize + dctx->outBuffSize;
220 }
221 
ZSTD_estimateDCtxSize(void)222 size_t ZSTD_estimateDCtxSize(void) { return sizeof(ZSTD_DCtx); }
223 
224 
/* Number of leading bytes needed before the frame header size can be computed.
 * The last of those bytes is the frame header descriptor
 * (see ZSTD_frameHeaderSize_internal()).
 * Only ZSTD_f_zstd1 and ZSTD_f_zstd1_magicless are supported. */
static size_t ZSTD_startingInputLength(ZSTD_format_e format)
{
    assert( (format == ZSTD_f_zstd1) || (format == ZSTD_f_zstd1_magicless) );
    return ZSTD_FRAMEHEADERSIZE_PREFIX(format);
}
232 
/* Restores the context's decoding parameters (format, window limit, output buffer
 * mode, checksum policy, multi-DDict mode) to their defaults.
 * Must only be called while the stream is in its initial stage. */
static void ZSTD_DCtx_resetParameters(ZSTD_DCtx* dctx)
{
    assert(dctx->streamStage == zdss_init);
    dctx->refMultipleDDicts   = ZSTD_rmd_refSingleDDict;
    dctx->forceIgnoreChecksum = ZSTD_d_validateChecksum;
    dctx->outBufferMode       = ZSTD_bm_buffered;
    dctx->maxWindowSize       = ZSTD_MAXWINDOWSIZE_DEFAULT;
    dctx->format              = ZSTD_f_zstd1;
}
242 
/* Puts a newly allocated (or static) context into a clean state: no dictionary,
 * no buffers, initial stream stage, detected CPU features, default parameters. */
static void ZSTD_initDCtx_internal(ZSTD_DCtx* dctx)
{
    /* allocation bookkeeping and buffers */
    dctx->staticSize  = 0;
    dctx->inBuff      = NULL;
    dctx->inBuffSize  = 0;
    dctx->outBuffSize = 0;

    /* dictionary state: nothing referenced */
    dctx->ddict       = NULL;
    dctx->ddictLocal  = NULL;
    dctx->ddictSet    = NULL;
    dctx->dictEnd     = NULL;
    dctx->ddictIsCold = 0;
    dctx->dictUses    = ZSTD_dont_use;

    /* streaming state (streamStage must be set before resetParameters asserts on it) */
    dctx->streamStage       = zdss_init;
    dctx->noForwardProgress = 0;
    dctx->oversizedDuration = 0;

    /* legacy-format fields */
    dctx->legacyContext         = NULL;
    dctx->previousLegacyVersion = 0;

    dctx->bmi2 = ZSTD_cpuid_bmi2(ZSTD_cpuid());
    ZSTD_DCtx_resetParameters(dctx);
#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
    dctx->dictContentEndForFuzzing = NULL;
#endif
}
266 
ZSTD_initStaticDCtx(void * workspace,size_t workspaceSize)267 ZSTD_DCtx* ZSTD_initStaticDCtx(void *workspace, size_t workspaceSize)
268 {
269     ZSTD_DCtx* const dctx = (ZSTD_DCtx*) workspace;
270 
271     if ((size_t)workspace & 7) return NULL;  /* 8-aligned */
272     if (workspaceSize < sizeof(ZSTD_DCtx)) return NULL;  /* minimum size */
273 
274     ZSTD_initDCtx_internal(dctx);
275     dctx->staticSize = workspaceSize;
276     dctx->inBuff = (char*)(dctx+1);
277     return dctx;
278 }
279 
ZSTD_createDCtx_advanced(ZSTD_customMem customMem)280 ZSTD_DCtx* ZSTD_createDCtx_advanced(ZSTD_customMem customMem)
281 {
282     if ((!customMem.customAlloc) ^ (!customMem.customFree)) return NULL;
283 
284     {   ZSTD_DCtx* const dctx = (ZSTD_DCtx*)ZSTD_customMalloc(sizeof(*dctx), customMem);
285         if (!dctx) return NULL;
286         dctx->customMem = customMem;
287         ZSTD_initDCtx_internal(dctx);
288         return dctx;
289     }
290 }
291 
ZSTD_createDCtx(void)292 ZSTD_DCtx* ZSTD_createDCtx(void)
293 {
294     DEBUGLOG(3, "ZSTD_createDCtx");
295     return ZSTD_createDCtx_advanced(ZSTD_defaultCMem);
296 }
297 
/* Drops every dictionary reference held by `dctx` and marks it as using no dictionary.
 * Only the locally owned copy (ddictLocal) is freed; the ddict pointer is just cleared. */
static void ZSTD_clearDict(ZSTD_DCtx* dctx)
{
    ZSTD_DDict* const owned = dctx->ddictLocal;
    dctx->ddictLocal = NULL;
    dctx->ddict = NULL;
    dctx->dictUses = ZSTD_dont_use;
    ZSTD_freeDDict(owned);
}
305 
/* Releases a heap-allocated decompression context together with its local
 * dictionary, input buffer and multi-DDict hash set. NULL is accepted (returns 0).
 * Contexts built with ZSTD_initStaticDCtx() cannot be freed and yield a
 * memory_allocation error. DDicts that are only referenced are not freed. */
size_t ZSTD_freeDCtx(ZSTD_DCtx* dctx)
{
    ZSTD_customMem cMem;
    if (dctx == NULL) return 0;   /* support free on NULL */
    RETURN_ERROR_IF(dctx->staticSize, memory_allocation, "not compatible with static DCtx");

    cMem = dctx->customMem;
    ZSTD_clearDict(dctx);

    ZSTD_customFree(dctx->inBuff, cMem);
    dctx->inBuff = NULL;

    if (dctx->ddictSet != NULL) {
        ZSTD_freeDDictHashSet(dctx->ddictSet, cMem);
        dctx->ddictSet = NULL;
    }

    ZSTD_customFree(dctx, cMem);
    return 0;
}
322 
323 /* no longer useful */
ZSTD_copyDCtx(ZSTD_DCtx * dstDCtx,const ZSTD_DCtx * srcDCtx)324 void ZSTD_copyDCtx(ZSTD_DCtx* dstDCtx, const ZSTD_DCtx* srcDCtx)
325 {
326     size_t const toCopy = (size_t)((char*)(&dstDCtx->inBuff) - (char*)dstDCtx);
327     ZSTD_memcpy(dstDCtx, srcDCtx, toCopy);  /* no need to copy workspace */
328 }
329 
330 /* Given a dctx with a digested frame params, re-selects the correct ZSTD_DDict based on
331  * the requested dict ID from the frame. If there exists a reference to the correct ZSTD_DDict, then
332  * accordingly sets the ddict to be used to decompress the frame.
333  *
334  * If no DDict is found, then no action is taken, and the ZSTD_DCtx::ddict remains as-is.
335  *
336  * ZSTD_d_refMultipleDDicts must be enabled for this function to be called.
337  */
ZSTD_DCtx_selectFrameDDict(ZSTD_DCtx * dctx)338 static void ZSTD_DCtx_selectFrameDDict(ZSTD_DCtx* dctx) {
339     assert(dctx->refMultipleDDicts && dctx->ddictSet);
340     DEBUGLOG(4, "Adjusting DDict based on requested dict ID from frame");
341     if (dctx->ddict) {
342         const ZSTD_DDict* frameDDict = ZSTD_DDictHashSet_getDDict(dctx->ddictSet, dctx->fParams.dictID);
343         if (frameDDict) {
344             DEBUGLOG(4, "DDict found!");
345             ZSTD_clearDict(dctx);
346             dctx->dictID = dctx->fParams.dictID;
347             dctx->ddict = frameDDict;
348             dctx->dictUses = ZSTD_use_indefinitely;
349         }
350     }
351 }
352 
353 
354 /*-*************************************************************
355  *   Frame header decoding
356  ***************************************************************/
357 
358 /*! ZSTD_isFrame() :
359  *  Tells if the content of `buffer` starts with a valid Frame Identifier.
360  *  Note : Frame Identifier is 4 bytes. If `size < 4`, @return will always be 0.
361  *  Note 2 : Legacy Frame Identifiers are considered valid only if Legacy Support is enabled.
362  *  Note 3 : Skippable Frame Identifiers are considered valid. */
ZSTD_isFrame(const void * buffer,size_t size)363 unsigned ZSTD_isFrame(const void* buffer, size_t size)
364 {
365     if (size < ZSTD_FRAMEIDSIZE) return 0;
366     {   U32 const magic = MEM_readLE32(buffer);
367         if (magic == ZSTD_MAGICNUMBER) return 1;
368         if ((magic & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) return 1;
369     }
370     return 0;
371 }
372 
373 /* ZSTD_frameHeaderSize_internal() :
374  *  srcSize must be large enough to reach header size fields.
375  *  note : only works for formats ZSTD_f_zstd1 and ZSTD_f_zstd1_magicless.
376  * @return : size of the Frame Header
377  *           or an error code, which can be tested with ZSTD_isError() */
ZSTD_frameHeaderSize_internal(const void * src,size_t srcSize,ZSTD_format_e format)378 static size_t ZSTD_frameHeaderSize_internal(const void* src, size_t srcSize, ZSTD_format_e format)
379 {
380     size_t const minInputSize = ZSTD_startingInputLength(format);
381     RETURN_ERROR_IF(srcSize < minInputSize, srcSize_wrong, "");
382 
383     {   BYTE const fhd = ((const BYTE*)src)[minInputSize-1];
384         U32 const dictID= fhd & 3;
385         U32 const singleSegment = (fhd >> 5) & 1;
386         U32 const fcsId = fhd >> 6;
387         return minInputSize + !singleSegment
388              + ZSTD_did_fieldSize[dictID] + ZSTD_fcs_fieldSize[fcsId]
389              + (singleSegment && !fcsId);
390     }
391 }
392 
393 /* ZSTD_frameHeaderSize() :
394  *  srcSize must be >= ZSTD_frameHeaderSize_prefix.
395  * @return : size of the Frame Header,
396  *           or an error code (if srcSize is too small) */
ZSTD_frameHeaderSize(const void * src,size_t srcSize)397 size_t ZSTD_frameHeaderSize(const void* src, size_t srcSize)
398 {
399     return ZSTD_frameHeaderSize_internal(src, srcSize, ZSTD_f_zstd1);
400 }
401 
402 
403 /* ZSTD_getFrameHeader_advanced() :
404  *  decode Frame Header, or require larger `srcSize`.
405  *  note : only works for formats ZSTD_f_zstd1 and ZSTD_f_zstd1_magicless
406  * @return : 0, `zfhPtr` is correctly filled,
407  *          >0, `srcSize` is too small, value is wanted `srcSize` amount,
408  *           or an error code, which can be tested using ZSTD_isError() */
ZSTD_getFrameHeader_advanced(ZSTD_frameHeader * zfhPtr,const void * src,size_t srcSize,ZSTD_format_e format)409 size_t ZSTD_getFrameHeader_advanced(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize, ZSTD_format_e format)
410 {
411     const BYTE* ip = (const BYTE*)src;
412     size_t const minInputSize = ZSTD_startingInputLength(format);
413 
414     ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr));   /* not strictly necessary, but static analyzer do not understand that zfhPtr is only going to be read only if return value is zero, since they are 2 different signals */
415     if (srcSize < minInputSize) return minInputSize;
416     RETURN_ERROR_IF(src==NULL, GENERIC, "invalid parameter");
417 
418     if ( (format != ZSTD_f_zstd1_magicless)
419       && (MEM_readLE32(src) != ZSTD_MAGICNUMBER) ) {
420         if ((MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
421             /* skippable frame */
422             if (srcSize < ZSTD_SKIPPABLEHEADERSIZE)
423                 return ZSTD_SKIPPABLEHEADERSIZE; /* magic number + frame length */
424             ZSTD_memset(zfhPtr, 0, sizeof(*zfhPtr));
425             zfhPtr->frameContentSize = MEM_readLE32((const char *)src + ZSTD_FRAMEIDSIZE);
426             zfhPtr->frameType = ZSTD_skippableFrame;
427             return 0;
428         }
429         RETURN_ERROR(prefix_unknown, "");
430     }
431 
432     /* ensure there is enough `srcSize` to fully read/decode frame header */
433     {   size_t const fhsize = ZSTD_frameHeaderSize_internal(src, srcSize, format);
434         if (srcSize < fhsize) return fhsize;
435         zfhPtr->headerSize = (U32)fhsize;
436     }
437 
438     {   BYTE const fhdByte = ip[minInputSize-1];
439         size_t pos = minInputSize;
440         U32 const dictIDSizeCode = fhdByte&3;
441         U32 const checksumFlag = (fhdByte>>2)&1;
442         U32 const singleSegment = (fhdByte>>5)&1;
443         U32 const fcsID = fhdByte>>6;
444         U64 windowSize = 0;
445         U32 dictID = 0;
446         U64 frameContentSize = ZSTD_CONTENTSIZE_UNKNOWN;
447         RETURN_ERROR_IF((fhdByte & 0x08) != 0, frameParameter_unsupported,
448                         "reserved bits, must be zero");
449 
450         if (!singleSegment) {
451             BYTE const wlByte = ip[pos++];
452             U32 const windowLog = (wlByte >> 3) + ZSTD_WINDOWLOG_ABSOLUTEMIN;
453             RETURN_ERROR_IF(windowLog > ZSTD_WINDOWLOG_MAX, frameParameter_windowTooLarge, "");
454             windowSize = (1ULL << windowLog);
455             windowSize += (windowSize >> 3) * (wlByte&7);
456         }
457         switch(dictIDSizeCode)
458         {
459             default:
460                 assert(0);  /* impossible */
461                 ZSTD_FALLTHROUGH;
462             case 0 : break;
463             case 1 : dictID = ip[pos]; pos++; break;
464             case 2 : dictID = MEM_readLE16(ip+pos); pos+=2; break;
465             case 3 : dictID = MEM_readLE32(ip+pos); pos+=4; break;
466         }
467         switch(fcsID)
468         {
469             default:
470                 assert(0);  /* impossible */
471                 ZSTD_FALLTHROUGH;
472             case 0 : if (singleSegment) frameContentSize = ip[pos]; break;
473             case 1 : frameContentSize = MEM_readLE16(ip+pos)+256; break;
474             case 2 : frameContentSize = MEM_readLE32(ip+pos); break;
475             case 3 : frameContentSize = MEM_readLE64(ip+pos); break;
476         }
477         if (singleSegment) windowSize = frameContentSize;
478 
479         zfhPtr->frameType = ZSTD_frame;
480         zfhPtr->frameContentSize = frameContentSize;
481         zfhPtr->windowSize = windowSize;
482         zfhPtr->blockSizeMax = (unsigned) MIN(windowSize, ZSTD_BLOCKSIZE_MAX);
483         zfhPtr->dictID = dictID;
484         zfhPtr->checksumFlag = checksumFlag;
485     }
486     return 0;
487 }
488 
489 /* ZSTD_getFrameHeader() :
490  *  decode Frame Header, or require larger `srcSize`.
491  *  note : this function does not consume input, it only reads it.
492  * @return : 0, `zfhPtr` is correctly filled,
493  *          >0, `srcSize` is too small, value is wanted `srcSize` amount,
494  *           or an error code, which can be tested using ZSTD_isError() */
ZSTD_getFrameHeader(ZSTD_frameHeader * zfhPtr,const void * src,size_t srcSize)495 size_t ZSTD_getFrameHeader(ZSTD_frameHeader* zfhPtr, const void* src, size_t srcSize)
496 {
497     return ZSTD_getFrameHeader_advanced(zfhPtr, src, srcSize, ZSTD_f_zstd1);
498 }
499 
500 
501 /* ZSTD_getFrameContentSize() :
502  *  compatible with legacy mode
503  * @return : decompressed size of the single frame pointed to be `src` if known, otherwise
504  *         - ZSTD_CONTENTSIZE_UNKNOWN if the size cannot be determined
505  *         - ZSTD_CONTENTSIZE_ERROR if an error occurred (e.g. invalid magic number, srcSize too small) */
ZSTD_getFrameContentSize(const void * src,size_t srcSize)506 unsigned long long ZSTD_getFrameContentSize(const void *src, size_t srcSize)
507 {
508     {   ZSTD_frameHeader zfh;
509         if (ZSTD_getFrameHeader(&zfh, src, srcSize) != 0)
510             return ZSTD_CONTENTSIZE_ERROR;
511         if (zfh.frameType == ZSTD_skippableFrame) {
512             return 0;
513         } else {
514             return zfh.frameContentSize;
515     }   }
516 }
517 
readSkippableFrameSize(void const * src,size_t srcSize)518 static size_t readSkippableFrameSize(void const* src, size_t srcSize)
519 {
520     size_t const skippableHeaderSize = ZSTD_SKIPPABLEHEADERSIZE;
521     U32 sizeU32;
522 
523     RETURN_ERROR_IF(srcSize < ZSTD_SKIPPABLEHEADERSIZE, srcSize_wrong, "");
524 
525     sizeU32 = MEM_readLE32((BYTE const*)src + ZSTD_FRAMEIDSIZE);
526     RETURN_ERROR_IF((U32)(sizeU32 + ZSTD_SKIPPABLEHEADERSIZE) < sizeU32,
527                     frameParameter_unsupported, "");
528     {
529         size_t const skippableSize = skippableHeaderSize + sizeU32;
530         RETURN_ERROR_IF(skippableSize > srcSize, srcSize_wrong, "");
531         return skippableSize;
532     }
533 }
534 
535 /* ZSTD_findDecompressedSize() :
536  *  compatible with legacy mode
537  *  `srcSize` must be the exact length of some number of ZSTD compressed and/or
538  *      skippable frames
539  *  @return : decompressed size of the frames contained */
ZSTD_findDecompressedSize(const void * src,size_t srcSize)540 unsigned long long ZSTD_findDecompressedSize(const void* src, size_t srcSize)
541 {
542     unsigned long long totalDstSize = 0;
543 
544     while (srcSize >= ZSTD_startingInputLength(ZSTD_f_zstd1)) {
545         U32 const magicNumber = MEM_readLE32(src);
546 
547         if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
548             size_t const skippableSize = readSkippableFrameSize(src, srcSize);
549             if (ZSTD_isError(skippableSize)) {
550                 return ZSTD_CONTENTSIZE_ERROR;
551             }
552             assert(skippableSize <= srcSize);
553 
554             src = (const BYTE *)src + skippableSize;
555             srcSize -= skippableSize;
556             continue;
557         }
558 
559         {   unsigned long long const ret = ZSTD_getFrameContentSize(src, srcSize);
560             if (ret >= ZSTD_CONTENTSIZE_ERROR) return ret;
561 
562             /* check for overflow */
563             if (totalDstSize + ret < totalDstSize) return ZSTD_CONTENTSIZE_ERROR;
564             totalDstSize += ret;
565         }
566         {   size_t const frameSrcSize = ZSTD_findFrameCompressedSize(src, srcSize);
567             if (ZSTD_isError(frameSrcSize)) {
568                 return ZSTD_CONTENTSIZE_ERROR;
569             }
570 
571             src = (const BYTE *)src + frameSrcSize;
572             srcSize -= frameSrcSize;
573         }
574     }  /* while (srcSize >= ZSTD_frameHeaderSize_prefix) */
575 
576     if (srcSize) return ZSTD_CONTENTSIZE_ERROR;
577 
578     return totalDstSize;
579 }
580 
581 /* ZSTD_getDecompressedSize() :
582  *  compatible with legacy mode
583  * @return : decompressed size if known, 0 otherwise
584              note : 0 can mean any of the following :
585                    - frame content is empty
586                    - decompressed size field is not present in frame header
587                    - frame header unknown / not supported
588                    - frame header not complete (`srcSize` too small) */
ZSTD_getDecompressedSize(const void * src,size_t srcSize)589 unsigned long long ZSTD_getDecompressedSize(const void* src, size_t srcSize)
590 {
591     unsigned long long const ret = ZSTD_getFrameContentSize(src, srcSize);
592     ZSTD_STATIC_ASSERT(ZSTD_CONTENTSIZE_ERROR < ZSTD_CONTENTSIZE_UNKNOWN);
593     return (ret >= ZSTD_CONTENTSIZE_ERROR) ? 0 : ret;
594 }
595 
596 
597 /* ZSTD_decodeFrameHeader() :
598  * `headerSize` must be the size provided by ZSTD_frameHeaderSize().
599  * If multiple DDict references are enabled, also will choose the correct DDict to use.
600  * @return : 0 if success, or an error code, which can be tested using ZSTD_isError() */
ZSTD_decodeFrameHeader(ZSTD_DCtx * dctx,const void * src,size_t headerSize)601 static size_t ZSTD_decodeFrameHeader(ZSTD_DCtx* dctx, const void* src, size_t headerSize)
602 {
603     size_t const result = ZSTD_getFrameHeader_advanced(&(dctx->fParams), src, headerSize, dctx->format);
604     if (ZSTD_isError(result)) return result;    /* invalid header */
605     RETURN_ERROR_IF(result>0, srcSize_wrong, "headerSize too small");
606 
607     /* Reference DDict requested by frame if dctx references multiple ddicts */
608     if (dctx->refMultipleDDicts == ZSTD_rmd_refMultipleDDicts && dctx->ddictSet) {
609         ZSTD_DCtx_selectFrameDDict(dctx);
610     }
611 
612 #ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
613     /* Skip the dictID check in fuzzing mode, because it makes the search
614      * harder.
615      */
616     RETURN_ERROR_IF(dctx->fParams.dictID && (dctx->dictID != dctx->fParams.dictID),
617                     dictionary_wrong, "");
618 #endif
619     dctx->validateChecksum = (dctx->fParams.checksumFlag && !dctx->forceIgnoreChecksum) ? 1 : 0;
620     if (dctx->validateChecksum) xxh64_reset(&dctx->xxhState, 0);
621     dctx->processedCSize += headerSize;
622     return 0;
623 }
624 
ZSTD_errorFrameSizeInfo(size_t ret)625 static ZSTD_frameSizeInfo ZSTD_errorFrameSizeInfo(size_t ret)
626 {
627     ZSTD_frameSizeInfo frameSizeInfo;
628     frameSizeInfo.compressedSize = ret;
629     frameSizeInfo.decompressedBound = ZSTD_CONTENTSIZE_ERROR;
630     return frameSizeInfo;
631 }
632 
ZSTD_findFrameSizeInfo(const void * src,size_t srcSize)633 static ZSTD_frameSizeInfo ZSTD_findFrameSizeInfo(const void* src, size_t srcSize)
634 {
635     ZSTD_frameSizeInfo frameSizeInfo;
636     ZSTD_memset(&frameSizeInfo, 0, sizeof(ZSTD_frameSizeInfo));
637 
638 
639     if ((srcSize >= ZSTD_SKIPPABLEHEADERSIZE)
640         && (MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
641         frameSizeInfo.compressedSize = readSkippableFrameSize(src, srcSize);
642         assert(ZSTD_isError(frameSizeInfo.compressedSize) ||
643                frameSizeInfo.compressedSize <= srcSize);
644         return frameSizeInfo;
645     } else {
646         const BYTE* ip = (const BYTE*)src;
647         const BYTE* const ipstart = ip;
648         size_t remainingSize = srcSize;
649         size_t nbBlocks = 0;
650         ZSTD_frameHeader zfh;
651 
652         /* Extract Frame Header */
653         {   size_t const ret = ZSTD_getFrameHeader(&zfh, src, srcSize);
654             if (ZSTD_isError(ret))
655                 return ZSTD_errorFrameSizeInfo(ret);
656             if (ret > 0)
657                 return ZSTD_errorFrameSizeInfo(ERROR(srcSize_wrong));
658         }
659 
660         ip += zfh.headerSize;
661         remainingSize -= zfh.headerSize;
662 
663         /* Iterate over each block */
664         while (1) {
665             blockProperties_t blockProperties;
666             size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSize, &blockProperties);
667             if (ZSTD_isError(cBlockSize))
668                 return ZSTD_errorFrameSizeInfo(cBlockSize);
669 
670             if (ZSTD_blockHeaderSize + cBlockSize > remainingSize)
671                 return ZSTD_errorFrameSizeInfo(ERROR(srcSize_wrong));
672 
673             ip += ZSTD_blockHeaderSize + cBlockSize;
674             remainingSize -= ZSTD_blockHeaderSize + cBlockSize;
675             nbBlocks++;
676 
677             if (blockProperties.lastBlock) break;
678         }
679 
680         /* Final frame content checksum */
681         if (zfh.checksumFlag) {
682             if (remainingSize < 4)
683                 return ZSTD_errorFrameSizeInfo(ERROR(srcSize_wrong));
684             ip += 4;
685         }
686 
687         frameSizeInfo.compressedSize = (size_t)(ip - ipstart);
688         frameSizeInfo.decompressedBound = (zfh.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN)
689                                         ? zfh.frameContentSize
690                                         : nbBlocks * zfh.blockSizeMax;
691         return frameSizeInfo;
692     }
693 }
694 
695 /* ZSTD_findFrameCompressedSize() :
696  *  compatible with legacy mode
697  *  `src` must point to the start of a ZSTD frame, ZSTD legacy frame, or skippable frame
698  *  `srcSize` must be at least as large as the frame contained
699  *  @return : the compressed size of the frame starting at `src` */
ZSTD_findFrameCompressedSize(const void * src,size_t srcSize)700 size_t ZSTD_findFrameCompressedSize(const void *src, size_t srcSize)
701 {
702     ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
703     return frameSizeInfo.compressedSize;
704 }
705 
706 /* ZSTD_decompressBound() :
707  *  compatible with legacy mode
708  *  `src` must point to the start of a ZSTD frame or a skippeable frame
709  *  `srcSize` must be at least as large as the frame contained
710  *  @return : the maximum decompressed size of the compressed source
711  */
ZSTD_decompressBound(const void * src,size_t srcSize)712 unsigned long long ZSTD_decompressBound(const void* src, size_t srcSize)
713 {
714     unsigned long long bound = 0;
715     /* Iterate over each frame */
716     while (srcSize > 0) {
717         ZSTD_frameSizeInfo const frameSizeInfo = ZSTD_findFrameSizeInfo(src, srcSize);
718         size_t const compressedSize = frameSizeInfo.compressedSize;
719         unsigned long long const decompressedBound = frameSizeInfo.decompressedBound;
720         if (ZSTD_isError(compressedSize) || decompressedBound == ZSTD_CONTENTSIZE_ERROR)
721             return ZSTD_CONTENTSIZE_ERROR;
722         assert(srcSize >= compressedSize);
723         src = (const BYTE*)src + compressedSize;
724         srcSize -= compressedSize;
725         bound += decompressedBound;
726     }
727     return bound;
728 }
729 
730 
731 /*-*************************************************************
732  *   Frame decoding
733  ***************************************************************/
734 
735 /* ZSTD_insertBlock() :
736  *  insert `src` block into `dctx` history. Useful to track uncompressed blocks. */
ZSTD_insertBlock(ZSTD_DCtx * dctx,const void * blockStart,size_t blockSize)737 size_t ZSTD_insertBlock(ZSTD_DCtx* dctx, const void* blockStart, size_t blockSize)
738 {
739     DEBUGLOG(5, "ZSTD_insertBlock: %u bytes", (unsigned)blockSize);
740     ZSTD_checkContinuity(dctx, blockStart, blockSize);
741     dctx->previousDstEnd = (const char*)blockStart + blockSize;
742     return blockSize;
743 }
744 
745 
ZSTD_copyRawBlock(void * dst,size_t dstCapacity,const void * src,size_t srcSize)746 static size_t ZSTD_copyRawBlock(void* dst, size_t dstCapacity,
747                           const void* src, size_t srcSize)
748 {
749     DEBUGLOG(5, "ZSTD_copyRawBlock");
750     RETURN_ERROR_IF(srcSize > dstCapacity, dstSize_tooSmall, "");
751     if (dst == NULL) {
752         if (srcSize == 0) return 0;
753         RETURN_ERROR(dstBuffer_null, "");
754     }
755     ZSTD_memcpy(dst, src, srcSize);
756     return srcSize;
757 }
758 
/* ZSTD_setRleBlock() :
 *  Expands an RLE block : writes byte `b` `regenSize` times into `dst`.
 *  `dst` may be NULL only when regenSize is 0.
 *  @return : regenSize, or an error code (dstSize_tooSmall, dstBuffer_null) */
static size_t ZSTD_setRleBlock(void* dst, size_t dstCapacity,
                               BYTE b,
                               size_t regenSize)
{
    RETURN_ERROR_IF(regenSize > dstCapacity, dstSize_tooSmall, "");
    if (dst != NULL) {
        ZSTD_memset(dst, b, regenSize);
        return regenSize;
    }
    /* no destination buffer : only an empty run can be honoured */
    if (regenSize != 0) {
        RETURN_ERROR(dstBuffer_null, "");
    }
    return 0;
}
771 
/* ZSTD_DCtx_trace_end() :
 *  End-of-frame tracing hook. In this build it does nothing :
 *  every argument is deliberately unused. */
static void ZSTD_DCtx_trace_end(ZSTD_DCtx const* dctx, U64 uncompressedSize, U64 compressedSize, unsigned streaming)
{
    (void)streaming;
    (void)compressedSize;
    (void)uncompressedSize;
    (void)dctx;
}
779 
780 
781 /*! ZSTD_decompressFrame() :
782  * @dctx must be properly initialized
783  *  will update *srcPtr and *srcSizePtr,
784  *  to make *srcPtr progress by one frame. */
ZSTD_decompressFrame(ZSTD_DCtx * dctx,void * dst,size_t dstCapacity,const void ** srcPtr,size_t * srcSizePtr)785 static size_t ZSTD_decompressFrame(ZSTD_DCtx* dctx,
786                                    void* dst, size_t dstCapacity,
787                              const void** srcPtr, size_t *srcSizePtr)
788 {
789     const BYTE* const istart = (const BYTE*)(*srcPtr);
790     const BYTE* ip = istart;
791     BYTE* const ostart = (BYTE*)dst;
792     BYTE* const oend = dstCapacity != 0 ? ostart + dstCapacity : ostart;
793     BYTE* op = ostart;
794     size_t remainingSrcSize = *srcSizePtr;
795 
796     DEBUGLOG(4, "ZSTD_decompressFrame (srcSize:%i)", (int)*srcSizePtr);
797 
798     /* check */
799     RETURN_ERROR_IF(
800         remainingSrcSize < ZSTD_FRAMEHEADERSIZE_MIN(dctx->format)+ZSTD_blockHeaderSize,
801         srcSize_wrong, "");
802 
803     /* Frame Header */
804     {   size_t const frameHeaderSize = ZSTD_frameHeaderSize_internal(
805                 ip, ZSTD_FRAMEHEADERSIZE_PREFIX(dctx->format), dctx->format);
806         if (ZSTD_isError(frameHeaderSize)) return frameHeaderSize;
807         RETURN_ERROR_IF(remainingSrcSize < frameHeaderSize+ZSTD_blockHeaderSize,
808                         srcSize_wrong, "");
809         FORWARD_IF_ERROR( ZSTD_decodeFrameHeader(dctx, ip, frameHeaderSize) , "");
810         ip += frameHeaderSize; remainingSrcSize -= frameHeaderSize;
811     }
812 
813     /* Loop on each block */
814     while (1) {
815         size_t decodedSize;
816         blockProperties_t blockProperties;
817         size_t const cBlockSize = ZSTD_getcBlockSize(ip, remainingSrcSize, &blockProperties);
818         if (ZSTD_isError(cBlockSize)) return cBlockSize;
819 
820         ip += ZSTD_blockHeaderSize;
821         remainingSrcSize -= ZSTD_blockHeaderSize;
822         RETURN_ERROR_IF(cBlockSize > remainingSrcSize, srcSize_wrong, "");
823 
824         switch(blockProperties.blockType)
825         {
826         case bt_compressed:
827             decodedSize = ZSTD_decompressBlock_internal(dctx, op, (size_t)(oend-op), ip, cBlockSize, /* frame */ 1);
828             break;
829         case bt_raw :
830             decodedSize = ZSTD_copyRawBlock(op, (size_t)(oend-op), ip, cBlockSize);
831             break;
832         case bt_rle :
833             decodedSize = ZSTD_setRleBlock(op, (size_t)(oend-op), *ip, blockProperties.origSize);
834             break;
835         case bt_reserved :
836         default:
837             RETURN_ERROR(corruption_detected, "invalid block type");
838         }
839 
840         if (ZSTD_isError(decodedSize)) return decodedSize;
841         if (dctx->validateChecksum)
842             xxh64_update(&dctx->xxhState, op, decodedSize);
843         if (decodedSize != 0)
844             op += decodedSize;
845         assert(ip != NULL);
846         ip += cBlockSize;
847         remainingSrcSize -= cBlockSize;
848         if (blockProperties.lastBlock) break;
849     }
850 
851     if (dctx->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN) {
852         RETURN_ERROR_IF((U64)(op-ostart) != dctx->fParams.frameContentSize,
853                         corruption_detected, "");
854     }
855     if (dctx->fParams.checksumFlag) { /* Frame content checksum verification */
856         RETURN_ERROR_IF(remainingSrcSize<4, checksum_wrong, "");
857         if (!dctx->forceIgnoreChecksum) {
858             U32 const checkCalc = (U32)xxh64_digest(&dctx->xxhState);
859             U32 checkRead;
860             checkRead = MEM_readLE32(ip);
861             RETURN_ERROR_IF(checkRead != checkCalc, checksum_wrong, "");
862         }
863         ip += 4;
864         remainingSrcSize -= 4;
865     }
866     ZSTD_DCtx_trace_end(dctx, (U64)(op-ostart), (U64)(ip-istart), /* streaming */ 0);
867     /* Allow caller to get size read */
868     *srcPtr = ip;
869     *srcSizePtr = remainingSrcSize;
870     return (size_t)(op-ostart);
871 }
872 
/*! ZSTD_decompressMultiFrame() :
 *  Decompresses every frame found in `src` and writes their content back to
 *  back into `dst`. Skippable frames are skipped; they are only looked for
 *  when dctx->format is ZSTD_f_zstd1.
 *  Before each zstd frame, @dctx is reset with either `ddict` or the raw
 *  `dict`/`dictSize` content (at most one of the two may be set).
 * @return : total nb of bytes written into `dst`, or an error code.
 *           Bytes left over after the last complete frame yield srcSize_wrong. */
static size_t ZSTD_decompressMultiFrame(ZSTD_DCtx* dctx,
                                        void* dst, size_t dstCapacity,
                                  const void* src, size_t srcSize,
                                  const void* dict, size_t dictSize,
                                  const ZSTD_DDict* ddict)
{
    void* const dststart = dst;
    int moreThan1Frame = 0;

    DEBUGLOG(5, "ZSTD_decompressMultiFrame");
    assert(dict==NULL || ddict==NULL);  /* either dict or ddict set, not both */

    if (ddict) {
        dict = ZSTD_DDict_dictContent(ddict);
        dictSize = ZSTD_DDict_dictSize(ddict);
    }

    while (srcSize >= ZSTD_startingInputLength(dctx->format)) {

        /* Skippable frames are identified by a 4-byte magic number, so :
         * - only look for them in the ZSTD_f_zstd1 format : magicless frames
         *   carry no magic number (same rule as ZSTD_decompressContinue()) ;
         * - make sure ZSTD_FRAMEIDSIZE bytes are really available : the loop
         *   condition only guarantees ZSTD_startingInputLength() bytes, which
         *   can be fewer than that for magicless frames. */
        if (dctx->format == ZSTD_f_zstd1 && srcSize >= ZSTD_FRAMEIDSIZE) {
            U32 const magicNumber = MEM_readLE32(src);
            DEBUGLOG(4, "reading magic number %08X (expecting %08X)",
                        (unsigned)magicNumber, ZSTD_MAGICNUMBER);
            if ((magicNumber & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {
                size_t const skippableSize = readSkippableFrameSize(src, srcSize);
                FORWARD_IF_ERROR(skippableSize, "readSkippableFrameSize failed");
                assert(skippableSize <= srcSize);

                src = (const BYTE *)src + skippableSize;
                srcSize -= skippableSize;
                continue;   /* check next frame */
        }   }

        if (ddict) {
            /* we were called from ZSTD_decompress_usingDDict */
            FORWARD_IF_ERROR(ZSTD_decompressBegin_usingDDict(dctx, ddict), "");
        } else {
            /* this will initialize correctly with no dict if dict == NULL, so
             * use this in all cases but ddict */
            FORWARD_IF_ERROR(ZSTD_decompressBegin_usingDict(dctx, dict, dictSize), "");
        }
        ZSTD_checkContinuity(dctx, dst, dstCapacity);

        {   const size_t res = ZSTD_decompressFrame(dctx, dst, dstCapacity,
                                                    &src, &srcSize);
            RETURN_ERROR_IF(
                (ZSTD_getErrorCode(res) == ZSTD_error_prefix_unknown)
             && (moreThan1Frame==1),
                srcSize_wrong,
                "At least one frame successfully completed, "
                "but following bytes are garbage: "
                "it's more likely to be a srcSize error, "
                "specifying more input bytes than size of frame(s). "
                "Note: one could be unlucky, it might be a corruption error instead, "
                "happening right at the place where we expect zstd magic bytes. "
                "But this is _much_ less likely than a srcSize field error.");
            if (ZSTD_isError(res)) return res;
            assert(res <= dstCapacity);
            if (res != 0)   /* dst may be NULL when nothing was written */
                dst = (BYTE*)dst + res;
            dstCapacity -= res;
        }
        moreThan1Frame = 1;
    }  /* while (srcSize >= ZSTD_startingInputLength(dctx->format)) */

    RETURN_ERROR_IF(srcSize, srcSize_wrong, "input not entirely consumed");

    return (size_t)((BYTE*)dst - (BYTE*)dststart);
}
942 
/* ZSTD_decompress_usingDict() :
 *  One-shot decompression of all frames in `src`, using raw dictionary
 *  content `dict` (may be NULL / 0 for no dictionary).
 *  @return : nb of bytes written into `dst`, or an error code */
size_t ZSTD_decompress_usingDict(ZSTD_DCtx* dctx,
                                 void* dst, size_t dstCapacity,
                           const void* src, size_t srcSize,
                           const void* dict, size_t dictSize)
{
    /* raw dictionary path : no pre-digested ZSTD_DDict */
    const ZSTD_DDict* const noDDict = NULL;
    return ZSTD_decompressMultiFrame(dctx, dst, dstCapacity, src, srcSize, dict, dictSize, noDDict);
}
950 
951 
ZSTD_getDDict(ZSTD_DCtx * dctx)952 static ZSTD_DDict const* ZSTD_getDDict(ZSTD_DCtx* dctx)
953 {
954     switch (dctx->dictUses) {
955     default:
956         assert(0 /* Impossible */);
957         ZSTD_FALLTHROUGH;
958     case ZSTD_dont_use:
959         ZSTD_clearDict(dctx);
960         return NULL;
961     case ZSTD_use_indefinitely:
962         return dctx->ddict;
963     case ZSTD_use_once:
964         dctx->dictUses = ZSTD_dont_use;
965         return dctx->ddict;
966     }
967 }
968 
/* ZSTD_decompressDCtx() :
 *  One-shot decompression with an explicit context, honouring whatever
 *  dictionary is currently referenced by `dctx` (see ZSTD_getDDict()). */
size_t ZSTD_decompressDCtx(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize)
{
    ZSTD_DDict const* const ddict = ZSTD_getDDict(dctx);
    return ZSTD_decompress_usingDDict(dctx, dst, dstCapacity, src, srcSize, ddict);
}
973 
974 
ZSTD_decompress(void * dst,size_t dstCapacity,const void * src,size_t srcSize)975 size_t ZSTD_decompress(void* dst, size_t dstCapacity, const void* src, size_t srcSize)
976 {
977 #if defined(ZSTD_HEAPMODE) && (ZSTD_HEAPMODE>=1)
978     size_t regenSize;
979     ZSTD_DCtx* const dctx = ZSTD_createDCtx();
980     RETURN_ERROR_IF(dctx==NULL, memory_allocation, "NULL pointer!");
981     regenSize = ZSTD_decompressDCtx(dctx, dst, dstCapacity, src, srcSize);
982     ZSTD_freeDCtx(dctx);
983     return regenSize;
984 #else   /* stack mode */
985     ZSTD_DCtx dctx;
986     ZSTD_initDCtx_internal(&dctx);
987     return ZSTD_decompressDCtx(&dctx, dst, dstCapacity, src, srcSize);
988 #endif
989 }
990 
991 
992 /*-**************************************
993 *   Advanced Streaming Decompression API
994 *   Bufferless and synchronous
995 ****************************************/
size_t ZSTD_nextSrcSizeToDecompress(ZSTD_DCtx* dctx)
{
    /* nb of input bytes the state machine expects next (dctx->expected) */
    return dctx->expected;
}
997 
998 /*
999  * Similar to ZSTD_nextSrcSizeToDecompress(), but when when a block input can be streamed,
1000  * we allow taking a partial block as the input. Currently only raw uncompressed blocks can
1001  * be streamed.
1002  *
1003  * For blocks that can be streamed, this allows us to reduce the latency until we produce
1004  * output, and avoid copying the input.
1005  *
1006  * @param inputSize - The total amount of input that the caller currently has.
1007  */
ZSTD_nextSrcSizeToDecompressWithInputSize(ZSTD_DCtx * dctx,size_t inputSize)1008 static size_t ZSTD_nextSrcSizeToDecompressWithInputSize(ZSTD_DCtx* dctx, size_t inputSize) {
1009     if (!(dctx->stage == ZSTDds_decompressBlock || dctx->stage == ZSTDds_decompressLastBlock))
1010         return dctx->expected;
1011     if (dctx->bType != bt_raw)
1012         return dctx->expected;
1013     return MIN(MAX(inputSize, 1), dctx->expected);
1014 }
1015 
ZSTD_nextInputType(ZSTD_DCtx * dctx)1016 ZSTD_nextInputType_e ZSTD_nextInputType(ZSTD_DCtx* dctx) {
1017     switch(dctx->stage)
1018     {
1019     default:   /* should not happen */
1020         assert(0);
1021         ZSTD_FALLTHROUGH;
1022     case ZSTDds_getFrameHeaderSize:
1023         ZSTD_FALLTHROUGH;
1024     case ZSTDds_decodeFrameHeader:
1025         return ZSTDnit_frameHeader;
1026     case ZSTDds_decodeBlockHeader:
1027         return ZSTDnit_blockHeader;
1028     case ZSTDds_decompressBlock:
1029         return ZSTDnit_block;
1030     case ZSTDds_decompressLastBlock:
1031         return ZSTDnit_lastBlock;
1032     case ZSTDds_checkChecksum:
1033         return ZSTDnit_checksum;
1034     case ZSTDds_decodeSkippableHeader:
1035         ZSTD_FALLTHROUGH;
1036     case ZSTDds_skipFrame:
1037         return ZSTDnit_skippableFrame;
1038     }
1039 }
1040 
/* ZSTD_isSkipFrame() : 1 while the content of a skippable frame is being skipped, 0 otherwise */
static int ZSTD_isSkipFrame(ZSTD_DCtx* dctx)
{
    return (dctx->stage == ZSTDds_skipFrame) ? 1 : 0;
}
1042 
1043 /* ZSTD_decompressContinue() :
1044  *  srcSize : must be the exact nb of bytes expected (see ZSTD_nextSrcSizeToDecompress())
1045  *  @return : nb of bytes generated into `dst` (necessarily <= `dstCapacity)
1046  *            or an error code, which can be tested using ZSTD_isError() */
ZSTD_decompressContinue(ZSTD_DCtx * dctx,void * dst,size_t dstCapacity,const void * src,size_t srcSize)1047 size_t ZSTD_decompressContinue(ZSTD_DCtx* dctx, void* dst, size_t dstCapacity, const void* src, size_t srcSize)
1048 {
1049     DEBUGLOG(5, "ZSTD_decompressContinue (srcSize:%u)", (unsigned)srcSize);
1050     /* Sanity check */
1051     RETURN_ERROR_IF(srcSize != ZSTD_nextSrcSizeToDecompressWithInputSize(dctx, srcSize), srcSize_wrong, "not allowed");
1052     ZSTD_checkContinuity(dctx, dst, dstCapacity);
1053 
1054     dctx->processedCSize += srcSize;
1055 
1056     switch (dctx->stage)
1057     {
1058     case ZSTDds_getFrameHeaderSize :
1059         assert(src != NULL);
1060         if (dctx->format == ZSTD_f_zstd1) {  /* allows header */
1061             assert(srcSize >= ZSTD_FRAMEIDSIZE);  /* to read skippable magic number */
1062             if ((MEM_readLE32(src) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {        /* skippable frame */
1063                 ZSTD_memcpy(dctx->headerBuffer, src, srcSize);
1064                 dctx->expected = ZSTD_SKIPPABLEHEADERSIZE - srcSize;  /* remaining to load to get full skippable frame header */
1065                 dctx->stage = ZSTDds_decodeSkippableHeader;
1066                 return 0;
1067         }   }
1068         dctx->headerSize = ZSTD_frameHeaderSize_internal(src, srcSize, dctx->format);
1069         if (ZSTD_isError(dctx->headerSize)) return dctx->headerSize;
1070         ZSTD_memcpy(dctx->headerBuffer, src, srcSize);
1071         dctx->expected = dctx->headerSize - srcSize;
1072         dctx->stage = ZSTDds_decodeFrameHeader;
1073         return 0;
1074 
1075     case ZSTDds_decodeFrameHeader:
1076         assert(src != NULL);
1077         ZSTD_memcpy(dctx->headerBuffer + (dctx->headerSize - srcSize), src, srcSize);
1078         FORWARD_IF_ERROR(ZSTD_decodeFrameHeader(dctx, dctx->headerBuffer, dctx->headerSize), "");
1079         dctx->expected = ZSTD_blockHeaderSize;
1080         dctx->stage = ZSTDds_decodeBlockHeader;
1081         return 0;
1082 
1083     case ZSTDds_decodeBlockHeader:
1084         {   blockProperties_t bp;
1085             size_t const cBlockSize = ZSTD_getcBlockSize(src, ZSTD_blockHeaderSize, &bp);
1086             if (ZSTD_isError(cBlockSize)) return cBlockSize;
1087             RETURN_ERROR_IF(cBlockSize > dctx->fParams.blockSizeMax, corruption_detected, "Block Size Exceeds Maximum");
1088             dctx->expected = cBlockSize;
1089             dctx->bType = bp.blockType;
1090             dctx->rleSize = bp.origSize;
1091             if (cBlockSize) {
1092                 dctx->stage = bp.lastBlock ? ZSTDds_decompressLastBlock : ZSTDds_decompressBlock;
1093                 return 0;
1094             }
1095             /* empty block */
1096             if (bp.lastBlock) {
1097                 if (dctx->fParams.checksumFlag) {
1098                     dctx->expected = 4;
1099                     dctx->stage = ZSTDds_checkChecksum;
1100                 } else {
1101                     dctx->expected = 0; /* end of frame */
1102                     dctx->stage = ZSTDds_getFrameHeaderSize;
1103                 }
1104             } else {
1105                 dctx->expected = ZSTD_blockHeaderSize;  /* jump to next header */
1106                 dctx->stage = ZSTDds_decodeBlockHeader;
1107             }
1108             return 0;
1109         }
1110 
1111     case ZSTDds_decompressLastBlock:
1112     case ZSTDds_decompressBlock:
1113         DEBUGLOG(5, "ZSTD_decompressContinue: case ZSTDds_decompressBlock");
1114         {   size_t rSize;
1115             switch(dctx->bType)
1116             {
1117             case bt_compressed:
1118                 DEBUGLOG(5, "ZSTD_decompressContinue: case bt_compressed");
1119                 rSize = ZSTD_decompressBlock_internal(dctx, dst, dstCapacity, src, srcSize, /* frame */ 1);
1120                 dctx->expected = 0;  /* Streaming not supported */
1121                 break;
1122             case bt_raw :
1123                 assert(srcSize <= dctx->expected);
1124                 rSize = ZSTD_copyRawBlock(dst, dstCapacity, src, srcSize);
1125                 FORWARD_IF_ERROR(rSize, "ZSTD_copyRawBlock failed");
1126                 assert(rSize == srcSize);
1127                 dctx->expected -= rSize;
1128                 break;
1129             case bt_rle :
1130                 rSize = ZSTD_setRleBlock(dst, dstCapacity, *(const BYTE*)src, dctx->rleSize);
1131                 dctx->expected = 0;  /* Streaming not supported */
1132                 break;
1133             case bt_reserved :   /* should never happen */
1134             default:
1135                 RETURN_ERROR(corruption_detected, "invalid block type");
1136             }
1137             FORWARD_IF_ERROR(rSize, "");
1138             RETURN_ERROR_IF(rSize > dctx->fParams.blockSizeMax, corruption_detected, "Decompressed Block Size Exceeds Maximum");
1139             DEBUGLOG(5, "ZSTD_decompressContinue: decoded size from block : %u", (unsigned)rSize);
1140             dctx->decodedSize += rSize;
1141             if (dctx->validateChecksum) xxh64_update(&dctx->xxhState, dst, rSize);
1142             dctx->previousDstEnd = (char*)dst + rSize;
1143 
1144             /* Stay on the same stage until we are finished streaming the block. */
1145             if (dctx->expected > 0) {
1146                 return rSize;
1147             }
1148 
1149             if (dctx->stage == ZSTDds_decompressLastBlock) {   /* end of frame */
1150                 DEBUGLOG(4, "ZSTD_decompressContinue: decoded size from frame : %u", (unsigned)dctx->decodedSize);
1151                 RETURN_ERROR_IF(
1152                     dctx->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN
1153                  && dctx->decodedSize != dctx->fParams.frameContentSize,
1154                     corruption_detected, "");
1155                 if (dctx->fParams.checksumFlag) {  /* another round for frame checksum */
1156                     dctx->expected = 4;
1157                     dctx->stage = ZSTDds_checkChecksum;
1158                 } else {
1159                     ZSTD_DCtx_trace_end(dctx, dctx->decodedSize, dctx->processedCSize, /* streaming */ 1);
1160                     dctx->expected = 0;   /* ends here */
1161                     dctx->stage = ZSTDds_getFrameHeaderSize;
1162                 }
1163             } else {
1164                 dctx->stage = ZSTDds_decodeBlockHeader;
1165                 dctx->expected = ZSTD_blockHeaderSize;
1166             }
1167             return rSize;
1168         }
1169 
1170     case ZSTDds_checkChecksum:
1171         assert(srcSize == 4);  /* guaranteed by dctx->expected */
1172         {
1173             if (dctx->validateChecksum) {
1174                 U32 const h32 = (U32)xxh64_digest(&dctx->xxhState);
1175                 U32 const check32 = MEM_readLE32(src);
1176                 DEBUGLOG(4, "ZSTD_decompressContinue: checksum : calculated %08X :: %08X read", (unsigned)h32, (unsigned)check32);
1177                 RETURN_ERROR_IF(check32 != h32, checksum_wrong, "");
1178             }
1179             ZSTD_DCtx_trace_end(dctx, dctx->decodedSize, dctx->processedCSize, /* streaming */ 1);
1180             dctx->expected = 0;
1181             dctx->stage = ZSTDds_getFrameHeaderSize;
1182             return 0;
1183         }
1184 
1185     case ZSTDds_decodeSkippableHeader:
1186         assert(src != NULL);
1187         assert(srcSize <= ZSTD_SKIPPABLEHEADERSIZE);
1188         ZSTD_memcpy(dctx->headerBuffer + (ZSTD_SKIPPABLEHEADERSIZE - srcSize), src, srcSize);   /* complete skippable header */
1189         dctx->expected = MEM_readLE32(dctx->headerBuffer + ZSTD_FRAMEIDSIZE);   /* note : dctx->expected can grow seriously large, beyond local buffer size */
1190         dctx->stage = ZSTDds_skipFrame;
1191         return 0;
1192 
1193     case ZSTDds_skipFrame:
1194         dctx->expected = 0;
1195         dctx->stage = ZSTDds_getFrameHeaderSize;
1196         return 0;
1197 
1198     default:
1199         assert(0);   /* impossible */
1200         RETURN_ERROR(GENERIC, "impossible to reach");   /* some compiler require default to do something */
1201     }
1202 }
1203 
1204 
/* ZSTD_refDictContent() :
 *  Makes `dict` the new prefix of the history window.
 *  The previously decoded segment [prefixStart, previousDstEnd) becomes the
 *  external dictionary : dictEnd takes the old previousDstEnd, and
 *  virtualStart is placed so that the old segment's length is preserved in
 *  front of `dict`. Must read the old pointers before overwriting them.
 *  @return : 0 */
static size_t ZSTD_refDictContent(ZSTD_DCtx* dctx, const void* dict, size_t dictSize)
{
    const char* const oldPrefixStart = (const char*)(dctx->prefixStart);
    const char* const oldDstEnd = (const char*)(dctx->previousDstEnd);
    const char* const newPrefix = (const char*)dict;

    dctx->dictEnd = dctx->previousDstEnd;
    dctx->virtualStart = newPrefix - (oldDstEnd - oldPrefixStart);
    dctx->prefixStart = dict;
    dctx->previousDstEnd = newPrefix + dictSize;
#ifdef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION
    dctx->dictContentBeginForFuzzing = dctx->prefixStart;
    dctx->dictContentEndForFuzzing = dctx->previousDstEnd;
#endif
    return 0;
}
1217 
1218 /*! ZSTD_loadDEntropy() :
1219  *  dict : must point at beginning of a valid zstd dictionary.
1220  * @return : size of entropy tables read */
1221 size_t
ZSTD_loadDEntropy(ZSTD_entropyDTables_t * entropy,const void * const dict,size_t const dictSize)1222 ZSTD_loadDEntropy(ZSTD_entropyDTables_t* entropy,
1223                   const void* const dict, size_t const dictSize)
1224 {
1225     const BYTE* dictPtr = (const BYTE*)dict;
1226     const BYTE* const dictEnd = dictPtr + dictSize;
1227 
1228     RETURN_ERROR_IF(dictSize <= 8, dictionary_corrupted, "dict is too small");
1229     assert(MEM_readLE32(dict) == ZSTD_MAGIC_DICTIONARY);   /* dict must be valid */
1230     dictPtr += 8;   /* skip header = magic + dictID */
1231 
1232     ZSTD_STATIC_ASSERT(offsetof(ZSTD_entropyDTables_t, OFTable) == offsetof(ZSTD_entropyDTables_t, LLTable) + sizeof(entropy->LLTable));
1233     ZSTD_STATIC_ASSERT(offsetof(ZSTD_entropyDTables_t, MLTable) == offsetof(ZSTD_entropyDTables_t, OFTable) + sizeof(entropy->OFTable));
1234     ZSTD_STATIC_ASSERT(sizeof(entropy->LLTable) + sizeof(entropy->OFTable) + sizeof(entropy->MLTable) >= HUF_DECOMPRESS_WORKSPACE_SIZE);
1235     {   void* const workspace = &entropy->LLTable;   /* use fse tables as temporary workspace; implies fse tables are grouped together */
1236         size_t const workspaceSize = sizeof(entropy->LLTable) + sizeof(entropy->OFTable) + sizeof(entropy->MLTable);
1237 #ifdef HUF_FORCE_DECOMPRESS_X1
1238         /* in minimal huffman, we always use X1 variants */
1239         size_t const hSize = HUF_readDTableX1_wksp(entropy->hufTable,
1240                                                 dictPtr, dictEnd - dictPtr,
1241                                                 workspace, workspaceSize);
1242 #else
1243         size_t const hSize = HUF_readDTableX2_wksp(entropy->hufTable,
1244                                                 dictPtr, (size_t)(dictEnd - dictPtr),
1245                                                 workspace, workspaceSize);
1246 #endif
1247         RETURN_ERROR_IF(HUF_isError(hSize), dictionary_corrupted, "");
1248         dictPtr += hSize;
1249     }
1250 
1251     {   short offcodeNCount[MaxOff+1];
1252         unsigned offcodeMaxValue = MaxOff, offcodeLog;
1253         size_t const offcodeHeaderSize = FSE_readNCount(offcodeNCount, &offcodeMaxValue, &offcodeLog, dictPtr, (size_t)(dictEnd-dictPtr));
1254         RETURN_ERROR_IF(FSE_isError(offcodeHeaderSize), dictionary_corrupted, "");
1255         RETURN_ERROR_IF(offcodeMaxValue > MaxOff, dictionary_corrupted, "");
1256         RETURN_ERROR_IF(offcodeLog > OffFSELog, dictionary_corrupted, "");
1257         ZSTD_buildFSETable( entropy->OFTable,
1258                             offcodeNCount, offcodeMaxValue,
1259                             OF_base, OF_bits,
1260                             offcodeLog,
1261                             entropy->workspace, sizeof(entropy->workspace),
1262                             /* bmi2 */0);
1263         dictPtr += offcodeHeaderSize;
1264     }
1265 
1266     {   short matchlengthNCount[MaxML+1];
1267         unsigned matchlengthMaxValue = MaxML, matchlengthLog;
1268         size_t const matchlengthHeaderSize = FSE_readNCount(matchlengthNCount, &matchlengthMaxValue, &matchlengthLog, dictPtr, (size_t)(dictEnd-dictPtr));
1269         RETURN_ERROR_IF(FSE_isError(matchlengthHeaderSize), dictionary_corrupted, "");
1270         RETURN_ERROR_IF(matchlengthMaxValue > MaxML, dictionary_corrupted, "");
1271         RETURN_ERROR_IF(matchlengthLog > MLFSELog, dictionary_corrupted, "");
1272         ZSTD_buildFSETable( entropy->MLTable,
1273                             matchlengthNCount, matchlengthMaxValue,
1274                             ML_base, ML_bits,
1275                             matchlengthLog,
1276                             entropy->workspace, sizeof(entropy->workspace),
1277                             /* bmi2 */ 0);
1278         dictPtr += matchlengthHeaderSize;
1279     }
1280 
1281     {   short litlengthNCount[MaxLL+1];
1282         unsigned litlengthMaxValue = MaxLL, litlengthLog;
1283         size_t const litlengthHeaderSize = FSE_readNCount(litlengthNCount, &litlengthMaxValue, &litlengthLog, dictPtr, (size_t)(dictEnd-dictPtr));
1284         RETURN_ERROR_IF(FSE_isError(litlengthHeaderSize), dictionary_corrupted, "");
1285         RETURN_ERROR_IF(litlengthMaxValue > MaxLL, dictionary_corrupted, "");
1286         RETURN_ERROR_IF(litlengthLog > LLFSELog, dictionary_corrupted, "");
1287         ZSTD_buildFSETable( entropy->LLTable,
1288                             litlengthNCount, litlengthMaxValue,
1289                             LL_base, LL_bits,
1290                             litlengthLog,
1291                             entropy->workspace, sizeof(entropy->workspace),
1292                             /* bmi2 */ 0);
1293         dictPtr += litlengthHeaderSize;
1294     }
1295 
1296     RETURN_ERROR_IF(dictPtr+12 > dictEnd, dictionary_corrupted, "");
1297     {   int i;
1298         size_t const dictContentSize = (size_t)(dictEnd - (dictPtr+12));
1299         for (i=0; i<3; i++) {
1300             U32 const rep = MEM_readLE32(dictPtr); dictPtr += 4;
1301             RETURN_ERROR_IF(rep==0 || rep > dictContentSize,
1302                             dictionary_corrupted, "");
1303             entropy->rep[i] = rep;
1304     }   }
1305 
1306     return (size_t)(dictPtr - (const BYTE*)dict);
1307 }
1308 
/* ZSTD_decompress_insertDictionary() :
 *  Loads `dict` into `dctx`. A buffer shorter than 8 bytes, or one that does
 *  not start with ZSTD_MAGIC_DICTIONARY, is used as raw content only.
 *  Otherwise the dictID and entropy tables are loaded first, and the bytes
 *  following the entropy section become the referenced content.
 *  @return : 0, or an error code (dictionary_corrupted) */
static size_t ZSTD_decompress_insertDictionary(ZSTD_DCtx* dctx, const void* dict, size_t dictSize)
{
    int const isZstdDict = (dictSize >= 8) && (MEM_readLE32(dict) == ZSTD_MAGIC_DICTIONARY);
    if (!isZstdDict) {
        return ZSTD_refDictContent(dctx, dict, dictSize);   /* pure content mode */
    }
    dctx->dictID = MEM_readLE32((const char*)dict + ZSTD_FRAMEIDSIZE);

    {   size_t const entropySize = ZSTD_loadDEntropy(&dctx->entropy, dict, dictSize);
        RETURN_ERROR_IF(ZSTD_isError(entropySize), dictionary_corrupted, "");
        dctx->litEntropy = dctx->fseEntropy = 1;
        /* what follows the entropy tables is the dictionary content */
        return ZSTD_refDictContent(dctx, (const char*)dict + entropySize, dictSize - entropySize);
    }
}
1329 
/* ZSTD_decompressBegin() :
 *  Resets `dctx` for a new frame, with no dictionary : state machine waits
 *  for a frame header, history pointers are cleared, entropy tables are marked
 *  as not loaded and repcodes take their starting values.
 *  @return : 0 */
size_t ZSTD_decompressBegin(ZSTD_DCtx* dctx)
{
    assert(dctx != NULL);

    /* state machine */
    dctx->stage = ZSTDds_getFrameHeaderSize;
    dctx->expected = ZSTD_startingInputLength(dctx->format);  /* dctx->format must be properly set */
    dctx->bType = bt_reserved;
    dctx->processedCSize = 0;
    dctx->decodedSize = 0;

    /* history / dictionary */
    dctx->dictEnd = NULL;
    dctx->virtualStart = NULL;
    dctx->prefixStart = NULL;
    dctx->previousDstEnd = NULL;
    dctx->dictID = 0;

    /* entropy */
    dctx->litEntropy = 0;
    dctx->fseEntropy = 0;
    dctx->entropy.hufTable[0] = (HUF_DTable)((HufLog)*0x1000001);  /* cover both little and big endian */
    ZSTD_STATIC_ASSERT(sizeof(dctx->entropy.rep) == sizeof(repStartValue));
    ZSTD_memcpy(dctx->entropy.rep, repStartValue, sizeof(repStartValue));  /* initial repcodes */

    /* table pointers target the context's own tables */
    dctx->HUFptr = dctx->entropy.hufTable;
    dctx->LLTptr = dctx->entropy.LLTable;
    dctx->MLTptr = dctx->entropy.MLTable;
    dctx->OFTptr = dctx->entropy.OFTable;
    return 0;
}
1353 
/* ZSTD_decompressBegin_usingDict() :
 * Resets @dctx like ZSTD_decompressBegin(), then loads @dict when it is
 * non-NULL and non-empty.
 * @return : 0, or an error code; any dictionary loading failure is
 *           reported as dictionary_corrupted. */
size_t ZSTD_decompressBegin_usingDict(ZSTD_DCtx* dctx, const void* dict, size_t dictSize)
{
    FORWARD_IF_ERROR( ZSTD_decompressBegin(dctx) , "");
    if ((dict == NULL) || (dictSize == 0))
        return 0;   /* no dictionary to load */
    RETURN_ERROR_IF(
        ZSTD_isError(ZSTD_decompress_insertDictionary(dctx, dict, dictSize)),
        dictionary_corrupted, "");
    return 0;
}
1363 
1364 
1365 /* ======   ZSTD_DDict   ====== */
1366 
/* ZSTD_decompressBegin_usingDDict() :
 * Prepares @dctx for a new frame using the pre-digested dictionary @ddict.
 * Before resetting, records in dctx->ddictIsCold whether @ddict's content
 * ends somewhere other than where the previously used dictionary ended.
 * A NULL @ddict means "no dictionary".
 * @return : 0, or an error code forwarded from ZSTD_decompressBegin(). */
size_t ZSTD_decompressBegin_usingDDict(ZSTD_DCtx* dctx, const ZSTD_DDict* ddict)
{
    DEBUGLOG(4, "ZSTD_decompressBegin_usingDDict");
    assert(dctx != NULL);
    if (ddict != NULL) {
        /* must run before ZSTD_decompressBegin(), which clears dctx->dictEnd */
        const char* const contentStart = (const char*)ZSTD_DDict_dictContent(ddict);
        const void* const contentEnd = contentStart + ZSTD_DDict_dictSize(ddict);
        dctx->ddictIsCold = (dctx->dictEnd != contentEnd);
        DEBUGLOG(4, "DDict is %s",
                    dctx->ddictIsCold ? "~cold~" : "hot!");
    }
    FORWARD_IF_ERROR( ZSTD_decompressBegin(dctx) , "");
    if (ddict != NULL)
        ZSTD_copyDDictParameters(dctx, ddict);
    return 0;
}
1385 
1386 /*! ZSTD_getDictID_fromDict() :
1387  *  Provides the dictID stored within dictionary.
1388  *  if @return == 0, the dictionary is not conformant with Zstandard specification.
1389  *  It can still be loaded, but as a content-only dictionary. */
ZSTD_getDictID_fromDict(const void * dict,size_t dictSize)1390 unsigned ZSTD_getDictID_fromDict(const void* dict, size_t dictSize)
1391 {
1392     if (dictSize < 8) return 0;
1393     if (MEM_readLE32(dict) != ZSTD_MAGIC_DICTIONARY) return 0;
1394     return MEM_readLE32((const char*)dict + ZSTD_FRAMEIDSIZE);
1395 }
1396 
1397 /*! ZSTD_getDictID_fromFrame() :
1398  *  Provides the dictID required to decompress frame stored within `src`.
1399  *  If @return == 0, the dictID could not be decoded.
1400  *  This could for one of the following reasons :
1401  *  - The frame does not require a dictionary (most common case).
1402  *  - The frame was built with dictID intentionally removed.
1403  *    Needed dictionary is a hidden information.
1404  *    Note : this use case also happens when using a non-conformant dictionary.
1405  *  - `srcSize` is too small, and as a result, frame header could not be decoded.
1406  *    Note : possible if `srcSize < ZSTD_FRAMEHEADERSIZE_MAX`.
1407  *  - This is not a Zstandard frame.
1408  *  When identifying the exact failure cause, it's possible to use
1409  *  ZSTD_getFrameHeader(), which will provide a more precise error code. */
ZSTD_getDictID_fromFrame(const void * src,size_t srcSize)1410 unsigned ZSTD_getDictID_fromFrame(const void* src, size_t srcSize)
1411 {
1412     ZSTD_frameHeader zfp = { 0, 0, 0, ZSTD_frame, 0, 0, 0 };
1413     size_t const hError = ZSTD_getFrameHeader(&zfp, src, srcSize);
1414     if (ZSTD_isError(hError)) return 0;
1415     return zfp.dictID;
1416 }
1417 
1418 
1419 /*! ZSTD_decompress_usingDDict() :
1420 *   Decompression using a pre-digested Dictionary
1421 *   Use dictionary without significant overhead. */
ZSTD_decompress_usingDDict(ZSTD_DCtx * dctx,void * dst,size_t dstCapacity,const void * src,size_t srcSize,const ZSTD_DDict * ddict)1422 size_t ZSTD_decompress_usingDDict(ZSTD_DCtx* dctx,
1423                                   void* dst, size_t dstCapacity,
1424                             const void* src, size_t srcSize,
1425                             const ZSTD_DDict* ddict)
1426 {
1427     /* pass content and size in case legacy frames are encountered */
1428     return ZSTD_decompressMultiFrame(dctx, dst, dstCapacity, src, srcSize,
1429                                      NULL, 0,
1430                                      ddict);
1431 }
1432 
1433 
1434 /*=====================================
1435 *   Streaming decompression
1436 *====================================*/
1437 
ZSTD_createDStream(void)1438 ZSTD_DStream* ZSTD_createDStream(void)
1439 {
1440     DEBUGLOG(3, "ZSTD_createDStream");
1441     return ZSTD_createDStream_advanced(ZSTD_defaultCMem);
1442 }
1443 
ZSTD_initStaticDStream(void * workspace,size_t workspaceSize)1444 ZSTD_DStream* ZSTD_initStaticDStream(void *workspace, size_t workspaceSize)
1445 {
1446     return ZSTD_initStaticDCtx(workspace, workspaceSize);
1447 }
1448 
ZSTD_createDStream_advanced(ZSTD_customMem customMem)1449 ZSTD_DStream* ZSTD_createDStream_advanced(ZSTD_customMem customMem)
1450 {
1451     return ZSTD_createDCtx_advanced(customMem);
1452 }
1453 
/* ZSTD_freeDStream() :
 * Releases @zds.
 * @return : the result of ZSTD_freeDCtx(). */
size_t ZSTD_freeDStream(ZSTD_DStream* zds)
{
    size_t const freeResult = ZSTD_freeDCtx(zds);
    return freeResult;
}
1458 
1459 
1460 /* ***  Initialization  *** */
1461 
ZSTD_DStreamInSize(void)1462 size_t ZSTD_DStreamInSize(void)  { return ZSTD_BLOCKSIZE_MAX + ZSTD_blockHeaderSize; }
ZSTD_DStreamOutSize(void)1463 size_t ZSTD_DStreamOutSize(void) { return ZSTD_BLOCKSIZE_MAX; }
1464 
/* ZSTD_DCtx_loadDictionary_advanced() :
 * Replaces whatever dictionary is currently attached to @dctx.
 * A NULL @dict or a zero @dictSize only clears the current dictionary.
 * Otherwise a new DDict is built into dctx->ddictLocal and kept for all
 * following frames (ZSTD_use_indefinitely).
 * Only allowed while the stream is in zdss_init.
 * @return : 0, or an error code (stage_wrong, memory_allocation). */
size_t ZSTD_DCtx_loadDictionary_advanced(ZSTD_DCtx* dctx,
                                   const void* dict, size_t dictSize,
                                         ZSTD_dictLoadMethod_e dictLoadMethod,
                                         ZSTD_dictContentType_e dictContentType)
{
    RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong, "");
    ZSTD_clearDict(dctx);
    if ((dict == NULL) || (dictSize == 0))
        return 0;   /* nothing to load : the dictionary is just cleared */

    dctx->ddictLocal = ZSTD_createDDict_advanced(dict, dictSize, dictLoadMethod, dictContentType, dctx->customMem);
    RETURN_ERROR_IF(dctx->ddictLocal == NULL, memory_allocation, "NULL pointer!");
    dctx->ddict = dctx->ddictLocal;
    dctx->dictUses = ZSTD_use_indefinitely;
    return 0;
}
1480 
/* ZSTD_DCtx_loadDictionary_byReference() :
 * Loads @dict with ZSTD_dlm_byRef (referenced rather than copied; the caller
 * presumably must keep it alive while in use) and auto-detects its type (ZSTD_dct_auto). */
size_t ZSTD_DCtx_loadDictionary_byReference(ZSTD_DCtx* dctx, const void* dict, size_t dictSize)
{
    ZSTD_dictLoadMethod_e const loadMethod = ZSTD_dlm_byRef;
    ZSTD_dictContentType_e const contentType = ZSTD_dct_auto;
    return ZSTD_DCtx_loadDictionary_advanced(dctx, dict, dictSize, loadMethod, contentType);
}
1485 
/* ZSTD_DCtx_loadDictionary() :
 * Loads @dict with ZSTD_dlm_byCopy and auto-detects its type (ZSTD_dct_auto). */
size_t ZSTD_DCtx_loadDictionary(ZSTD_DCtx* dctx, const void* dict, size_t dictSize)
{
    ZSTD_dictLoadMethod_e const loadMethod = ZSTD_dlm_byCopy;
    ZSTD_dictContentType_e const contentType = ZSTD_dct_auto;
    return ZSTD_DCtx_loadDictionary_advanced(dctx, dict, dictSize, loadMethod, contentType);
}
1490 
/* ZSTD_DCtx_refPrefix_advanced() :
 * Loads @prefix by reference (ZSTD_dlm_byRef) like
 * ZSTD_DCtx_loadDictionary_advanced(), but marks it for a single use (ZSTD_use_once).
 * @return : 0, or an error code forwarded from the dictionary load. */
size_t ZSTD_DCtx_refPrefix_advanced(ZSTD_DCtx* dctx, const void* prefix, size_t prefixSize, ZSTD_dictContentType_e dictContentType)
{
    FORWARD_IF_ERROR(ZSTD_DCtx_loadDictionary_advanced(dctx, prefix, prefixSize, ZSTD_dlm_byRef, dictContentType), "");
    dctx->dictUses = ZSTD_use_once;   /* a prefix only applies to the next frame */
    return 0;
}
1497 
/* ZSTD_DCtx_refPrefix() :
 * References @prefix as raw content (ZSTD_dct_rawContent) for the next frame only. */
size_t ZSTD_DCtx_refPrefix(ZSTD_DCtx* dctx, const void* prefix, size_t prefixSize)
{
    ZSTD_dictContentType_e const contentType = ZSTD_dct_rawContent;
    return ZSTD_DCtx_refPrefix_advanced(dctx, prefix, prefixSize, contentType);
}
1502 
1503 
1504 /* ZSTD_initDStream_usingDict() :
1505  * return : expected size, aka ZSTD_startingInputLength().
1506  * this function cannot fail */
ZSTD_initDStream_usingDict(ZSTD_DStream * zds,const void * dict,size_t dictSize)1507 size_t ZSTD_initDStream_usingDict(ZSTD_DStream* zds, const void* dict, size_t dictSize)
1508 {
1509     DEBUGLOG(4, "ZSTD_initDStream_usingDict");
1510     FORWARD_IF_ERROR( ZSTD_DCtx_reset(zds, ZSTD_reset_session_only) , "");
1511     FORWARD_IF_ERROR( ZSTD_DCtx_loadDictionary(zds, dict, dictSize) , "");
1512     return ZSTD_startingInputLength(zds->format);
1513 }
1514 
1515 /* note : this variant can't fail */
ZSTD_initDStream(ZSTD_DStream * zds)1516 size_t ZSTD_initDStream(ZSTD_DStream* zds)
1517 {
1518     DEBUGLOG(4, "ZSTD_initDStream");
1519     return ZSTD_initDStream_usingDDict(zds, NULL);
1520 }
1521 
1522 /* ZSTD_initDStream_usingDDict() :
1523  * ddict will just be referenced, and must outlive decompression session
1524  * this function cannot fail */
ZSTD_initDStream_usingDDict(ZSTD_DStream * dctx,const ZSTD_DDict * ddict)1525 size_t ZSTD_initDStream_usingDDict(ZSTD_DStream* dctx, const ZSTD_DDict* ddict)
1526 {
1527     FORWARD_IF_ERROR( ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only) , "");
1528     FORWARD_IF_ERROR( ZSTD_DCtx_refDDict(dctx, ddict) , "");
1529     return ZSTD_startingInputLength(dctx->format);
1530 }
1531 
1532 /* ZSTD_resetDStream() :
1533  * return : expected size, aka ZSTD_startingInputLength().
1534  * this function cannot fail */
ZSTD_resetDStream(ZSTD_DStream * dctx)1535 size_t ZSTD_resetDStream(ZSTD_DStream* dctx)
1536 {
1537     FORWARD_IF_ERROR(ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only), "");
1538     return ZSTD_startingInputLength(dctx->format);
1539 }
1540 
1541 
/* ZSTD_DCtx_refDDict() :
 * Attaches @ddict by reference for all following frames, replacing any
 * current dictionary; NULL simply clears the dictionary.
 * When ZSTD_d_refMultipleDDicts is enabled, @ddict is also registered in
 * dctx->ddictSet, which is created on first use.
 * Only allowed while the stream is in zdss_init.
 * @return : 0, or an error code (stage_wrong, memory_allocation, or one
 *           forwarded from ZSTD_DDictHashSet_addDDict()). */
size_t ZSTD_DCtx_refDDict(ZSTD_DCtx* dctx, const ZSTD_DDict* ddict)
{
    RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong, "");
    ZSTD_clearDict(dctx);
    if (ddict == NULL)
        return 0;

    dctx->ddict = ddict;
    dctx->dictUses = ZSTD_use_indefinitely;
    if (dctx->refMultipleDDicts != ZSTD_rmd_refMultipleDDicts)
        return 0;

    /* lazily create the set of referenced DDicts */
    if (dctx->ddictSet == NULL) {
        dctx->ddictSet = ZSTD_createDDictHashSet(dctx->customMem);
        if (!dctx->ddictSet) {
            RETURN_ERROR(memory_allocation, "Failed to allocate memory for hash set!");
        }
    }
    assert(!dctx->staticSize);  /* Impossible: ddictSet cannot have been allocated if static dctx */
    FORWARD_IF_ERROR(ZSTD_DDictHashSet_addDDict(dctx->ddictSet, ddict, dctx->customMem), "");
    return 0;
}
1562 
1563 /* ZSTD_DCtx_setMaxWindowSize() :
1564  * note : no direct equivalence in ZSTD_DCtx_setParameter,
1565  * since this version sets windowSize, and the other sets windowLog */
ZSTD_DCtx_setMaxWindowSize(ZSTD_DCtx * dctx,size_t maxWindowSize)1566 size_t ZSTD_DCtx_setMaxWindowSize(ZSTD_DCtx* dctx, size_t maxWindowSize)
1567 {
1568     ZSTD_bounds const bounds = ZSTD_dParam_getBounds(ZSTD_d_windowLogMax);
1569     size_t const min = (size_t)1 << bounds.lowerBound;
1570     size_t const max = (size_t)1 << bounds.upperBound;
1571     RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong, "");
1572     RETURN_ERROR_IF(maxWindowSize < min, parameter_outOfBound, "");
1573     RETURN_ERROR_IF(maxWindowSize > max, parameter_outOfBound, "");
1574     dctx->maxWindowSize = maxWindowSize;
1575     return 0;
1576 }
1577 
/* ZSTD_DCtx_setFormat() :
 * Shortcut for ZSTD_DCtx_setParameter(dctx, ZSTD_d_format, format). */
size_t ZSTD_DCtx_setFormat(ZSTD_DCtx* dctx, ZSTD_format_e format)
{
    int const formatValue = (int)format;
    return ZSTD_DCtx_setParameter(dctx, ZSTD_d_format, formatValue);
}
1582 
ZSTD_dParam_getBounds(ZSTD_dParameter dParam)1583 ZSTD_bounds ZSTD_dParam_getBounds(ZSTD_dParameter dParam)
1584 {
1585     ZSTD_bounds bounds = { 0, 0, 0 };
1586     switch(dParam) {
1587         case ZSTD_d_windowLogMax:
1588             bounds.lowerBound = ZSTD_WINDOWLOG_ABSOLUTEMIN;
1589             bounds.upperBound = ZSTD_WINDOWLOG_MAX;
1590             return bounds;
1591         case ZSTD_d_format:
1592             bounds.lowerBound = (int)ZSTD_f_zstd1;
1593             bounds.upperBound = (int)ZSTD_f_zstd1_magicless;
1594             ZSTD_STATIC_ASSERT(ZSTD_f_zstd1 < ZSTD_f_zstd1_magicless);
1595             return bounds;
1596         case ZSTD_d_stableOutBuffer:
1597             bounds.lowerBound = (int)ZSTD_bm_buffered;
1598             bounds.upperBound = (int)ZSTD_bm_stable;
1599             return bounds;
1600         case ZSTD_d_forceIgnoreChecksum:
1601             bounds.lowerBound = (int)ZSTD_d_validateChecksum;
1602             bounds.upperBound = (int)ZSTD_d_ignoreChecksum;
1603             return bounds;
1604         case ZSTD_d_refMultipleDDicts:
1605             bounds.lowerBound = (int)ZSTD_rmd_refSingleDDict;
1606             bounds.upperBound = (int)ZSTD_rmd_refMultipleDDicts;
1607             return bounds;
1608         default:;
1609     }
1610     bounds.error = ERROR(parameter_unsupported);
1611     return bounds;
1612 }
1613 
1614 /* ZSTD_dParam_withinBounds:
1615  * @return 1 if value is within dParam bounds,
1616  * 0 otherwise */
ZSTD_dParam_withinBounds(ZSTD_dParameter dParam,int value)1617 static int ZSTD_dParam_withinBounds(ZSTD_dParameter dParam, int value)
1618 {
1619     ZSTD_bounds const bounds = ZSTD_dParam_getBounds(dParam);
1620     if (ZSTD_isError(bounds.error)) return 0;
1621     if (value < bounds.lowerBound) return 0;
1622     if (value > bounds.upperBound) return 0;
1623     return 1;
1624 }
1625 
/* CHECK_DBOUNDS() :
 * Makes the enclosing function return parameter_outOfBound when @v is outside
 * the valid range of decompression parameter @p.
 * Wrapped in do { } while (0) so that the expansion is a single statement :
 * with bare braces, `if (x) CHECK_DBOUNDS(p, v); else ...` would not compile. */
#define CHECK_DBOUNDS(p,v) do {                \
    RETURN_ERROR_IF(!ZSTD_dParam_withinBounds(p, v), parameter_outOfBound, ""); \
} while (0)
1629 
/* ZSTD_DCtx_getParameter() :
 * Writes the current value of @param into *@value (which must be non-NULL).
 * ZSTD_d_windowLogMax is reported as the index of the highest set bit of the
 * stored maximum window size.
 * @return : 0, or parameter_unsupported for an unknown parameter. */
size_t ZSTD_DCtx_getParameter(ZSTD_DCtx* dctx, ZSTD_dParameter param, int* value)
{
    switch (param) {
        case ZSTD_d_windowLogMax:
            *value = (int)ZSTD_highbit32((U32)dctx->maxWindowSize);
            break;
        case ZSTD_d_format:
            *value = (int)dctx->format;
            break;
        case ZSTD_d_stableOutBuffer:
            *value = (int)dctx->outBufferMode;
            break;
        case ZSTD_d_forceIgnoreChecksum:
            *value = (int)dctx->forceIgnoreChecksum;
            break;
        case ZSTD_d_refMultipleDDicts:
            *value = (int)dctx->refMultipleDDicts;
            break;
        default:
            RETURN_ERROR(parameter_unsupported, "");
    }
    return 0;
}
1652 
/* ZSTD_DCtx_setParameter() :
 * Sets decompression parameter @dParam to @value, after a bounds check.
 * Only allowed while the stream is in zdss_init.
 * - ZSTD_d_windowLogMax : 0 selects ZSTD_WINDOWLOG_LIMIT_DEFAULT; stored as a size (1 << value).
 * - ZSTD_d_refMultipleDDicts : rejected for static contexts.
 * @return : 0, or an error code (stage_wrong, parameter_outOfBound, parameter_unsupported). */
size_t ZSTD_DCtx_setParameter(ZSTD_DCtx* dctx, ZSTD_dParameter dParam, int value)
{
    RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong, "");
    switch(dParam) {
        case ZSTD_d_windowLogMax:
            if (value == 0) value = ZSTD_WINDOWLOG_LIMIT_DEFAULT;
            CHECK_DBOUNDS(ZSTD_d_windowLogMax, value);
            dctx->maxWindowSize = ((size_t)1) << value;
            break;
        case ZSTD_d_format:
            CHECK_DBOUNDS(ZSTD_d_format, value);
            dctx->format = (ZSTD_format_e)value;
            break;
        case ZSTD_d_stableOutBuffer:
            CHECK_DBOUNDS(ZSTD_d_stableOutBuffer, value);
            dctx->outBufferMode = (ZSTD_bufferMode_e)value;
            break;
        case ZSTD_d_forceIgnoreChecksum:
            CHECK_DBOUNDS(ZSTD_d_forceIgnoreChecksum, value);
            dctx->forceIgnoreChecksum = (ZSTD_forceIgnoreChecksum_e)value;
            break;
        case ZSTD_d_refMultipleDDicts:
            CHECK_DBOUNDS(ZSTD_d_refMultipleDDicts, value);
            /* multiple DDicts need a dynamically allocated hash set */
            if (dctx->staticSize != 0) {
                RETURN_ERROR(parameter_unsupported, "Static dctx does not support multiple DDicts!");
            }
            dctx->refMultipleDDicts = (ZSTD_refMultipleDDicts_e)value;
            break;
        default:
            RETURN_ERROR(parameter_unsupported, "");
    }
    return 0;
}
1685 
/* ZSTD_DCtx_reset() :
 * - session part (ZSTD_reset_session_only, ZSTD_reset_session_and_parameters) :
 *   returns the stream to zdss_init and clears the no-progress counter; never fails.
 * - parameters part (ZSTD_reset_parameters, ZSTD_reset_session_and_parameters) :
 *   clears the dictionary and restores default parameters; only allowed in zdss_init.
 *   The session part runs first, so the combined directive always passes that check.
 * @return : 0, or stage_wrong. */
size_t ZSTD_DCtx_reset(ZSTD_DCtx* dctx, ZSTD_ResetDirective reset)
{
    int const resetSession = (reset == ZSTD_reset_session_only)
                          || (reset == ZSTD_reset_session_and_parameters);
    int const resetParams  = (reset == ZSTD_reset_parameters)
                          || (reset == ZSTD_reset_session_and_parameters);
    if (resetSession) {
        dctx->streamStage = zdss_init;
        dctx->noForwardProgress = 0;
    }
    if (resetParams) {
        RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong, "");
        ZSTD_clearDict(dctx);
        ZSTD_DCtx_resetParameters(dctx);
    }
    return 0;
}
1701 
1702 
/* ZSTD_sizeof_DStream() :
 * Memory used by @dctx, as reported by ZSTD_sizeof_DCtx(). */
size_t ZSTD_sizeof_DStream(const ZSTD_DStream* dctx)
{
    size_t const footprint = ZSTD_sizeof_DCtx(dctx);
    return footprint;
}
1707 
ZSTD_decodingBufferSize_min(unsigned long long windowSize,unsigned long long frameContentSize)1708 size_t ZSTD_decodingBufferSize_min(unsigned long long windowSize, unsigned long long frameContentSize)
1709 {
1710     size_t const blockSize = (size_t) MIN(windowSize, ZSTD_BLOCKSIZE_MAX);
1711     unsigned long long const neededRBSize = windowSize + blockSize + (WILDCOPY_OVERLENGTH * 2);
1712     unsigned long long const neededSize = MIN(frameContentSize, neededRBSize);
1713     size_t const minRBSize = (size_t) neededSize;
1714     RETURN_ERROR_IF((unsigned long long)minRBSize != neededSize,
1715                     frameParameter_windowTooLarge, "");
1716     return minRBSize;
1717 }
1718 
ZSTD_estimateDStreamSize(size_t windowSize)1719 size_t ZSTD_estimateDStreamSize(size_t windowSize)
1720 {
1721     size_t const blockSize = MIN(windowSize, ZSTD_BLOCKSIZE_MAX);
1722     size_t const inBuffSize = blockSize;  /* no block can be larger */
1723     size_t const outBuffSize = ZSTD_decodingBufferSize_min(windowSize, ZSTD_CONTENTSIZE_UNKNOWN);
1724     return ZSTD_estimateDCtxSize() + inBuffSize + outBuffSize;
1725 }
1726 
ZSTD_estimateDStreamSize_fromFrame(const void * src,size_t srcSize)1727 size_t ZSTD_estimateDStreamSize_fromFrame(const void* src, size_t srcSize)
1728 {
1729     U32 const windowSizeMax = 1U << ZSTD_WINDOWLOG_MAX;   /* note : should be user-selectable, but requires an additional parameter (or a dctx) */
1730     ZSTD_frameHeader zfh;
1731     size_t const err = ZSTD_getFrameHeader(&zfh, src, srcSize);
1732     if (ZSTD_isError(err)) return err;
1733     RETURN_ERROR_IF(err>0, srcSize_wrong, "");
1734     RETURN_ERROR_IF(zfh.windowSize > windowSizeMax,
1735                     frameParameter_windowTooLarge, "");
1736     return ZSTD_estimateDStreamSize((size_t)zfh.windowSize);
1737 }
1738 
1739 
1740 /* *****   Decompression   ***** */
1741 
/* ZSTD_DCtx_isOverflow() :
 * @return : non-zero when the currently allocated input + output buffers are at
 * least ZSTD_WORKSPACETOOLARGE_FACTOR times larger than the current frame needs. */
static int ZSTD_DCtx_isOverflow(ZSTD_DStream* zds, size_t const neededInBuffSize, size_t const neededOutBuffSize)
{
    size_t const allocated = zds->inBuffSize + zds->outBuffSize;
    size_t const needed = neededInBuffSize + neededOutBuffSize;
    return allocated >= needed * ZSTD_WORKSPACETOOLARGE_FACTOR;
}
1746 
/* ZSTD_DCtx_updateOversizedDuration() :
 * Counts consecutive frames for which the buffers were oversized
 * (see ZSTD_DCtx_isOverflow()); the count restarts at 0 as soon as they are not. */
static void ZSTD_DCtx_updateOversizedDuration(ZSTD_DStream* zds, size_t const neededInBuffSize, size_t const neededOutBuffSize)
{
    zds->oversizedDuration = ZSTD_DCtx_isOverflow(zds, neededInBuffSize, neededOutBuffSize)
                           ? zds->oversizedDuration + 1
                           : 0;
}
1754 
/* ZSTD_DCtx_isOversizedTooLong() :
 * @return : non-zero once the buffers have stayed oversized for at least
 * ZSTD_WORKSPACETOOLARGE_MAXDURATION consecutive updates. */
static int ZSTD_DCtx_isOversizedTooLong(ZSTD_DStream* zds)
{
    int const tooLong = zds->oversizedDuration >= ZSTD_WORKSPACETOOLARGE_MAXDURATION;
    return tooLong;
}
1759 
1760 /* Checks that the output buffer hasn't changed if ZSTD_obm_stable is used. */
ZSTD_checkOutBuffer(ZSTD_DStream const * zds,ZSTD_outBuffer const * output)1761 static size_t ZSTD_checkOutBuffer(ZSTD_DStream const* zds, ZSTD_outBuffer const* output)
1762 {
1763     ZSTD_outBuffer const expect = zds->expectedOutBuffer;
1764     /* No requirement when ZSTD_obm_stable is not enabled. */
1765     if (zds->outBufferMode != ZSTD_bm_stable)
1766         return 0;
1767     /* Any buffer is allowed in zdss_init, this must be the same for every other call until
1768      * the context is reset.
1769      */
1770     if (zds->streamStage == zdss_init)
1771         return 0;
1772     /* The buffer must match our expectation exactly. */
1773     if (expect.dst == output->dst && expect.pos == output->pos && expect.size == output->size)
1774         return 0;
1775     RETURN_ERROR(dstBuffer_wrong, "ZSTD_d_stableOutBuffer enabled but output differs!");
1776 }
1777 
1778 /* Calls ZSTD_decompressContinue() with the right parameters for ZSTD_decompressStream()
1779  * and updates the stage and the output buffer state. This call is extracted so it can be
1780  * used both when reading directly from the ZSTD_inBuffer, and in buffered input mode.
1781  * NOTE: You must break after calling this function since the streamStage is modified.
1782  */
ZSTD_decompressContinueStream(ZSTD_DStream * zds,char ** op,char * oend,void const * src,size_t srcSize)1783 static size_t ZSTD_decompressContinueStream(
1784             ZSTD_DStream* zds, char** op, char* oend,
1785             void const* src, size_t srcSize) {
1786     int const isSkipFrame = ZSTD_isSkipFrame(zds);
1787     if (zds->outBufferMode == ZSTD_bm_buffered) {
1788         size_t const dstSize = isSkipFrame ? 0 : zds->outBuffSize - zds->outStart;
1789         size_t const decodedSize = ZSTD_decompressContinue(zds,
1790                 zds->outBuff + zds->outStart, dstSize, src, srcSize);
1791         FORWARD_IF_ERROR(decodedSize, "");
1792         if (!decodedSize && !isSkipFrame) {
1793             zds->streamStage = zdss_read;
1794         } else {
1795             zds->outEnd = zds->outStart + decodedSize;
1796             zds->streamStage = zdss_flush;
1797         }
1798     } else {
1799         /* Write directly into the output buffer */
1800         size_t const dstSize = isSkipFrame ? 0 : (size_t)(oend - *op);
1801         size_t const decodedSize = ZSTD_decompressContinue(zds, *op, dstSize, src, srcSize);
1802         FORWARD_IF_ERROR(decodedSize, "");
1803         *op += decodedSize;
1804         /* Flushing is not needed. */
1805         zds->streamStage = zdss_read;
1806         assert(*op <= oend);
1807         assert(zds->outBufferMode == ZSTD_bm_stable);
1808     }
1809     return 0;
1810 }
1811 
ZSTD_decompressStream(ZSTD_DStream * zds,ZSTD_outBuffer * output,ZSTD_inBuffer * input)1812 size_t ZSTD_decompressStream(ZSTD_DStream* zds, ZSTD_outBuffer* output, ZSTD_inBuffer* input)
1813 {
1814     const char* const src = (const char*)input->src;
1815     const char* const istart = input->pos != 0 ? src + input->pos : src;
1816     const char* const iend = input->size != 0 ? src + input->size : src;
1817     const char* ip = istart;
1818     char* const dst = (char*)output->dst;
1819     char* const ostart = output->pos != 0 ? dst + output->pos : dst;
1820     char* const oend = output->size != 0 ? dst + output->size : dst;
1821     char* op = ostart;
1822     U32 someMoreWork = 1;
1823 
1824     DEBUGLOG(5, "ZSTD_decompressStream");
1825     RETURN_ERROR_IF(
1826         input->pos > input->size,
1827         srcSize_wrong,
1828         "forbidden. in: pos: %u   vs size: %u",
1829         (U32)input->pos, (U32)input->size);
1830     RETURN_ERROR_IF(
1831         output->pos > output->size,
1832         dstSize_tooSmall,
1833         "forbidden. out: pos: %u   vs size: %u",
1834         (U32)output->pos, (U32)output->size);
1835     DEBUGLOG(5, "input size : %u", (U32)(input->size - input->pos));
1836     FORWARD_IF_ERROR(ZSTD_checkOutBuffer(zds, output), "");
1837 
1838     while (someMoreWork) {
1839         switch(zds->streamStage)
1840         {
1841         case zdss_init :
1842             DEBUGLOG(5, "stage zdss_init => transparent reset ");
1843             zds->streamStage = zdss_loadHeader;
1844             zds->lhSize = zds->inPos = zds->outStart = zds->outEnd = 0;
1845             zds->legacyVersion = 0;
1846             zds->hostageByte = 0;
1847             zds->expectedOutBuffer = *output;
1848             ZSTD_FALLTHROUGH;
1849 
1850         case zdss_loadHeader :
1851             DEBUGLOG(5, "stage zdss_loadHeader (srcSize : %u)", (U32)(iend - ip));
1852             {   size_t const hSize = ZSTD_getFrameHeader_advanced(&zds->fParams, zds->headerBuffer, zds->lhSize, zds->format);
1853                 if (zds->refMultipleDDicts && zds->ddictSet) {
1854                     ZSTD_DCtx_selectFrameDDict(zds);
1855                 }
1856                 DEBUGLOG(5, "header size : %u", (U32)hSize);
1857                 if (ZSTD_isError(hSize)) {
1858                     return hSize;   /* error */
1859                 }
1860                 if (hSize != 0) {   /* need more input */
1861                     size_t const toLoad = hSize - zds->lhSize;   /* if hSize!=0, hSize > zds->lhSize */
1862                     size_t const remainingInput = (size_t)(iend-ip);
1863                     assert(iend >= ip);
1864                     if (toLoad > remainingInput) {   /* not enough input to load full header */
1865                         if (remainingInput > 0) {
1866                             ZSTD_memcpy(zds->headerBuffer + zds->lhSize, ip, remainingInput);
1867                             zds->lhSize += remainingInput;
1868                         }
1869                         input->pos = input->size;
1870                         return (MAX((size_t)ZSTD_FRAMEHEADERSIZE_MIN(zds->format), hSize) - zds->lhSize) + ZSTD_blockHeaderSize;   /* remaining header bytes + next block header */
1871                     }
1872                     assert(ip != NULL);
1873                     ZSTD_memcpy(zds->headerBuffer + zds->lhSize, ip, toLoad); zds->lhSize = hSize; ip += toLoad;
1874                     break;
1875             }   }
1876 
1877             /* check for single-pass mode opportunity */
1878             if (zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN
1879                 && zds->fParams.frameType != ZSTD_skippableFrame
1880                 && (U64)(size_t)(oend-op) >= zds->fParams.frameContentSize) {
1881                 size_t const cSize = ZSTD_findFrameCompressedSize(istart, (size_t)(iend-istart));
1882                 if (cSize <= (size_t)(iend-istart)) {
1883                     /* shortcut : using single-pass mode */
1884                     size_t const decompressedSize = ZSTD_decompress_usingDDict(zds, op, (size_t)(oend-op), istart, cSize, ZSTD_getDDict(zds));
1885                     if (ZSTD_isError(decompressedSize)) return decompressedSize;
1886                     DEBUGLOG(4, "shortcut to single-pass ZSTD_decompress_usingDDict()")
1887                     ip = istart + cSize;
1888                     op += decompressedSize;
1889                     zds->expected = 0;
1890                     zds->streamStage = zdss_init;
1891                     someMoreWork = 0;
1892                     break;
1893             }   }
1894 
1895             /* Check output buffer is large enough for ZSTD_odm_stable. */
1896             if (zds->outBufferMode == ZSTD_bm_stable
1897                 && zds->fParams.frameType != ZSTD_skippableFrame
1898                 && zds->fParams.frameContentSize != ZSTD_CONTENTSIZE_UNKNOWN
1899                 && (U64)(size_t)(oend-op) < zds->fParams.frameContentSize) {
1900                 RETURN_ERROR(dstSize_tooSmall, "ZSTD_obm_stable passed but ZSTD_outBuffer is too small");
1901             }
1902 
1903             /* Consume header (see ZSTDds_decodeFrameHeader) */
1904             DEBUGLOG(4, "Consume header");
1905             FORWARD_IF_ERROR(ZSTD_decompressBegin_usingDDict(zds, ZSTD_getDDict(zds)), "");
1906 
1907             if ((MEM_readLE32(zds->headerBuffer) & ZSTD_MAGIC_SKIPPABLE_MASK) == ZSTD_MAGIC_SKIPPABLE_START) {  /* skippable frame */
1908                 zds->expected = MEM_readLE32(zds->headerBuffer + ZSTD_FRAMEIDSIZE);
1909                 zds->stage = ZSTDds_skipFrame;
1910             } else {
1911                 FORWARD_IF_ERROR(ZSTD_decodeFrameHeader(zds, zds->headerBuffer, zds->lhSize), "");
1912                 zds->expected = ZSTD_blockHeaderSize;
1913                 zds->stage = ZSTDds_decodeBlockHeader;
1914             }
1915 
1916             /* control buffer memory usage */
1917             DEBUGLOG(4, "Control max memory usage (%u KB <= max %u KB)",
1918                         (U32)(zds->fParams.windowSize >>10),
1919                         (U32)(zds->maxWindowSize >> 10) );
1920             zds->fParams.windowSize = MAX(zds->fParams.windowSize, 1U << ZSTD_WINDOWLOG_ABSOLUTEMIN);
1921             RETURN_ERROR_IF(zds->fParams.windowSize > zds->maxWindowSize,
1922                             frameParameter_windowTooLarge, "");
1923 
1924             /* Adapt buffer sizes to frame header instructions */
1925             {   size_t const neededInBuffSize = MAX(zds->fParams.blockSizeMax, 4 /* frame checksum */);
1926                 size_t const neededOutBuffSize = zds->outBufferMode == ZSTD_bm_buffered
1927                         ? ZSTD_decodingBufferSize_min(zds->fParams.windowSize, zds->fParams.frameContentSize)
1928                         : 0;
1929 
1930                 ZSTD_DCtx_updateOversizedDuration(zds, neededInBuffSize, neededOutBuffSize);
1931 
1932                 {   int const tooSmall = (zds->inBuffSize < neededInBuffSize) || (zds->outBuffSize < neededOutBuffSize);
1933                     int const tooLarge = ZSTD_DCtx_isOversizedTooLong(zds);
1934 
1935                     if (tooSmall || tooLarge) {
1936                         size_t const bufferSize = neededInBuffSize + neededOutBuffSize;
1937                         DEBUGLOG(4, "inBuff  : from %u to %u",
1938                                     (U32)zds->inBuffSize, (U32)neededInBuffSize);
1939                         DEBUGLOG(4, "outBuff : from %u to %u",
1940                                     (U32)zds->outBuffSize, (U32)neededOutBuffSize);
1941                         if (zds->staticSize) {  /* static DCtx */
1942                             DEBUGLOG(4, "staticSize : %u", (U32)zds->staticSize);
1943                             assert(zds->staticSize >= sizeof(ZSTD_DCtx));  /* controlled at init */
1944                             RETURN_ERROR_IF(
1945                                 bufferSize > zds->staticSize - sizeof(ZSTD_DCtx),
1946                                 memory_allocation, "");
1947                         } else {
1948                             ZSTD_customFree(zds->inBuff, zds->customMem);
1949                             zds->inBuffSize = 0;
1950                             zds->outBuffSize = 0;
1951                             zds->inBuff = (char*)ZSTD_customMalloc(bufferSize, zds->customMem);
1952                             RETURN_ERROR_IF(zds->inBuff == NULL, memory_allocation, "");
1953                         }
1954                         zds->inBuffSize = neededInBuffSize;
1955                         zds->outBuff = zds->inBuff + zds->inBuffSize;
1956                         zds->outBuffSize = neededOutBuffSize;
1957             }   }   }
1958             zds->streamStage = zdss_read;
1959             ZSTD_FALLTHROUGH;
1960 
1961         case zdss_read:
1962             DEBUGLOG(5, "stage zdss_read");
1963             {   size_t const neededInSize = ZSTD_nextSrcSizeToDecompressWithInputSize(zds, (size_t)(iend - ip));
1964                 DEBUGLOG(5, "neededInSize = %u", (U32)neededInSize);
1965                 if (neededInSize==0) {  /* end of frame */
1966                     zds->streamStage = zdss_init;
1967                     someMoreWork = 0;
1968                     break;
1969                 }
1970                 if ((size_t)(iend-ip) >= neededInSize) {  /* decode directly from src */
1971                     FORWARD_IF_ERROR(ZSTD_decompressContinueStream(zds, &op, oend, ip, neededInSize), "");
1972                     ip += neededInSize;
1973                     /* Function modifies the stage so we must break */
1974                     break;
1975             }   }
1976             if (ip==iend) { someMoreWork = 0; break; }   /* no more input */
1977             zds->streamStage = zdss_load;
1978             ZSTD_FALLTHROUGH;
1979 
1980         case zdss_load:
1981             {   size_t const neededInSize = ZSTD_nextSrcSizeToDecompress(zds);
1982                 size_t const toLoad = neededInSize - zds->inPos;
1983                 int const isSkipFrame = ZSTD_isSkipFrame(zds);
1984                 size_t loadedSize;
1985                 /* At this point we shouldn't be decompressing a block that we can stream. */
1986                 assert(neededInSize == ZSTD_nextSrcSizeToDecompressWithInputSize(zds, iend - ip));
1987                 if (isSkipFrame) {
1988                     loadedSize = MIN(toLoad, (size_t)(iend-ip));
1989                 } else {
1990                     RETURN_ERROR_IF(toLoad > zds->inBuffSize - zds->inPos,
1991                                     corruption_detected,
1992                                     "should never happen");
1993                     loadedSize = ZSTD_limitCopy(zds->inBuff + zds->inPos, toLoad, ip, (size_t)(iend-ip));
1994                 }
1995                 ip += loadedSize;
1996                 zds->inPos += loadedSize;
1997                 if (loadedSize < toLoad) { someMoreWork = 0; break; }   /* not enough input, wait for more */
1998 
1999                 /* decode loaded input */
2000                 zds->inPos = 0;   /* input is consumed */
2001                 FORWARD_IF_ERROR(ZSTD_decompressContinueStream(zds, &op, oend, zds->inBuff, neededInSize), "");
2002                 /* Function modifies the stage so we must break */
2003                 break;
2004             }
2005         case zdss_flush:
2006             {   size_t const toFlushSize = zds->outEnd - zds->outStart;
2007                 size_t const flushedSize = ZSTD_limitCopy(op, (size_t)(oend-op), zds->outBuff + zds->outStart, toFlushSize);
2008                 op += flushedSize;
2009                 zds->outStart += flushedSize;
2010                 if (flushedSize == toFlushSize) {  /* flush completed */
2011                     zds->streamStage = zdss_read;
2012                     if ( (zds->outBuffSize < zds->fParams.frameContentSize)
2013                       && (zds->outStart + zds->fParams.blockSizeMax > zds->outBuffSize) ) {
2014                         DEBUGLOG(5, "restart filling outBuff from beginning (left:%i, needed:%u)",
2015                                 (int)(zds->outBuffSize - zds->outStart),
2016                                 (U32)zds->fParams.blockSizeMax);
2017                         zds->outStart = zds->outEnd = 0;
2018                     }
2019                     break;
2020             }   }
2021             /* cannot complete flush */
2022             someMoreWork = 0;
2023             break;
2024 
2025         default:
2026             assert(0);    /* impossible */
2027             RETURN_ERROR(GENERIC, "impossible to reach");   /* some compiler require default to do something */
2028     }   }
2029 
2030     /* result */
2031     input->pos = (size_t)(ip - (const char*)(input->src));
2032     output->pos = (size_t)(op - (char*)(output->dst));
2033 
2034     /* Update the expected output buffer for ZSTD_obm_stable. */
2035     zds->expectedOutBuffer = *output;
2036 
2037     if ((ip==istart) && (op==ostart)) {  /* no forward progress */
2038         zds->noForwardProgress ++;
2039         if (zds->noForwardProgress >= ZSTD_NO_FORWARD_PROGRESS_MAX) {
2040             RETURN_ERROR_IF(op==oend, dstSize_tooSmall, "");
2041             RETURN_ERROR_IF(ip==iend, srcSize_wrong, "");
2042             assert(0);
2043         }
2044     } else {
2045         zds->noForwardProgress = 0;
2046     }
2047     {   size_t nextSrcSizeHint = ZSTD_nextSrcSizeToDecompress(zds);
2048         if (!nextSrcSizeHint) {   /* frame fully decoded */
2049             if (zds->outEnd == zds->outStart) {  /* output fully flushed */
2050                 if (zds->hostageByte) {
2051                     if (input->pos >= input->size) {
2052                         /* can't release hostage (not present) */
2053                         zds->streamStage = zdss_read;
2054                         return 1;
2055                     }
2056                     input->pos++;  /* release hostage */
2057                 }   /* zds->hostageByte */
2058                 return 0;
2059             }  /* zds->outEnd == zds->outStart */
2060             if (!zds->hostageByte) { /* output not fully flushed; keep last byte as hostage; will be released when all output is flushed */
2061                 input->pos--;   /* note : pos > 0, otherwise, impossible to finish reading last block */
2062                 zds->hostageByte=1;
2063             }
2064             return 1;
2065         }  /* nextSrcSizeHint==0 */
2066         nextSrcSizeHint += ZSTD_blockHeaderSize * (ZSTD_nextInputType(zds) == ZSTDnit_block);   /* preload header of next block */
2067         assert(zds->inPos <= nextSrcSizeHint);
2068         nextSrcSizeHint -= zds->inPos;   /* part already loaded*/
2069         return nextSrcSizeHint;
2070     }
2071 }
2072 
/*! ZSTD_decompressStream_simpleArgs() :
 *  Same as ZSTD_decompressStream(), but takes the buffers as plain
 *  (pointer, capacity/size, position) arguments instead of
 *  ZSTD_outBuffer / ZSTD_inBuffer structures.
 *  @dst, @dstCapacity, @dstPos : output buffer; *dstPos is read as the starting
 *                                position and updated with the new position.
 *  @src, @srcSize, @srcPos     : input buffer; *srcPos is read as the starting
 *                                position and updated with the new position.
 *  @return : the value returned by ZSTD_decompressStream(), unchanged.
 *  Note : *dstPos and *srcPos are written back unconditionally, even when
 *         ZSTD_decompressStream() reports an error.
 */
size_t ZSTD_decompressStream_simpleArgs (
                            ZSTD_DCtx* dctx,
                            void* dst, size_t dstCapacity, size_t* dstPos,
                      const void* src, size_t srcSize, size_t* srcPos)
{
    ZSTD_outBuffer output;
    ZSTD_inBuffer  input;
    /* Fields are assigned one by one: C90 forbids non-constant expressions
     * in the brace-initializer of an automatic aggregate. */
    output.dst  = dst;
    output.size = dstCapacity;
    output.pos  = *dstPos;
    input.src   = src;
    input.size  = srcSize;
    input.pos   = *srcPos;
    /* ZSTD_decompressStream() will check validity of dstPos and srcPos */
    {   size_t const cErr = ZSTD_decompressStream(dctx, &output, &input);
        *dstPos = output.pos;
        *srcPos = input.pos;
        return cErr;
    }
}
2086