f84a15ef330dfb9c5db21c3e4343458d5236d064
[pcsx_rearmed.git] / deps / libchdr / deps / zstd-1.5.5 / tests / regression / method.c
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under both the BSD-style license (found in the
6  * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7  * in the COPYING file in the root directory of this source tree).
8  * You may select, at your option, one of the above-listed licenses.
9  */
10
11 #include "method.h"
12
13 #include <stdio.h>
14 #include <stdlib.h>
15
16 #define ZSTD_STATIC_LINKING_ONLY
17 #include <zstd.h>
18
/** Yields the smaller of a and b. Each argument may be evaluated twice. */
#define MIN(a, b) ((a) < (b) ? (a) : (b))
20
/** Path to the zstd CLI binary; stays NULL until method_set_zstdcli() runs. */
static char const* g_zstdcli = NULL;

/**
 * Records the path of the zstd command line binary.
 *
 * Only the pointer is stored (the string is not copied), so the caller must
 * keep the string alive for as long as it may be used.
 */
void method_set_zstdcli(char const* zstdcli) {
    g_zstdcli = zstdcli;
}
26
/**
 * Macro to get a pointer of type, given ptr, which is a member variable with
 * the given name, member. Evaluates to NULL when ptr is NULL.
 *
 *     method_state_t* base = ...;
 *     buffer_state_t* state = container_of(base, buffer_state_t, base);
 *
 * ptr is parenthesized at every use so that expression arguments (e.g. a
 * conditional expression) expand correctly. ptr is evaluated twice, so it
 * must not have side effects.
 */
#define container_of(ptr, type, member) \
    ((type*)((ptr) == NULL ? NULL : (char*)(ptr) - offsetof(type, member)))
36
37 /** State to reuse the same buffers between compression calls. */
38 typedef struct {
39     method_state_t base;
40     data_buffers_t inputs; /**< The input buffer for each file. */
41     data_buffer_t dictionary; /**< The dictionary. */
42     data_buffer_t compressed; /**< The compressed data buffer. */
43     data_buffer_t decompressed; /**< The decompressed data buffer. */
44 } buffer_state_t;
45
46 static size_t buffers_max_size(data_buffers_t buffers) {
47     size_t max = 0;
48     for (size_t i = 0; i < buffers.size; ++i) {
49         if (buffers.buffers[i].size > max)
50             max = buffers.buffers[i].size;
51     }
52     return max;
53 }
54
55 static method_state_t* buffer_state_create(data_t const* data) {
56     buffer_state_t* state = (buffer_state_t*)calloc(1, sizeof(buffer_state_t));
57     if (state == NULL)
58         return NULL;
59     state->base.data = data;
60     state->inputs = data_buffers_get(data);
61     state->dictionary = data_buffer_get_dict(data);
62     size_t const max_size = buffers_max_size(state->inputs);
63     state->compressed = data_buffer_create(ZSTD_compressBound(max_size));
64     state->decompressed = data_buffer_create(max_size);
65     return &state->base;
66 }
67
68 static void buffer_state_destroy(method_state_t* base) {
69     if (base == NULL)
70         return;
71     buffer_state_t* state = container_of(base, buffer_state_t, base);
72     free(state);
73 }
74
75 static int buffer_state_bad(
76     buffer_state_t const* state,
77     config_t const* config) {
78     if (state == NULL) {
79         fprintf(stderr, "buffer_state_t is NULL\n");
80         return 1;
81     }
82     if (state->inputs.size == 0 || state->compressed.data == NULL ||
83         state->decompressed.data == NULL) {
84         fprintf(stderr, "buffer state allocation failure\n");
85         return 1;
86     }
87     if (config->use_dictionary && state->dictionary.data == NULL) {
88         fprintf(stderr, "dictionary loading failed\n");
89         return 1;
90     }
91     return 0;
92 }
93
94 static result_t simple_compress(method_state_t* base, config_t const* config) {
95     buffer_state_t* state = container_of(base, buffer_state_t, base);
96
97     if (buffer_state_bad(state, config))
98         return result_error(result_error_system_error);
99
100     /* Keep the tests short by skipping directories, since behavior shouldn't
101      * change.
102      */
103     if (base->data->type != data_type_file)
104         return result_error(result_error_skip);
105     
106     if (config->advanced_api_only)
107         return result_error(result_error_skip);
108
109     if (config->use_dictionary || config->no_pledged_src_size)
110         return result_error(result_error_skip);
111
112     /* If the config doesn't specify a level, skip. */
113     int const level = config_get_level(config);
114     if (level == CONFIG_NO_LEVEL)
115         return result_error(result_error_skip);
116
117     data_buffer_t const input = state->inputs.buffers[0];
118
119     /* Compress, decompress, and check the result. */
120     state->compressed.size = ZSTD_compress(
121         state->compressed.data,
122         state->compressed.capacity,
123         input.data,
124         input.size,
125         level);
126     if (ZSTD_isError(state->compressed.size))
127         return result_error(result_error_compression_error);
128
129     state->decompressed.size = ZSTD_decompress(
130         state->decompressed.data,
131         state->decompressed.capacity,
132         state->compressed.data,
133         state->compressed.size);
134     if (ZSTD_isError(state->decompressed.size))
135         return result_error(result_error_decompression_error);
136     if (data_buffer_compare(input, state->decompressed))
137         return result_error(result_error_round_trip_error);
138
139     result_data_t data;
140     data.total_size = state->compressed.size;
141     return result_data(data);
142 }
143
144 static result_t compress_cctx_compress(
145     method_state_t* base,
146     config_t const* config) {
147     buffer_state_t* state = container_of(base, buffer_state_t, base);
148
149     if (buffer_state_bad(state, config))
150         return result_error(result_error_system_error);
151
152     if (config->no_pledged_src_size)
153         return result_error(result_error_skip);
154
155     if (base->data->type != data_type_dir)
156         return result_error(result_error_skip);
157     
158     if (config->advanced_api_only)
159         return result_error(result_error_skip);
160
161     int const level = config_get_level(config);
162
163     ZSTD_CCtx* cctx = ZSTD_createCCtx();
164     ZSTD_DCtx* dctx = ZSTD_createDCtx();
165     if (cctx == NULL || dctx == NULL) {
166         fprintf(stderr, "context creation failed\n");
167         return result_error(result_error_system_error);
168     }
169
170     result_t result;
171     result_data_t data = {.total_size = 0};
172     for (size_t i = 0; i < state->inputs.size; ++i) {
173         data_buffer_t const input = state->inputs.buffers[i];
174         ZSTD_parameters const params =
175             config_get_zstd_params(config, input.size, state->dictionary.size);
176
177         if (level == CONFIG_NO_LEVEL)
178             state->compressed.size = ZSTD_compress_advanced(
179                 cctx,
180                 state->compressed.data,
181                 state->compressed.capacity,
182                 input.data,
183                 input.size,
184                 config->use_dictionary ? state->dictionary.data : NULL,
185                 config->use_dictionary ? state->dictionary.size : 0,
186                 params);
187         else if (config->use_dictionary)
188             state->compressed.size = ZSTD_compress_usingDict(
189                 cctx,
190                 state->compressed.data,
191                 state->compressed.capacity,
192                 input.data,
193                 input.size,
194                 state->dictionary.data,
195                 state->dictionary.size,
196                 level);
197         else
198             state->compressed.size = ZSTD_compressCCtx(
199                 cctx,
200                 state->compressed.data,
201                 state->compressed.capacity,
202                 input.data,
203                 input.size,
204                 level);
205
206         if (ZSTD_isError(state->compressed.size)) {
207             result = result_error(result_error_compression_error);
208             goto out;
209         }
210
211         if (config->use_dictionary)
212             state->decompressed.size = ZSTD_decompress_usingDict(
213                 dctx,
214                 state->decompressed.data,
215                 state->decompressed.capacity,
216                 state->compressed.data,
217                 state->compressed.size,
218                 state->dictionary.data,
219                 state->dictionary.size);
220         else
221             state->decompressed.size = ZSTD_decompressDCtx(
222                 dctx,
223                 state->decompressed.data,
224                 state->decompressed.capacity,
225                 state->compressed.data,
226                 state->compressed.size);
227         if (ZSTD_isError(state->decompressed.size)) {
228             result = result_error(result_error_decompression_error);
229             goto out;
230         }
231         if (data_buffer_compare(input, state->decompressed)) {
232             result = result_error(result_error_round_trip_error);
233             goto out;
234         }
235
236         data.total_size += state->compressed.size;
237     }
238
239     result = result_data(data);
240 out:
241     ZSTD_freeCCtx(cctx);
242     ZSTD_freeDCtx(dctx);
243     return result;
244 }
245
246 /** Generic state creation function. */
247 static method_state_t* method_state_create(data_t const* data) {
248     method_state_t* state = (method_state_t*)malloc(sizeof(method_state_t));
249     if (state == NULL)
250         return NULL;
251     state->data = data;
252     return state;
253 }
254
255 static void method_state_destroy(method_state_t* state) {
256     free(state);
257 }
258
259 static result_t cli_compress(method_state_t* state, config_t const* config) {
260     if (config->cli_args == NULL)
261         return result_error(result_error_skip);
262
263     if (config->advanced_api_only)
264         return result_error(result_error_skip);
265
266     /* We don't support no pledged source size with directories. Too slow. */
267     if (state->data->type == data_type_dir && config->no_pledged_src_size)
268         return result_error(result_error_skip);
269
270     if (g_zstdcli == NULL)
271         return result_error(result_error_system_error);
272
273     /* '<zstd>' -cqr <args> [-D '<dict>'] '<file/dir>' */
274     char cmd[1024];
275     size_t const cmd_size = snprintf(
276         cmd,
277         sizeof(cmd),
278         "'%s' -cqr %s %s%s%s %s '%s'",
279         g_zstdcli,
280         config->cli_args,
281         config->use_dictionary ? "-D '" : "",
282         config->use_dictionary ? state->data->dict.path : "",
283         config->use_dictionary ? "'" : "",
284         config->no_pledged_src_size ? "<" : "",
285         state->data->data.path);
286     if (cmd_size >= sizeof(cmd)) {
287         fprintf(stderr, "command too large: %s\n", cmd);
288         return result_error(result_error_system_error);
289     }
290     FILE* zstd = popen(cmd, "r");
291     if (zstd == NULL) {
292         fprintf(stderr, "failed to popen command: %s\n", cmd);
293         return result_error(result_error_system_error);
294     }
295
296     char out[4096];
297     size_t total_size = 0;
298     while (1) {
299         size_t const size = fread(out, 1, sizeof(out), zstd);
300         total_size += size;
301         if (size != sizeof(out))
302             break;
303     }
304     if (ferror(zstd) || pclose(zstd) != 0) {
305         fprintf(stderr, "zstd failed with command: %s\n", cmd);
306         return result_error(result_error_compression_error);
307     }
308
309     result_data_t const data = {.total_size = total_size};
310     return result_data(data);
311 }
312
313 static int advanced_config(
314     ZSTD_CCtx* cctx,
315     buffer_state_t* state,
316     config_t const* config) {
317     ZSTD_CCtx_reset(cctx, ZSTD_reset_session_and_parameters);
318     for (size_t p = 0; p < config->param_values.size; ++p) {
319         param_value_t const pv = config->param_values.data[p];
320         if (ZSTD_isError(ZSTD_CCtx_setParameter(cctx, pv.param, pv.value))) {
321             return 1;
322         }
323     }
324     if (config->use_dictionary) {
325         if (ZSTD_isError(ZSTD_CCtx_loadDictionary(
326                 cctx, state->dictionary.data, state->dictionary.size))) {
327             return 1;
328         }
329     }
330     return 0;
331 }
332
333 static result_t advanced_one_pass_compress_output_adjustment(
334     method_state_t* base,
335     config_t const* config,
336     size_t const subtract) {
337     buffer_state_t* state = container_of(base, buffer_state_t, base);
338
339     if (buffer_state_bad(state, config))
340         return result_error(result_error_system_error);
341
342     ZSTD_CCtx* cctx = ZSTD_createCCtx();
343     result_t result;
344
345     if (!cctx || advanced_config(cctx, state, config)) {
346         result = result_error(result_error_compression_error);
347         goto out;
348     }
349
350     result_data_t data = {.total_size = 0};
351     for (size_t i = 0; i < state->inputs.size; ++i) {
352         data_buffer_t const input = state->inputs.buffers[i];
353
354         if (!config->no_pledged_src_size) {
355             if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
356                 result = result_error(result_error_compression_error);
357                 goto out;
358             }
359         }
360         size_t const size = ZSTD_compress2(
361             cctx,
362             state->compressed.data,
363             ZSTD_compressBound(input.size) - subtract,
364             input.data,
365             input.size);
366         if (ZSTD_isError(size)) {
367             result = result_error(result_error_compression_error);
368             goto out;
369         }
370         data.total_size += size;
371     }
372
373     result = result_data(data);
374 out:
375     ZSTD_freeCCtx(cctx);
376     return result;
377 }
378
379 static result_t advanced_one_pass_compress(
380     method_state_t* base,
381     config_t const* config) {
382   return advanced_one_pass_compress_output_adjustment(base, config, 0);
383 }
384
385 static result_t advanced_one_pass_compress_small_output(
386     method_state_t* base,
387     config_t const* config) {
388   return advanced_one_pass_compress_output_adjustment(base, config, 1);
389 }
390
391 static result_t advanced_streaming_compress(
392     method_state_t* base,
393     config_t const* config) {
394     buffer_state_t* state = container_of(base, buffer_state_t, base);
395
396     if (buffer_state_bad(state, config))
397         return result_error(result_error_system_error);
398
399     ZSTD_CCtx* cctx = ZSTD_createCCtx();
400     result_t result;
401
402     if (!cctx || advanced_config(cctx, state, config)) {
403         result = result_error(result_error_compression_error);
404         goto out;
405     }
406
407     result_data_t data = {.total_size = 0};
408     for (size_t i = 0; i < state->inputs.size; ++i) {
409         data_buffer_t input = state->inputs.buffers[i];
410
411         if (!config->no_pledged_src_size) {
412             if (ZSTD_isError(ZSTD_CCtx_setPledgedSrcSize(cctx, input.size))) {
413                 result = result_error(result_error_compression_error);
414                 goto out;
415             }
416         }
417
418         while (input.size > 0) {
419             ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
420             input.data += in.size;
421             input.size -= in.size;
422             ZSTD_EndDirective const op =
423                 input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;
424             size_t ret = 0;
425             while (in.pos < in.size || (op == ZSTD_e_end && ret != 0)) {
426                 ZSTD_outBuffer out = {state->compressed.data,
427                                       MIN(state->compressed.capacity, 1024)};
428                 ret = ZSTD_compressStream2(cctx, &out, &in, op);
429                 if (ZSTD_isError(ret)) {
430                     result = result_error(result_error_compression_error);
431                     goto out;
432                 }
433                 data.total_size += out.pos;
434             }
435         }
436     }
437
438     result = result_data(data);
439 out:
440     ZSTD_freeCCtx(cctx);
441     return result;
442 }
443
444 static int init_cstream(
445     buffer_state_t* state,
446     ZSTD_CStream* zcs,
447     config_t const* config,
448     int const advanced,
449     ZSTD_CDict** cdict)
450 {
451     size_t zret;
452     if (advanced) {
453         ZSTD_parameters const params = config_get_zstd_params(config, 0, 0);
454         ZSTD_CDict* dict = NULL;
455         if (cdict) {
456             if (!config->use_dictionary)
457               return 1;
458             *cdict = ZSTD_createCDict_advanced(
459                 state->dictionary.data,
460                 state->dictionary.size,
461                 ZSTD_dlm_byRef,
462                 ZSTD_dct_auto,
463                 params.cParams,
464                 ZSTD_defaultCMem);
465             if (!*cdict) {
466                 return 1;
467             }
468             zret = ZSTD_initCStream_usingCDict_advanced(
469                 zcs, *cdict, params.fParams, ZSTD_CONTENTSIZE_UNKNOWN);
470         } else {
471             zret = ZSTD_initCStream_advanced(
472                 zcs,
473                 config->use_dictionary ? state->dictionary.data : NULL,
474                 config->use_dictionary ? state->dictionary.size : 0,
475                 params,
476                 ZSTD_CONTENTSIZE_UNKNOWN);
477         }
478     } else {
479         int const level = config_get_level(config);
480         if (level == CONFIG_NO_LEVEL)
481             return 1;
482         if (cdict) {
483             if (!config->use_dictionary)
484               return 1;
485             *cdict = ZSTD_createCDict(
486                 state->dictionary.data,
487                 state->dictionary.size,
488                 level);
489             if (!*cdict) {
490                 return 1;
491             }
492             zret = ZSTD_initCStream_usingCDict(zcs, *cdict);
493         } else if (config->use_dictionary) {
494             zret = ZSTD_initCStream_usingDict(
495                 zcs,
496                 state->dictionary.data,
497                 state->dictionary.size,
498                 level);
499         } else {
500             zret = ZSTD_initCStream(zcs, level);
501         }
502     }
503     if (ZSTD_isError(zret)) {
504         return 1;
505     }
506     return 0;
507 }
508
509 static result_t old_streaming_compress_internal(
510     method_state_t* base,
511     config_t const* config,
512     int const advanced,
513     int const cdict) {
514   buffer_state_t* state = container_of(base, buffer_state_t, base);
515
516   if (buffer_state_bad(state, config))
517     return result_error(result_error_system_error);
518
519
520   ZSTD_CStream* zcs = ZSTD_createCStream();
521   ZSTD_CDict* cd = NULL;
522   result_t result;
523   if (zcs == NULL) {
524     result = result_error(result_error_compression_error);
525     goto out;
526   }
527   if (!advanced && config_get_level(config) == CONFIG_NO_LEVEL) {
528     result = result_error(result_error_skip);
529     goto out;
530   }
531   if (cdict && !config->use_dictionary) {
532     result = result_error(result_error_skip);
533     goto out;
534   }
535   if (config->advanced_api_only) {
536     result = result_error(result_error_skip);
537     goto out;
538   }
539   if (init_cstream(state, zcs, config, advanced, cdict ? &cd : NULL)) {
540     result = result_error(result_error_compression_error);
541     goto out;
542   }
543
544   result_data_t data = {.total_size = 0};
545   for (size_t i = 0; i < state->inputs.size; ++i) {
546     data_buffer_t input = state->inputs.buffers[i];
547     size_t zret = ZSTD_resetCStream(
548         zcs,
549         config->no_pledged_src_size ? ZSTD_CONTENTSIZE_UNKNOWN : input.size);
550     if (ZSTD_isError(zret)) {
551       result = result_error(result_error_compression_error);
552       goto out;
553     }
554
555     while (input.size > 0) {
556       ZSTD_inBuffer in = {input.data, MIN(input.size, 4096)};
557       input.data += in.size;
558       input.size -= in.size;
559       ZSTD_EndDirective const op =
560           input.size > 0 ? ZSTD_e_continue : ZSTD_e_end;
561       zret = 0;
562       while (in.pos < in.size || (op == ZSTD_e_end && zret != 0)) {
563         ZSTD_outBuffer out = {state->compressed.data,
564                               MIN(state->compressed.capacity, 1024)};
565         if (op == ZSTD_e_continue || in.pos < in.size)
566           zret = ZSTD_compressStream(zcs, &out, &in);
567         else
568           zret = ZSTD_endStream(zcs, &out);
569         if (ZSTD_isError(zret)) {
570           result = result_error(result_error_compression_error);
571           goto out;
572         }
573         data.total_size += out.pos;
574       }
575     }
576   }
577
578   result = result_data(data);
579 out:
580     ZSTD_freeCStream(zcs);
581     ZSTD_freeCDict(cd);
582     return result;
583 }
584
585 static result_t old_streaming_compress(
586     method_state_t* base,
587     config_t const* config)
588 {
589     return old_streaming_compress_internal(
590         base, config, /* advanced */ 0, /* cdict */ 0);
591 }
592
593 static result_t old_streaming_compress_advanced(
594     method_state_t* base,
595     config_t const* config)
596 {
597     return old_streaming_compress_internal(
598         base, config, /* advanced */ 1, /* cdict */ 0);
599 }
600
601 static result_t old_streaming_compress_cdict(
602     method_state_t* base,
603     config_t const* config)
604 {
605     return old_streaming_compress_internal(
606         base, config, /* advanced */ 0, /* cdict */ 1);
607 }
608
609 static result_t old_streaming_compress_cdict_advanced(
610     method_state_t* base,
611     config_t const* config)
612 {
613     return old_streaming_compress_internal(
614         base, config, /* advanced */ 1, /* cdict */ 1);
615 }
616
617 method_t const simple = {
618     .name = "compress simple",
619     .create = buffer_state_create,
620     .compress = simple_compress,
621     .destroy = buffer_state_destroy,
622 };
623
624 method_t const compress_cctx = {
625     .name = "compress cctx",
626     .create = buffer_state_create,
627     .compress = compress_cctx_compress,
628     .destroy = buffer_state_destroy,
629 };
630
631 method_t const advanced_one_pass = {
632     .name = "advanced one pass",
633     .create = buffer_state_create,
634     .compress = advanced_one_pass_compress,
635     .destroy = buffer_state_destroy,
636 };
637
638 method_t const advanced_one_pass_small_out = {
639     .name = "advanced one pass small out",
640     .create = buffer_state_create,
641     .compress = advanced_one_pass_compress,
642     .destroy = buffer_state_destroy,
643 };
644
645 method_t const advanced_streaming = {
646     .name = "advanced streaming",
647     .create = buffer_state_create,
648     .compress = advanced_streaming_compress,
649     .destroy = buffer_state_destroy,
650 };
651
652 method_t const old_streaming = {
653     .name = "old streaming",
654     .create = buffer_state_create,
655     .compress = old_streaming_compress,
656     .destroy = buffer_state_destroy,
657 };
658
659 method_t const old_streaming_advanced = {
660     .name = "old streaming advanced",
661     .create = buffer_state_create,
662     .compress = old_streaming_compress_advanced,
663     .destroy = buffer_state_destroy,
664 };
665
666 method_t const old_streaming_cdict = {
667     .name = "old streaming cdict",
668     .create = buffer_state_create,
669     .compress = old_streaming_compress_cdict,
670     .destroy = buffer_state_destroy,
671 };
672
673 method_t const old_streaming_advanced_cdict = {
674     .name = "old streaming advanced cdict",
675     .create = buffer_state_create,
676     .compress = old_streaming_compress_cdict_advanced,
677     .destroy = buffer_state_destroy,
678 };
679
680 method_t const cli = {
681     .name = "zstdcli",
682     .create = method_state_create,
683     .compress = cli_compress,
684     .destroy = method_state_destroy,
685 };
686
687 static method_t const* g_methods[] = {
688     &simple,
689     &compress_cctx,
690     &cli,
691     &advanced_one_pass,
692     &advanced_one_pass_small_out,
693     &advanced_streaming,
694     &old_streaming,
695     &old_streaming_advanced,
696     &old_streaming_cdict,
697     &old_streaming_advanced_cdict,
698     NULL,
699 };
700
701 method_t const* const* methods = g_methods;