648db22b |
1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under both the BSD-style license (found in the |
6 | * LICENSE file in the root directory of this source tree) and the GPLv2 (found |
7 | * in the COPYING file in the root directory of this source tree). |
8 | * You may select, at your option, one of the above-listed licenses. |
9 | */ |
10 | |
/*
  This program takes a file as input,
  performs a zstd round-trip test (compression, then decompression),
  compares the result with the original,
  and crashes on corruption detection (abort() in fuzzing builds, exit() otherwise).
*/
17 | |
18 | /*=========================================== |
19 | * Dependencies |
20 | *==========================================*/ |
21 | #include <stddef.h> /* size_t */ |
22 | #include <stdlib.h> /* malloc, free, exit */ |
23 | #include <stdio.h> /* fprintf */ |
24 | #include <string.h> /* strcmp */ |
25 | #include <sys/types.h> /* stat */ |
26 | #include <sys/stat.h> /* stat */ |
27 | #include "xxhash.h" |
28 | |
29 | #define ZSTD_STATIC_LINKING_ONLY |
30 | #include "zstd.h" |
31 | |
32 | /*=========================================== |
33 | * Macros |
34 | *==========================================*/ |
/* Smaller of two values. Each argument may be evaluated twice: pass only side-effect-free expressions. */
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
36 | |
/* crash() :
 * Terminates the process on a detected failure:
 * abort() when built for fuzzing (AFL / libFuzzer), exit(errorCode) otherwise. */
static void crash(int errorCode)
{
#ifndef FUZZING_BUILD_MODE_UNSAFE_FOR_PRODUCTION   /* could also use __AFL_COMPILER */
    exit(errorCode);
#else
    (void)errorCode;   /* unused in fuzzing builds */
    abort();
#endif
}
45 | |
/* CHECK_Z() :
 * Evaluates `f` exactly once. If the result is a zstd error code, prints the
 * failing expression and the zstd error name to stderr, then calls crash(1).
 * Wrapped in do { } while (0) so `CHECK_Z(x);` behaves as a single statement
 * (safe inside an unbraced if/else). The local is named `chkz_err_` so an
 * argument mentioning a caller variable called `err` is not shadowed. */
#define CHECK_Z(f) do {                                  \
        size_t const chkz_err_ = (f);                    \
        if (ZSTD_isError(chkz_err_)) {                   \
            fprintf(stderr,                              \
                    "Error=> %s: %s\n",                  \
                    #f, ZSTD_getErrorName(chkz_err_));   \
            crash(1);                                    \
        }                                                \
    } while (0)
54 | |
/** roundTripTest() :
 *  Compresses `srcBuff` into `compressedBuff` with ZSTD_compress(), then
 *  decompresses `compressedBuff` into `resultBuff` with ZSTD_decompress().
 *  The compression level is h % 19, where h is XXH32 of the first
 *  min(128, srcBuffSize) bytes. So the same input always picks the same level.
 * @return : result of decompression, which should be == `srcBuffSize`,
 *           or a zstd error code if compression or decompression fails
 *           (a compression error is also reported on stderr).
 *  Note : `compressedBuffCapacity` should be `>= ZSTD_compressBound(srcBuffSize)`
 *         for compression to be guaranteed to work. */
static size_t roundTripTest(void* resultBuff, size_t resultBuffCapacity,
                            void* compressedBuff, size_t compressedBuffCapacity,
                            const void* srcBuff, size_t srcBuffSize)
{
    static const int maxClevel = 19;
    unsigned const seed = XXH32(srcBuff, MIN(128, srcBuffSize), 0);
    int const level = (int)(seed % maxClevel);
    size_t const compressedSize = ZSTD_compress(compressedBuff, compressedBuffCapacity,
                                                srcBuff, srcBuffSize, level);

    if (!ZSTD_isError(compressedSize))
        return ZSTD_decompress(resultBuff, resultBuffCapacity, compressedBuff, compressedSize);

    fprintf(stderr, "Compression error : %s \n", ZSTD_getErrorName(compressedSize));
    return compressedSize;
}
78 | |
79 | /** cctxParamRoundTripTest() : |
80 | * Same as roundTripTest() except allows experimenting with ZSTD_CCtx_params. */ |
81 | static size_t cctxParamRoundTripTest(void* resultBuff, size_t resultBuffCapacity, |
82 | void* compressedBuff, size_t compressedBuffCapacity, |
83 | const void* srcBuff, size_t srcBuffSize) |
84 | { |
85 | ZSTD_CCtx* const cctx = ZSTD_createCCtx(); |
86 | ZSTD_CCtx_params* const cctxParams = ZSTD_createCCtxParams(); |
87 | ZSTD_inBuffer inBuffer = { srcBuff, srcBuffSize, 0 }; |
88 | ZSTD_outBuffer outBuffer = { compressedBuff, compressedBuffCapacity, 0 }; |
89 | |
90 | static const int maxClevel = 19; |
91 | size_t const hashLength = MIN(128, srcBuffSize); |
92 | unsigned const h32 = XXH32(srcBuff, hashLength, 0); |
93 | int const cLevel = h32 % maxClevel; |
94 | |
95 | /* Set parameters */ |
96 | CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_compressionLevel, cLevel) ); |
97 | CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_nbWorkers, 2) ); |
98 | CHECK_Z( ZSTD_CCtxParams_setParameter(cctxParams, ZSTD_c_overlapLog, 5) ); |
99 | |
100 | |
101 | /* Apply parameters */ |
102 | CHECK_Z( ZSTD_CCtx_setParametersUsingCCtxParams(cctx, cctxParams) ); |
103 | |
104 | CHECK_Z (ZSTD_compressStream2(cctx, &outBuffer, &inBuffer, ZSTD_e_end) ); |
105 | |
106 | ZSTD_freeCCtxParams(cctxParams); |
107 | ZSTD_freeCCtx(cctx); |
108 | |
109 | return ZSTD_decompress(resultBuff, resultBuffCapacity, compressedBuff, outBuffer.pos); |
110 | } |
111 | |
/* checkBuffers() :
 * @return : index of the first byte at which `buff1` and `buff2` differ,
 *           or `buffSize` when the first `buffSize` bytes are identical. */
static size_t checkBuffers(const void* buff1, const void* buff2, size_t buffSize)
{
    const unsigned char* const a = (const unsigned char*)buff1;
    const unsigned char* const b = (const unsigned char*)buff2;
    size_t i = 0;

    while (i < buffSize && a[i] == b[i])
        i++;

    return i;
}
124 | |
/* roundTripCheck() :
 * Runs `srcBuff` through roundTripTest(), or through cctxParamRoundTripTest()
 * when `testCCtxParams` is non-zero.
 * Calls crash(1) if the round trip returns an error, regenerates a different
 * size, or produces bytes that differ from the source.
 * Exits with status 1 if the work buffers cannot be allocated. */
static void roundTripCheck(const void* srcBuff, size_t srcBuffSize, int testCCtxParams)
{
    size_t const bound = ZSTD_compressBound(srcBuffSize);
    void* const compressed = malloc(bound);
    /* The regenerated buffer gets the same capacity as the compressed one. */
    void* const regenerated = malloc(bound);
    size_t result;

    if (compressed == NULL || regenerated == NULL) {
        fprintf(stderr, "not enough memory ! \n");
        exit (1);
    }

    if (testCCtxParams)
        result = cctxParamRoundTripTest(regenerated, bound, compressed, bound, srcBuff, srcBuffSize);
    else
        result = roundTripTest(regenerated, bound, compressed, bound, srcBuff, srcBuffSize);

    if (ZSTD_isError(result)) {
        fprintf(stderr, "roundTripTest error : %s \n", ZSTD_getErrorName(result));
        crash(1);
    }
    if (result != srcBuffSize) {
        fprintf(stderr, "Incorrect regenerated size : %u != %u\n", (unsigned)result, (unsigned)srcBuffSize);
        crash(1);
    }
    if (checkBuffers(srcBuff, regenerated, srcBuffSize) != srcBuffSize) {
        fprintf(stderr, "Silent decoding corruption !!!");
        crash(1);
    }

    free(compressed);
    free(regenerated);
}
156 | |
157 | |
/* getFileSize() :
 * @return : size in bytes of `infilename` if it names a regular file,
 *           0 if stat fails or the path is not a regular file. */
static size_t getFileSize(const char* infilename)
{
#if defined(_MSC_VER)
    struct _stat64 info;
    if (_stat64(infilename, &info) != 0) return 0;
    if (!(info.st_mode & S_IFREG)) return 0;   /* not a regular file */
#else
    struct stat info;
    if (stat(infilename, &info) != 0) return 0;
    if (!S_ISREG(info.st_mode)) return 0;      /* not a regular file */
#endif
    return (size_t)info.st_size;
}
172 | |
173 | |
/* isDirectory() :
 * @return : 1 if `infilename` can be stat'ed and is a directory, 0 otherwise. */
static int isDirectory(const char* infilename)
{
#if defined(_MSC_VER)
    struct _stat64 info;
    return (_stat64(infilename, &info) == 0 && (info.st_mode & _S_IFDIR)) ? 1 : 0;
#else
    struct stat info;
    return (stat(infilename, &info) == 0 && S_ISDIR(info.st_mode)) ? 1 : 0;
#endif
}
188 | |
189 | |
/** loadFile() :
 *  Reads exactly `fileSize` bytes of `fileName` into `buffer`.
 *  requirement : `buffer` size >= `fileSize`
 *  Exit codes:
 *  - 2 if `fileName` is a directory,
 *  - 3 if it cannot be opened,
 *  - 5 if fewer than `fileSize` bytes could be read. */
static void loadFile(void* buffer, const char* fileName, size_t fileSize)
{
    FILE* f;

    /* Reject directories before opening anything. fopen() may succeed on a
     * directory on some platforms, and this way no FILE* is left open on the
     * exit path. */
    if (isDirectory(fileName)) {
        fprintf(stderr, "Ignoring %s directory \n", fileName);
        exit(2);
    }

    f = fopen(fileName, "rb");
    if (f == NULL) {
        fprintf(stderr, "Impossible to open %s \n", fileName);
        exit(3);
    }

    {   size_t const readSize = fread(buffer, 1, fileSize, f);
        if (readSize != fileSize) {
            fclose(f);
            fprintf(stderr, "Error reading %s \n", fileName);
            exit(5);
        }
    }

    fclose(f);
}
210 | |
211 | |
/* fileCheck() :
 * Loads the whole of `fileName` into memory and runs roundTripCheck() on it.
 * Exits with status 4 if the load buffer cannot be allocated. */
static void fileCheck(const char* fileName, int testCCtxParams)
{
    size_t const fileSize = getFileSize(fileName);
    /* Never request 0 bytes: malloc(0) is allowed to return NULL. */
    size_t const allocSize = (fileSize != 0) ? fileSize : 1;
    void* const buffer = malloc(allocSize);

    if (buffer == NULL) {
        fprintf(stderr, "not enough memory \n");
        exit(4);
    }

    loadFile(buffer, fileName, fileSize);
    roundTripCheck(buffer, fileSize, testCCtxParams);
    free(buffer);
}
224 | |
/* Usage : roundTripCrash [--cctxParams] FILE
 * Round-trips FILE through zstd and crashes on any mismatch.
 * "--cctxParams" selects the ZSTD_CCtx_params / ZSTD_compressStream2 path.
 * Exits with status 9 when no input file is given. */
int main(int argCount, const char** argv) {
    int argNb = 1;
    int testCCtxParams = 0;
    if (argCount < 2) {
        fprintf(stderr, "Error : no argument : need input file \n");
        exit(9);
    }

    if (!strcmp(argv[argNb], "--cctxParams")) {
        testCCtxParams = 1;
        argNb++;
    }

    /* "--cctxParams" given alone leaves no file name; argv[argCount] is NULL. */
    if (argNb >= argCount) {
        fprintf(stderr, "Error : no argument : need input file \n");
        exit(9);
    }

    fileCheck(argv[argNb], testCCtxParams);
    fprintf(stderr, "no pb detected\n");
    return 0;
}