Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@ cmake_minimum_required(VERSION 3.15...3.27)
project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX)

find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)

# Ask the Python interpreter selected above where JAX ships the XLA FFI
# headers. The exit status is captured so that a missing or broken jax
# installation is reported here, not later as a missing-header compile error.
execute_process(
  COMMAND "${Python_EXECUTABLE}"
          "-c" "from jax.extend import ffi; print(ffi.include_dir())"
  RESULT_VARIABLE XLA_DIR_RESULT
  OUTPUT_VARIABLE XLA_DIR
  OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT XLA_DIR_RESULT EQUAL 0 OR NOT IS_DIRECTORY "${XLA_DIR}")
  message(FATAL_ERROR
    "Could not determine the XLA FFI include directory via "
    "'${Python_EXECUTABLE}' (exit status: ${XLA_DIR_RESULT}, "
    "output: '${XLA_DIR}'). Is jax installed in the build environment?")
endif()
message(STATUS "XLA include directory: ${XLA_DIR}")

find_package(nanobind CONFIG REQUIRED)
nanobind_add_module(_jaxbind NOSTRIP NB_SUPPRESS_WARNINGS NOMINSIZE src/_jaxbind.cc)
# Quoted so that an include path containing spaces stays a single argument.
target_include_directories(_jaxbind PUBLIC "${XLA_DIR}")

install(TARGETS _jaxbind LIBRARY DESTINATION .)
57 changes: 57 additions & 0 deletions src/_jaxbind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>
#include <nanobind/stl/complex.h>
#include <xla/ffi/api/ffi.h>
#include <exception>
#include <iostream>
#include <string>

#include <vector>
Expand All @@ -21,6 +22,7 @@
namespace detail_pymodule_jax {

namespace nb=nanobind;
namespace ffi = xla::ffi;
using namespace std;

using shape_t = vector<size_t>;
Expand Down Expand Up @@ -61,6 +63,56 @@ template <typename T>
nb::capsule EncapsulateFunction(T* fn)
  {
  // Reinterpret the function pointer as an untyped address, then wrap it in
  // a capsule carrying the fixed name "xla._CUSTOM_CALL_TARGET".
  void *opaque_target = bit_cast<void*>(fn);
  return nb::capsule(opaque_target, "xla._CUSTOM_CALL_TARGET");
  }

/// XLA FFI handler body: wraps the incoming buffers as array objects for Python.
///
/// Reads the attributes "nargs" and "n_out", wraps the first `nargs` entries of
/// `args` and the first `n_out` entries of `results` as array wrappers on top
/// of the buffers' memory, and collects them into two Python lists.
///
/// NOTE: the actual Python callback is not invoked yet; see the TODO at the
/// end of the function body.
///
/// @param attrs    attribute dictionary; must contain "nargs" and "n_out"
/// @param args     variadic input buffers
/// @param results  variadic output buffers
/// @return `ffi::Error::Success()` when everything could be wrapped, otherwise
///         an error describing the missing attribute, the inaccessible buffer,
///         the unsupported element type or the exception that was caught.
ffi::Error pycallImpl(ffi::Dictionary attrs,
                      ffi::RemainingArgs args,
                      ffi::RemainingRets results)
  {
  // Acquired outside the try block so that the GIL is still held while the
  // catch handlers below inspect a (possibly Python-originated) exception.
  nb::gil_scoped_acquire get_GIL;

  // Failures are reported through the returned ffi::Error; no exception is
  // allowed to propagate out of this handler.
  try
    {
    static const map<ffi::DataType, nb::dlpack::dtype> tcdict = {
      {ffi::DataType::F32 , nb::dtype<float>()},
      {ffi::DataType::F64 , nb::dtype<double>()},
      {ffi::DataType::U8  , nb::dtype<uint8_t>()},
      {ffi::DataType::U64 , nb::dtype<uint64_t>()},
      {ffi::DataType::C64 , nb::dtype<complex<float>>()},
      {ffi::DataType::C128, nb::dtype<complex<double>>()}
      };

    auto nargs_attr = attrs.get<size_t>("nargs");
    if (!nargs_attr.has_value())
      return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                        "pycall: missing or invalid attribute 'nargs'");
    auto n_out_attr = attrs.get<size_t>("n_out");
    if (!n_out_attr.has_value())
      return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                        "pycall: missing or invalid attribute 'n_out'");
    const size_t nargs = nargs_attr.value();
    const size_t n_out = n_out_attr.value();

    nb::list py_in;
    for (size_t i=0; i<nargs; i++)
      {
      // Getting type, rank, and shape of the input
      auto arg_or = args.get<ffi::AnyBuffer>(i);
      if (!arg_or.has_value())
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "pycall: cannot access input buffer " + std::to_string(i));
      auto arg = arg_or.value();
      const auto dtp_it = tcdict.find(arg.element_type());
      if (dtp_it == tcdict.end())
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "pycall: unsupported element type for input " + std::to_string(i));
      auto dtp_a = dtp_it->second;
      auto dims = arg.dimensions();
      shape_t shape_a;
      shape_a.reserve(dims.size());
      for (auto x : dims) shape_a.push_back(static_cast<size_t>(x));
      // Building "pseudo" numpy arrays on top of the provided memory regions.
      // This should be completely fine, as long as the called function does not
      // keep any references to them.
      CNpArr py_a = make_CArr_wrapper(dtp_a, arg.untyped_data(), shape_a);
      py_in.append(py_a);
      }

    nb::list py_out;
    for (size_t i=0; i<n_out; i++)
      {
      // Getting type, rank, and shape of the output
      auto out_or = results.get<ffi::AnyBuffer>(i);
      if (!out_or.has_value())
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "pycall: cannot access output buffer " + std::to_string(i));
      auto out = out_or.value();
      const auto dtp_it = tcdict.find(out->element_type());
      if (dtp_it == tcdict.end())
        return ffi::Error(ffi::ErrorCode::kInvalidArgument,
                          "pycall: unsupported element type for output " + std::to_string(i));
      auto dtp_out = dtp_it->second;
      auto dims = out->dimensions();
      shape_t shape_out;
      shape_out.reserve(dims.size());
      for (auto x : dims) shape_out.push_back(static_cast<size_t>(x));
      NpArr py_o = make_Arr_wrapper(dtp_out, out->untyped_data(), shape_out);
      py_out.append(py_o);
      }

    // TODO: obtain the Python callable and its kwargs and invoke it:
    // auto func = attrs.get<nb::object>("func").value();
    // auto py_kwargs = attrs.get<nb::dict>("kwargs").value();
    // func(py_out, py_in, py_kwargs);
    return ffi::Error::Success();
    }
  catch (const std::exception &e)
    {
    return ffi::Error(ffi::ErrorCode::kInternal,
                      std::string("pycall: ") + e.what());
    }
  catch (...)
    {
    return ffi::Error(ffi::ErrorCode::kInternal,
                      "pycall: unknown exception");
    }
  }

void pycall(void *out_raw, void **in)
{
nb::gil_scoped_acquire get_GIL;
Expand Down Expand Up @@ -122,10 +174,15 @@ void pycall(void *out_raw, void **in)
func(py_out, py_in, py_kwargs);
}

// Defines the XLA FFI handler symbol `pycallNew`, implemented by pycallImpl.
// The binding requests the attributes as one dictionary, followed by the
// variadic arguments and the variadic results, which mirrors the parameter
// list of pycallImpl (ffi::Dictionary, ffi::RemainingArgs, ffi::RemainingRets).
// NOTE(review): in the visible code `pycallNew` is not added to the dictionary
// built by Registrations() — presumably still to be wired up; confirm.
XLA_FFI_DEFINE_HANDLER_SYMBOL(
pycallNew, pycallImpl,
ffi::Ffi::Bind().Attrs().RemainingArgs().RemainingRets());

// Builds the dictionary of call targets exported by this module: currently
// the single entry "cpu_pycall", holding the capsule produced by
// EncapsulateFunction for `pycall`.
nb::dict Registrations()
  {
  nb::dict targets;
  targets["cpu_pycall"] = EncapsulateFunction(pycall);
  return targets;
  }

Expand Down