git subrepo pull (merge) --force deps/libchdr
[pcsx_rearmed.git] / deps / libchdr / deps / zstd-1.5.5 / contrib / diagnose_corruption / check_flipped_bits.c
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10
11 #define ZSTD_STATIC_LINKING_ONLY
12 #include "zstd.h"
13 #include "zstd_errors.h"
14
15 #include <stdio.h>
16 #include <stdlib.h>
17 #include <string.h>
18 #include <sys/types.h>
19 #include <sys/stat.h>
20 #include <unistd.h>
21
22 typedef struct {
23   char *input;
24   size_t input_size;
25
26   char *perturbed; /* same size as input */
27
28   char *output;
29   size_t output_size;
30
31   const char *dict_file_name;
32   const char *dict_file_dir_name;
33   int32_t dict_id;
34   char *dict;
35   size_t dict_size;
36   ZSTD_DDict* ddict;
37
38   ZSTD_DCtx* dctx;
39
40   int success_count;
41   int error_counts[ZSTD_error_maxCode];
42 } stuff_t;
43
44 static void free_stuff(stuff_t* stuff) {
45   free(stuff->input);
46   free(stuff->output);
47   ZSTD_freeDDict(stuff->ddict);
48   free(stuff->dict);
49   ZSTD_freeDCtx(stuff->dctx);
50 }
51
52 static void usage(void) {
53   fprintf(stderr, "check_flipped_bits input_filename [-d dict] [-D dict_dir]\n");
54   fprintf(stderr, "\n");
55   fprintf(stderr, "Arguments:\n");
56   fprintf(stderr, "    -d file: path to a dictionary file to use.\n");
57   fprintf(stderr, "    -D dir : path to a directory, with files containing dictionaries, of the\n"
58                   "             form DICTID.zstd-dict, e.g., 12345.zstd-dict.\n");
59   exit(1);
60 }
61
62 static void print_summary(stuff_t* stuff) {
63   int error_code;
64   fprintf(stderr, "%9d successful decompressions\n", stuff->success_count);
65   for (error_code = 0; error_code < ZSTD_error_maxCode; error_code++) {
66     int count = stuff->error_counts[error_code];
67     if (count) {
68       fprintf(
69           stderr, "%9d failed decompressions with message: %s\n",
70           count, ZSTD_getErrorString(error_code));
71     }
72   }
73 }
74
75 static char* readFile(const char* filename, size_t* size) {
76   struct stat statbuf;
77   int ret;
78   FILE* f;
79   char *buf;
80   size_t bytes_read;
81
82   ret = stat(filename, &statbuf);
83   if (ret != 0) {
84     fprintf(stderr, "stat failed: %m\n");
85     return NULL;
86   }
87   if ((statbuf.st_mode & S_IFREG) != S_IFREG) {
88     fprintf(stderr, "Input must be regular file\n");
89     return NULL;
90   }
91
92   *size = statbuf.st_size;
93
94   f = fopen(filename, "r");
95   if (f == NULL) {
96     fprintf(stderr, "fopen failed: %m\n");
97     return NULL;
98   }
99
100   buf = malloc(*size);
101   if (buf == NULL) {
102     fprintf(stderr, "malloc failed\n");
103     fclose(f);
104     return NULL;
105   }
106
107   bytes_read = fread(buf, 1, *size, f);
108   if (bytes_read != *size) {
109     fprintf(stderr, "failed to read whole file\n");
110     fclose(f);
111     free(buf);
112     return NULL;
113   }
114
115   ret = fclose(f);
116   if (ret != 0) {
117     fprintf(stderr, "fclose failed: %m\n");
118     free(buf);
119     return NULL;
120   }
121
122   return buf;
123 }
124
125 static ZSTD_DDict* readDict(const char* filename, char **buf, size_t* size, int32_t* dict_id) {
126   ZSTD_DDict* ddict;
127   *buf = readFile(filename, size);
128   if (*buf == NULL) {
129     fprintf(stderr, "Opening dictionary file '%s' failed\n", filename);
130     return NULL;
131   }
132
133   ddict = ZSTD_createDDict_advanced(*buf, *size, ZSTD_dlm_byRef, ZSTD_dct_auto, ZSTD_defaultCMem);
134   if (ddict == NULL) {
135     fprintf(stderr, "Failed to create ddict.\n");
136     return NULL;
137   }
138   if (dict_id != NULL) {
139     *dict_id = ZSTD_getDictID_fromDDict(ddict);
140   }
141   return ddict;
142 }
143
144 static ZSTD_DDict* readDictByID(stuff_t *stuff, int32_t dict_id, char **buf, size_t* size) {
145   if (stuff->dict_file_dir_name == NULL) {
146     return NULL;
147   } else {
148     size_t dir_name_len = strlen(stuff->dict_file_dir_name);
149     int dir_needs_separator = 0;
150     size_t dict_file_name_alloc_size = dir_name_len + 1 /* '/' */ + 10 /* max int32_t len */ + strlen(".zstd-dict") + 1 /* '\0' */;
151     char *dict_file_name = malloc(dict_file_name_alloc_size);
152     ZSTD_DDict* ddict;
153     int32_t read_dict_id;
154     if (dict_file_name == NULL) {
155       fprintf(stderr, "malloc failed.\n");
156       return 0;
157     }
158
159     if (dir_name_len > 0 && stuff->dict_file_dir_name[dir_name_len - 1] != '/') {
160       dir_needs_separator = 1;
161     }
162
163     snprintf(
164       dict_file_name,
165       dict_file_name_alloc_size,
166       "%s%s%u.zstd-dict",
167       stuff->dict_file_dir_name,
168       dir_needs_separator ? "/" : "",
169       dict_id);
170
171     /* fprintf(stderr, "Loading dict %u from '%s'.\n", dict_id, dict_file_name); */
172
173     ddict = readDict(dict_file_name, buf, size, &read_dict_id);
174     if (ddict == NULL) {
175       fprintf(stderr, "Failed to create ddict from '%s'.\n", dict_file_name);
176       free(dict_file_name);
177       return 0;
178     }
179     if (read_dict_id != dict_id) {
180       fprintf(stderr, "Read dictID (%u) does not match expected (%u).\n", read_dict_id, dict_id);
181       free(dict_file_name);
182       ZSTD_freeDDict(ddict);
183       return 0;
184     }
185
186     free(dict_file_name);
187     return ddict;
188   }
189 }
190
191 static int init_stuff(stuff_t* stuff, int argc, char *argv[]) {
192   const char* input_filename;
193
194   if (argc < 2) {
195     usage();
196   }
197
198   input_filename = argv[1];
199   stuff->input_size = 0;
200   stuff->input = readFile(input_filename, &stuff->input_size);
201   if (stuff->input == NULL) {
202     fprintf(stderr, "Failed to read input file.\n");
203     return 0;
204   }
205
206   stuff->perturbed = malloc(stuff->input_size);
207   if (stuff->perturbed == NULL) {
208     fprintf(stderr, "malloc failed.\n");
209     return 0;
210   }
211   memcpy(stuff->perturbed, stuff->input, stuff->input_size);
212
213   stuff->output_size = ZSTD_DStreamOutSize();
214   stuff->output = malloc(stuff->output_size);
215   if (stuff->output == NULL) {
216     fprintf(stderr, "malloc failed.\n");
217     return 0;
218   }
219
220   stuff->dict_file_name = NULL;
221   stuff->dict_file_dir_name = NULL;
222   stuff->dict_id = 0;
223   stuff->dict = NULL;
224   stuff->dict_size = 0;
225   stuff->ddict = NULL;
226
227   if (argc > 2) {
228     if (!strcmp(argv[2], "-d")) {
229       if (argc > 3) {
230         stuff->dict_file_name = argv[3];
231       } else {
232         usage();
233       }
234     } else
235     if (!strcmp(argv[2], "-D")) {
236       if (argc > 3) {
237         stuff->dict_file_dir_name = argv[3];
238       } else {
239         usage();
240       }
241     } else {
242       usage();
243     }
244   }
245
246   if (stuff->dict_file_dir_name) {
247     int32_t dict_id = ZSTD_getDictID_fromFrame(stuff->input, stuff->input_size);
248     if (dict_id != 0) {
249       stuff->ddict = readDictByID(stuff, dict_id, &stuff->dict, &stuff->dict_size);
250       if (stuff->ddict == NULL) {
251         fprintf(stderr, "Failed to create cached ddict.\n");
252         return 0;
253       }
254       stuff->dict_id = dict_id;
255     }
256   } else
257   if (stuff->dict_file_name) {
258     stuff->ddict = readDict(stuff->dict_file_name, &stuff->dict, &stuff->dict_size, &stuff->dict_id);
259     if (stuff->ddict == NULL) {
260       fprintf(stderr, "Failed to create ddict from '%s'.\n", stuff->dict_file_name);
261       return 0;
262     }
263   }
264
265   stuff->dctx = ZSTD_createDCtx();
266   if (stuff->dctx == NULL) {
267     return 0;
268   }
269
270   stuff->success_count = 0;
271   memset(stuff->error_counts, 0, sizeof(stuff->error_counts));
272
273   return 1;
274 }
275
276 static int test_decompress(stuff_t* stuff) {
277   size_t ret;
278   ZSTD_inBuffer in = {stuff->perturbed, stuff->input_size, 0};
279   ZSTD_outBuffer out = {stuff->output, stuff->output_size, 0};
280   ZSTD_DCtx* dctx = stuff->dctx;
281   int32_t custom_dict_id = ZSTD_getDictID_fromFrame(in.src, in.size);
282   char *custom_dict = NULL;
283   size_t custom_dict_size = 0;
284   ZSTD_DDict* custom_ddict = NULL;
285
286   if (custom_dict_id != 0 && custom_dict_id != stuff->dict_id) {
287     /* fprintf(stderr, "Instead of dict %u, this perturbed blob wants dict %u.\n", stuff->dict_id, custom_dict_id); */
288     custom_ddict = readDictByID(stuff, custom_dict_id, &custom_dict, &custom_dict_size);
289   }
290
291   ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only);
292
293   if (custom_ddict != NULL) {
294     ZSTD_DCtx_refDDict(dctx, custom_ddict);
295   } else {
296     ZSTD_DCtx_refDDict(dctx, stuff->ddict);
297   }
298
299   while (in.pos != in.size) {
300     out.pos = 0;
301     ret = ZSTD_decompressStream(dctx, &out, &in);
302
303     if (ZSTD_isError(ret)) {
304       unsigned int code = ZSTD_getErrorCode(ret);
305       if (code >= ZSTD_error_maxCode) {
306         fprintf(stderr, "Received unexpected error code!\n");
307         exit(1);
308       }
309       stuff->error_counts[code]++;
310       /*
311       fprintf(
312           stderr, "Decompression failed: %s\n", ZSTD_getErrorName(ret));
313       */
314       if (custom_ddict != NULL) {
315         ZSTD_freeDDict(custom_ddict);
316         free(custom_dict);
317       }
318       return 0;
319     }
320   }
321
322   stuff->success_count++;
323
324   if (custom_ddict != NULL) {
325     ZSTD_freeDDict(custom_ddict);
326     free(custom_dict);
327   }
328   return 1;
329 }
330
331 static int perturb_bits(stuff_t* stuff) {
332   size_t pos;
333   size_t bit;
334   for (pos = 0; pos < stuff->input_size; pos++) {
335     unsigned char old_val = stuff->input[pos];
336     if (pos % 1000 == 0) {
337       fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size);
338     }
339     for (bit = 0; bit < 8; bit++) {
340       unsigned char new_val = old_val ^ (1 << bit);
341       stuff->perturbed[pos] = new_val;
342       if (test_decompress(stuff)) {
343         fprintf(
344             stderr,
345             "Flipping byte %zu bit %zu (0x%02x -> 0x%02x) "
346             "produced a successful decompression!\n",
347             pos, bit, old_val, new_val);
348       }
349     }
350     stuff->perturbed[pos] = old_val;
351   }
352   return 1;
353 }
354
355 static int perturb_bytes(stuff_t* stuff) {
356   size_t pos;
357   size_t new_val;
358   for (pos = 0; pos < stuff->input_size; pos++) {
359     unsigned char old_val = stuff->input[pos];
360     if (pos % 1000 == 0) {
361       fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size);
362     }
363     for (new_val = 0; new_val < 256; new_val++) {
364       stuff->perturbed[pos] = new_val;
365       if (test_decompress(stuff)) {
366         fprintf(
367             stderr,
368             "Changing byte %zu (0x%02x -> 0x%02x) "
369             "produced a successful decompression!\n",
370             pos, old_val, (unsigned char)new_val);
371       }
372     }
373     stuff->perturbed[pos] = old_val;
374   }
375   return 1;
376 }
377
378 int main(int argc, char* argv[]) {
379   stuff_t stuff;
380
381   if(!init_stuff(&stuff, argc, argv)) {
382     fprintf(stderr, "Failed to init.\n");
383     return 1;
384   }
385
386   if (test_decompress(&stuff)) {
387     fprintf(stderr, "Blob already decompresses successfully!\n");
388     return 1;
389   }
390
391   perturb_bits(&stuff);
392
393   perturb_bytes(&stuff);
394
395   print_summary(&stuff);
396
397   free_stuff(&stuff);
398
399   return 0;
400 }