44
44
#include < json/json.hpp>
45
45
46
46
#include < pybind11_json/pybind11_json.hpp>
47
+ #include < pybind11/functional.h>
47
48
48
49
#include < tiny-cuda-nn/cpp_api.h>
49
50
53
54
#define CHECK_THROW (x ) \
54
55
do { if (!(x)) throw std::runtime_error (std::string (FILE_LINE " check failed " #x)); } while (0 )
55
56
56
- c10::ScalarType torch_type (tcnn::cpp::EPrecision precision) {
57
+ c10::ScalarType torch_type (tcnn::cpp::Precision precision) {
57
58
switch (precision) {
58
- case tcnn::cpp::EPrecision ::Fp32: return torch::kFloat32 ;
59
- case tcnn::cpp::EPrecision ::Fp16: return torch::kHalf ;
59
+ case tcnn::cpp::Precision ::Fp32: return torch::kFloat32 ;
60
+ case tcnn::cpp::Precision ::Fp16: return torch::kHalf ;
60
61
default : throw std::runtime_error{" Unknown precision tcnn->torch" };
61
62
}
62
63
}
@@ -246,41 +247,19 @@ class Module {
246
247
return output;
247
248
}
248
249
249
- uint32_t n_input_dims () const {
250
- return m_module->n_input_dims ();
251
- }
250
+ uint32_t n_input_dims () const { return m_module->n_input_dims (); }
252
251
253
- uint32_t n_params () const {
254
- return ( uint32_t ) m_module->n_params ();
255
- }
252
+ uint32_t n_params () const { return ( uint32_t )m_module-> n_params (); }
253
+ tcnn::cpp::Precision param_precision () const { return m_module->param_precision (); }
254
+ c10::ScalarType c10_param_precision () const { return torch_type ( param_precision ()); }
256
255
257
- tcnn::cpp::EPrecision param_precision () const {
258
- return m_module->param_precision ();
259
- }
256
+ uint32_t n_output_dims () const { return m_module-> n_output_dims (); }
257
+ tcnn::cpp::Precision output_precision () const { return m_module->output_precision (); }
258
+ c10::ScalarType c10_output_precision () const { return torch_type ( output_precision ()); }
260
259
261
- c10::ScalarType c10_param_precision () const {
262
- return torch_type (param_precision ());
263
- }
260
+ nlohmann::json hyperparams () const { return m_module->hyperparams (); }
261
+ std::string name () const { return m_module->name (); }
264
262
265
- uint32_t n_output_dims () const {
266
- return m_module->n_output_dims ();
267
- }
268
-
269
- tcnn::cpp::EPrecision output_precision () const {
270
- return m_module->output_precision ();
271
- }
272
-
273
- c10::ScalarType c10_output_precision () const {
274
- return torch_type (output_precision ());
275
- }
276
-
277
- nlohmann::json hyperparams () const {
278
- return m_module->hyperparams ();
279
- }
280
-
281
- std::string name () const {
282
- return m_module->name ();
283
- }
284
263
285
264
private:
286
265
std::unique_ptr<tcnn::cpp::Module> m_module;
@@ -296,22 +275,34 @@ Module create_network(uint32_t n_input_dims, uint32_t n_output_dims, const nlohm
296
275
}
297
276
#endif
298
277
299
- Module create_encoding (uint32_t n_input_dims, const nlohmann::json& encoding, tcnn::cpp::EPrecision requested_precision) {
278
+ Module create_encoding (uint32_t n_input_dims, const nlohmann::json& encoding, tcnn::cpp::Precision requested_precision) {
300
279
return Module{tcnn::cpp::create_encoding (n_input_dims, encoding, requested_precision)};
301
280
}
302
281
303
282
PYBIND11_MODULE (TORCH_EXTENSION_NAME, m) {
304
- py::enum_<tcnn::cpp::EPrecision>(m, " Precision" )
305
- .value (" Fp32" , tcnn::cpp::EPrecision::Fp32)
306
- .value (" Fp16" , tcnn::cpp::EPrecision::Fp16)
283
+ py::enum_<tcnn::cpp::LogSeverity>(m, " LogSeverity" )
284
+ .value (" Info" , tcnn::cpp::LogSeverity::Info)
285
+ .value (" Debug" , tcnn::cpp::LogSeverity::Debug)
286
+ .value (" Warning" , tcnn::cpp::LogSeverity::Warning)
287
+ .value (" Error" , tcnn::cpp::LogSeverity::Error)
288
+ .value (" Success" , tcnn::cpp::LogSeverity::Success)
289
+ .export_values ()
290
+ ;
291
+
292
+ py::enum_<tcnn::cpp::Precision>(m, " Precision" )
293
+ .value (" Fp32" , tcnn::cpp::Precision::Fp32)
294
+ .value (" Fp16" , tcnn::cpp::Precision::Fp16)
307
295
.export_values ()
308
296
;
309
297
310
298
m.def (" batch_size_granularity" , &tcnn::cpp::batch_size_granularity);
299
+ m.def (" default_loss_scale" , &tcnn::cpp::default_loss_scale);
311
300
m.def (" free_temporary_memory" , &tcnn::cpp::free_temporary_memory);
312
301
m.def (" has_networks" , &tcnn::cpp::has_networks);
313
302
m.def (" preferred_precision" , &tcnn::cpp::preferred_precision);
314
303
304
+ m.def (" set_log_callback" , &tcnn::cpp::set_log_callback);
305
+
315
306
// Encapsulates an abstract context of an operation
316
307
// (commonly the forward pass) to be passed on to other
317
308
// operations (commonly the backward pass).
0 commit comments