forked from CoinCheung/pytorch-loss
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
46 lines (43 loc) · 1.36 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from setuptools import setup, Extension, find_packages
from torch.utils import cpp_extension
'''
Install with:
    python setup.py install
Usage: import torch first, then import this module.
'''
# Table of (extension module name, CUDA source file) pairs. Every extension
# is built from exactly one .cu file; the order of ``ext_modules`` passed to
# ``setup`` follows the order of this table.
_EXTENSION_SOURCES = (
    ('focal_cpp', 'csrc/focal_kernel.cu'),
    ('mish_cpp', 'csrc/mish_kernel.cu'),
    ('swish_cpp', 'csrc/swish_kernel.cu'),
    ('soft_dice_cpp', 'csrc/soft_dice_kernel_v2.cu'),
    ('lsr_cpp', 'csrc/lsr_kernel.cu'),
    ('large_margin_cpp', 'csrc/large_margin_kernel.cu'),
    ('ohem_cpp', 'csrc/ohem_label_kernel.cu'),
    ('one_hot_cpp', 'csrc/one_hot_kernel.cu'),
    ('lovasz_softmax_cpp', 'csrc/lovasz_softmax.cu'),
    ('taylor_softmax_cpp', 'csrc/taylor_softmax.cu'),
)

# Register the ``pytorch_loss`` distribution: one CUDAExtension per table
# entry, torch's BuildExtension as the ``build_ext`` command, and whatever
# Python packages ``find_packages`` discovers.
setup(
    name='pytorch_loss',
    ext_modules=[
        cpp_extension.CUDAExtension(module_name, [source_file])
        for module_name, source_file in _EXTENSION_SOURCES
    ],
    cmdclass={'build_ext': cpp_extension.BuildExtension},
    packages=find_packages(),
)