Fixed issue with PyTorch deleting objects that are still in use
benja263 committed Feb 4, 2025
1 parent 9d8f93a commit c1a0369
Showing 13 changed files with 69 additions and 34 deletions.
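The underlying failure mode, as a minimal standalone sketch (illustrative only, not code from this repository; `get_tensor_info` and `cpp_model.step` refer to the GBRL internals patched below):

    import torch as th

    x = th.arange(6, dtype=th.float64)

    # Before this commit, call sites did the equivalent of
    #     cpp_model.step(get_tensor_info(x.float()), None, ...)
    # x.float() allocates a fresh temporary tensor, and only its raw data
    # pointer crosses the pybind11 boundary. Once the expression has been
    # evaluated, no Python reference to the temporary remains, so PyTorch is
    # free to reclaim its storage while the C++ side still reads through the
    # pointer -- a use-after-free.
    ptr = x.float().data_ptr()  # dangling as soon as this line completes

    # The fix below binds the converted tensor to a name (and to
    # self._save_memory) so it stays alive until the native call returns.
    x32 = x.float()
    ptr = x32.data_ptr()  # valid for as long as x32 is referenced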
6 changes: 6 additions & 0 deletions CMakeLists.txt
@@ -54,6 +54,7 @@ find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
 include(CMakeDependentOption)
 option(USE_CUDA "Build with GPU acceleration" OFF)
 option(COVERAGE "Run coverage report" OFF)
+option(ASAN "Use address-sanitizer" OFF)
 
 if(COVERAGE)
     message(STATUS "Coverage build")
@@ -231,6 +232,11 @@ target_compile_definitions(gbrl_cpp PRIVATE MODULE_NAME="gbrl_cpp")
 target_link_libraries(gbrl_cpp PRIVATE ${Python3_LIBRARIES})
 target_link_libraries(gbrl_cpp PRIVATE gbrl_cpp_src)
 
+if (CMAKE_BUILD_TYPE STREQUAL "Debug")
+    target_link_libraries(gbrl_cpp PRIVATE ${DEBUG_LINK_FLAGS})
+endif()
+
+
 if (USE_CUDA)
     target_link_libraries(gbrl_cpp PRIVATE cuda_gbrl_src)
 endif()
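(`${DEBUG_LINK_FLAGS}` is defined in `gbrl/src/cpp/CMakeLists.txt` further down; it is only non-empty for ASAN debug builds, in which case this block links the sanitizer runtime into the Python extension, and otherwise linking the empty variable is a no-op.)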
2 changes: 1 addition & 1 deletion gbrl/__init__.py
@@ -6,7 +6,7 @@
 # https://nvlabs.github.io/gbrl/license.html
 #
 ##############################################################################
-__version__ = "1.0.10"
+__version__ = "1.0.11"
 
 import importlib.util
 import os
23 changes: 17 additions & 6 deletions gbrl/gbrl_wrapper.py
@@ -76,9 +76,11 @@ def reset(self) -> None:
     def step(self, features: Union[np.ndarray, th.Tensor, Tuple], grads: Union[np.ndarray, th.Tensor]) -> None:
         features, grads = ensure_same_type(features, grads)
         if isinstance(features, th.Tensor):
-            if self.feature_weights is None:
-                self.feature_weights = th.ones(get_input_dim(features), device=features.device).float()
-            self.cpp_model.step(get_tensor_info(features.float()), None, get_tensor_info(grads.float()))
+            features = features.float()
+            grads = grads.float()
+            self._save_memory = (features, grads)
+            self.cpp_model.step(get_tensor_info(features), None, get_tensor_info(grads))
+            self._save_memory = None
         else:
             num_features, cat_features = preprocess_features(features)
             grads = np.ascontiguousarray(grads.reshape((len(grads), self.params['output_dim']))).astype(numerical_dtype)
@@ -273,12 +275,16 @@ def predict(self, features: Union[np.ndarray, th.Tensor], requires_grad: bool =
         if stop_idx is None:
             stop_idx = 0
         if isinstance(features, th.Tensor):
-            preds_dlpack = self.cpp_model.predict(get_tensor_info(features.float()), None, start_idx, stop_idx)
+            features = features.float()
+            # store features so that data isn't garbage collected while GBRL uses it
+            self._save_memory = features
+            preds_dlpack = self.cpp_model.predict(get_tensor_info(features), None, start_idx, stop_idx)
             preds = th.from_dlpack(preds_dlpack)
             if self.student_model is not None:
-                student_dlpack = self.student_model.predict(get_tensor_info(features.float()), None)
+                student_dlpack = self.student_model.predict(get_tensor_info(features), None)
                 preds += th.from_dlpack(student_dlpack)
             preds.requires_grad_(requires_grad)
+            self._save_memory = None
         else:
             num_features, cat_features = preprocess_features(features)
             preds = self.cpp_model.predict(num_features, cat_features, start_idx, stop_idx)
@@ -528,7 +534,12 @@ def step(self, observations: Union[np.ndarray, th.Tensor], theta_grad: np.ndarra
         grads = concatenate_arrays(theta_grad, value_grad)
         observations, grads = ensure_same_type(observations, grads)
         if isinstance(observations, th.Tensor):
-            self.cpp_model.step(get_tensor_info(observations.float()), None, get_tensor_info(grads.float()))
+            observations = observations.float()
+            grads = grads.float()
+            # store data so that data isn't garbage collected while GBRL uses it
+            self._save_memory = (observations, grads)
+            self.cpp_model.step(get_tensor_info(observations), None, get_tensor_info(grads))
+            self._save_memory = None
         else:
             num_observations, cat_observations = preprocess_features(observations)
             grads = np.ascontiguousarray(grads).astype(numerical_dtype)
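All three call sites above follow the same keep-alive discipline; condensed into a sketch (the `NativeCaller` class is hypothetical, while `get_tensor_info` is the real helper from `gbrl.utils`):

    import torch as th
    from gbrl.utils import get_tensor_info

    class NativeCaller:
        # Hypothetical wrapper illustrating the keep-alive pattern above.
        def __init__(self, cpp_model):
            self.cpp_model = cpp_model
            self._save_memory = None

        def step(self, features: th.Tensor, grads: th.Tensor) -> None:
            features, grads = features.float(), grads.float()  # named copies
            # A reference held on self is guaranteed to outlive this frame,
            # so the storage behind the raw pointers cannot be reclaimed
            # while the native call is running.
            self._save_memory = (features, grads)
            try:
                self.cpp_model.step(get_tensor_info(features), None,
                                    get_tensor_info(grads))
            finally:
                self._save_memory = None  # call returned; release the copies

Resetting `_save_memory` to None afterwards matters for memory pressure: without it, the previous step's feature and gradient copies would stay resident until the next call.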
5 changes: 5 additions & 0 deletions gbrl/src/cpp/CMakeLists.txt
@@ -58,6 +58,11 @@ if (CMAKE_BUILD_TYPE STREQUAL "Debug")
     message(STATUS "Configuring for Debug build.")
     if (CMAKE_SYSTEM_NAME STREQUAL "Linux" OR CMAKE_SYSTEM_NAME STREQUAL "Darwin")
         set(DEBUG_CXX_FLAGS "-g")
+        if (ASAN)
+            message(STATUS "Configuring for ASAN.")
+            set(DEBUG_CXX_FLAGS "-g -fsanitize=address,undefined -fno-omit-frame-pointer")
+            set(DEBUG_LINK_FLAGS "-fsanitize=address,undefined")
+        endif()
     elseif (MSVC)
         set(DEBUG_CXX_FLAGS "/Zi") # Debug symbols
         set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /DEBUG") # Generate .pdb
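With the matching `setup.py` change below, an instrumented build can be requested from the environment, e.g. `ASAN=1 pip install -e .` on Linux or macOS (assuming `CMAKE_BUILD_TYPE` is `Debug`, since the flags are only set in the Debug branch). Running the resulting extension under a stock Python interpreter typically also requires preloading the sanitizer runtime, e.g. `LD_PRELOAD=$(gcc -print-file-name=libasan.so) python -m pytest`.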
17 changes: 14 additions & 3 deletions gbrl/src/cpp/binding.cpp
@@ -54,16 +54,27 @@ void get_tensor_info(py::tuple tensor_info, T*& ptr, std::vector<size_t>& shape,
     if (tensor_info.size() != 4) {
         throw std::runtime_error("Expected a tuple of size 4: (data_ptr, shape, dtype, device)");
     }
-    // Cast data_ptr directly
-    ptr = reinterpret_cast<T*>(tensor_info[0].cast<size_t>());
-
+    // size_t raw_ptr = tensor_info[0].cast<size_t>();
+    size_t raw_ptr = tensor_info[0].cast<uintptr_t>();
+
+    if (raw_ptr == 0 || raw_ptr == (size_t)-1) { // Check for null or invalid pointer values
+        std::cerr << "ERROR: Extracted an invalid pointer! Setting ptr to nullptr." << std::endl;
+        ptr = nullptr;
+    } else {
+        ptr = reinterpret_cast<T*>(raw_ptr);
+    }
+
+    if (ptr) {
+        if (reinterpret_cast<uintptr_t>(ptr) % alignof(T) != 0) {
+            std::cerr << "ERROR: Pointer is not properly aligned! Possible misaligned memory access." << std::endl;
+        }
+    }
     // Extract shape
     py::tuple shape_tuple = tensor_info[1].cast<py::tuple>();
     shape.clear();
     for (py::handle dim : shape_tuple) {
         shape.push_back(dim.cast<size_t>());
     }
 
     // Extract and verify dtype
     std::string dtype = tensor_info[2].cast<std::string>();
     std::string expected_dtype;
2 changes: 1 addition & 1 deletion gbrl/src/cpp/config.h
@@ -11,6 +11,6 @@
 
 #define MAJOR_VERSION 1
 #define MINOR_VERSION 0
-#define PATCH_VERSION 10
+#define PATCH_VERSION 11
 
 #endif // VERSION_CONFIG_H
12 changes: 6 additions & 6 deletions gbrl/src/cpp/fitter.cpp
@@ -299,9 +299,9 @@ int Fitter::fit_greedy_tree(dataSet *dataset, ensembleData *edata, ensembleMetaD
             float score = crnt_node->getSplitScore(dataset, edata->feature_weights, metadata->split_score_func, split_candidates[j], metadata->min_data_in_leaf);
             int feat_idx = (split_candidates[j].categorical_value == nullptr) ? split_candidates[j].feature_idx : split_candidates[j].feature_idx + metadata->n_num_features;
             score = score * edata->feature_weights[feat_idx] - parent_score;
-#ifdef DEBUG
-            std::cout << " cand: " << j << " score: " << score << " parent score: " << parent_score << " info: " << split_candidates[j] << std::endl;
-#endif
+            // #ifdef DEBUG
+            // std::cout << " cand: " << j << " score: " << score << " parent score: " << parent_score << " info: " << split_candidates[j] << std::endl;
+            // #endif
             if (score > local_best_score) {
                 local_best_score = score;
                 local_chosen_idx = j;
@@ -396,9 +396,9 @@ int Fitter::fit_oblivious_tree(dataSet *dataset, ensembleData *edata, ensembleMe
             }
             int feat_idx = (split_candidates[j].categorical_value == nullptr) ? split_candidates[j].feature_idx : split_candidates[j].feature_idx + metadata->n_num_features;
             score = score*edata->feature_weights[feat_idx] - parent_score;
-#ifdef DEBUG
-            std::cout << " cand: " << j << " score: " << score << " parent_score: " << parent_score << " info: " << split_candidates[j] << std::endl;
-#endif
+            // #ifdef DEBUG
+            // std::cout << " cand: " << j << " score: " << score << " parent_score: " << parent_score << " info: " << split_candidates[j] << std::endl;
+            // #endif
             if (score > local_best_score) {
                 local_best_score = score;
                 local_chosen_idx = j;
3 changes: 1 addition & 2 deletions gbrl/src/cpp/predictor.cpp
@@ -129,9 +129,9 @@ void Predictor::predict_cpu(dataSet *dataset, float *preds, const ensembleData *
     if (n_tree_threads > 1 && parallel_predict && n_tree_threads > n_sample_threads){
         std::vector<float *> preds_buffer(n_tree_threads);
         int trees_per_thread = n_trees / n_tree_threads;
-        omp_set_num_threads(n_tree_threads);
         for (int i = 0; i < n_tree_threads; ++i)
             preds_buffer[i] = init_zero_mat(n_samples*output_dim);
+        omp_set_num_threads(n_tree_threads);
         #pragma omp parallel
         {
             int thread_id = omp_get_thread_num();
@@ -227,7 +227,6 @@ void Predictor::predict_over_trees(const float *obs, const char *categorical_obs
     const int* feature_indices = edata->feature_indices;
     const int* tree_indices = edata->tree_indices;
    const char* categorical_values = edata->categorical_values;
-
 
    while (tree_idx < stop_tree_idx)
    {
5 changes: 4 additions & 1 deletion gbrl/src/cpp/types.cpp
@@ -182,6 +182,7 @@ ensembleData* ensemble_data_alloc(ensembleMetaData *metadata){
     data_size += sizeof(float) * metadata->output_dim;
     memset(edata->bias, 0, metadata->output_dim * sizeof(float));
     edata->feature_weights = new float[metadata->input_dim];
+    data_size += sizeof(float) * metadata->input_dim;
     memset(edata->feature_weights, 0, metadata->input_dim * sizeof(float));
     int split_sizes = (metadata->grow_policy == OBLIVIOUS) ? metadata->max_trees : metadata->max_leaves;
 #ifdef DEBUG
@@ -233,6 +234,7 @@ ensembleData* ensemble_copy_data_alloc(ensembleMetaData *metadata){
     data_size += sizeof(float) * metadata->output_dim;
     memset(edata->bias, 0, metadata->output_dim * sizeof(float));
     edata->feature_weights = new float[metadata->input_dim];
+    data_size += sizeof(float) * metadata->input_dim;
     memset(edata->feature_weights, 0, metadata->input_dim * sizeof(float));
     int split_sizes = (metadata->grow_policy == OBLIVIOUS) ? metadata->n_trees : metadata->n_leaves;
 #ifdef DEBUG
@@ -284,6 +286,7 @@ ensembleData* copy_ensemble_data(ensembleData *other_edata, ensembleMetaData *me
     memcpy(edata->bias, other_edata->bias, metadata->output_dim * sizeof(float));
     edata->feature_weights = new float[metadata->input_dim];
     memcpy(edata->feature_weights, other_edata->feature_weights, metadata->input_dim * sizeof(float));
+    data_size += sizeof(float) * metadata->input_dim;
     int split_sizes = (metadata->grow_policy == OBLIVIOUS) ? metadata->n_trees : metadata->n_leaves;
 #ifdef DEBUG
     edata->n_samples = new int[metadata->n_leaves]; // debugging
@@ -330,10 +333,10 @@ void ensemble_data_dealloc(ensembleData *edata){
 #ifdef DEBUG
     delete[] edata->n_samples;
 #endif
+    delete[] edata->tree_indices;
     delete[] edata->depths;
     delete[] edata->values;
     delete[] edata->feature_indices;
-    delete[] edata->tree_indices;
     delete[] edata->feature_values;
     delete[] edata->edge_weights;
     delete[] edata->is_numerics;
13 changes: 8 additions & 5 deletions gbrl/src/cuda/cuda_fitter.cu
@@ -163,7 +163,7 @@ void evaluate_oblivious_splits_cuda(dataSet *dataset, ensembleData *edata, TreeN
     calc_oblivious_parallelism(candidata->n_candidates, metadata->output_dim, threads_per_block, metadata->split_score_func, depth);
     for (int i = 0; i < n_nodes; ++i){
         if (metadata->split_score_func == Cosine){
-            size_t shared_mem = sizeof(float)*2*(metadata->output_dim + 3)*threads_per_block;
+            size_t shared_mem = sizeof(float)*2*(metadata->output_dim + 2)*threads_per_block;
             split_score_cosine_cuda<<<candidata->n_candidates, threads_per_block, shared_mem>>>(dataset->obs, dataset->categorical_obs, dataset->build_grads, edata->feature_weights, nodes[i], candidata->candidate_indices, candidata->candidate_values, candidata->candidate_categories, candidata->candidate_numeric, metadata->min_data_in_leaf, split_data->oblivious_split_scores + candidata->n_candidates*i, dataset->n_samples, metadata->n_num_features);
         } else if (metadata->split_score_func == L2){
             size_t shared_mem = sizeof(float)*2*(metadata->output_dim + 1)*threads_per_block;
@@ -260,6 +260,8 @@ __global__ void split_score_cosine_cuda(const float* __restrict__ obs, const cha
         return;
     }
 
+
+
     // Accumulate per thread partial sum
     for(int i=threadIdx.x; i < n_samples; i += blockDim.x) {
         int sample_idx = __ldg(&node->sample_indices[i]); // Access the spec
@@ -299,10 +301,11 @@ __global__ void split_score_cosine_cuda(const float* __restrict__ obs, const cha
     if (denominator > 0.0f) {
         cosine = (l_dot_sum[0] + r_dot_sum[0]) / sqrtf(denominator);
     }
-    else {
-        split_scores[cand_idx] = -INFINITY;
-        return;
-    }
+    // else {
+    //     split_scores[cand_idx] = -INFINITY;
+    //     split_scores[cand_idx] = 0.0f;
+    //     return;
+    // }
     int feat_idx = __ldg(&candidate_indices[cand_idx]);
     if (!candidate_numeric[cand_idx])
         feat_idx += n_num_features;
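For sizing intuition: the corrected formula reserves two staging buffers of `output_dim + 2` floats per thread (the exact shared-memory layout is not visible in this hunk, so the "+ 2" breakdown is an assumption). With `output_dim = 4` and 256 threads per block that is 4 bytes * 2 * (4 + 2) * 256 = 12,288 bytes, well within the 48 KB of static shared memory available per block on most NVIDIA GPUs; the previous `output_dim + 3` over-reserved one float per side per thread.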
1 change: 1 addition & 0 deletions gbrl/utils.py
@@ -29,6 +29,7 @@ def get_tensor_info(tensor: th.Tensor) -> Tuple[int, Tuple[int, ...], str, str]:
     """
     if not tensor.is_contiguous():
         tensor = tensor.contiguous()
+
     data_ptr = tensor.data_ptr()
     shape = tuple(tensor.size())  # Convert torch.Size to tuple
     dtype = str(tensor.dtype)
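For reference, a sketch of the tuple contract `get_tensor_info` feeds into the binding above (the binding's own error message documents the four fields; the exact `device` string form is assumed):

    import torch as th
    from gbrl.utils import get_tensor_info

    t = th.randn(8, 3).float()
    data_ptr, shape, dtype, device = get_tensor_info(t)
    # data_ptr -> integer address of the first element (what binding.cpp
    #             casts back via reinterpret_cast<T*>)
    # shape    -> (8, 3)
    # dtype    -> 'torch.float32'
    # device   -> 'cpu' or 'cuda:0'  (assumed form)
    # The caller must keep `t` referenced until the native call returns
    # (see the _save_memory pattern in gbrl_wrapper.py above).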
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "gbrl"
-version = "1.0.10"
+version = "1.0.11"
 description = "Gradient Boosted Trees for RL"
 readme = { file = "README.md", content-type = "text/markdown" }
 authors = [
@@ -45,7 +45,7 @@ requires-python = ">=3.9"
 
 [tool.poetry]
 name = "gbrl"
-version = "1.0.9"
+version = "1.0.11"
 description = "Gradient Boosted Trees for RL"
 authors = ["Benjamin Fuhrer <[email protected]>", "Chen Tessler <[email protected]>", "Gal Dalal <[email protected]>"]
 readme = "README.md"
10 changes: 3 additions & 7 deletions setup.py
@@ -44,6 +44,8 @@ def build_extension(self, ext):
         ]
         if os.environ.get('COVERAGE', '0') == '1':
             cmake_args.append('-DCOVERAGE=ON')
+        if os.environ.get('ASAN', '0') == '1':
+            cmake_args.append('-DASAN=ON')
         if sysconfig.get_config_var('LIBRARY') is not None:
             cmake_args.append('-DPYTHON_LIBRARY=' + sysconfig.get_config_var('LIBRARY'))
         if 'CC' in os.environ:
@@ -62,12 +64,6 @@ def build_extension(self, ext):
             cmake_args.append('-DUSE_CUDA=ON')
             if 'CUDACXX' in os.environ:
                 cmake_args.append('-DCMAKE_CUDA_COMPILER=' + os.environ['CUDACXX'])
-        # # Set CMAKE_PREFIX_PATH for LLVM
-        # if platform.system() == 'Darwin':  # MacOS specific logic
-        #     brew_prefix = subprocess.check_output(['brew', '--prefix', 'llvm']).decode().strip()
-        #     print(brew_prefix)
-        #     cmake_args.append(f'-DCMAKE_PREFIX_PATH={brew_prefix}')
-        #     cmake_args.append(f'-DLLVM_DIR={brew_prefix}/lib/cmake/llvm')
 
         build_temp = self.build_temp
         if not os.path.exists(build_temp):
@@ -105,7 +101,7 @@ def move_built_library(self, build_temp):
 
 setup(
     name="gbrl",
-    version = "1.0.10",
+    version = "1.0.11",
     description = "Gradient Boosted Trees for RL",
     author="Benjamin Fuhrer, Chen Tessler, Gal Dalal",
     author_email="[email protected], [email protected], [email protected]",
