Skip to content

Commit 879dc69

Browse files
committed
Cover some more MPI_Win_* API surface
1 parent 1738e29 commit 879dc69

File tree

2 files changed

+149
-24
lines changed

2 files changed

+149
-24
lines changed

c++/mpi/mpi.hpp

Lines changed: 81 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -368,73 +368,131 @@ namespace mpi {
368368
window& operator=(window const&) = delete;
369369
window& operator=(window &&) = delete;
370370

371-
explicit window(communicator &c, BaseType *base, MPI_Aint size = 0) {
371+
/// Create a window over an existing local memory buffer
372+
explicit window(communicator &c, BaseType *base, MPI_Aint size = 0) noexcept {
372373
MPI_Win_create(base, size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &win);
373374
}
374375

376+
/// Create a window and allocate memory for a local memory buffer
377+
explicit window(communicator &c, MPI_Aint size = 0) noexcept {
378+
void *baseptr = nullptr;
379+
MPI_Win_allocate(size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &baseptr, &win);
380+
}
381+
375382
~window() {
376383
if (win != MPI_WIN_NULL) {
377384
MPI_Win_free(&win);
378385
}
379386
}
380387

381-
operator MPI_Win() const { return win; };
382-
operator MPI_Win*() { return &win; };
388+
explicit operator MPI_Win() const noexcept { return win; };
389+
explicit operator MPI_Win*() noexcept { return &win; };
383390

384-
void fence(int assert = 0) const {
391+
/// Synchronization routine in active target RMA. It opens and closes an access epoch.
392+
void fence(int assert = 0) const noexcept {
385393
MPI_Win_fence(assert, win);
386394
}
387395

396+
/// Complete all outstanding RMA operations at both the origin and the target
397+
void flush(int rank = -1) const noexcept {
398+
if (rank < 0) {
399+
MPI_Win_flush_all(win);
400+
} else {
401+
MPI_Win_flush(rank, win);
402+
}
403+
}
404+
405+
/// Synchronize the private and public copies of the window
406+
void sync() const noexcept {
407+
MPI_Win_sync(win);
408+
}
409+
410+
/// Starts an RMA access epoch, locking access to a particular rank or to all ranks in the window
411+
void lock(int rank = -1, int lock_type = MPI_LOCK_SHARED, int assert = 0) const noexcept {
412+
if (rank < 0) {
413+
MPI_Win_lock_all(assert, win);
414+
} else {
415+
MPI_Win_lock(lock_type, rank, assert, win);
416+
}
417+
}
418+
419+
/// Completes an RMA access epoch started by a call to lock()
420+
void unlock(int rank = -1) const noexcept {
421+
if (rank < 0) {
422+
MPI_Win_unlock_all(win);
423+
} else {
424+
MPI_Win_unlock(rank, win);
425+
}
426+
}
427+
428+
/// Load data from a remote memory window.
388429
template <typename TargetType = BaseType, typename OriginType>
389430
std::enable_if_t<has_mpi_type<OriginType> && has_mpi_type<TargetType>, void>
390-
get(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const {
391-
MPI_Datatype origin_datatype = mpi_type<OriginType>::get();
392-
MPI_Datatype target_datatype = mpi_type<TargetType>::get();
393-
int target_count_ = target_count < 0 ? origin_count : target_count;
394-
MPI_Get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
431+
get(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const noexcept {
432+
MPI_Datatype origin_datatype = mpi_type<OriginType>::get();
433+
MPI_Datatype target_datatype = mpi_type<TargetType>::get();
434+
int target_count_ = target_count < 0 ? origin_count : target_count;
435+
MPI_Get(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
395436
};
396437

438+
/// Store data to a remote memory window.
397439
template <typename TargetType = BaseType, typename OriginType>
398440
std::enable_if_t<has_mpi_type<OriginType> && has_mpi_type<TargetType>, void>
399-
put(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const {
400-
MPI_Datatype origin_datatype = mpi_type<OriginType>::get();
401-
MPI_Datatype target_datatype = mpi_type<TargetType>::get();
402-
int target_count_ = target_count < 0 ? origin_count : target_count;
403-
MPI_Put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
441+
put(OriginType *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1) const noexcept {
442+
MPI_Datatype origin_datatype = mpi_type<OriginType>::get();
443+
MPI_Datatype target_datatype = mpi_type<TargetType>::get();
444+
int target_count_ = target_count < 0 ? origin_count : target_count;
445+
MPI_Put(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, win);
404446
};
405447

406-
void* get_attr(int win_keyval) const {
448+
/// Accumulate data into target process through remote memory access.
449+
template <typename TargetType = BaseType, typename OriginType>
450+
std::enable_if_t<has_mpi_type<OriginType> && has_mpi_type<TargetType>, void>
451+
accumulate(OriginType const *origin_addr, int origin_count, int target_rank, MPI_Aint target_disp = 0, int target_count = -1, MPI_Op op = MPI_SUM) const noexcept {
452+
MPI_Datatype origin_datatype = mpi_type<OriginType>::get();
453+
MPI_Datatype target_datatype = mpi_type<TargetType>::get();
454+
int target_count_ = target_count < 0 ? origin_count : target_count;
455+
MPI_Accumulate(origin_addr, origin_count, origin_datatype, target_rank, target_disp, target_count_, target_datatype, op, win);
456+
}
457+
458+
/// Obtains the value of a window attribute.
459+
void* get_attr(int win_keyval) const noexcept {
407460
int flag;
408461
void *attribute_val;
409462
MPI_Win_get_attr(win, win_keyval, &attribute_val, &flag);
410463
assert(flag);
411464
return attribute_val;
412465
}
413-
BaseType* base() const { return static_cast<BaseType*>(get_attr(MPI_WIN_BASE)); }
414-
MPI_Aint size() const { return *static_cast<MPI_Aint*>(get_attr(MPI_WIN_SIZE)); }
415-
int disp_unit() const { return *static_cast<int*>(get_attr(MPI_WIN_DISP_UNIT)); }
466+
467+
// Expose some commonly used attributes
468+
BaseType* base() const noexcept { return static_cast<BaseType*>(get_attr(MPI_WIN_BASE)); }
469+
MPI_Aint size() const noexcept { return *static_cast<MPI_Aint*>(get_attr(MPI_WIN_SIZE)); }
470+
int disp_unit() const noexcept { return *static_cast<int*>(get_attr(MPI_WIN_DISP_UNIT)); }
416471
};
417472

418473
/// The shared_window class
419474
template <class BaseType>
420475
class shared_window : public window<BaseType> {
421476
public:
422-
shared_window(shared_communicator& c, MPI_Aint size) {
477+
/// Create a window and allocate memory for a shared memory buffer
478+
shared_window(shared_communicator& c, MPI_Aint size) noexcept {
423479
void* baseptr = nullptr;
424480
MPI_Win_allocate_shared(size * sizeof(BaseType), alignof(BaseType), MPI_INFO_NULL, c.get(), &baseptr, &(this->win));
425481
}
426482

427-
std::tuple<MPI_Aint, int, void*> query(int rank = MPI_PROC_NULL) const {
483+
/// Query a shared memory window
484+
std::tuple<MPI_Aint, int, void*> query(int rank = MPI_PROC_NULL) const noexcept {
428485
MPI_Aint size = 0;
429486
int disp_unit = 0;
430487
void *baseptr = nullptr;
431488
MPI_Win_shared_query(this->win, rank, &size, &disp_unit, &baseptr);
432489
return {size, disp_unit, baseptr};
433490
}
434491

435-
MPI_Aint size(int rank = MPI_PROC_NULL) const { return std::get<0>(query(rank)) / sizeof(BaseType); }
436-
int disp_unit(int rank = MPI_PROC_NULL) const { return std::get<1>(query(rank)); }
437-
BaseType* base(int rank = MPI_PROC_NULL) const { return static_cast<BaseType*>(std::get<2>(query(rank))); }
492+
// Override the commonly used attributes of the window base class
493+
BaseType* base(int rank = MPI_PROC_NULL) const noexcept { return static_cast<BaseType*>(std::get<2>(query(rank))); }
494+
MPI_Aint size(int rank = MPI_PROC_NULL) const noexcept { return std::get<0>(query(rank)) / sizeof(BaseType); }
495+
int disp_unit(int rank = MPI_PROC_NULL) const noexcept { return std::get<1>(query(rank)); }
438496
};
439497

440498
/* -----------------------------------------------------------

test/c++/mpi_window.cpp

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@
1919
#include <gtest/gtest.h>
2020
#include <numeric>
2121

22+
// Test cases are adapted from slides and exercises of the HLRS course:
23+
// Introduction to the Message Passing Interface (MPI)
24+
// Authors: Joel Malard, Alan Simpson, (EPCC)
25+
// Rolf Rabenseifner, Traugott Streicher, Tobias Haas (HLRS)
26+
// https://fs.hlrs.de/projects/par/par_prog_ws/pdf/mpi_3.1_rab.pdf
27+
// https://fs.hlrs.de/projects/par/par_prog_ws/practical/MPI31single.tar.gz
28+
2229
TEST(MPI_Window, SharedCommunicator) {
2330
mpi::communicator world;
2431
[[maybe_unused]] auto shm = world.split_shared();
@@ -68,7 +75,7 @@ TEST(MPI_Window, RingOneSidedPut) {
6875
EXPECT_EQ(sum, (size * (size - 1)) / 2);
6976
}
7077

71-
TEST(MPI_Window, RingOneSidedAllowShared) {
78+
TEST(MPI_Window, RingOneSidedAllocShared) {
7279
mpi::communicator world;
7380
auto shm = world.split_shared();
7481
int const rank_shm = shm.rank();
@@ -91,6 +98,66 @@ TEST(MPI_Window, RingOneSidedAllowShared) {
9198
EXPECT_EQ(sum, (size_shm * (size_shm - 1)) / 2);
9299
}
93100

101+
TEST(MPI_Window, RingOneSidedStoreWinAllocSharedSignal) {
102+
mpi::communicator world;
103+
auto shm = world.split_shared();
104+
105+
int const rank_shm = shm.rank();
106+
int const size_shm = shm.size();
107+
int const right = (rank_shm+1) % size_shm;
108+
int const left = (rank_shm-1+size_shm) % size_shm;
109+
110+
mpi::shared_window<int> win{shm, 1};
111+
int *rcv_buf_ptr = win.base(rank_shm);
112+
win.lock();
113+
114+
int sum = 0;
115+
int snd_buf = rank_shm;
116+
117+
MPI_Request rq;
118+
MPI_Status status;
119+
int snd_dummy, rcv_dummy;
120+
121+
for(int i = 0; i < size_shm; ++i) {
122+
// ... The local Win_syncs are needed to sync the processor and real memory.
123+
// ... The following pair of syncs is needed so that the read-write-rule is fulfilled.
124+
win.sync();
125+
126+
// ... tag=17: posting to left that rcv_buf is exposed to left, i.e.,
127+
// the left process is now allowed to store data into the local rcv_buf
128+
MPI_Irecv(&rcv_dummy, 0, MPI_INT, right, 17, shm.get(), &rq);
129+
MPI_Send (&snd_dummy, 0, MPI_INT, left, 17, shm.get());
130+
MPI_Wait(&rq, &status);
131+
132+
win.sync();
133+
134+
// MPI_Put(&snd_buf, 1, MPI_INT, right, (MPI_Aint) 0, 1, MPI_INT, win);
135+
// ... is substituted by (with offset "right-my_rank" to store into right neighbor's rcv_buf):
136+
*(rcv_buf_ptr+(right-rank_shm)) = snd_buf;
137+
138+
139+
// ... The following pair of syncs is needed so that the write-read-rule is fulfilled.
140+
win.sync();
141+
142+
// ... The following communication synchronizes the processors in the way
143+
// that the origin processor has finished the store
144+
// before the target processor starts to load the data.
145+
// ... tag=18: posting to right that rcv_buf was stored from left
146+
MPI_Irecv(&rcv_dummy, 0, MPI_INT, left, 18, shm.get(), &rq);
147+
MPI_Send (&snd_dummy, 0, MPI_INT, right, 18, shm.get());
148+
MPI_Wait(&rq, &status);
149+
150+
win.sync();
151+
152+
snd_buf = *rcv_buf_ptr;
153+
sum += *rcv_buf_ptr;
154+
}
155+
156+
EXPECT_EQ(sum, (size_shm * (size_shm - 1)) / 2);
157+
158+
win.unlock();
159+
}
160+
94161
TEST(MPI_Window, SharedArray) {
95162
mpi::communicator world;
96163
auto shm = world.split_shared();

0 commit comments

Comments
 (0)