Commit 3e4aa3b
upload pointops for segmentation
1 parent 8130a00 commit 3e4aa3b

30 files changed: +1112 -5 lines changed

segmentation/init.sh

Lines changed: 1 addition & 1 deletion

@@ -13,6 +13,6 @@ pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f htt
 conda install -c anaconda h5py pyyaml -y
 conda install -c conda-forge sharedarray tensorboardx -y
 
-cd lib/pointops
+cd modules/pointops
 python3 setup.py install
 cd -

segmentation/modules/pointnet2_utils.py

Lines changed: 1 addition & 1 deletion

@@ -7,7 +7,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-from lib.pointops.functions import pointops
+from modules.pointops.functions import pointops
 
 
 def sample_and_group(stride, nsample, xyz, points, offset, return_idx=False, num_sector=1):
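The import path changes in step with the directory move in init.sh. A minimal import check (assumes the segmentation/ directory is on PYTHONPATH and the compiled pointops_cuda extension is available or can be JIT-built):

# Verify the relocated package resolves after the lib/ -> modules/ move
# (assumes segmentation/ is on PYTHONPATH and pointops_cuda is installed or buildable).
from modules.pointops.functions import pointops

print(hasattr(pointops, "furthestsampling"), hasattr(pointops, "knnquery"))  # True True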

segmentation/modules/pointops/__init__.py

Whitespace-only changes.

segmentation/modules/pointops/functions/__init__.py

Whitespace-only changes.
Lines changed: 307 additions & 0 deletions

@@ -0,0 +1,307 @@
from typing import Tuple

import torch
from torch.autograd import Function
import torch.nn as nn

try:
    import pointops_cuda
except ImportError:
    import warnings
    import os
    from torch.utils.cpp_extension import load

    warnings.warn("Unable to load pointops_cuda cpp extension.")
    pointops_cuda_src = os.path.join(os.path.dirname(__file__), "../src")
    pointops_cuda = load('pointops_cuda', [
        pointops_cuda_src + '/pointops_api.cpp',
        pointops_cuda_src + '/knnquery/knnquery_cuda.cpp',
        pointops_cuda_src + '/knnquery/knnquery_cuda_kernel.cu',
        pointops_cuda_src + '/interpolation/interpolation_cuda.cpp',
        pointops_cuda_src + '/interpolation/interpolation_cuda_kernel.cu',
        pointops_cuda_src + '/sampling/sampling_cuda.cpp',
        pointops_cuda_src + '/sampling/sampling_cuda_kernel.cu',
        pointops_cuda_src + '/subtraction/subtraction_cuda.cpp',
        pointops_cuda_src + '/subtraction/subtraction_cuda_kernel.cu',
        pointops_cuda_src + '/aggregation/aggregation_cuda.cpp',
        pointops_cuda_src + '/aggregation/aggregation_cuda_kernel.cu',
    ], build_directory=pointops_cuda_src, verbose=False)


class FurthestSampling(Function):
    @staticmethod
    def forward(ctx, xyz, offset, new_offset):
        """
        input: xyz: (n, 3), offset: (b), new_offset: (b)
        output: idx: (m)
        """
        assert xyz.is_contiguous()
        n, b, n_max = xyz.shape[0], offset.shape[0], offset[0]
        for i in range(1, b):
            n_max = max(offset[i] - offset[i - 1], n_max)
        idx = torch.cuda.IntTensor(new_offset[b - 1].item()).zero_()
        tmp = torch.cuda.FloatTensor(n).fill_(1e10)
        pointops_cuda.furthestsampling_cuda(b, n_max, xyz, offset, new_offset, tmp, idx)
        del tmp
        return idx


furthestsampling = FurthestSampling.apply


class SectorizedFurthestSampling(Function):
    @staticmethod
    def forward(ctx, xyz, offset, new_offset, num_sectors, min_points=10000):
        """
        input: xyz: (n, 3), offset: (b), new_offset: (b)
        output: idx: (m)
        """
        assert xyz.is_contiguous()

        # cut into batches
        last_offset = 0
        sizes = []
        new_sizes = []
        indices = []
        for i in range(offset.shape[0]):
            size = offset[i] - last_offset
            if size < min_points:
                tmp_num_sectors = 1
            else:
                tmp_num_sectors = num_sectors
            batch_xyz = xyz[last_offset:last_offset + size]
            angle = torch.atan2(batch_xyz[:, 0], batch_xyz[:, 1])  # in (-pi, pi]
            sector_range = torch.linspace(angle.min(), angle.max() + 1e-4, tmp_num_sectors + 1)
            for s in range(tmp_num_sectors):
                indices.append(
                    torch.where((angle >= sector_range[s]) & (angle < sector_range[s + 1]))[0] + last_offset
                )
                sizes.append(indices[-1].shape[0])
            if i > 0:
                new_size = (new_offset[i] - new_offset[i - 1]).item()
            else:
                new_size = new_offset[i].item()
            new_sizes_this_batch = [new_size // tmp_num_sectors for _ in range(tmp_num_sectors)]
            new_sizes_this_batch[-1] += new_size % tmp_num_sectors
            new_sizes += new_sizes_this_batch
            last_offset = offset[i]

        sizes = torch.tensor(sizes, dtype=torch.long).to(offset)
        sector_offset = sizes.cumsum(dim=0)
        new_sizes = torch.tensor(new_sizes, dtype=torch.long).to(offset)
        new_sector_offset = new_sizes.cumsum(dim=0)
        indices = torch.cat(indices).long().to(offset.device)
        sector_xyz = xyz[indices].contiguous()

        # transform to sectors
        new_xyz = []
        n, b, n_max = sector_xyz.shape[0], sector_offset.shape[0], sector_offset[0]
        for i in range(1, b):
            n_max = max(sector_offset[i] - sector_offset[i - 1], n_max)
        idx = torch.cuda.IntTensor(new_sector_offset[b - 1].item()).zero_()
        tmp = torch.cuda.FloatTensor(n).fill_(1e10)
        pointops_cuda.furthestsampling_cuda(b, n_max, sector_xyz, sector_offset.int(), new_sector_offset.int(), tmp,
                                            idx)
        idx = indices[idx.long()]
        del tmp
        del sector_xyz
        return idx


sectorized_fps = SectorizedFurthestSampling.apply


class KNNQuery(Function):
    @staticmethod
    def forward(ctx, nsample, xyz, new_xyz, offset, new_offset):
        """
        input: xyz: (n, 3), new_xyz: (m, 3), offset: (b), new_offset: (b)
        output: idx: (m, nsample), dist2: (m, nsample)
        """
        if new_xyz is None:
            new_xyz = xyz
        assert xyz.is_contiguous() and new_xyz.is_contiguous()
        m = new_xyz.shape[0]
        idx = torch.cuda.IntTensor(m, nsample).zero_()
        dist2 = torch.cuda.FloatTensor(m, nsample).zero_()
        pointops_cuda.knnquery_cuda(m, nsample, xyz, new_xyz, offset, new_offset, idx, dist2)
        return idx, torch.sqrt(dist2)


knnquery = KNNQuery.apply


class Grouping(Function):
    @staticmethod
    def forward(ctx, input, idx):
        """
        input: input: (n, c), idx: (m, nsample)
        output: (m, nsample, c)
        """
        assert input.is_contiguous() and idx.is_contiguous()
        m, nsample, n, c = idx.shape[0], idx.shape[1], input.shape[0], input.shape[1]
        output = torch.cuda.FloatTensor(m, nsample, c)
        pointops_cuda.grouping_forward_cuda(m, nsample, c, input, idx, output)
        ctx.n = n
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (m, nsample, c)
        output: (n, c), None
        """
        n = ctx.n
        idx, = ctx.saved_tensors
        m, nsample, c = grad_output.shape
        grad_input = torch.cuda.FloatTensor(n, c).zero_()
        pointops_cuda.grouping_backward_cuda(m, nsample, c, grad_output, idx, grad_input)
        return grad_input, None


grouping = Grouping.apply


def queryandgroup(nsample, xyz, new_xyz, feat, idx, offset, new_offset, use_xyz=True):
    """
    input: xyz: (n, 3), new_xyz: (m, 3), feat: (n, c), idx: (m, nsample), offset: (b), new_offset: (b)
    output: new_feat: (m, nsample, 3+c) if use_xyz else (m, nsample, c)
    """
    assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous()
    if new_xyz is None:
        new_xyz = xyz
    if idx is None:
        idx, _ = knnquery(nsample, xyz, new_xyz, offset, new_offset)  # (m, nsample)

    n, m, c = xyz.shape[0], new_xyz.shape[0], feat.shape[1]
    grouped_xyz = xyz[idx.view(-1).long(), :].view(m, nsample, 3)  # (m, nsample, 3)
    # grouped_xyz = grouping(xyz, idx)  # (m, nsample, 3)
    grouped_xyz -= new_xyz.unsqueeze(1)  # (m, nsample, 3)
    grouped_feat = feat[idx.view(-1).long(), :].view(m, nsample, c)  # (m, nsample, c)
    # grouped_feat = grouping(feat, idx)  # (m, nsample, c)

    if use_xyz:
        return torch.cat((grouped_xyz, grouped_feat), -1)  # (m, nsample, 3+c)
    else:
        return grouped_feat


class Subtraction(Function):
    @staticmethod
    def forward(ctx, input1, input2, idx):
        """
        input: input1: (n, c), input2: (n, c), idx: (n, nsample)
        output: (n, nsample, c)
        """
        assert input1.is_contiguous() and input2.is_contiguous()
        n, c = input1.shape
        nsample = idx.shape[-1]
        output = torch.cuda.FloatTensor(n, nsample, c).zero_()
        pointops_cuda.subtraction_forward_cuda(n, nsample, c, input1, input2, idx, output)
        ctx.save_for_backward(idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (n, nsample, c)
        output: grad_input1: (n, c), grad_input2: (n, c)
        """
        idx, = ctx.saved_tensors
        n, nsample, c = grad_output.shape
        grad_input1 = torch.cuda.FloatTensor(n, c).zero_()
        grad_input2 = torch.cuda.FloatTensor(n, c).zero_()
        pointops_cuda.subtraction_backward_cuda(n, nsample, c, idx, grad_output, grad_input1, grad_input2)
        return grad_input1, grad_input2, None


subtraction = Subtraction.apply


class Aggregation(Function):
    @staticmethod
    def forward(ctx, input, position, weight, idx):
        """
        input: input: (n, c), position: (n, nsample, c), weight: (n, nsample, c'), idx: (n, nsample)
        output: (n, c)
        """
        assert input.is_contiguous() and position.is_contiguous() and weight.is_contiguous()
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        output = torch.cuda.FloatTensor(n, c).zero_()
        pointops_cuda.aggregation_forward_cuda(n, nsample, c, w_c, input, position, weight, idx, output)
        ctx.save_for_backward(input, position, weight, idx)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (n, c)
        output: grad_input: (n, c), grad_position: (n, nsample, c), grad_weight: (n, nsample, c')
        """
        input, position, weight, idx = ctx.saved_tensors
        n, nsample, c = position.shape
        w_c = weight.shape[-1]
        grad_input = torch.cuda.FloatTensor(n, c).zero_()
        grad_position = torch.cuda.FloatTensor(n, nsample, c).zero_()
        grad_weight = torch.cuda.FloatTensor(n, nsample, w_c).zero_()
        pointops_cuda.aggregation_backward_cuda(n, nsample, c, w_c, input, position, weight, idx, grad_output,
                                                grad_input, grad_position, grad_weight)
        return grad_input, grad_position, grad_weight, None


aggregation = Aggregation.apply


def interpolation(xyz, new_xyz, feat, offset, new_offset, k=3):
    """
    input: xyz: (m, 3), new_xyz: (n, 3), feat: (m, c), offset: (b), new_offset: (b)
    output: (n, c)
    """
    assert xyz.is_contiguous() and new_xyz.is_contiguous() and feat.is_contiguous()
    idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset)  # (n, k), (n, k)
    dist_recip = 1.0 / (dist + 1e-8)  # (n, k)
    norm = torch.sum(dist_recip, dim=1, keepdim=True)
    weight = dist_recip / norm  # (n, k)

    new_feat = torch.cuda.FloatTensor(new_xyz.shape[0], feat.shape[1]).zero_()
    for i in range(k):
        new_feat += feat[idx[:, i].long(), :] * weight[:, i].unsqueeze(-1)
    return new_feat


class Interpolation(Function):
    @staticmethod
    def forward(ctx, xyz, new_xyz, input, offset, new_offset, k=3):
        """
        input: xyz: (m, 3), new_xyz: (n, 3), input: (m, c), offset: (b), new_offset: (b)
        output: (n, c)
        """
        assert xyz.is_contiguous() and new_xyz.is_contiguous() and input.is_contiguous()
        idx, dist = knnquery(k, xyz, new_xyz, offset, new_offset)  # (n, k), (n, k)
        dist_recip = 1.0 / (dist + 1e-8)  # (n, k)
        norm = torch.sum(dist_recip, dim=1, keepdim=True)
        weight = dist_recip / norm  # (n, k)

        n, c, m = new_xyz.shape[0], input.shape[1], input.shape[0]
        output = torch.cuda.FloatTensor(n, c).zero_()
        pointops_cuda.interpolation_forward_cuda(n, c, k, input, idx, weight, output)
        ctx.m, ctx.k = m, k
        ctx.save_for_backward(idx, weight)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        input: grad_out: (n, c)
        output: grad_input: (m, c) for the feature argument, None for the rest
        """
        m, k = ctx.m, ctx.k
        idx, weight = ctx.saved_tensors
        n, c = grad_output.shape
        grad_input = torch.cuda.FloatTensor(m, c).zero_()
        pointops_cuda.interpolation_backward_cuda(n, c, k, grad_output, idx, weight, grad_input)
        return None, None, grad_input, None, None, None


interpolation2 = Interpolation.apply
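A usage sketch for the offset-batched API defined above (the point counts, feature width, and pooling step are hypothetical; assumes a CUDA device and a built pointops_cuda extension):

# Hypothetical sizes; assumes a CUDA device and a built pointops_cuda extension.
import torch
from modules.pointops.functions import pointops

# Two point clouds of 4096 and 2048 points concatenated along dim 0;
# `offset` holds the cumulative point counts per cloud.
xyz = torch.rand(4096 + 2048, 3).cuda().contiguous()
feat = torch.rand(4096 + 2048, 32).cuda().contiguous()
offset = torch.cuda.IntTensor([4096, 4096 + 2048])

# Furthest point sampling down to a quarter of the points per cloud.
new_offset = torch.cuda.IntTensor([1024, 1024 + 512])
idx = pointops.furthestsampling(xyz, offset, new_offset)           # (1536,)
new_xyz = xyz[idx.long(), :]                                        # (1536, 3)

# Group the 16 nearest neighbors of each sampled point
# (relative coordinates concatenated with features).
grouped = pointops.queryandgroup(16, xyz, new_xyz, feat, None, offset, new_offset)  # (1536, 16, 35)

# Toy max pooling over neighbors, then interpolate back to full resolution.
pooled = grouped.max(dim=1)[0].contiguous()                         # (1536, 35)
up = pointops.interpolation(new_xyz, xyz, pooled, new_offset, offset)  # (6144, 35)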
Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
# python3 setup.py install
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
from distutils.sysconfig import get_config_vars

(opt,) = get_config_vars('OPT')
os.environ['OPT'] = " ".join(
    flag for flag in opt.split() if flag != '-Wstrict-prototypes'
)

setup(
    name='pointops_cuda',
    author='Hengshuang Zhao',
    ext_modules=[
        CUDAExtension('pointops_cuda', [
            'src/pointops_api.cpp',
            'src/knnquery/knnquery_cuda.cpp',
            'src/knnquery/knnquery_cuda_kernel.cu',
            'src/sampling/sampling_cuda.cpp',
            'src/sampling/sampling_cuda_kernel.cu',
            'src/grouping/grouping_cuda.cpp',
            'src/grouping/grouping_cuda_kernel.cu',
            'src/interpolation/interpolation_cuda.cpp',
            'src/interpolation/interpolation_cuda_kernel.cu',
            'src/subtraction/subtraction_cuda.cpp',
            'src/subtraction/subtraction_cuda_kernel.cu',
            'src/aggregation/aggregation_cuda.cpp',
            'src/aggregation/aggregation_cuda_kernel.cu',
        ],
            extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2']}
        )
    ],
    cmdclass={'build_ext': BuildExtension}
)
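This setup script builds the same CUDA sources that the JIT fallback in the functions module loads at import time; segmentation/init.sh invokes it via `cd modules/pointops` followed by `python3 setup.py install`. A quick post-install sanity check that mirrors the call made in FurthestSampling.forward above (hypothetical sizes; assumes a CUDA device):

# Post-install sanity check; mirrors FurthestSampling.forward above.
# Hypothetical sizes; assumes a CUDA device.
import torch
import pointops_cuda  # provided by `python3 setup.py install`

xyz = torch.rand(1024, 3).cuda().contiguous()
offset = torch.cuda.IntTensor([1024])
new_offset = torch.cuda.IntTensor([256])
idx = torch.cuda.IntTensor(256).zero_()
tmp = torch.cuda.FloatTensor(1024).fill_(1e10)
pointops_cuda.furthestsampling_cuda(1, 1024, xyz, offset, new_offset, tmp, idx)
print(idx.shape)  # torch.Size([256])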

segmentation/modules/pointops/src/__init__.py

Whitespace-only changes.
Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
#include <vector>
#include <THC/THC.h>
#include <torch/serialize/tensor.h>
#include <ATen/cuda/CUDAContext.h>
#include "aggregation_cuda_kernel.h"


void aggregation_forward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor output_tensor)
{
    const float *input = input_tensor.data_ptr<float>();
    const float *position = position_tensor.data_ptr<float>();
    const float *weight = weight_tensor.data_ptr<float>();
    const int *idx = idx_tensor.data_ptr<int>();
    float *output = output_tensor.data_ptr<float>();
    aggregation_forward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, output);
}

void aggregation_backward_cuda(int n, int nsample, int c, int w_c, at::Tensor input_tensor, at::Tensor position_tensor, at::Tensor weight_tensor, at::Tensor idx_tensor, at::Tensor grad_output_tensor, at::Tensor grad_input_tensor, at::Tensor grad_position_tensor, at::Tensor grad_weight_tensor)
{
    const float *input = input_tensor.data_ptr<float>();
    const float *position = position_tensor.data_ptr<float>();
    const float *weight = weight_tensor.data_ptr<float>();
    const int *idx = idx_tensor.data_ptr<int>();
    const float *grad_output = grad_output_tensor.data_ptr<float>();
    float *grad_input = grad_input_tensor.data_ptr<float>();
    float *grad_position = grad_position_tensor.data_ptr<float>();
    float *grad_weight = grad_weight_tensor.data_ptr<float>();
    aggregation_backward_cuda_launcher(n, nsample, c, w_c, input, position, weight, idx, grad_output, grad_input, grad_position, grad_weight);
}
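These bindings unpack the tensors into raw pointers and forward them to the launchers declared in aggregation_cuda_kernel.h; on the Python side they are reached through the Aggregation autograd Function shown earlier. A minimal call sketch (hypothetical sizes; choosing w_c as a divisor of c is an assumption about the shared-weight layout; assumes a CUDA device and a built extension):

# Hypothetical sizes; assumes a CUDA device and a built pointops_cuda extension.
# w_c < c is chosen on the assumption that weights are shared across channel groups.
import torch
from modules.pointops.functions import pointops

n, nsample, c, w_c = 2048, 16, 32, 8
feat = torch.rand(n, c).cuda().contiguous()
position = torch.rand(n, nsample, c).cuda().contiguous()
weight = torch.rand(n, nsample, w_c).cuda().contiguous()
idx = torch.randint(0, n, (n, nsample), dtype=torch.int32).cuda()

out = pointops.aggregation(feat, position, weight, idx)  # (n, c)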
