648db22b |
1 | /* |
2 | * Copyright (c) Meta Platforms, Inc. and affiliates. |
3 | * All rights reserved. |
4 | * |
5 | * This source code is licensed under both the BSD-style license (found in the |
6 | * LICENSE file in the root directory of this source tree) and the GPLv2 (found |
7 | * in the COPYING file in the root directory of this source tree). |
8 | */ |
9 | #pragma once |
10 | |
11 | #include "utils/Buffer.h" |
12 | |
13 | #include <atomic> |
14 | #include <cassert> |
15 | #include <cstddef> |
16 | #include <condition_variable> |
17 | #include <cstddef> |
18 | #include <functional> |
19 | #include <mutex> |
20 | #include <queue> |
21 | |
22 | namespace pzstd { |
23 | |
/**
 * Thread-safe work queue, unbounded by default.
 *
 * All state is guarded by a single mutex, so any number of threads may call
 * `push()` and `pop()` concurrently. When constructed with (or later given) a
 * non-zero maximum size, `push()` blocks while the queue is full. After
 * `finish()` is called, new pushes are rejected and `pop()` drains the items
 * that remain before reporting that no more work will arrive.
 */
template <typename T>
class WorkQueue {
  // Protects all member variable access
  std::mutex mutex_;
  // Signalled when an item is pushed, or when `finish()` is called.
  std::condition_variable readerCv_;
  // Signalled when an item is popped, the max size changes, or on `finish()`.
  std::condition_variable writerCv_;
  // Signalled when `finish()` is called.
  std::condition_variable finishCv_;

  std::queue<T> queue_;
  bool done_;
  std::size_t maxSize_;

  // Must have lock to call this function.
  // A max size of 0 means the queue is unbounded and therefore never full.
  bool full() const {
    return maxSize_ != 0 && queue_.size() >= maxSize_;
  }

 public:
  /**
   * Constructs an empty work queue with an optional max size.
   * If `maxSize == 0` the queue size is unbounded.
   *
   * @param maxSize The maximum allowed size of the work queue.
   */
  WorkQueue(std::size_t maxSize = 0) : done_(false), maxSize_(maxSize) {}

  /**
   * Push an item onto the work queue, blocking while the queue is full.
   * Notifies a single waiting reader that work is available.
   * If `finish()` has been called (including while blocked), do nothing and
   * return false. If `push()` returns false, then `item` has not been moved
   * from.
   *
   * @param item Item to push onto the queue.
   * @returns True upon success, false if `finish()` has been called. An
   *          item was pushed iff `push()` returns true.
   */
  bool push(T&& item) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      // Wait for room, unless finish() releases every blocked writer.
      writerCv_.wait(lock, [this] { return done_ || !full(); });
      if (done_) {
        return false;
      }
      queue_.push(std::move(item));
    }
    // Notify after releasing the lock so the woken reader can acquire it.
    readerCv_.notify_one();
    return true;
  }

  /**
   * Attempts to pop an item off the work queue. It will block until data is
   * available or `finish()` has been called.
   *
   * @param[out] item If `pop` returns `true`, it contains the popped item.
   *                  If `pop` returns `false`, it is unmodified.
   * @returns True upon success. False if the queue is empty and
   *          `finish()` has been called.
   */
  bool pop(T& item) {
    {
      std::unique_lock<std::mutex> lock(mutex_);
      readerCv_.wait(lock, [this] { return done_ || !queue_.empty(); });
      if (queue_.empty()) {
        // Only reachable once finish() has been called and the queue drained.
        assert(done_);
        return false;
      }
      item = std::move(queue_.front());
      queue_.pop();
    }
    // A slot was freed; wake one writer that may be blocked on a full queue.
    writerCv_.notify_one();
    return true;
  }

  /**
   * Sets the maximum queue size. If `maxSize == 0` then it is unbounded.
   * Blocked writers are woken so they can re-check against the new limit.
   *
   * @param maxSize The new maximum queue size.
   */
  void setMaxSize(std::size_t maxSize) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      maxSize_ = maxSize;
    }
    writerCv_.notify_all();
  }

  /**
   * Promise that `push()` won't be called again, so once the queue is empty
   * there will never be any more work. Must be called at most once (checked
   * by an assertion). Wakes every blocked reader, writer, and
   * `waitUntilFinished()` caller.
   */
  void finish() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      assert(!done_);
      done_ = true;
    }
    readerCv_.notify_all();
    writerCv_.notify_all();
    finishCv_.notify_all();
  }

  /// Blocks until `finish()` has been called (but the queue may not be empty).
  void waitUntilFinished() {
    std::unique_lock<std::mutex> lock(mutex_);
    finishCv_.wait(lock, [this] { return done_; });
  }
};
140 | |
141 | /// Work queue for `Buffer`s that knows the total number of bytes in the queue. |
142 | class BufferWorkQueue { |
143 | WorkQueue<Buffer> queue_; |
144 | std::atomic<std::size_t> size_; |
145 | |
146 | public: |
147 | BufferWorkQueue(std::size_t maxSize = 0) : queue_(maxSize), size_(0) {} |
148 | |
149 | void push(Buffer buffer) { |
150 | size_.fetch_add(buffer.size()); |
151 | queue_.push(std::move(buffer)); |
152 | } |
153 | |
154 | bool pop(Buffer& buffer) { |
155 | bool result = queue_.pop(buffer); |
156 | if (result) { |
157 | size_.fetch_sub(buffer.size()); |
158 | } |
159 | return result; |
160 | } |
161 | |
162 | void setMaxSize(std::size_t maxSize) { |
163 | queue_.setMaxSize(maxSize); |
164 | } |
165 | |
166 | void finish() { |
167 | queue_.finish(); |
168 | } |
169 | |
170 | /** |
171 | * Blocks until `finish()` has been called. |
172 | * |
173 | * @returns The total number of bytes of all the `Buffer`s currently in the |
174 | * queue. |
175 | */ |
176 | std::size_t size() { |
177 | queue_.waitUntilFinished(); |
178 | return size_.load(); |
179 | } |
180 | }; |
181 | } |