Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 70 additions & 15 deletions unified-runtime/source/adapters/offload/kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,37 @@

#include "common.hpp"

namespace {
template <typename T> T nextPowerOf2(T Val) {
// https://graphics.stanford.edu/%7Eseander/bithacks.html#RoundUpPowerOf2
Val--;
Val |= Val >> 1;
Val |= Val >> 2;
Val |= Val >> 4;
Val |= Val >> 8;
Val |= Val >> 16;
if constexpr (sizeof(Val) > 4) {
Val |= Val >> 32;
}
return ++Val;
}
} // namespace

struct ur_kernel_handle_t_ : RefCounted {

// Simplified version of the CUDA adapter's argument implementation
struct OffloadKernelArguments {
static constexpr size_t MaxParamBytes = 4096u;
using args_t = std::array<char, MaxParamBytes>;
using final_buffer_t = std::array<char, MaxParamBytes>;
using args_t = std::vector<char>;
using args_size_t = std::vector<size_t>;
using args_ptr_t = std::vector<void *>;
args_t Storage;
size_t StorageUsed = 0;
using args_offset_t = std::vector<size_t>;
final_buffer_t RealisedBuffer;
args_t ParamStorage;
args_size_t ParamSizes;
args_ptr_t Pointers;
args_offset_t Pointers;
bool Dirty = true;
size_t RealisedSpace;

struct MemObjArg {
ur_mem_handle_t_ *Mem;
Expand All @@ -47,12 +66,12 @@ struct ur_kernel_handle_t_ : RefCounted {
ParamSizes.resize(Index + 1);
}
ParamSizes[Index] = Size;
// Calculate the insertion point in the array.
size_t InsertPos = std::accumulate(std::begin(ParamSizes),
std::begin(ParamSizes) + Index, 0);
// Update the stored value for the argument.
std::memcpy(&Storage[InsertPos], Arg, Size);
Pointers[Index] = &Storage[InsertPos];

auto Base = ParamStorage.size();
ParamStorage.resize(Base + Size);
std::memcpy(&ParamStorage[Base], Arg, Size);
Pointers[Index] = Base;
Dirty = true;
}

void addMemObjArg(int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
Expand All @@ -66,14 +85,50 @@ struct ur_kernel_handle_t_ : RefCounted {
}
}
MemObjArgs.push_back(MemObjArg{hMem, Index, Flags});
Dirty = true;
}

const args_ptr_t &getPointers() const noexcept { return Pointers; }
void realise() noexcept {
if (!Dirty) {
return;
}

size_t Space = sizeof(RealisedBuffer);
void *Offset = reinterpret_cast<void *>(0);
char *Base = &RealisedBuffer[0];
for (size_t I = 0; I < Pointers.size(); I++) {
void *ValueBase = &ParamStorage[Pointers[I]];
size_t Size = ParamSizes[I];
size_t Align = nextPowerOf2(Size);

// Align the value to a multiple of the size
// TODO: This is probably not correct, but UR doesn't allow specifying
// the alignment of arguments
if (!std::align(Align, Size, Offset, Space) && Offset) {
// Ran out of space. TODO: Handle properly
// TODO: Since we start at address 0, there's no way to check whether
// the first allocation is a success or not.
abort();
}
Space -= Size;

const char *getStorage() const noexcept { return Storage.data(); }
std::memcpy(&Base[reinterpret_cast<uintptr_t>(Offset)], ValueBase,
Size);
Offset = &reinterpret_cast<char *>(Offset)[Size];
}

Dirty = false;
RealisedSpace = reinterpret_cast<uintptr_t>(Offset);
}

const char *getStorage() noexcept {
realise();
return &RealisedBuffer[0];
}

size_t getStorageSize() const noexcept {
return std::accumulate(std::begin(ParamSizes), std::end(ParamSizes), 0);
size_t getStorageSize() noexcept {
realise();
return RealisedSpace;
}
};

Expand Down
Loading