| 1 | /* |
| 2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | * All rights reserved. |
| 4 | * |
| 5 | * This source code is licensed under both the BSD-style license (found in the |
| 6 | * LICENSE file in the root directory of this source tree) and the GPLv2 (found |
| 7 | * in the COPYING file in the root directory of this source tree). |
| 8 | */ |
| 9 | #include "platform.h" /* Large Files support, SET_BINARY_MODE */ |
| 10 | #include "Pzstd.h" |
| 11 | #include "SkippableFrame.h" |
| 12 | #include "utils/FileSystem.h" |
| 13 | #include "utils/Portability.h" |
| 14 | #include "utils/Range.h" |
| 15 | #include "utils/ScopeGuard.h" |
| 16 | #include "utils/ThreadPool.h" |
| 17 | #include "utils/WorkQueue.h" |
| 18 | |
| 19 | #include <algorithm> |
| 20 | #include <chrono> |
| 21 | #include <cinttypes> |
| 22 | #include <cstddef> |
| 23 | #include <cstdio> |
| 24 | #include <memory> |
| 25 | #include <string> |
| 26 | |
| 27 | |
| 28 | namespace pzstd { |
| 29 | |
namespace {
// Name of the OS null device: writes to it are discarded.  openOutputFile()
// compares against this to skip the "already exists" overwrite prompt.
#ifdef _WIN32
const std::string nullOutput = "nul";
#else
const std::string nullOutput = "/dev/null";
#endif
}

using std::size_t;
| 39 | |
| 40 | static std::uintmax_t fileSizeOrZero(const std::string &file) { |
| 41 | if (file == "-") { |
| 42 | return 0; |
| 43 | } |
| 44 | std::error_code ec; |
| 45 | auto size = file_size(file, ec); |
| 46 | if (ec) { |
| 47 | size = 0; |
| 48 | } |
| 49 | return size; |
| 50 | } |
| 51 | |
| 52 | static std::uint64_t handleOneInput(const Options &options, |
| 53 | const std::string &inputFile, |
| 54 | FILE* inputFd, |
| 55 | const std::string &outputFile, |
| 56 | FILE* outputFd, |
| 57 | SharedState& state) { |
| 58 | auto inputSize = fileSizeOrZero(inputFile); |
| 59 | // WorkQueue outlives ThreadPool so in the case of error we are certain |
| 60 | // we don't accidentally try to call push() on it after it is destroyed |
| 61 | WorkQueue<std::shared_ptr<BufferWorkQueue>> outs{options.numThreads + 1}; |
| 62 | std::uint64_t bytesRead; |
| 63 | std::uint64_t bytesWritten; |
| 64 | { |
| 65 | // Initialize the (de)compression thread pool with numThreads |
| 66 | ThreadPool executor(options.numThreads); |
| 67 | // Run the reader thread on an extra thread |
| 68 | ThreadPool readExecutor(1); |
| 69 | if (!options.decompress) { |
| 70 | // Add a job that reads the input and starts all the compression jobs |
| 71 | readExecutor.add( |
| 72 | [&state, &outs, &executor, inputFd, inputSize, &options, &bytesRead] { |
| 73 | bytesRead = asyncCompressChunks( |
| 74 | state, |
| 75 | outs, |
| 76 | executor, |
| 77 | inputFd, |
| 78 | inputSize, |
| 79 | options.numThreads, |
| 80 | options.determineParameters()); |
| 81 | }); |
| 82 | // Start writing |
| 83 | bytesWritten = writeFile(state, outs, outputFd, options.decompress); |
| 84 | } else { |
| 85 | // Add a job that reads the input and starts all the decompression jobs |
| 86 | readExecutor.add([&state, &outs, &executor, inputFd, &bytesRead] { |
| 87 | bytesRead = asyncDecompressFrames(state, outs, executor, inputFd); |
| 88 | }); |
| 89 | // Start writing |
| 90 | bytesWritten = writeFile(state, outs, outputFd, options.decompress); |
| 91 | } |
| 92 | } |
| 93 | if (!state.errorHolder.hasError()) { |
| 94 | std::string inputFileName = inputFile == "-" ? "stdin" : inputFile; |
| 95 | std::string outputFileName = outputFile == "-" ? "stdout" : outputFile; |
| 96 | if (!options.decompress) { |
| 97 | double ratio = static_cast<double>(bytesWritten) / |
| 98 | static_cast<double>(bytesRead + !bytesRead); |
| 99 | state.log(kLogInfo, "%-20s :%6.2f%% (%6" PRIu64 " => %6" PRIu64 |
| 100 | " bytes, %s)\n", |
| 101 | inputFileName.c_str(), ratio * 100, bytesRead, bytesWritten, |
| 102 | outputFileName.c_str()); |
| 103 | } else { |
| 104 | state.log(kLogInfo, "%-20s: %" PRIu64 " bytes \n", |
| 105 | inputFileName.c_str(),bytesWritten); |
| 106 | } |
| 107 | } |
| 108 | return bytesWritten; |
| 109 | } |
| 110 | |
| 111 | static FILE *openInputFile(const std::string &inputFile, |
| 112 | ErrorHolder &errorHolder) { |
| 113 | if (inputFile == "-") { |
| 114 | SET_BINARY_MODE(stdin); |
| 115 | return stdin; |
| 116 | } |
| 117 | // Check if input file is a directory |
| 118 | { |
| 119 | std::error_code ec; |
| 120 | if (is_directory(inputFile, ec)) { |
| 121 | errorHolder.setError("Output file is a directory -- ignored"); |
| 122 | return nullptr; |
| 123 | } |
| 124 | } |
| 125 | auto inputFd = std::fopen(inputFile.c_str(), "rb"); |
| 126 | if (!errorHolder.check(inputFd != nullptr, "Failed to open input file")) { |
| 127 | return nullptr; |
| 128 | } |
| 129 | return inputFd; |
| 130 | } |
| 131 | |
/**
 * Open `outputFile` for writing in binary mode ("-" means stdout).
 * If the file already exists and `--overwrite` was not given, either fail
 * (when the log level forbids prompting) or ask the user for confirmation.
 *
 * @param options    Runtime options; `overwrite` is consulted here
 * @param outputFile Path of the output file, or "-" for stdout
 * @param state      Shared state; receives an error message on failure
 * @returns          The open stream, or nullptr on error/refusal
 */
static FILE *openOutputFile(const Options &options,
                            const std::string &outputFile,
                            SharedState& state) {
  if (outputFile == "-") {
    SET_BINARY_MODE(stdout);
    return stdout;
  }
  // Check if the output file exists and then open it
  // (the null device is exempt: clobbering it is always harmless).
  if (!options.overwrite && outputFile != nullOutput) {
    // Probe for existence by opening read-only.
    auto outputFd = std::fopen(outputFile.c_str(), "rb");
    if (outputFd != nullptr) {
      std::fclose(outputFd);
      // At quieter-than-info log levels we cannot prompt, so refuse.
      if (!state.log.logsAt(kLogInfo)) {
        state.errorHolder.setError("Output file exists");
        return nullptr;
      }
      state.log(
          kLogInfo,
          "pzstd: %s already exists; do you wish to overwrite (y/n) ? ",
          outputFile.c_str());
      int c = getchar();
      if (c != 'y' && c != 'Y') {
        state.errorHolder.setError("Not overwritten");
        return nullptr;
      }
    }
  }
  // Create/truncate the output file for writing.
  auto outputFd = std::fopen(outputFile.c_str(), "wb");
  if (!state.errorHolder.check(
          outputFd != nullptr, "Failed to open output file")) {
    return nullptr;
  }
  return outputFd;
}
| 166 | |
/**
 * Main pzstd driver: (de)compress every file in `options.inputFiles` in turn.
 * A failure on one input is reported and the loop continues with the next.
 *
 * @param options Parsed command-line options
 * @returns       0 on success, 1 if any of the files failed to (de)compress
 */
int pzstdMain(const Options &options) {
  int returnCode = 0;
  SharedState state(options);
  for (const auto& input : options.inputFiles) {
    // Setup the shared state
    // This guard runs at the end of each iteration: it reports any error
    // recorded for this input and records the failure in the exit code.
    auto printErrorGuard = makeScopeGuard([&] {
      if (state.errorHolder.hasError()) {
        returnCode = 1;
        state.log(kLogError, "pzstd: %s: %s.\n", input.c_str(),
                  state.errorHolder.getError().c_str());
      }
    });
    // Open the input file
    auto inputFd = openInputFile(input, state.errorHolder);
    if (inputFd == nullptr) {
      continue;
    }
    auto closeInputGuard = makeScopeGuard([&] { std::fclose(inputFd); });
    // Open the output file
    auto outputFile = options.getOutputFile(input);
    if (!state.errorHolder.check(outputFile != "",
                                 "Input file does not have extension .zst")) {
      continue;
    }
    auto outputFd = openOutputFile(options, outputFile, state);
    if (outputFd == nullptr) {
      continue;
    }
    auto closeOutputGuard = makeScopeGuard([&] { std::fclose(outputFd); });
    // (de)compress the file
    handleOneInput(options, input, inputFd, outputFile, outputFd, state);
    if (state.errorHolder.hasError()) {
      continue;
    }
    // Delete the input file if necessary
    if (!options.keepSource) {
      // Be sure that we are done and have written everything before we delete
      // Close explicitly so errors surface, then dismiss the matching guard
      // so the stream is not closed a second time at scope exit.
      if (!state.errorHolder.check(std::fclose(inputFd) == 0,
                                   "Failed to close input file")) {
        continue;
      }
      closeInputGuard.dismiss();
      if (!state.errorHolder.check(std::fclose(outputFd) == 0,
                                   "Failed to close output file")) {
        continue;
      }
      closeOutputGuard.dismiss();
      if (std::remove(input.c_str()) != 0) {
        state.errorHolder.setError("Failed to remove input file");
        continue;
      }
    }
  }
  // Returns 1 if any of the files failed to (de)compress.
  return returnCode;
}
| 223 | |
| 224 | /// Construct a `ZSTD_inBuffer` that points to the data in `buffer`. |
| 225 | static ZSTD_inBuffer makeZstdInBuffer(const Buffer& buffer) { |
| 226 | return ZSTD_inBuffer{buffer.data(), buffer.size(), 0}; |
| 227 | } |
| 228 | |
| 229 | /** |
| 230 | * Advance `buffer` and `inBuffer` by the amount of data read, as indicated by |
| 231 | * `inBuffer.pos`. |
| 232 | */ |
| 233 | void advance(Buffer& buffer, ZSTD_inBuffer& inBuffer) { |
| 234 | auto pos = inBuffer.pos; |
| 235 | inBuffer.src = static_cast<const unsigned char*>(inBuffer.src) + pos; |
| 236 | inBuffer.size -= pos; |
| 237 | inBuffer.pos = 0; |
| 238 | return buffer.advance(pos); |
| 239 | } |
| 240 | |
| 241 | /// Construct a `ZSTD_outBuffer` that points to the data in `buffer`. |
| 242 | static ZSTD_outBuffer makeZstdOutBuffer(Buffer& buffer) { |
| 243 | return ZSTD_outBuffer{buffer.data(), buffer.size(), 0}; |
| 244 | } |
| 245 | |
| 246 | /** |
| 247 | * Split `buffer` and advance `outBuffer` by the amount of data written, as |
| 248 | * indicated by `outBuffer.pos`. |
| 249 | */ |
| 250 | Buffer split(Buffer& buffer, ZSTD_outBuffer& outBuffer) { |
| 251 | auto pos = outBuffer.pos; |
| 252 | outBuffer.dst = static_cast<unsigned char*>(outBuffer.dst) + pos; |
| 253 | outBuffer.size -= pos; |
| 254 | outBuffer.pos = 0; |
| 255 | return buffer.splitAt(pos); |
| 256 | } |
| 257 | |
/**
 * Stream chunks of input from `in`, compress it, and stream it out to `out`.
 *
 * @param state The shared state
 * @param in Queue that we `pop()` input buffers from
 * @param out Queue that we `push()` compressed output buffers to
 * @param maxInputSize An upper bound on the size of the input
 */
static void compress(
    SharedState& state,
    std::shared_ptr<BufferWorkQueue> in,
    std::shared_ptr<BufferWorkQueue> out,
    size_t maxInputSize) {
  auto& errorHolder = state.errorHolder;
  // Always mark the output queue finished — even on early return — so the
  // writer thread blocked on pop() can make progress.
  auto guard = makeScopeGuard([&] { out->finish(); });
  // Initialize the CCtx
  auto ctx = state.cStreamPool->get();
  if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_CStream")) {
    return;
  }
  {
    // The stream comes from a pool: reset the session, keep the parameters.
    auto err = ZSTD_CCtx_reset(ctx.get(), ZSTD_reset_session_only);
    if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
      return;
    }
  }

  // Allocate space for the result
  // (ZSTD_compressBound() of the whole frame, so one buffer suffices).
  auto outBuffer = Buffer(ZSTD_compressBound(maxInputSize));
  auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
  {
    Buffer inBuffer;
    // Read a buffer in from the input queue
    while (in->pop(inBuffer) && !errorHolder.hasError()) {
      auto zstdInBuffer = makeZstdInBuffer(inBuffer);
      // Compress the whole buffer and send it to the output queue
      while (!inBuffer.empty() && !errorHolder.hasError()) {
        if (!errorHolder.check(
                !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
          return;
        }
        // Compress
        auto err =
            ZSTD_compressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
        if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
          return;
        }
        // Split the compressed data off outBuffer and pass to the output queue
        out->push(split(outBuffer, zstdOutBuffer));
        // Forget about the data we already compressed
        advance(inBuffer, zstdInBuffer);
      }
    }
  }
  // Write the epilog
  size_t bytesLeft;
  do {
    if (!errorHolder.check(
            !outBuffer.empty(), "ZSTD_compressBound() was too small")) {
      return;
    }
    // bytesLeft is how much zstd still has to flush; 0 means the frame ended.
    bytesLeft = ZSTD_endStream(ctx.get(), &zstdOutBuffer);
    if (!errorHolder.check(
            !ZSTD_isError(bytesLeft), ZSTD_getErrorName(bytesLeft))) {
      return;
    }
    out->push(split(outBuffer, zstdOutBuffer));
  } while (bytesLeft != 0 && !errorHolder.hasError());
}
| 327 | |
| 328 | /** |
| 329 | * Calculates how large each independently compressed frame should be. |
| 330 | * |
| 331 | * @param size The size of the source if known, 0 otherwise |
| 332 | * @param numThreads The number of threads available to run compression jobs on |
| 333 | * @param params The zstd parameters to be used for compression |
| 334 | */ |
| 335 | static size_t calculateStep( |
| 336 | std::uintmax_t size, |
| 337 | size_t numThreads, |
| 338 | const ZSTD_parameters ¶ms) { |
| 339 | (void)size; |
| 340 | (void)numThreads; |
| 341 | // Not validated to work correctly for window logs > 23. |
| 342 | // It will definitely fail if windowLog + 2 is >= 4GB because |
| 343 | // the skippable frame can only store sizes up to 4GB. |
| 344 | assert(params.cParams.windowLog <= 23); |
| 345 | return size_t{1} << (params.cParams.windowLog + 2); |
| 346 | } |
| 347 | |
namespace {
/// Tri-state result of probing a stdio stream after a read.
enum class FileStatus { Continue, Done, Error };
/// Determines the status of the file descriptor `fd`.
FileStatus fileStatus(FILE* fd) {
  // EOF takes precedence over the error flag, matching the read loops'
  // expectation that a clean end of input reports Done.
  if (std::feof(fd)) {
    return FileStatus::Done;
  }
  return std::ferror(fd) ? FileStatus::Error : FileStatus::Continue;
}
} // anonymous namespace
| 360 | |
| 361 | /** |
| 362 | * Reads `size` data in chunks of `chunkSize` and puts it into `queue`. |
| 363 | * Will read less if an error or EOF occurs. |
| 364 | * Returns the status of the file after all of the reads have occurred. |
| 365 | */ |
| 366 | static FileStatus |
| 367 | readData(BufferWorkQueue& queue, size_t chunkSize, size_t size, FILE* fd, |
| 368 | std::uint64_t *totalBytesRead) { |
| 369 | Buffer buffer(size); |
| 370 | while (!buffer.empty()) { |
| 371 | auto bytesRead = |
| 372 | std::fread(buffer.data(), 1, std::min(chunkSize, buffer.size()), fd); |
| 373 | *totalBytesRead += bytesRead; |
| 374 | queue.push(buffer.splitAt(bytesRead)); |
| 375 | auto status = fileStatus(fd); |
| 376 | if (status != FileStatus::Continue) { |
| 377 | return status; |
| 378 | } |
| 379 | } |
| 380 | return FileStatus::Continue; |
| 381 | } |
| 382 | |
/**
 * Reads the input from `fd` in frame-sized chunks and schedules one
 * compression job per chunk on `executor`.  Each chunk's output queue is
 * pushed onto `chunks` in order, so the writer can emit frames in sequence.
 *
 * @param state      The shared state
 * @param chunks     Queue of per-frame output queues, consumed by writeFile()
 * @param executor   Thread pool that runs the compression jobs
 * @param fd         Input stream
 * @param size       Size of the input if known, 0 otherwise
 * @param numThreads Number of threads available for compression
 * @param params     The zstd parameters to use for compression
 * @returns          Total number of bytes read from `fd`
 */
std::uint64_t asyncCompressChunks(
    SharedState& state,
    WorkQueue<std::shared_ptr<BufferWorkQueue>>& chunks,
    ThreadPool& executor,
    FILE* fd,
    std::uintmax_t size,
    size_t numThreads,
    ZSTD_parameters params) {
  // Mark `chunks` finished on exit so the writer thread can terminate.
  auto chunksGuard = makeScopeGuard([&] { chunks.finish(); });
  std::uint64_t bytesRead = 0;

  // Break the input up into chunks of size `step` and compress each chunk
  // independently.
  size_t step = calculateStep(size, numThreads, params);
  state.log(kLogDebug, "Chosen frame size: %zu\n", step);
  auto status = FileStatus::Continue;
  while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
    // Make a new input queue that we will put the chunk's input data into.
    auto in = std::make_shared<BufferWorkQueue>();
    // Finish `in` on loop exit (normal or error) so the compression job
    // popping from it is not blocked forever.
    auto inGuard = makeScopeGuard([&] { in->finish(); });
    // Make a new output queue that compress will put the compressed data into.
    auto out = std::make_shared<BufferWorkQueue>();
    // Start compression in the thread pool
    executor.add([&state, in, out, step] {
      return compress(
          state, std::move(in), std::move(out), step);
    });
    // Pass the output queue to the writer thread.
    chunks.push(std::move(out));
    state.log(kLogVerbose, "%s\n", "Starting a new frame");
    // Fill the input queue for the compression job we just started
    status = readData(*in, ZSTD_CStreamInSize(), step, fd, &bytesRead);
  }
  state.errorHolder.check(status != FileStatus::Error, "Error reading input");
  return bytesRead;
}
| 419 | |
/**
 * Decompress a frame, whose data is streamed into `in`, and stream the output
 * to `out`.
 *
 * @param state The shared state
 * @param in Queue that we `pop()` input buffers from. It contains
 *           exactly one compressed frame.
 * @param out Queue that we `push()` decompressed output buffers to
 */
static void decompress(
    SharedState& state,
    std::shared_ptr<BufferWorkQueue> in,
    std::shared_ptr<BufferWorkQueue> out) {
  auto& errorHolder = state.errorHolder;
  // Always mark the output queue finished — even on early return — so the
  // writer thread blocked on pop() can make progress.
  auto guard = makeScopeGuard([&] { out->finish(); });
  // Initialize the DCtx
  auto ctx = state.dStreamPool->get();
  if (!errorHolder.check(ctx != nullptr, "Failed to allocate ZSTD_DStream")) {
    return;
  }
  {
    // The stream comes from a pool: reset the session before reuse.
    auto err = ZSTD_DCtx_reset(ctx.get(), ZSTD_reset_session_only);
    if (!errorHolder.check(!ZSTD_isError(err), ZSTD_getErrorName(err))) {
      return;
    }
  }

  const size_t outSize = ZSTD_DStreamOutSize();
  Buffer inBuffer;
  // Last ZSTD_decompressStream() result: 0 when a frame ended exactly at the
  // consumed input, otherwise a hint that zstd expects more data.
  size_t returnCode = 0;
  // Read a buffer in from the input queue
  while (in->pop(inBuffer) && !errorHolder.hasError()) {
    auto zstdInBuffer = makeZstdInBuffer(inBuffer);
    // Decompress the whole buffer and send it to the output queue
    while (!inBuffer.empty() && !errorHolder.hasError()) {
      // Allocate a buffer with at least outSize bytes.
      Buffer outBuffer(outSize);
      auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
      // Decompress
      returnCode =
          ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
      if (!errorHolder.check(
              !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
        return;
      }
      // Pass the buffer with the decompressed data to the output queue
      out->push(split(outBuffer, zstdOutBuffer));
      // Advance past the input we already read
      advance(inBuffer, zstdInBuffer);
      if (returnCode == 0) {
        // The frame is over, prepare to (maybe) start a new frame
        ZSTD_initDStream(ctx.get());
      }
    }
  }
  // With all input consumed, anything larger than 1 means zstd still expects
  // more frame data, i.e. the input was truncated.
  if (!errorHolder.check(returnCode <= 1, "Incomplete block")) {
    return;
  }
  // We've given ZSTD_decompressStream all of our data, but there may still
  // be data to read.
  while (returnCode == 1) {
    // Allocate a buffer with at least outSize bytes.
    Buffer outBuffer(outSize);
    auto zstdOutBuffer = makeZstdOutBuffer(outBuffer);
    // Pass in no input.
    ZSTD_inBuffer zstdInBuffer{nullptr, 0, 0};
    // Decompress
    returnCode =
        ZSTD_decompressStream(ctx.get(), &zstdOutBuffer, &zstdInBuffer);
    if (!errorHolder.check(
            !ZSTD_isError(returnCode), ZSTD_getErrorName(returnCode))) {
      return;
    }
    // Pass the buffer with the decompressed data to the output queue
    out->push(split(outBuffer, zstdOutBuffer));
  }
}
| 497 | |
/**
 * Reads from `fd` one frame at a time and schedules one decompression job per
 * frame on `executor`.  Frame sizes are discovered via the skippable frames
 * pzstd wrote during compression; if one is missing, the remaining input is
 * fed to a single (serial) decompression job.
 *
 * @param state    The shared state
 * @param frames   Queue of per-frame output queues, consumed by writeFile()
 * @param executor Thread pool that runs the decompression jobs
 * @param fd       Input stream
 * @returns        Total number of bytes read from `fd`
 */
std::uint64_t asyncDecompressFrames(
    SharedState& state,
    WorkQueue<std::shared_ptr<BufferWorkQueue>>& frames,
    ThreadPool& executor,
    FILE* fd) {
  // Mark `frames` finished on exit so the writer thread can terminate.
  auto framesGuard = makeScopeGuard([&] { frames.finish(); });
  std::uint64_t totalBytesRead = 0;

  // Split the source up into its component frames.
  // If we find our recognized skippable frame we know the next frames size
  // which means that we can decompress each standard frame in independently.
  // Otherwise, we will decompress using only one decompression task.
  const size_t chunkSize = ZSTD_DStreamInSize();
  auto status = FileStatus::Continue;
  while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
    // Make a new input queue that we will put the frames's bytes into.
    auto in = std::make_shared<BufferWorkQueue>();
    // Finish `in` on loop exit so the decompression job is not blocked.
    auto inGuard = makeScopeGuard([&] { in->finish(); });
    // Make a output queue that decompress will put the decompressed data into
    auto out = std::make_shared<BufferWorkQueue>();

    size_t frameSize;
    {
      // Calculate the size of the next frame.
      // frameSize is 0 if the frame info can't be decoded.
      Buffer buffer(SkippableFrame::kSize);
      auto bytesRead = std::fread(buffer.data(), 1, buffer.size(), fd);
      totalBytesRead += bytesRead;
      status = fileStatus(fd);
      if (bytesRead == 0 && status != FileStatus::Continue) {
        break;
      }
      // Trim the unread tail off the buffer before handing it on.
      buffer.subtract(buffer.size() - bytesRead);
      frameSize = SkippableFrame::tryRead(buffer.range());
      // The header bytes still belong to the stream, so forward them too.
      in->push(std::move(buffer));
    }
    if (frameSize == 0) {
      // We hit a non SkippableFrame, so this will be the last job.
      // Make sure that we don't use too much memory
      in->setMaxSize(64);
      out->setMaxSize(64);
    }
    // Start decompression in the thread pool
    executor.add([&state, in, out] {
      return decompress(state, std::move(in), std::move(out));
    });
    // Pass the output queue to the writer thread
    frames.push(std::move(out));
    if (frameSize == 0) {
      // We hit a non SkippableFrame ==> not compressed by pzstd or corrupted
      // Pass the rest of the source to this decompression task
      state.log(kLogVerbose, "%s\n",
                "Input not in pzstd format, falling back to serial decompression");
      while (status == FileStatus::Continue && !state.errorHolder.hasError()) {
        status = readData(*in, chunkSize, chunkSize, fd, &totalBytesRead);
      }
      break;
    }
    state.log(kLogVerbose, "Decompressing a frame of size %zu", frameSize);
    // Fill the input queue for the decompression job we just started
    status = readData(*in, chunkSize, frameSize, fd, &totalBytesRead);
  }
  state.errorHolder.check(status != FileStatus::Error, "Error reading input");
  return totalBytesRead;
}
| 563 | |
| 564 | /// Write `data` to `fd`, returns true iff success. |
| 565 | static bool writeData(ByteRange data, FILE* fd) { |
| 566 | while (!data.empty()) { |
| 567 | data.advance(std::fwrite(data.begin(), 1, data.size(), fd)); |
| 568 | if (std::ferror(fd)) { |
| 569 | return false; |
| 570 | } |
| 571 | } |
| 572 | return true; |
| 573 | } |
| 574 | |
/**
 * Pops each job's output queue off `outs` in order and writes its buffers to
 * `outputFd`.  When compressing, every data frame is preceded by a skippable
 * frame recording its compressed size, which is what enables the parallel
 * decompression path in asyncDecompressFrames().
 *
 * @param state      The shared state
 * @param outs       Queue of per-frame output queues, filled by the reader
 * @param outputFd   Output stream
 * @param decompress True iff we are decompressing (skips the size frames)
 * @returns          Number of bytes written to `outputFd`
 */
std::uint64_t writeFile(
    SharedState& state,
    WorkQueue<std::shared_ptr<BufferWorkQueue>>& outs,
    FILE* outputFd,
    bool decompress) {
  auto& errorHolder = state.errorHolder;
  // Erase the "Written: ... MB" progress line when we are done.
  auto lineClearGuard = makeScopeGuard([&state] {
    state.log.clear(kLogInfo);
  });
  std::uint64_t bytesWritten = 0;
  std::shared_ptr<BufferWorkQueue> out;
  // Grab the output queue for each decompression job (in order).
  while (outs.pop(out)) {
    // Keep draining even after an error so the producers are not blocked.
    if (errorHolder.hasError()) {
      continue;
    }
    if (!decompress) {
      // If we are compressing and want to write skippable frames we can't
      // start writing before compression is done because we need to know the
      // compressed size.
      // Wait for the compressed size to be available and write skippable frame
      assert(uint64_t(out->size()) < uint64_t(1) << 32);
      SkippableFrame frame(uint32_t(out->size()));
      if (!writeData(frame.data(), outputFd)) {
        errorHolder.setError("Failed to write output");
        return bytesWritten;
      }
      bytesWritten += frame.kSize;
    }
    // For each chunk of the frame: Pop it from the queue and write it
    Buffer buffer;
    while (out->pop(buffer) && !errorHolder.hasError()) {
      if (!writeData(buffer.range(), outputFd)) {
        errorHolder.setError("Failed to write output");
        return bytesWritten;
      }
      bytesWritten += buffer.size();
      state.log.update(kLogInfo, "Written: %u MB ",
                       static_cast<std::uint32_t>(bytesWritten >> 20));
    }
  }
  return bytesWritten;
}
| 618 | } |