X-Git-Url: https://notaz.gp2x.de/cgi-bin/gitweb.cgi?a=blobdiff_plain;f=deps%2Flibchdr%2Fdeps%2Fzstd-1.5.5%2Fcontrib%2Fdiagnose_corruption%2Fcheck_flipped_bits.c;fp=deps%2Flibchdr%2Fdeps%2Fzstd-1.5.5%2Fcontrib%2Fdiagnose_corruption%2Fcheck_flipped_bits.c;h=09ddd4674768ec85c919604a609c42de4638fd21;hb=648db22b0750712da893c306efcc8e4b2d3a4e3c;hp=0000000000000000000000000000000000000000;hpb=e2fb1389dc12376acb84e4993ed3b08760257252;p=pcsx_rearmed.git diff --git a/deps/libchdr/deps/zstd-1.5.5/contrib/diagnose_corruption/check_flipped_bits.c b/deps/libchdr/deps/zstd-1.5.5/contrib/diagnose_corruption/check_flipped_bits.c new file mode 100644 index 00000000..09ddd467 --- /dev/null +++ b/deps/libchdr/deps/zstd-1.5.5/contrib/diagnose_corruption/check_flipped_bits.c @@ -0,0 +1,400 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under both the BSD-style license (found in the + * LICENSE file in the root directory of this source tree) and the GPLv2 (found + * in the COPYING file in the root directory of this source tree). + * You may select, at your option, one of the above-listed licenses. + */ + +#define ZSTD_STATIC_LINKING_ONLY +#include "zstd.h" +#include "zstd_errors.h" + +#include +#include +#include +#include +#include +#include + +typedef struct { + char *input; + size_t input_size; + + char *perturbed; /* same size as input */ + + char *output; + size_t output_size; + + const char *dict_file_name; + const char *dict_file_dir_name; + int32_t dict_id; + char *dict; + size_t dict_size; + ZSTD_DDict* ddict; + + ZSTD_DCtx* dctx; + + int success_count; + int error_counts[ZSTD_error_maxCode]; +} stuff_t; + +static void free_stuff(stuff_t* stuff) { + free(stuff->input); + free(stuff->output); + ZSTD_freeDDict(stuff->ddict); + free(stuff->dict); + ZSTD_freeDCtx(stuff->dctx); +} + +static void usage(void) { + fprintf(stderr, "check_flipped_bits input_filename [-d dict] [-D dict_dir]\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "Arguments:\n"); + fprintf(stderr, " -d file: path to a dictionary file to use.\n"); + fprintf(stderr, " -D dir : path to a directory, with files containing dictionaries, of the\n" + " form DICTID.zstd-dict, e.g., 12345.zstd-dict.\n"); + exit(1); +} + +static void print_summary(stuff_t* stuff) { + int error_code; + fprintf(stderr, "%9d successful decompressions\n", stuff->success_count); + for (error_code = 0; error_code < ZSTD_error_maxCode; error_code++) { + int count = stuff->error_counts[error_code]; + if (count) { + fprintf( + stderr, "%9d failed decompressions with message: %s\n", + count, ZSTD_getErrorString(error_code)); + } + } +} + +static char* readFile(const char* filename, size_t* size) { + struct stat statbuf; + int ret; + FILE* f; + char *buf; + size_t bytes_read; + + ret = stat(filename, &statbuf); + if (ret != 0) { + fprintf(stderr, "stat failed: %m\n"); + return NULL; + } + if ((statbuf.st_mode & S_IFREG) != S_IFREG) { + fprintf(stderr, "Input must be regular file\n"); + return NULL; + } + + *size = statbuf.st_size; + + f = fopen(filename, "r"); + if (f == NULL) { + fprintf(stderr, "fopen failed: %m\n"); + return NULL; + } + + buf = malloc(*size); + if (buf == NULL) { + fprintf(stderr, "malloc failed\n"); + fclose(f); + return NULL; + } + + bytes_read = fread(buf, 1, *size, f); + if (bytes_read != *size) { + fprintf(stderr, "failed to read whole file\n"); + fclose(f); + free(buf); + return NULL; + } + + ret = fclose(f); + if (ret != 0) { + fprintf(stderr, "fclose failed: %m\n"); + free(buf); + return NULL; + } + + return buf; +} + +static ZSTD_DDict* readDict(const char* filename, char **buf, size_t* size, int32_t* dict_id) { + ZSTD_DDict* ddict; + *buf = readFile(filename, size); + if (*buf == NULL) { + fprintf(stderr, "Opening dictionary file '%s' failed\n", filename); + return NULL; + } + + ddict = ZSTD_createDDict_advanced(*buf, *size, ZSTD_dlm_byRef, ZSTD_dct_auto, ZSTD_defaultCMem); + if (ddict == NULL) { + fprintf(stderr, "Failed to create ddict.\n"); + return NULL; + } + if (dict_id != NULL) { + *dict_id = ZSTD_getDictID_fromDDict(ddict); + } + return ddict; +} + +static ZSTD_DDict* readDictByID(stuff_t *stuff, int32_t dict_id, char **buf, size_t* size) { + if (stuff->dict_file_dir_name == NULL) { + return NULL; + } else { + size_t dir_name_len = strlen(stuff->dict_file_dir_name); + int dir_needs_separator = 0; + size_t dict_file_name_alloc_size = dir_name_len + 1 /* '/' */ + 10 /* max int32_t len */ + strlen(".zstd-dict") + 1 /* '\0' */; + char *dict_file_name = malloc(dict_file_name_alloc_size); + ZSTD_DDict* ddict; + int32_t read_dict_id; + if (dict_file_name == NULL) { + fprintf(stderr, "malloc failed.\n"); + return 0; + } + + if (dir_name_len > 0 && stuff->dict_file_dir_name[dir_name_len - 1] != '/') { + dir_needs_separator = 1; + } + + snprintf( + dict_file_name, + dict_file_name_alloc_size, + "%s%s%u.zstd-dict", + stuff->dict_file_dir_name, + dir_needs_separator ? "/" : "", + dict_id); + + /* fprintf(stderr, "Loading dict %u from '%s'.\n", dict_id, dict_file_name); */ + + ddict = readDict(dict_file_name, buf, size, &read_dict_id); + if (ddict == NULL) { + fprintf(stderr, "Failed to create ddict from '%s'.\n", dict_file_name); + free(dict_file_name); + return 0; + } + if (read_dict_id != dict_id) { + fprintf(stderr, "Read dictID (%u) does not match expected (%u).\n", read_dict_id, dict_id); + free(dict_file_name); + ZSTD_freeDDict(ddict); + return 0; + } + + free(dict_file_name); + return ddict; + } +} + +static int init_stuff(stuff_t* stuff, int argc, char *argv[]) { + const char* input_filename; + + if (argc < 2) { + usage(); + } + + input_filename = argv[1]; + stuff->input_size = 0; + stuff->input = readFile(input_filename, &stuff->input_size); + if (stuff->input == NULL) { + fprintf(stderr, "Failed to read input file.\n"); + return 0; + } + + stuff->perturbed = malloc(stuff->input_size); + if (stuff->perturbed == NULL) { + fprintf(stderr, "malloc failed.\n"); + return 0; + } + memcpy(stuff->perturbed, stuff->input, stuff->input_size); + + stuff->output_size = ZSTD_DStreamOutSize(); + stuff->output = malloc(stuff->output_size); + if (stuff->output == NULL) { + fprintf(stderr, "malloc failed.\n"); + return 0; + } + + stuff->dict_file_name = NULL; + stuff->dict_file_dir_name = NULL; + stuff->dict_id = 0; + stuff->dict = NULL; + stuff->dict_size = 0; + stuff->ddict = NULL; + + if (argc > 2) { + if (!strcmp(argv[2], "-d")) { + if (argc > 3) { + stuff->dict_file_name = argv[3]; + } else { + usage(); + } + } else + if (!strcmp(argv[2], "-D")) { + if (argc > 3) { + stuff->dict_file_dir_name = argv[3]; + } else { + usage(); + } + } else { + usage(); + } + } + + if (stuff->dict_file_dir_name) { + int32_t dict_id = ZSTD_getDictID_fromFrame(stuff->input, stuff->input_size); + if (dict_id != 0) { + stuff->ddict = readDictByID(stuff, dict_id, &stuff->dict, &stuff->dict_size); + if (stuff->ddict == NULL) { + fprintf(stderr, "Failed to create cached ddict.\n"); + return 0; + } + stuff->dict_id = dict_id; + } + } else + if (stuff->dict_file_name) { + stuff->ddict = readDict(stuff->dict_file_name, &stuff->dict, &stuff->dict_size, &stuff->dict_id); + if (stuff->ddict == NULL) { + fprintf(stderr, "Failed to create ddict from '%s'.\n", stuff->dict_file_name); + return 0; + } + } + + stuff->dctx = ZSTD_createDCtx(); + if (stuff->dctx == NULL) { + return 0; + } + + stuff->success_count = 0; + memset(stuff->error_counts, 0, sizeof(stuff->error_counts)); + + return 1; +} + +static int test_decompress(stuff_t* stuff) { + size_t ret; + ZSTD_inBuffer in = {stuff->perturbed, stuff->input_size, 0}; + ZSTD_outBuffer out = {stuff->output, stuff->output_size, 0}; + ZSTD_DCtx* dctx = stuff->dctx; + int32_t custom_dict_id = ZSTD_getDictID_fromFrame(in.src, in.size); + char *custom_dict = NULL; + size_t custom_dict_size = 0; + ZSTD_DDict* custom_ddict = NULL; + + if (custom_dict_id != 0 && custom_dict_id != stuff->dict_id) { + /* fprintf(stderr, "Instead of dict %u, this perturbed blob wants dict %u.\n", stuff->dict_id, custom_dict_id); */ + custom_ddict = readDictByID(stuff, custom_dict_id, &custom_dict, &custom_dict_size); + } + + ZSTD_DCtx_reset(dctx, ZSTD_reset_session_only); + + if (custom_ddict != NULL) { + ZSTD_DCtx_refDDict(dctx, custom_ddict); + } else { + ZSTD_DCtx_refDDict(dctx, stuff->ddict); + } + + while (in.pos != in.size) { + out.pos = 0; + ret = ZSTD_decompressStream(dctx, &out, &in); + + if (ZSTD_isError(ret)) { + unsigned int code = ZSTD_getErrorCode(ret); + if (code >= ZSTD_error_maxCode) { + fprintf(stderr, "Received unexpected error code!\n"); + exit(1); + } + stuff->error_counts[code]++; + /* + fprintf( + stderr, "Decompression failed: %s\n", ZSTD_getErrorName(ret)); + */ + if (custom_ddict != NULL) { + ZSTD_freeDDict(custom_ddict); + free(custom_dict); + } + return 0; + } + } + + stuff->success_count++; + + if (custom_ddict != NULL) { + ZSTD_freeDDict(custom_ddict); + free(custom_dict); + } + return 1; +} + +static int perturb_bits(stuff_t* stuff) { + size_t pos; + size_t bit; + for (pos = 0; pos < stuff->input_size; pos++) { + unsigned char old_val = stuff->input[pos]; + if (pos % 1000 == 0) { + fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size); + } + for (bit = 0; bit < 8; bit++) { + unsigned char new_val = old_val ^ (1 << bit); + stuff->perturbed[pos] = new_val; + if (test_decompress(stuff)) { + fprintf( + stderr, + "Flipping byte %zu bit %zu (0x%02x -> 0x%02x) " + "produced a successful decompression!\n", + pos, bit, old_val, new_val); + } + } + stuff->perturbed[pos] = old_val; + } + return 1; +} + +static int perturb_bytes(stuff_t* stuff) { + size_t pos; + size_t new_val; + for (pos = 0; pos < stuff->input_size; pos++) { + unsigned char old_val = stuff->input[pos]; + if (pos % 1000 == 0) { + fprintf(stderr, "Perturbing byte %zu / %zu\n", pos, stuff->input_size); + } + for (new_val = 0; new_val < 256; new_val++) { + stuff->perturbed[pos] = new_val; + if (test_decompress(stuff)) { + fprintf( + stderr, + "Changing byte %zu (0x%02x -> 0x%02x) " + "produced a successful decompression!\n", + pos, old_val, (unsigned char)new_val); + } + } + stuff->perturbed[pos] = old_val; + } + return 1; +} + +int main(int argc, char* argv[]) { + stuff_t stuff; + + if(!init_stuff(&stuff, argc, argv)) { + fprintf(stderr, "Failed to init.\n"); + return 1; + } + + if (test_decompress(&stuff)) { + fprintf(stderr, "Blob already decompresses successfully!\n"); + return 1; + } + + perturb_bits(&stuff); + + perturb_bytes(&stuff); + + print_summary(&stuff); + + free_stuff(&stuff); + + return 0; +}