2 * Copyright (c) Meta Platforms, Inc. and affiliates.
5 * This source code is licensed under both the BSD-style license (found in the
6 * LICENSE file in the root directory of this source tree) and the GPLv2 (found
7 * in the COPYING file in the root directory of this source tree).
11 * A simple demo that sums up all the bytes in the file in parallel using
12 * seekable decompression and the zstd thread pool
15 #include <stdlib.h> // malloc, exit
16 #include <stdio.h> // fprintf, perror, feof
17 #include <string.h> // strerror
18 #include <errno.h> // errno
19 #define ZSTD_STATIC_LINKING_ONLY
20 #include <zstd.h> // presumes zstd library is installed
21 #include <zstd_errors.h>
22 #if defined(WIN32) || defined(_WIN32)
24 # define SLEEP(x) Sleep(x)
/* SLEEP(x) : pause for x milliseconds; usleep() takes microseconds, hence the * 1000.
 * The argument is parenthesized so that an expression such as SLEEP(a + b)
 * is scaled as a whole instead of expanding to usleep(a + b * 1000). */
#  define SLEEP(x) usleep((x) * 1000)
30 #include "pool.h" // use zstd thread pool for demo
32 #include "../zstd_seekable.h"
34 #define MIN(a, b) ((a) < (b) ? (a) : (b))
/* malloc_orDie() :
 * Requests `size` bytes from malloc() and returns the buffer to the caller
 * when the allocation succeeded (non-NULL result).
 * NOTE(review): the failure path is not visible in this excerpt; the "_orDie"
 * suffix suggests the process is terminated rather than NULL being returned -
 * confirm against the full file. */
36 static void* malloc_orDie(size_t size)
38 void* const buff = malloc(size);
39 if (buff) return buff;
/* realloc_orDie() :
 * Resizes the allocation `ptr` to `size` bytes through realloc().
 * NOTE(review): the realloc() result is stored straight back into the local
 * `ptr`; how a NULL result is handled and what is returned are not visible
 * in this excerpt - confirm against the full file. */
45 static void* realloc_orDie(void* ptr, size_t size)
47 ptr = realloc(ptr, size);
/* fopen_orDie() :
 * Opens `filename` with fopen(), using `instruction` as the mode string,
 * and returns the stream when the open succeeded (non-NULL result).
 * NOTE(review): the failure path is not visible in this excerpt; the "_orDie"
 * suffix suggests the process is terminated - confirm. */
54 static FILE* fopen_orDie(const char *filename, const char *instruction)
56 FILE* const inFile = fopen(filename, instruction);
57 if (inFile) return inFile;
/* fread_orDie() :
 * Reads up to `sizeToRead` bytes (element size 1) from `file` into `buffer`.
 * @return : the number of bytes actually read, either because the full amount
 *           was obtained or because the stream reports end-of-file.
 * NOTE(review): the handling of a short read that is not end-of-file is not
 * visible in this excerpt - presumably a fatal error; confirm. */
63 static size_t fread_orDie(void* buffer, size_t sizeToRead, FILE* file)
65 size_t const readSize = fread(buffer, 1, sizeToRead, file);
66 if (readSize == sizeToRead) return readSize; /* good */
67 if (feof(file)) return readSize; /* good, reached end of file */
/* fwrite_orDie() :
 * Writes `sizeToWrite` bytes (element size 1) from `buffer` to `file`.
 * @return : `sizeToWrite` when fwrite() reports that every byte was written.
 * NOTE(review): the handling of a short write is not visible in this excerpt -
 * presumably a fatal error; confirm. */
73 static size_t fwrite_orDie(const void* buffer, size_t sizeToWrite, FILE* file)
75 size_t const writtenSize = fwrite(buffer, 1, sizeToWrite, file);
76 if (writtenSize == sizeToWrite) return sizeToWrite; /* good */
/* fclose_orDie() :
 * Closes `file` with fclose().
 * @return : 0 when fclose() itself returned 0.
 * NOTE(review): the handling of a non-zero fclose() result is not visible in
 * this excerpt - presumably a fatal error; confirm. */
82 static size_t fclose_orDie(FILE* file)
84 if (!fclose(file)) return 0;
/* fseek_orDie() :
 * Repositions `file` with fseek(file, offset, origin) and, when that call
 * returns 0, follows it with fflush(file); the function returns normally only
 * when both calls returned 0.
 * NOTE(review): what happens when either call returns non-zero is not visible
 * in this excerpt - presumably a fatal error; confirm. */
90 static void fseek_orDie(FILE* file, long int offset, int origin) {
91 if (!fseek(file, offset, origin)) {
92 if (!fflush(file)) return;
/* NOTE(review): the enclosing declaration is not visible in this excerpt;
 * presumably this is the `sum` member of struct sum_job, which sumFile_orDie()
 * reads as jobs[fnb].sum once a job reports done - confirm. */
101 unsigned long long sum;
/* sumFrame() :
 * Job function handed to POOL_add() by sumFile_orDie(); `opaque` is the
 * struct sum_job describing one frame.
 * Each invocation opens its own handle on job->fname and builds its own
 * ZSTD_seekable object, then decompresses frame number job->frameNb into a
 * freshly allocated buffer and walks over every decompressed byte.
 * Exits the process with code 10 (seekable creation failed), 11 (seekable
 * initialisation failed) or 12 (frame decompression failed).
 * NOTE(review): the loop body and the statements that publish the result to
 * the job are not visible in this excerpt; presumably the bytes are added
 * into `sum` and then stored in the job - confirm. Release of `fin` and
 * `data` is likewise not visible here. */
106 static void sumFrame(void* opaque)
108 struct sum_job* job = (struct sum_job*)opaque;
/* private file handle for this job, opened in binary read mode */
111 FILE* const fin = fopen_orDie(job->fname, "rb");
113 ZSTD_seekable* const seekable = ZSTD_seekable_create();
114 if (seekable==NULL) { fprintf(stderr, "ZSTD_seekable_create() error \n"); exit(10); }
116 size_t const initResult = ZSTD_seekable_initFile(seekable, fin);
117 if (ZSTD_isError(initResult)) { fprintf(stderr, "ZSTD_seekable_init() error : %s \n", ZSTD_getErrorName(initResult)); exit(11); }
/* size the output buffer from the frame's recorded decompressed size */
119 size_t const frameSize = ZSTD_seekable_getFrameDecompressedSize(seekable, job->frameNb);
120 unsigned char* data = malloc_orDie(frameSize);
122 size_t result = ZSTD_seekable_decompressFrame(seekable, data, frameSize, job->frameNb);
123 if (ZSTD_isError(result)) { fprintf(stderr, "ZSTD_seekable_decompressFrame() error : %s \n", ZSTD_getErrorName(result)); exit(12); }
125 unsigned long long sum = 0;
/* visit each of the frameSize decompressed bytes */
127 for (i = 0; i < frameSize; i++) {
134 ZSTD_seekable_free(seekable);
/* sumFile_orDie() :
 * Sums the per-frame results of the seekable zstd file `fname` and prints the
 * total on stdout as "Sum: <total>".
 * A pool is created with `nbThreads` passed for both POOL_create() arguments,
 * the file is opened only to learn how many frames it holds, and one
 * sumFrame() job per frame is queued with POOL_add(). The results are then
 * collected in frame order.
 * Exits the process with code 9 (pool creation failed), 10 (seekable creation
 * failed) or 11 (seekable initialisation failed).
 * NOTE(review): release of `pool`, `fin` and `jobs` is not visible in this
 * excerpt - confirm against the full file. */
138 static void sumFile_orDie(const char* fname, int nbThreads)
140 POOL_ctx* pool = POOL_create(nbThreads, nbThreads);
141 if (pool == NULL) { fprintf(stderr, "POOL_create() error \n"); exit(9); }
143 FILE* const fin = fopen_orDie(fname, "rb");
145 ZSTD_seekable* const seekable = ZSTD_seekable_create();
146 if (seekable==NULL) { fprintf(stderr, "ZSTD_seekable_create() error \n"); exit(10); }
148 size_t const initResult = ZSTD_seekable_initFile(seekable, fin);
149 if (ZSTD_isError(initResult)) { fprintf(stderr, "ZSTD_seekable_init() error : %s \n", ZSTD_getErrorName(initResult)); exit(11); }
151 unsigned const numFrames = ZSTD_seekable_getNumFrames(seekable);
/* one job descriptor per frame.
 * NOTE(review): this uses plain malloc() and no NULL check is visible here,
 * unlike the malloc_orDie() helper defined in this file - confirm. */
152 struct sum_job* jobs = (struct sum_job*)malloc(numFrames * sizeof(struct sum_job));
155 for (fnb = 0; fnb < numFrames; fnb++) {
/* positional initialiser: file name, 0, frame number, 0 */
156 jobs[fnb] = (struct sum_job){ fname, 0, fnb, 0 };
157 POOL_add(pool, sumFrame, &jobs[fnb]);
160 unsigned long long total = 0;
162 for (fnb = 0; fnb < numFrames; fnb++) {
/* NOTE(review): `done` is polled here with no synchronisation visible in this
 * excerpt; confirm how the field is declared and written by the job. */
163 while (!jobs[fnb].done) SLEEP(5); /* wake up every 5 milliseconds to check */
164 total += jobs[fnb].sum;
167 printf("Sum: %llu\n", total);
170 ZSTD_seekable_free(seekable);
/* main() :
 * Command-line entry point, invoked as "<exe> FILE NB_THREADS".
 * argv[1] is the file to process and argv[2] is converted to the thread count
 * passed to sumFile_orDie(). A usage message is written to stderr.
 * NOTE(review): the condition guarding the usage message and the return
 * values are not visible in this excerpt - confirm against the full file. */
176 int main(int argc, const char** argv)
178 const char* const exeName = argv[0];
181 fprintf(stderr, "wrong arguments\n");
182 fprintf(stderr, "usage:\n");
183 fprintf(stderr, "%s FILE NB_THREADS\n", exeName);
188 const char* const inFilename = argv[1];
/* NOTE(review): atoi() reports no conversion errors, so a non-numeric
 * NB_THREADS is passed on as 0 without any diagnostic visible here. */
189 int const nbThreads = atoi(argv[2]);
190 sumFile_orDie(inFilename, nbThreads);