git subrepo pull (merge) --force deps/libchdr
[pcsx_rearmed.git] / deps / libchdr / deps / zstd-1.5.5 / contrib / diagnose_corruption / check_flipped_bits.c
CommitLineData
648db22b 1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under both the BSD-style license (found in the
6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7 * in the COPYING file in the root directory of this source tree).
8 * You may select, at your option, one of the above-listed licenses.
9 */
10
11#define ZSTD_STATIC_LINKING_ONLY
12#include "zstd.h"
13#include "zstd_errors.h"
14
15#include <stdio.h>
16#include <stdlib.h>
17#include <string.h>
18#include <sys/types.h>
19#include <sys/stat.h>
20#include <unistd.h>
21
22typedef struct {
23 char *input;
24 size_t input_size;
25
26 char *perturbed; /* same size as input */
27
28 char *output;
29 size_t output_size;
30
31 const char *dict_file_name;
32 const char *dict_file_dir_name;
33 int32_t dict_id;
34 char *dict;
35 size_t dict_size;
36 ZSTD_DDict* ddict;
37
38 ZSTD_DCtx* dctx;
39
40 int success_count;
41 int error_counts[ZSTD_error_maxCode];
42} stuff_t;
43
44static void free_stuff(stuff_t* stuff) {
45 free(stuff->input);
46 free(stuff->output);
47 ZSTD_freeDDict(stuff->ddict);
48 free(stuff->dict);
49 ZSTD_freeDCtx(stuff->dctx);
50}
51
/* Prints command-line help to stderr and terminates with exit status 1. */
static void usage(void) {
    fputs("check_flipped_bits input_filename [-d dict] [-D dict_dir]\n"
          "\n"
          "Arguments:\n"
          " -d file: path to a dictionary file to use.\n"
          " -D dir : path to a directory, with files containing dictionaries, of the\n"
          " form DICTID.zstd-dict, e.g., 12345.zstd-dict.\n",
          stderr);
    exit(1);
}
61
62static void print_summary(stuff_t* stuff) {
63 int error_code;
64 fprintf(stderr, "%9d successful decompressions\n", stuff->success_count);
65 for (error_code = 0; error_code < ZSTD_error_maxCode; error_code++) {
66 int count = stuff->error_counts[error_code];
67 if (count) {
68 fprintf(
69 stderr, "%9d failed decompressions with message: %s\n",
70 count, ZSTD_getErrorString(error_code));
71 }
72 }
73}
74
/*
 * Reads the whole of the regular file `filename` into a new heap buffer.
 *
 * On success returns the buffer and stores its length in *size. The caller
 * owns the buffer and releases it with free(). An empty file yields a valid
 * (non-NULL) buffer with *size == 0.
 *
 * On failure prints a diagnostic to stderr and returns NULL. *size may
 * already have been overwritten.
 */
static char* readFile(const char* filename, size_t* size) {
    struct stat statbuf;
    FILE* f;
    char *buf;
    size_t bytes_read;

    if (stat(filename, &statbuf) != 0) {
        /* perror gives "stat failed: <reason>" portably (%m is glibc-only). */
        perror("stat failed");
        return NULL;
    }
    /* S_ISREG compares the whole file-type field. Masking with S_IFREG alone
     * also accepts types whose bits contain S_IFREG's on common platforms
     * (e.g. sockets). */
    if (!S_ISREG(statbuf.st_mode)) {
        fprintf(stderr, "Input must be regular file\n");
        return NULL;
    }

    *size = (size_t)statbuf.st_size;

    /* Binary mode: compressed frames and dictionaries are raw bytes. */
    f = fopen(filename, "rb");
    if (f == NULL) {
        perror("fopen failed");
        return NULL;
    }

    /* malloc(0) may legitimately return NULL; request at least one byte so
     * an empty file is not reported as an allocation failure. */
    buf = malloc(*size > 0 ? *size : 1);
    if (buf == NULL) {
        fprintf(stderr, "malloc failed\n");
        fclose(f);
        return NULL;
    }

    bytes_read = fread(buf, 1, *size, f);
    if (bytes_read != *size) {
        fprintf(stderr, "failed to read whole file\n");
        fclose(f);
        free(buf);
        return NULL;
    }

    if (fclose(f) != 0) {
        perror("fclose failed");
        free(buf);
        return NULL;
    }

    return buf;
}
124
125static ZSTD_DDict* readDict(const char* filename, char **buf, size_t* size, int32_t* dict_id) {
126 ZSTD_DDict* ddict;
127 *buf = readFile(filename, size);
128 if (*buf == NULL) {
129 fprintf(stderr, "Opening dictionary file '%s' failed\n", filename);
130 return NULL;
131 }
132
133 ddict = ZSTD_createDDict_advanced(*buf, *size, ZSTD_dlm_byRef, ZSTD_dct_auto, ZSTD_defaultCMem);
134 if (ddict == NULL) {
135 fprintf(stderr, "Failed to create ddict.\n");
136 return NULL;
137 }
138 if (dict_id != NULL) {
139 *dict_id = ZSTD_getDictID_fromDDict(ddict);
140 }
141 return ddict;
142}
143
144static ZSTD_DDict* readDictByID(stuff_t *stuff, int32_t dict_id, char **buf, size_t* size) {
145 if (stuff->dict_file_dir_name == NULL) {
146 return NULL;
147 } else {
148 size_t dir_name_len = strlen(stuff->dict_file_dir_name);
149 int dir_needs_separator = 0;
150 size_t dict_file_name_alloc_size = dir_name_len + 1 /* '/' */ + 10 /* max int32_t len */ + strlen(".zstd-dict") + 1 /* '\0' */;
151 char *dict_file_name = malloc(dict_file_name_alloc_size);
152 ZSTD_DDict* ddict;
153 int32_t read_dict_id;
154 if (dict_file_name == NULL) {
155 fprintf(stderr, "malloc failed.\n");
156 return 0;
157 }
158
159 if (dir_name_len > 0 && stuff->dict_file_dir_name[dir_name_len - 1] != '/') {
160 dir_needs_separator = 1;
161 }
162
163 snprintf(
164 dict_file_name,
165 dict_file_name_alloc_size,
166 "%s%s%u.zstd-dict",
167 stuff->dict_file_dir_name,
168 dir_needs_separator ? "/" : "",
169 dict_id);
170
171 /* fprintf(stderr, "Loading dict %u from '%s'.\n", dict_id, dict_file_name); */
172
173 ddict = readDict(dict_file_name, buf, size, &read_dict_id);
174 if (ddict == NULL) {
175 fprintf(stderr, "Failed to create ddict from '%s'.\n", dict_file_name);
176 free(dict_file_name);
177 return 0;
178 }
179 if (read_dict_id != dict_id) {
180 fprintf(stderr, "Read dictID (%u) does not match expected (%u).\n", read_dict_id, dict_id);
181 free(dict_file_name);
182 ZSTD_freeDDict(ddict);
183 return 0;
184 }
185
186 free(dict_file_name);
187 return ddict;
188 }
189}
190
191static int init_stuff(stuff_t* stuff, int argc, char *argv[]) {
192 const char* input_filename;
193
194 if (argc < 2) {
195 usage();
196 }
197
198 input_filename = argv[1];
199 stuff->input_size = 0;
200 stuff->input = readFile(input_filename, &stuff->input_size);
201 if (stuff->input == NULL) {
202 fprintf(stderr, "Failed to read input file.\n");
203 return 0;
204 }
205
206 stuff->perturbed = malloc(stuff->input_size);
207 if (stuff->perturbed == NULL) {
208 fprintf(stderr, "malloc failed.\n");
209 return 0;
210 }
211 memcpy(stuff->perturbed, stuff->input, stuff->input_size);
212
213 stuff->output_size = ZSTD_DStreamOutSize();
214 stuff->output = malloc(stuff->output_size);
215 if (stuff->output == NULL) {
216 fprintf(stderr, "malloc failed.\n");
217 return 0;
218 }
219
220 stuff->dict_file_name = NULL;
221 stuff->dict_file_dir_name = NULL;
222 stuff->dict_id = 0;
223 stuff->dict = NULL;
224 stuff->dict_size = 0;
225 stuff->ddict = NULL;
226
227 if (argc > 2) {
228 if (!strcmp(argv[2], "-d")) {
229 if (argc > 3) {
230 stuff->dict_file_name = argv[3];
231 } else {
232 usage();
233 }
234 } else
235 if (!strcmp(argv[2], "-D")) {
236 if (argc > 3) {
237 stuff->dict_file_dir_name = argv[3];
238 } else {
239 usage();
240 }
241 } else {
242 usage();
243 }
244 }
245
246 if (stuff->dict_file_dir_name) {
247 int32_t dict_id = ZSTD_getDictID_fromFrame(stuff->input, stuff->input_size);
248 if (dict_id != 0) {
249 stuff->ddict = readDictByID(stuff, dict_id, &stuff->dict, &stuff->dict_size);
250 if (stuff->ddict == NULL) {
251 fprintf(stderr, "Failed to create cached ddict.\n");
252 return 0;
253 }
254 stuff->dict_id = dict_id;
255 }
256 } else
257 if (stuff->dict_file_name) {
258 stuff->ddict = readDict(stuff->dict_file_name, &stuff->dict, &stuff->dict_size, &stuff->dict_id);
259 if (stuff->ddict == NULL) {
260 fprintf(stderr, "Failed to create ddict from '%s'.\n", stuff->dict_file_name);
261 return 0;
262 }
263 }
264
265 stuff->dctx = ZSTD_createDCtx();
266 if (stuff->dctx == NULL) {
267 return 0;
268 }
269
270 stuff->success_count = 0;
271 memset(stuff->error_counts, 0, sizeof(stuff->error_counts));
272
273 return 1;
274}
275
276static int test_decompress(stuff_t* stuff) {
277 size_t ret;
278 ZSTD_inBuffer in = {stuff->perturbed, stuff->input_size, 0};
279 ZSTD_outBuffer out = {stuff->output, stuff->output_size, 0};
280 ZSTD_DCtx* dctx = stuff->dctx;
281 int32_t custom_dict_id = ZSTD_getDictID_fromFrame(in.src, in.size);
282 char *custom_dict = NULL;
283 size_t custom_dict_size = 0;
284 ZSTD_DDict* custom_ddict = NULL;
285
286 if (custom_dict_id != 0 && custom_dict_id != stuff->dict_id) {
287 /* fprintf(stderr, "Instead of dict %u, this perturbed blob wants dict %u.\n", stuff->dict_id, custom_dict_id); */
288 custom_ddict = readDictByID(stuff, custom_dict_id, &custom_dict, &custom_dict_size);
289 }
290
291 ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
292
293 if (custom_ddict != NULL) {
294 ZSTD_DCtx_refDDict(dctx, custom_ddict);
295 } else {
296 ZSTD_DCtx_refDDict(dctx, stuff->ddict);
297 }
298
299 while (in.pos != in.size) {
300 out.pos = 0;
301 ret = ZSTD_decompressStream(dctx, &out, &in);
302
303 if (ZSTD_isError(ret)) {
304 unsigned int code = ZSTD_getErrorCode(ret);
305 if (code >= ZSTD_error_maxCode) {
306 fprintf(stderr, "Received unexpected error code!\n");
307 exit(1);
308 }
309 stuff->error_counts[code]++;
310 /*
311 fprintf(
312 stderr, "Decompression failed: %s\n", ZSTD_getErrorName(ret));
313 */
314 if (custom_ddict != NULL) {
315 ZSTD_freeDDict(custom_ddict);
316 free(custom_dict);
317 }
318 return 0;
319 }
320 }
321
322 stuff->success_count++;
323
324 if (custom_ddict != NULL) {
325 ZSTD_freeDDict(custom_ddict);
326 free(custom_dict);
327 }
328 return 1;
329}
330
331static int perturb_bits(stuff_t* stuff) {
332 size_t pos;
333 size_t bit;
334 for (pos = 0; pos < stuff->input_size; pos++) {
335 unsigned char old_val = stuff->input[pos];
336 if (pos % 1000 == 0) {
337 fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size);
338 }
339 for (bit = 0; bit < 8; bit++) {
340 unsigned char new_val = old_val ^ (1 << bit);
341 stuff->perturbed[pos] = new_val;
342 if (test_decompress(stuff)) {
343 fprintf(
344 stderr,
345 "Flipping byte %zu bit %zu (0x%02x -> 0x%02x) "
346 "produced a successful decompression!\n",
347 pos, bit, old_val, new_val);
348 }
349 }
350 stuff->perturbed[pos] = old_val;
351 }
352 return 1;
353}
354
355static int perturb_bytes(stuff_t* stuff) {
356 size_t pos;
357 size_t new_val;
358 for (pos = 0; pos < stuff->input_size; pos++) {
359 unsigned char old_val = stuff->input[pos];
360 if (pos % 1000 == 0) {
361 fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size);
362 }
363 for (new_val = 0; new_val < 256; new_val++) {
364 stuff->perturbed[pos] = new_val;
365 if (test_decompress(stuff)) {
366 fprintf(
367 stderr,
368 "Changing byte %zu (0x%02x -> 0x%02x) "
369 "produced a successful decompression!\n",
370 pos, old_val, (unsigned char)new_val);
371 }
372 }
373 stuff->perturbed[pos] = old_val;
374 }
375 return 1;
376}
377
378int main(int argc, char* argv[]) {
379 stuff_t stuff;
380
381 if(!init_stuff(&stuff, argc, argv)) {
382 fprintf(stderr, "Failed to init.\n");
383 return 1;
384 }
385
386 if (test_decompress(&stuff)) {
387 fprintf(stderr, "Blob already decompresses successfully!\n");
388 return 1;
389 }
390
391 perturb_bits(&stuff);
392
393 perturb_bytes(&stuff);
394
395 print_summary(&stuff);
396
397 free_stuff(&stuff);
398
399 return 0;
400}