@@ -38,8 +38,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
38
38
#include < barrier>
39
39
#include < cassert>
40
40
#include < cstdint>
41
- #include < iostream>
42
41
#include < algorithm>
42
+ #include < stdexcept>
43
43
44
44
#include " mkn/gpu.hpp"
45
45
@@ -135,7 +135,6 @@ struct StreamLauncher {
135
135
assert (step < fns.size ());
136
136
assert (i < events.size ());
137
137
138
- // if (fns[step]->mode == StreamFunctionMode::HOST_WAIT) events[i].stream.sync();
139
138
fns[step]->run (i);
140
139
if (fns[step]->mode == StreamFunctionMode::DEVICE_WAIT) events[i].record ().wait ();
141
140
}
@@ -203,6 +202,55 @@ struct StreamBarrierFunction : StreamFunction<Strat> {
203
202
std::barrier<decltype(on_completion)> sync_point;
204
203
};
205
204
205
+ template <typename Strat>
206
+ struct StreamGroupBarrierFunction : StreamFunction<Strat> {
207
+ using This = StreamGroupBarrierFunction<Strat>;
208
+ using Super = StreamFunction<Strat>;
209
+ using Super::strat;
210
+
211
+ std::string_view constexpr static MOD_GROUP_ERROR =
212
+ " mkn.gpu error: StreamGroupBarrierFunction Group size must be a divisor of datas" ;
213
+
214
+ struct GroupBarrier {
215
+ This* self;
216
+ std::uint16_t group_id;
217
+
218
+ std::function<void ()> on_completion = [this ]() {
219
+ std::size_t const offset = self->group_size * group_id;
220
+ for (std::size_t i = offset; i < offset + self->group_size ; ++i)
221
+ self->strat .status [i] = SFS::WAIT;
222
+ };
223
+
224
+ std::barrier<decltype(on_completion)> sync_point{static_cast <std::int64_t >(self->group_size ),
225
+ on_completion};
226
+
227
+ GroupBarrier (This& slf, std::uint16_t const gid) : self{&slf}, group_id{gid} {}
228
+ void arrive () { [[maybe_unused]] auto ret = sync_point.arrive (); }
229
+ };
230
+
231
+ static auto make_sync_points (This& self, Strat const & strat, std::size_t const & group_size) {
232
+ if (strat.datas .size () % group_size > 0 ) throw std::runtime_error (std::string{MOD_GROUP_ERROR});
233
+ std::vector<std::unique_ptr<GroupBarrier>> v;
234
+ std::uint16_t const groups = strat.datas .size () / group_size;
235
+ v.reserve (groups);
236
+ for (std::size_t i = 0 ; i < groups; ++i)
237
+ v.emplace_back (std::make_unique<GroupBarrier>(self, i));
238
+ return std::move (v);
239
+ }
240
+
241
+ StreamGroupBarrierFunction (std::size_t const & gs, Strat& strat)
242
+ : Super{strat, StreamFunctionMode::BARRIER},
243
+ group_size{gs},
244
+ sync_points{make_sync_points (*this , strat, group_size)} {}
245
+
246
+ void run (std::uint32_t const i) override {
247
+ sync_points[((i - (i % group_size)) / group_size)]->arrive ();
248
+ }
249
+
250
+ std::size_t const group_size;
251
+ std::vector<std::unique_ptr<GroupBarrier>> sync_points;
252
+ };
253
+
206
254
template <typename Datas>
207
255
struct ThreadedStreamLauncher : public StreamLauncher <Datas, ThreadedStreamLauncher<Datas>> {
208
256
using This = ThreadedStreamLauncher<Datas>;
@@ -211,8 +259,9 @@ struct ThreadedStreamLauncher : public StreamLauncher<Datas, ThreadedStreamLaunc
211
259
using Super::events;
212
260
using Super::fns;
213
261
214
- constexpr static std::size_t wait_ms = 1 ;
215
- constexpr static std::size_t wait_max_ms = 100 ;
262
+ constexpr static std::size_t wait_ms = _MKN_GPU_THREADED_STREAM_LAUNCHER_WAIT_MS_;
263
+ constexpr static std::size_t wait_add_ms = _MKN_GPU_THREADED_STREAM_LAUNCHER_WAIT_MS_ADD_;
264
+ constexpr static std::size_t wait_max_ms = _MKN_GPU_THREADED_STREAM_LAUNCHER_WAIT_MS_MAX_;
216
265
217
266
ThreadedStreamLauncher (Datas& datas, std::size_t const _n_threads = 1 ,
218
267
std::size_t const device = 0 )
@@ -235,6 +284,11 @@ struct ThreadedStreamLauncher : public StreamLauncher<Datas, ThreadedStreamLaunc
235
284
return *this ;
236
285
}
237
286
287
+ This& group_barrier (std::size_t const & group_size) {
288
+ fns.emplace_back (std::make_shared<StreamGroupBarrierFunction<This>>(group_size, *this ));
289
+ return *this ;
290
+ }
291
+
238
292
// Calling the launcher with no arguments simply invokes join().
void operator ()() { join (); }
239
293
// Returns this object viewed through its base-class (Super) interface.
Super& super () { return *this ; }
240
294
// Invokes the base class's call operator with index `idx`.
void super (std::size_t const & idx) { return super ()(idx); }
@@ -264,7 +318,7 @@ struct ThreadedStreamLauncher : public StreamLauncher<Datas, ThreadedStreamLaunc
264
318
}
265
319
266
320
std::this_thread::sleep_for (std::chrono::milliseconds (waitms));
267
- waitms = waitms >= wait_max_ms ? wait_max_ms : waitms + 10 ;
321
+ waitms = waitms >= wait_max_ms ? wait_max_ms : waitms + wait_add_ms ;
268
322
}
269
323
}
270
324
0 commit comments