Commit 9ca63c5

ezyang authored and facebook-github-bot committed
Reorganize methods in Type, add CPUTypeDefault/CUDATypeDefault (pytorch#11205)
Summary: Pull Request resolved: pytorch#11205

Our short-term plan for supporting out-of-tree complex development requires an external library to add a custom subclass of Type without access to the code generation facilities in ATen. This commit reorganizes Type so as to minimize the amount of boilerplate you have to write when making a subclass of Type. In particular, it:

- Creates new CPUTypeDefault/CUDATypeDefault classes, which you are intended to inherit from, and which provide layout/dtype-agnostic default implementations for CPU/CUDA.
- Adds new getCPUAllocator() and getCUDAAllocator() functions, as a more public API for getting your hands on an Allocator.
- Adds allocator() and getDeviceFromPtr(), abstracting the device-specific parts of the storage() methods; these methods are now implemented in the base TypeDefault.
- Deletes the static typeString() method, which is now dead.
- Moves is_cuda/is_sparse/is_distributed to TypeDefault.

Reviewed By: SsnL

Differential Revision: D9631619

fbshipit-source-id: 40b600d99691230e36e03eb56434c351cbc2aa3a
1 parent f0d3fda commit 9ca63c5

22 files changed, +177 -148 lines
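Illustration (not part of the commit): the intended out-of-tree pattern described in the summary is to inherit from CPUTypeDefault (or CUDATypeDefault) and hand-write only the layout/dtype-specific pieces. A minimal sketch in which ComplexCPUType, the tensor type id, and the ComplexFloat scalar type are placeholder assumptions, and the remaining pure-virtual operator methods a concrete Type must still implement are elided:

// Hypothetical out-of-tree backend built on CPUTypeDefault (sketch only).
#include <ATen/CPUTypeDefault.h>

namespace complex_ext {

struct ComplexCPUType : public at::CPUTypeDefault {
  ComplexCPUType()
      // Placeholder tensor type id; a real extension would register its own.
      : at::CPUTypeDefault(at::CPUTensorId(),
                           /*is_variable=*/false,
                           /*is_undefined=*/false) {}

  // Only the dtype/layout-specific bits need to be written by hand...
  at::ScalarType scalarType() const override { return at::ScalarType::ComplexFloat; }
  at::Backend backend() const override { return at::Backend::CPU; }
  const char* toString() const override { return "ComplexCPUType"; }

  // ...while allocator(), getDeviceFromPtr(), generator(), and the storage()/
  // unsafe*FromTH() methods are inherited from CPUTypeDefault/TypeDefault.
  // The remaining pure-virtual operator methods are omitted from this sketch.
};

} // namespace complex_ext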

aten/src/ATen/CPUTypeDefault.cpp (new file, +20)

@@ -0,0 +1,20 @@
+#include <ATen/CPUTypeDefault.h>
+
+#include <ATen/Context.h>
+#include <ATen/CPUGenerator.h>
+
+namespace at {
+
+Allocator* CPUTypeDefault::allocator() const {
+  return getCPUAllocator();
+}
+
+Device CPUTypeDefault::getDeviceFromPtr(void * data) const {
+  return DeviceType::CPU;
+}
+
+std::unique_ptr<Generator> CPUTypeDefault::generator() const {
+  return std::unique_ptr<Generator>(new CPUGenerator(&at::globalContext()));
+}
+
+} // namespace at

aten/src/ATen/CPUTypeDefault.h (new file, +14)

@@ -0,0 +1,14 @@
+#pragma once
+#include <ATen/TypeDefault.h>
+
+namespace at {
+
+struct AT_API CPUTypeDefault : public TypeDefault {
+  CPUTypeDefault(TensorTypeId type_id, bool is_variable, bool is_undefined)
+      : TypeDefault(type_id, is_variable, is_undefined) {}
+  Allocator* allocator() const override;
+  Device getDeviceFromPtr(void * data) const override;
+  std::unique_ptr<Generator> generator() const override;
+};
+
+} // namespace at

aten/src/ATen/Context.cpp (+4)

@@ -118,4 +118,8 @@ Type& getMaybeVariableType(const TensorImpl* impl) {
       backend, impl->scalar_type(), impl->is_variable());
 }
 
+Allocator* getCPUAllocator() {
+  return getTHDefaultAllocator();
+}
+
 }

aten/src/ATen/Context.h (+2)

@@ -158,6 +158,8 @@ static inline Type& getNonVariableType(DeviceType p, ScalarType s) {
 AT_API Type& getMaybeVariableType(TensorOptions options);
 AT_API Type& getMaybeVariableType(const TensorImpl*);
 
+AT_API Allocator* getCPUAllocator();
+
 static inline Type& CPU(ScalarType s) {
   return getNonVariableType(Backend::CPU, s);
 }
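Usage note (illustrative, not part of the diff): with getCPUAllocator() declared in Context.h, client code can obtain the default CPU allocator without reaching into TH. A minimal sketch, assuming at::Allocator::allocate(size_t) returns an at::DataPtr that frees its buffer when it goes out of scope:

#include <ATen/ATen.h>

void allocator_demo() {
  // Fetch the process-wide default CPU allocator exposed by this commit.
  at::Allocator* alloc = at::getCPUAllocator();
  // Allocate 256 bytes; the DataPtr releases them at end of scope.
  at::DataPtr buf = alloc->allocate(256);
  // buf.get() is the raw void* to the allocation.
}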

aten/src/ATen/UndefinedType.cpp (+10 -8)

@@ -11,9 +11,14 @@ ScalarType UndefinedType::scalarType() const {
 Backend UndefinedType::backend() const {
   return Backend::Undefined;
 }
-bool UndefinedType::is_cuda() const { return false; }
-bool UndefinedType::is_sparse() const { return false; }
-bool UndefinedType::is_distributed() const { return false; }
+
+Allocator* UndefinedType::allocator() const {
+  AT_ERROR("allocator not defined for UndefinedType");
+}
+
+Device UndefinedType::getDeviceFromPtr(void*) const {
+  AT_ERROR("getDeviceFromPtr not defined for UndefinedType");
+}
 
 Storage UndefinedType::storage(bool resizable) const {
   AT_ERROR("storage not defined for UndefinedType");

@@ -38,8 +43,9 @@ std::unique_ptr<Generator> UndefinedType::generator() const {
 }
 
 const char * UndefinedType::toString() const {
-  return UndefinedType::typeString();
+  return "UndefinedType";
 }
+
 TypeID UndefinedType::ID() const {
   return TypeID::Undefined;
 }

@@ -61,10 +67,6 @@ Type & UndefinedType::toScalarType(ScalarType s) const {
   AT_ERROR("toScalarType not implemented for UndefinedType to non-UndefinedType");
 }
 
-const char * UndefinedType::typeString() {
-  return "UndefinedType";
-}
-
 Tensor & UndefinedType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
   AT_ERROR("s_copy not defined for UndefinedType");
 }

aten/src/ATen/UndefinedType.h (+2 -4)

@@ -15,9 +15,8 @@ struct UndefinedType final : public TypeDefault {
   explicit UndefinedType();
   virtual ScalarType scalarType() const override;
   virtual Backend backend() const override;
-  virtual bool is_cuda() const override;
-  virtual bool is_sparse() const override;
-  virtual bool is_distributed() const override;
+  virtual Allocator* allocator() const override;
+  virtual Device getDeviceFromPtr(void* data) const override;
   virtual Storage storage(bool resizable = false) const override;
   virtual Storage storage(size_t size, bool resizable = false) const override;
   virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;

@@ -28,7 +27,6 @@ struct UndefinedType final : public TypeDefault {
   virtual Type & toBackend(Backend b) const override;
   virtual Type & toScalarType(ScalarType s) const override;
   virtual TypeID ID() const override;
-  static const char * typeString();
   virtual Storage unsafeStorageFromTH(void * th_pointer, bool retain) const override;
   virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;

aten/src/ATen/cuda/CUDAContext.cpp (+6 -2)

@@ -1,5 +1,5 @@
 #include "ATen/cuda/CUDAContext.h"
-#include "THC/THCGeneral.h"
+#include "THC/THCGeneral.hpp"
 
 namespace at { namespace cuda {
 

@@ -45,6 +45,10 @@ void uncheckedSetCurrentCUDAStream(CUDAStream stream) {
   detail::CUDAStream_uncheckedSetStream(stream.internals());
 }
 
+Allocator* getCUDADeviceAllocator() {
+  return at::globalContext().getTHCState()->cudaDeviceAllocator;
+}
+
 /* Handles */
 #ifndef __HIP_PLATFORM_HCC__
 cusparseHandle_t getCurrentCUDASparseHandle() {

@@ -54,4 +58,4 @@ void uncheckedSetCurrentCUDAStream(CUDAStream stream) {
 
 } // namespace cuda
 
-} // namespace at
+} // namespace at

aten/src/ATen/cuda/CUDAContext.h (+2)

@@ -54,6 +54,8 @@ AT_API CUDAStream getCurrentCUDAStream(int64_t device = -1);
 AT_API void setCurrentCUDAStream(CUDAStream stream);
 AT_API void uncheckedSetCurrentCUDAStream(CUDAStream stream);
 
+AT_API Allocator* getCUDADeviceAllocator();
+
 /* Handles */
 #ifndef __HIP_PLATFORM_HCC__
 AT_API cusparseHandle_t getCurrentCUDASparseHandle();

aten/src/ATen/cuda/CUDADevice.h (new file, +16)

@@ -0,0 +1,16 @@
+#pragma once
+
+#include "ATen/cuda/Exceptions.h"
+
+#include "cuda.h"
+
+namespace at {
+namespace cuda {
+
+inline Device getDeviceFromPtr(void* ptr) {
+  struct cudaPointerAttributes attr;
+  AT_CUDA_CHECK(cudaPointerGetAttributes(&attr, ptr));
+  return {DeviceType::CUDA, attr.device};
+}
+
+}} // namespace at::cuda
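Usage note (illustrative, not part of the diff): the new helper recovers the owning CUDA device of a raw device pointer. A small sketch, assuming a CUDA build and omitting error handling:

#include <ATen/cuda/CUDADevice.h>
#include <cuda_runtime.h>

// Allocate a buffer on the current device and ask ATen which device owns it.
at::Device deviceOf(size_t bytes) {
  void* p = nullptr;
  cudaMalloc(&p, bytes);
  at::Device dev = at::cuda::getDeviceFromPtr(p);  // e.g. Device(kCUDA, 0)
  cudaFree(p);
  return dev;
}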
aten/src/ATen/cuda/CUDATypeDefault.cpp (new file, +19)

@@ -0,0 +1,19 @@
+#include <ATen/cuda/CUDATypeDefault.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/CUDADevice.h>
+#include <ATen/CUDAGenerator.h>
+
+namespace at {
+
+Allocator* CUDATypeDefault::allocator() const {
+  return cuda::getCUDADeviceAllocator();
+}
+Device CUDATypeDefault::getDeviceFromPtr(void * data) const {
+  return cuda::getDeviceFromPtr(data);
+}
+std::unique_ptr<Generator> CUDATypeDefault::generator() const {
+  return std::unique_ptr<Generator>(new CUDAGenerator(&at::globalContext()));
+}
+
+} // namespace at

aten/src/ATen/cuda/CUDATypeDefault.h (new file, +16)

@@ -0,0 +1,16 @@
+#pragma once
+#include <ATen/TypeDefault.h>
+#include <ATen/cuda/ATenCUDAGeneral.h>
+
+namespace at {
+
+struct AT_CUDA_API CUDATypeDefault : public TypeDefault {
+  CUDATypeDefault(TensorTypeId type_id, bool is_variable, bool is_undefined)
+      : TypeDefault(type_id, is_variable, is_undefined) {}
+
+  Allocator* allocator() const override;
+  Device getDeviceFromPtr(void * data) const override;
+  std::unique_ptr<Generator> generator() const override;
+};
+
+} // namespace at

aten/src/ATen/gen.py (+2)

@@ -256,6 +256,8 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations):
     ]
     env['extra_cuda_headers'] = ['#include <ATen/cuda/CUDAHalf.cuh>']
     env['extra_cuda_headers'].append('#include <ATen/DeviceGuard.h>')
+    env['extra_cuda_headers'].append('#include <ATen/cuda/CUDADevice.h>')
+    env['extra_cuda_headers'].append('#include <ATen/cuda/CUDATypeDefault.h>')
     sname = '' if scalar_name == "Float" else scalar_name
     env['THType'] = 'Cuda{}'.format(sname)
     env['THStorage'] = 'THCuda{}Storage'.format(sname)

aten/src/ATen/templates/SparseTypeDerived.cpp (+3 -31)

@@ -28,42 +28,18 @@
 namespace at {
 
 ${Type}::${Type}()
-  : TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
+  : ${DenseBackend}TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}
 ScalarType ${Type}::scalarType() const {
   return ScalarType::${ScalarName};
 }
 Backend ${Type}::backend() const {
   return Backend::${Backend};
 }
-bool ${Type}::is_cuda() const { return backend() == Backend::CUDA || backend() == Backend::SparseCUDA; }
-bool ${Type}::is_sparse() const { return backend() == Backend::SparseCPU || backend() == Backend::SparseCUDA; }
-bool ${Type}::is_distributed() const { return false; }
-
-Storage ${Type}::storage(bool resizable) const {
-  AT_ERROR("storage not supported on sparse");
-}
-Storage ${Type}::storage(size_t size, bool resizable) const {
-  AT_ERROR("storage not supported on sparse");
-}
-Storage ${Type}::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
-  AT_ERROR("storage not supported on sparse");
-}
-Storage ${Type}::storageWithAllocator(int64_t size, Allocator* allocator) const {
-  AT_ERROR("storage not supported on sparse");
-}
-Tensor ${Type}::unsafeTensorFromTH(void * th_pointer, bool retain) const {
-  AT_ERROR("unsafeTensorFromTH not supported on sparse");
-}
-Storage ${Type}::unsafeStorageFromTH(void * th_pointer, bool retain) const {
-  AT_ERROR("unsafeTensorFromTH not supported on sparse");
-}
-std::unique_ptr<Generator> ${Type}::generator() const {
-  return std::unique_ptr<Generator>(new ${Generator}(&at::globalContext()));
-}
 
 const char * ${Type}::toString() const {
-  return ${Type}::typeString();
+  return "${Type}";
 }
+
 TypeID ${Type}::ID() const {
   return ${TypeID};
 }

@@ -72,10 +48,6 @@ size_t ${Type}::elementSizeInBytes() const {
   return sizeof(${ScalarType});
 }
 
-const char * ${Type}::typeString() {
-  return "${Type}";
-}
-
 ${type_derived_method_definitions}
 
 }

aten/src/ATen/templates/Type.h (+2)

@@ -58,6 +58,8 @@ struct AT_API Type {
   virtual bool is_distributed() const = 0;
   bool is_variable() const noexcept { return is_variable_; }
   bool is_undefined() const noexcept { return is_undefined_; }
+  virtual Allocator * allocator() const = 0;
+  virtual Device getDeviceFromPtr(void * data) const = 0;
   virtual Storage storage(bool resizable = false) const = 0;
   virtual Storage storage(size_t size, bool resizable = false) const = 0;
   virtual Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter=noop_deleter) const = 0;

aten/src/ATen/templates/TypeDefault.cpp (+28)

@@ -78,6 +78,34 @@ Tensor TypeDefault::tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const {
   auto storage = storageWithAllocator(computeStorageSize(sizes, strides), std::move(allocator));
   return tensor(storage, 0, sizes, strides);
 }
+
+Storage TypeDefault::storage(bool resizable) const {
+  return Storage(scalarType(), 0, allocator(), resizable);
+}
+Storage TypeDefault::storage(size_t size, bool resizable) const {
+  return Storage(scalarType(), size, allocator(), resizable);
+}
+Storage TypeDefault::storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const {
+  return Storage(
+      scalarType(),
+      InefficientStdFunctionContext::makeDataPtr(data, deleter, getDeviceFromPtr(data)),
+      size,
+      deleter);
+}
+Storage TypeDefault::storageWithAllocator(int64_t size, Allocator* allocator) const {
+  return Storage(scalarType(), size, allocator);
+}
+Tensor TypeDefault::unsafeTensorFromTH(void * th_pointer, bool retain) const {
+  return Tensor(static_cast<TensorImpl*>(th_pointer), retain);
+}
+Storage TypeDefault::unsafeStorageFromTH(void * th_pointer, bool retain) const {
+  if (retain && th_pointer) {
+    c10::raw::intrusive_ptr::incref(static_cast<StorageImpl*>(th_pointer));
+  }
+  return Storage(static_cast<StorageImpl*>(th_pointer));
+}
+
+
 Tensor TypeDefault::scalarTensor(Scalar s) const {
   return tensor({}).fill_(s);
 }
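Usage note (illustrative, not part of the diff): with these defaults in TypeDefault, a subclass that supplies allocator() and getDeviceFromPtr() gets the storage methods for free; for instance, wrapping an external CPU buffer now routes through getDeviceFromPtr() to tag the DataPtr. A hedged sketch, assuming the size argument counts elements of the type's scalar type:

#include <ATen/ATen.h>
#include <cstdlib>

// Wrap an externally allocated float buffer in a Storage via the default
// TypeDefault::storageFromBlob(); the deleter runs when the Storage dies.
at::Storage wrapExternalBuffer() {
  void* buf = std::malloc(16 * sizeof(float));
  at::Type& cpu_float = at::CPU(at::kFloat);
  return cpu_float.storageFromBlob(buf, /*size=*/16,
                                   [](void* p) { std::free(p); });
}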

aten/src/ATen/templates/TypeDefault.h (+16 -3)

@@ -12,9 +12,15 @@ struct AT_API TypeDefault : public Type {
 
   // Make sure overload resolution considers the nullary virtual method.
   // (A single argument overload is generated in the list.)
-  bool is_cuda() const override = 0;
-  bool is_sparse() const override = 0;
-  bool is_distributed() const override = 0;
+  bool is_cuda() const override {
+    return backend() == Backend::CUDA || backend() == Backend::SparseCUDA;
+  }
+  bool is_sparse() const override {
+    return backend() == Backend::SparseCPU || backend() == Backend::SparseCUDA;
+  }
+  bool is_distributed() const override {
+    return false;
+  }
 
   Type & toBackend(Backend b) const override;
   Type & toScalarType(ScalarType s) const override;

@@ -28,6 +34,13 @@ struct AT_API TypeDefault : public Type {
   Tensor tensorWithAllocator(IntList sizes, IntList strides, Allocator* allocator) const override;
   Tensor scalarTensor(Scalar s) const override;
 
+  Storage storage(bool resizable = false) const override;
+  Storage storage(size_t size, bool resizable = false) const override;
+  Storage storageFromBlob(void * data, int64_t size, const std::function<void(void*)> & deleter) const override;
+  Storage storageWithAllocator(int64_t size, Allocator* allocator) const override;
+  Storage unsafeStorageFromTH(void * th_pointer, bool retain) const override;
+  Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;
+
   // example
   // virtual Tensor * add(Tensor & a, Tensor & b) = 0;
   ${type_method_declarations}
