| 1 | /* |
| 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | * All rights reserved. |
| 4 | * |
| 5 | * This source code is licensed under both the BSD-style license (found in the |
| 6 | * LICENSE file in the root directory of this source tree) and the GPLv2 (found |
| 7 | * in the COPYING file in the root directory of this source tree). |
| 8 | * You may select, at your option, one of the above-listed licenses. |
| 9 | */ |
| 10 | |
| 11 | #define ZSTD_STATIC_LINKING_ONLY |
| 12 | #include "zstd.h" |
| 13 | #include "zstd_errors.h" |
| 14 | |
| 15 | #include <stdio.h> |
| 16 | #include <stdlib.h> |
| 17 | #include <string.h> |
| 18 | #include <sys/types.h> |
| 19 | #include <sys/stat.h> |
| 20 | #include <unistd.h> |
| 21 | |
| 22 | typedef struct { |
| 23 | char *input; |
| 24 | size_t input_size; |
| 25 | |
| 26 | char *perturbed; /* same size as input */ |
| 27 | |
| 28 | char *output; |
| 29 | size_t output_size; |
| 30 | |
| 31 | const char *dict_file_name; |
| 32 | const char *dict_file_dir_name; |
| 33 | int32_t dict_id; |
| 34 | char *dict; |
| 35 | size_t dict_size; |
| 36 | ZSTD_DDict* ddict; |
| 37 | |
| 38 | ZSTD_DCtx* dctx; |
| 39 | |
| 40 | int success_count; |
| 41 | int error_counts[ZSTD_error_maxCode]; |
| 42 | } stuff_t; |
| 43 | |
| 44 | static void free_stuff(stuff_t* stuff) { |
| 45 | free(stuff->input); |
| 46 | free(stuff->output); |
| 47 | ZSTD_freeDDict(stuff->ddict); |
| 48 | free(stuff->dict); |
| 49 | ZSTD_freeDCtx(stuff->dctx); |
| 50 | } |
| 51 | |
/*
 * Writes the command-line help text to stderr and terminates the process with
 * exit status 1. Never returns.
 */
static void usage(void) {
  fputs(
      "check_flipped_bits input_filename [-d dict] [-D dict_dir]\n"
      "\n"
      "Arguments:\n"
      " -d file: path to a dictionary file to use.\n"
      " -D dir : path to a directory, with files containing dictionaries, of the\n"
      " form DICTID.zstd-dict, e.g., 12345.zstd-dict.\n",
      stderr);
  exit(1);
}
| 61 | |
| 62 | static void print_summary(stuff_t* stuff) { |
| 63 | int error_code; |
| 64 | fprintf(stderr, "%9d successful decompressions\n", stuff->success_count); |
| 65 | for (error_code = 0; error_code < ZSTD_error_maxCode; error_code++) { |
| 66 | int count = stuff->error_counts[error_code]; |
| 67 | if (count) { |
| 68 | fprintf( |
| 69 | stderr, "%9d failed decompressions with message: %s\n", |
| 70 | count, ZSTD_getErrorString(error_code)); |
| 71 | } |
| 72 | } |
| 73 | } |
| 74 | |
/*
 * Reads the whole of a regular file into a freshly allocated buffer.
 *
 * filename: path of the file to read.
 * size:     out-parameter; receives the file size in bytes reported by stat().
 *
 * Returns the buffer, which the caller owns and must release with free(), or
 * NULL after printing a diagnostic to stderr when any step fails. An empty
 * file yields a valid non-NULL buffer with *size == 0.
 */
static char* readFile(const char* filename, size_t* size) {
  struct stat statbuf;
  FILE* f;
  char *buf;
  size_t bytes_read;

  if (stat(filename, &statbuf) != 0) {
    /* perror() is the portable spelling of the glibc-only "%m" conversion. */
    perror("stat failed");
    return NULL;
  }
  /* S_ISREG compares the full file-type field; a plain bit test against
   * S_IFREG would also accept other types that share that bit (sockets). */
  if (!S_ISREG(statbuf.st_mode)) {
    fprintf(stderr, "Input must be regular file\n");
    return NULL;
  }

  *size = (size_t)statbuf.st_size;

  /* Binary mode: the payload is compressed data, not text. */
  f = fopen(filename, "rb");
  if (f == NULL) {
    perror("fopen failed");
    return NULL;
  }

  /* malloc(0) is allowed to return NULL, so always request at least one byte
   * to keep "empty file" distinguishable from "out of memory". */
  buf = malloc(*size > 0 ? *size : 1);
  if (buf == NULL) {
    fprintf(stderr, "malloc failed\n");
    fclose(f);
    return NULL;
  }

  bytes_read = fread(buf, 1, *size, f);
  if (bytes_read != *size) {
    fprintf(stderr, "failed to read whole file\n");
    fclose(f);
    free(buf);
    return NULL;
  }

  if (fclose(f) != 0) {
    perror("fclose failed");
    free(buf);
    return NULL;
  }

  return buf;
}
| 124 | |
| 125 | static ZSTD_DDict* readDict(const char* filename, char **buf, size_t* size, int32_t* dict_id) { |
| 126 | ZSTD_DDict* ddict; |
| 127 | *buf = readFile(filename, size); |
| 128 | if (*buf == NULL) { |
| 129 | fprintf(stderr, "Opening dictionary file '%s' failed\n", filename); |
| 130 | return NULL; |
| 131 | } |
| 132 | |
| 133 | ddict = ZSTD_createDDict_advanced(*buf, *size, ZSTD_dlm_byRef, ZSTD_dct_auto, ZSTD_defaultCMem); |
| 134 | if (ddict == NULL) { |
| 135 | fprintf(stderr, "Failed to create ddict.\n"); |
| 136 | return NULL; |
| 137 | } |
| 138 | if (dict_id != NULL) { |
| 139 | *dict_id = ZSTD_getDictID_fromDDict(ddict); |
| 140 | } |
| 141 | return ddict; |
| 142 | } |
| 143 | |
| 144 | static ZSTD_DDict* readDictByID(stuff_t *stuff, int32_t dict_id, char **buf, size_t* size) { |
| 145 | if (stuff->dict_file_dir_name == NULL) { |
| 146 | return NULL; |
| 147 | } else { |
| 148 | size_t dir_name_len = strlen(stuff->dict_file_dir_name); |
| 149 | int dir_needs_separator = 0; |
| 150 | size_t dict_file_name_alloc_size = dir_name_len + 1 /* '/' */ + 10 /* max int32_t len */ + strlen(".zstd-dict") + 1 /* '\0' */; |
| 151 | char *dict_file_name = malloc(dict_file_name_alloc_size); |
| 152 | ZSTD_DDict* ddict; |
| 153 | int32_t read_dict_id; |
| 154 | if (dict_file_name == NULL) { |
| 155 | fprintf(stderr, "malloc failed.\n"); |
| 156 | return 0; |
| 157 | } |
| 158 | |
| 159 | if (dir_name_len > 0 && stuff->dict_file_dir_name[dir_name_len - 1] != '/') { |
| 160 | dir_needs_separator = 1; |
| 161 | } |
| 162 | |
| 163 | snprintf( |
| 164 | dict_file_name, |
| 165 | dict_file_name_alloc_size, |
| 166 | "%s%s%u.zstd-dict", |
| 167 | stuff->dict_file_dir_name, |
| 168 | dir_needs_separator ? "/" : "", |
| 169 | dict_id); |
| 170 | |
| 171 | /* fprintf(stderr, "Loading dict %u from '%s'.\n", dict_id, dict_file_name); */ |
| 172 | |
| 173 | ddict = readDict(dict_file_name, buf, size, &read_dict_id); |
| 174 | if (ddict == NULL) { |
| 175 | fprintf(stderr, "Failed to create ddict from '%s'.\n", dict_file_name); |
| 176 | free(dict_file_name); |
| 177 | return 0; |
| 178 | } |
| 179 | if (read_dict_id != dict_id) { |
| 180 | fprintf(stderr, "Read dictID (%u) does not match expected (%u).\n", read_dict_id, dict_id); |
| 181 | free(dict_file_name); |
| 182 | ZSTD_freeDDict(ddict); |
| 183 | return 0; |
| 184 | } |
| 185 | |
| 186 | free(dict_file_name); |
| 187 | return ddict; |
| 188 | } |
| 189 | } |
| 190 | |
| 191 | static int init_stuff(stuff_t* stuff, int argc, char *argv[]) { |
| 192 | const char* input_filename; |
| 193 | |
| 194 | if (argc < 2) { |
| 195 | usage(); |
| 196 | } |
| 197 | |
| 198 | input_filename = argv[1]; |
| 199 | stuff->input_size = 0; |
| 200 | stuff->input = readFile(input_filename, &stuff->input_size); |
| 201 | if (stuff->input == NULL) { |
| 202 | fprintf(stderr, "Failed to read input file.\n"); |
| 203 | return 0; |
| 204 | } |
| 205 | |
| 206 | stuff->perturbed = malloc(stuff->input_size); |
| 207 | if (stuff->perturbed == NULL) { |
| 208 | fprintf(stderr, "malloc failed.\n"); |
| 209 | return 0; |
| 210 | } |
| 211 | memcpy(stuff->perturbed, stuff->input, stuff->input_size); |
| 212 | |
| 213 | stuff->output_size = ZSTD_DStreamOutSize(); |
| 214 | stuff->output = malloc(stuff->output_size); |
| 215 | if (stuff->output == NULL) { |
| 216 | fprintf(stderr, "malloc failed.\n"); |
| 217 | return 0; |
| 218 | } |
| 219 | |
| 220 | stuff->dict_file_name = NULL; |
| 221 | stuff->dict_file_dir_name = NULL; |
| 222 | stuff->dict_id = 0; |
| 223 | stuff->dict = NULL; |
| 224 | stuff->dict_size = 0; |
| 225 | stuff->ddict = NULL; |
| 226 | |
| 227 | if (argc > 2) { |
| 228 | if (!strcmp(argv[2], "-d")) { |
| 229 | if (argc > 3) { |
| 230 | stuff->dict_file_name = argv[3]; |
| 231 | } else { |
| 232 | usage(); |
| 233 | } |
| 234 | } else |
| 235 | if (!strcmp(argv[2], "-D")) { |
| 236 | if (argc > 3) { |
| 237 | stuff->dict_file_dir_name = argv[3]; |
| 238 | } else { |
| 239 | usage(); |
| 240 | } |
| 241 | } else { |
| 242 | usage(); |
| 243 | } |
| 244 | } |
| 245 | |
| 246 | if (stuff->dict_file_dir_name) { |
| 247 | int32_t dict_id = ZSTD_getDictID_fromFrame(stuff->input, stuff->input_size); |
| 248 | if (dict_id != 0) { |
| 249 | stuff->ddict = readDictByID(stuff, dict_id, &stuff->dict, &stuff->dict_size); |
| 250 | if (stuff->ddict == NULL) { |
| 251 | fprintf(stderr, "Failed to create cached ddict.\n"); |
| 252 | return 0; |
| 253 | } |
| 254 | stuff->dict_id = dict_id; |
| 255 | } |
| 256 | } else |
| 257 | if (stuff->dict_file_name) { |
| 258 | stuff->ddict = readDict(stuff->dict_file_name, &stuff->dict, &stuff->dict_size, &stuff->dict_id); |
| 259 | if (stuff->ddict == NULL) { |
| 260 | fprintf(stderr, "Failed to create ddict from '%s'.\n", stuff->dict_file_name); |
| 261 | return 0; |
| 262 | } |
| 263 | } |
| 264 | |
| 265 | stuff->dctx = ZSTD_createDCtx(); |
| 266 | if (stuff->dctx == NULL) { |
| 267 | return 0; |
| 268 | } |
| 269 | |
| 270 | stuff->success_count = 0; |
| 271 | memset(stuff->error_counts, 0, sizeof(stuff->error_counts)); |
| 272 | |
| 273 | return 1; |
| 274 | } |
| 275 | |
| 276 | static int test_decompress(stuff_t* stuff) { |
| 277 | size_t ret; |
| 278 | ZSTD_inBuffer in = {stuff->perturbed, stuff->input_size, 0}; |
| 279 | ZSTD_outBuffer out = {stuff->output, stuff->output_size, 0}; |
| 280 | ZSTD_DCtx* dctx = stuff->dctx; |
| 281 | int32_t custom_dict_id = ZSTD_getDictID_fromFrame(in.src, in.size); |
| 282 | char *custom_dict = NULL; |
| 283 | size_t custom_dict_size = 0; |
| 284 | ZSTD_DDict* custom_ddict = NULL; |
| 285 | |
| 286 | if (custom_dict_id != 0 && custom_dict_id != stuff->dict_id) { |
| 287 | /* fprintf(stderr, "Instead of dict %u, this perturbed blob wants dict %u.\n", stuff->dict_id, custom_dict_id); */ |
| 288 | custom_ddict = readDictByID(stuff, custom_dict_id, &custom_dict, &custom_dict_size); |
| 289 | } |
| 290 | |
| 291 | ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only); |
| 292 | |
| 293 | if (custom_ddict != NULL) { |
| 294 | ZSTD_DCtx_refDDict(dctx, custom_ddict); |
| 295 | } else { |
| 296 | ZSTD_DCtx_refDDict(dctx, stuff->ddict); |
| 297 | } |
| 298 | |
| 299 | while (in.pos != in.size) { |
| 300 | out.pos = 0; |
| 301 | ret = ZSTD_decompressStream(dctx, &out, &in); |
| 302 | |
| 303 | if (ZSTD_isError(ret)) { |
| 304 | unsigned int code = ZSTD_getErrorCode(ret); |
| 305 | if (code >= ZSTD_error_maxCode) { |
| 306 | fprintf(stderr, "Received unexpected error code!\n"); |
| 307 | exit(1); |
| 308 | } |
| 309 | stuff->error_counts[code]++; |
| 310 | /* |
| 311 | fprintf( |
| 312 | stderr, "Decompression failed: %s\n", ZSTD_getErrorName(ret)); |
| 313 | */ |
| 314 | if (custom_ddict != NULL) { |
| 315 | ZSTD_freeDDict(custom_ddict); |
| 316 | free(custom_dict); |
| 317 | } |
| 318 | return 0; |
| 319 | } |
| 320 | } |
| 321 | |
| 322 | stuff->success_count++; |
| 323 | |
| 324 | if (custom_ddict != NULL) { |
| 325 | ZSTD_freeDDict(custom_ddict); |
| 326 | free(custom_dict); |
| 327 | } |
| 328 | return 1; |
| 329 | } |
| 330 | |
| 331 | static int perturb_bits(stuff_t* stuff) { |
| 332 | size_t pos; |
| 333 | size_t bit; |
| 334 | for (pos = 0; pos < stuff->input_size; pos++) { |
| 335 | unsigned char old_val = stuff->input[pos]; |
| 336 | if (pos % 1000 == 0) { |
| 337 | fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size); |
| 338 | } |
| 339 | for (bit = 0; bit < 8; bit++) { |
| 340 | unsigned char new_val = old_val ^ (1 << bit); |
| 341 | stuff->perturbed[pos] = new_val; |
| 342 | if (test_decompress(stuff)) { |
| 343 | fprintf( |
| 344 | stderr, |
| 345 | "Flipping byte %zu bit %zu (0x%02x -> 0x%02x) " |
| 346 | "produced a successful decompression!\n", |
| 347 | pos, bit, old_val, new_val); |
| 348 | } |
| 349 | } |
| 350 | stuff->perturbed[pos] = old_val; |
| 351 | } |
| 352 | return 1; |
| 353 | } |
| 354 | |
| 355 | static int perturb_bytes(stuff_t* stuff) { |
| 356 | size_t pos; |
| 357 | size_t new_val; |
| 358 | for (pos = 0; pos < stuff->input_size; pos++) { |
| 359 | unsigned char old_val = stuff->input[pos]; |
| 360 | if (pos % 1000 == 0) { |
| 361 | fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size); |
| 362 | } |
| 363 | for (new_val = 0; new_val < 256; new_val++) { |
| 364 | stuff->perturbed[pos] = new_val; |
| 365 | if (test_decompress(stuff)) { |
| 366 | fprintf( |
| 367 | stderr, |
| 368 | "Changing byte %zu (0x%02x -> 0x%02x) " |
| 369 | "produced a successful decompression!\n", |
| 370 | pos, old_val, (unsigned char)new_val); |
| 371 | } |
| 372 | } |
| 373 | stuff->perturbed[pos] = old_val; |
| 374 | } |
| 375 | return 1; |
| 376 | } |
| 377 | |
| 378 | int main(int argc, char* argv[]) { |
| 379 | stuff_t stuff; |
| 380 | |
| 381 | if(!init_stuff(&stuff, argc, argv)) { |
| 382 | fprintf(stderr, "Failed to init.\n"); |
| 383 | return 1; |
| 384 | } |
| 385 | |
| 386 | if (test_decompress(&stuff)) { |
| 387 | fprintf(stderr, "Blob already decompresses successfully!\n"); |
| 388 | return 1; |
| 389 | } |
| 390 | |
| 391 | perturb_bits(&stuff); |
| 392 | |
| 393 | perturb_bytes(&stuff); |
| 394 | |
| 395 | print_summary(&stuff); |
| 396 | |
| 397 | free_stuff(&stuff); |
| 398 | |
| 399 | return 0; |
| 400 | } |