Skip to content

Commit ef6548c

Browse files
committed
[STABLE ABI] Port lfilter
1 parent 32ce8c0 commit ef6548c

File tree

5 files changed

+203
-114
lines changed

5 files changed

+203
-114
lines changed

src/libtorchaudio/iir_cuda.cu

Lines changed: 41 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1+
#include <libtorchaudio/utils.h>
2+
#include <torch/headeronly/core/Dispatch_v2.h>
3+
#include <torch/headeronly/core/ScalarType.h>
14
#include <c10/cuda/CUDAException.h>
25
#include <c10/cuda/CUDAGuard.h>
3-
#include <torch/torch.h>
6+
#include <c10/core/DeviceGuard.h>
7+
8+
using torch::headeronly::ScalarType;
9+
using torch::stable::Tensor;
410

511
template <typename scalar_t>
612
__global__ void iir_cu_kernel(
7-
const torch::
8-
PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t> in,
9-
const torch::
10-
PackedTensorAccessor<scalar_t, 2, torch::RestrictPtrTraits, size_t>
11-
a_flipped,
12-
torch::PackedTensorAccessor<scalar_t, 3, torch::RestrictPtrTraits, size_t>
13-
out) {
13+
const torchaudio::PackedTensorAccessorSizeT<scalar_t, 3> in,
14+
const torchaudio::PackedTensorAccessorSizeT<scalar_t, 2> a_flipped,
15+
torchaudio::PackedTensorAccessorSizeT<scalar_t, 3> out) {
1416
int64_t tid = blockIdx.x * blockDim.x + threadIdx.x;
1517
int64_t n = in.size(0);
1618
int64_t c = in.size(1);
@@ -33,51 +35,49 @@ __global__ void iir_cu_kernel(
3335
}
3436
}
3537

// CUDA implementation of `torchaudio::_lfilter_core_loop`.
//
// Preconditions (each enforced with STD_TORCH_CHECK):
//   * `in`, `a_flipped` and `padded_out` are contiguous CUDA tensors that
//     live on the same device;
//   * all three share one dtype, which must be float or double;
//   * `padded_out.size(0) == in.size(0)`, `padded_out.size(1) == in.size(1)`
//     and `padded_out.size(2) == in.size(2) + a_flipped.size(1) - 1`.
//
// Launches `iir_cu_kernel` on a 1-D grid of 256-thread blocks that covers
// `in.size(0) * in.size(1)` threads, and returns `padded_out` (the kernel
// receives it as its output accessor).
Tensor cuda_lfilter_core_loop(
    Tensor in,
    Tensor a_flipped,
    Tensor padded_out) {
  STD_TORCH_CHECK(
      in.is_cuda() && a_flipped.is_cuda() && padded_out.is_cuda());

  STD_TORCH_CHECK(
      (in.get_device_index() == a_flipped.get_device_index()) &&
      (in.get_device_index() == padded_out.get_device_index()));

  STD_TORCH_CHECK(
      in.is_contiguous() && a_flipped.is_contiguous() &&
      padded_out.is_contiguous());

  STD_TORCH_CHECK(
      in.scalar_type() == ScalarType::Float ||
      in.scalar_type() == ScalarType::Double);
  // The kernel is instantiated for the dtype of `in` only and every accessor
  // below is typed with that scalar_t, so the other tensors must match it.
  STD_TORCH_CHECK(
      a_flipped.scalar_type() == in.scalar_type() &&
          padded_out.scalar_type() == in.scalar_type(),
      "cuda_lfilter_core_loop: all tensors must share the same dtype");

  // Keep sizes in 64 bits: narrowing to int silently truncates large
  // dimensions and lets N * C overflow.
  const int64_t N = in.size(0);
  const int64_t C = in.size(1);
  STD_TORCH_CHECK(N == padded_out.size(0));
  STD_TORCH_CHECK(C == padded_out.size(1));

  STD_TORCH_CHECK(in.size(2) + a_flipped.size(1) - 1 == padded_out.size(2));

  const int64_t n_rows = N * C;
  if (n_rows == 0) {
    // Nothing to filter. A launch with zero blocks is an invalid launch
    // configuration, so return before reaching the kernel.
    return padded_out;
  }

  constexpr int64_t kThreadsPerBlock = 256;
  constexpr int64_t kMaxGridDimX = 2147483647; // CUDA limit for gridDim.x
  const int64_t n_blocks = (n_rows + kThreadsPerBlock - 1) / kThreadsPerBlock;
  STD_TORCH_CHECK(
      n_blocks <= kMaxGridDimX,
      "cuda_lfilter_core_loop: batch * channel is too large for one launch");

  // Make the tensors' device current for the launch; without this the kernel
  // runs on whatever device happens to be current in the calling thread.
  const c10::cuda::CUDAGuard device_guard(
      static_cast<c10::DeviceIndex>(in.get_device_index()));

  const dim3 threads(static_cast<unsigned int>(kThreadsPerBlock));
  const dim3 blocks(static_cast<unsigned int>(n_blocks));

  THO_DISPATCH_V2(
      in.scalar_type(), "iir_cu_loop", AT_WRAP([&] {
        (iir_cu_kernel<scalar_t><<<blocks, threads>>>(
            torchaudio::packed_accessor_size_t<scalar_t, 3>(in),
            torchaudio::packed_accessor_size_t<scalar_t, 2>(a_flipped),
            torchaudio::packed_accessor_size_t<scalar_t, 3>(padded_out)));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }), AT_FLOATING_TYPES);
  return padded_out;
}

src/libtorchaudio/iir_cuda.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#pragma once
22

3-
#include <torch/types.h>
3+
#include <torch/csrc/stable/tensor.h>
44

5-
void cuda_lfilter_core_loop(
6-
const torch::Tensor& in,
7-
const torch::Tensor& a_flipped,
8-
torch::Tensor& padded_out);
5+
using torch::stable::Tensor;
6+
7+
Tensor cuda_lfilter_core_loop(Tensor in, Tensor a_flipped, Tensor padded_out);

src/libtorchaudio/lfilter.cpp

Lines changed: 93 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,116 +1,141 @@
1-
#include <torch/script.h>
2-
#include <torch/torch.h>
1+
#include <libtorchaudio/utils.h>
2+
#include <torch/csrc/stable/library.h>
3+
#include <torch/csrc/stable/ops.h>
4+
#include <torch/csrc/stable/tensor.h>
5+
#include <torch/headeronly/core/Dispatch_v2.h>
6+
#include <torch/headeronly/core/ScalarType.h>
37

48
#ifdef USE_CUDA
59
#include <libtorchaudio/iir_cuda.h>
610
#endif
711

812
namespace {
913

14+
using torch::headeronly::ScalarType;
15+
using torch::stable::Tensor;
16+
1017
template <typename scalar_t>
1118
void host_lfilter_core_loop(
12-
const torch::Tensor& input_signal_windows,
13-
const torch::Tensor& a_coeff_flipped,
14-
torch::Tensor& padded_output_waveform) {
19+
const Tensor& input_signal_windows,
20+
const Tensor& a_coeff_flipped,
21+
Tensor& padded_output_waveform) {
1522
int64_t n_batch = input_signal_windows.size(0);
1623
int64_t n_channel = input_signal_windows.size(1);
1724
int64_t n_samples_input = input_signal_windows.size(2);
1825
int64_t n_samples_output = padded_output_waveform.size(2);
1926
int64_t n_order = a_coeff_flipped.size(1);
20-
scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>();
21-
const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>();
22-
const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>();
23-
24-
at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
25-
for (auto i = begin; i < end; i++) {
26-
int64_t offset_input = i * n_samples_input;
27-
int64_t offset_output = i * n_samples_output;
28-
int64_t i_channel = i % n_channel;
29-
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
30-
scalar_t a0 = input_data[offset_input + i_sample];
31-
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
32-
a0 -= output_data[offset_output + i_sample + i_coeff] *
33-
a_coeff_flipped_data[i_coeff + i_channel * n_order];
27+
scalar_t* output_data =
28+
reinterpret_cast<scalar_t*>(padded_output_waveform.data_ptr());
29+
const scalar_t* input_data =
30+
reinterpret_cast<scalar_t*>(input_signal_windows.data_ptr());
31+
const scalar_t* a_coeff_flipped_data =
32+
reinterpret_cast<scalar_t*>(a_coeff_flipped.data_ptr());
33+
34+
torch::stable::parallel_for(
35+
0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) {
36+
for (auto i = begin; i < end; i++) {
37+
int64_t offset_input = i * n_samples_input;
38+
int64_t offset_output = i * n_samples_output;
39+
int64_t i_channel = i % n_channel;
40+
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
41+
scalar_t a0 = input_data[offset_input + i_sample];
42+
for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) {
43+
a0 -= output_data[offset_output + i_sample + i_coeff] *
44+
a_coeff_flipped_data[i_coeff + i_channel * n_order];
45+
}
46+
output_data[offset_output + i_sample + n_order - 1] = a0;
47+
}
3448
}
35-
output_data[offset_output + i_sample + n_order - 1] = a0;
36-
}
37-
}
38-
});
49+
});
3950
}
4051

41-
void cpu_lfilter_core_loop(
42-
const torch::Tensor& input_signal_windows,
43-
const torch::Tensor& a_coeff_flipped,
44-
torch::Tensor& padded_output_waveform) {
45-
TORCH_CHECK(
46-
input_signal_windows.device().is_cpu() &&
47-
a_coeff_flipped.device().is_cpu() &&
48-
padded_output_waveform.device().is_cpu());
52+
Tensor cpu_lfilter_core_loop(
53+
Tensor input_signal_windows,
54+
Tensor a_coeff_flipped,
55+
Tensor padded_output_waveform) {
56+
STD_TORCH_CHECK(
57+
input_signal_windows.is_cpu() && a_coeff_flipped.is_cpu() &&
58+
padded_output_waveform.is_cpu());
4959

50-
TORCH_CHECK(
60+
STD_TORCH_CHECK(
5161
input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() &&
5262
padded_output_waveform.is_contiguous());
5363

54-
TORCH_CHECK(
55-
(input_signal_windows.dtype() == torch::kFloat32 ||
56-
input_signal_windows.dtype() == torch::kFloat64) &&
57-
(a_coeff_flipped.dtype() == torch::kFloat32 ||
58-
a_coeff_flipped.dtype() == torch::kFloat64) &&
59-
(padded_output_waveform.dtype() == torch::kFloat32 ||
60-
padded_output_waveform.dtype() == torch::kFloat64));
64+
STD_TORCH_CHECK(
65+
(input_signal_windows.scalar_type() == ScalarType::Float ||
66+
input_signal_windows.scalar_type() == ScalarType::Double) &&
67+
(a_coeff_flipped.scalar_type() == ScalarType::Float ||
68+
a_coeff_flipped.scalar_type() == ScalarType::Double) &&
69+
(padded_output_waveform.scalar_type() == ScalarType::Float ||
70+
padded_output_waveform.scalar_type() == ScalarType::Double));
6171

62-
TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0));
63-
TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1));
72+
STD_TORCH_CHECK(
73+
input_signal_windows.size(0) == padded_output_waveform.size(0));
74+
STD_TORCH_CHECK(
75+
input_signal_windows.size(1) == padded_output_waveform.size(1));
6476

65-
TORCH_CHECK(
77+
STD_TORCH_CHECK(
6678
input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 ==
6779
padded_output_waveform.size(2));
6880

69-
AT_DISPATCH_FLOATING_TYPES(
70-
input_signal_windows.scalar_type(), "lfilter_core_loop", [&] {
81+
THO_DISPATCH_V2(
82+
input_signal_windows.scalar_type(),
83+
"lfilter_core_loop",
84+
[&] {
7185
host_lfilter_core_loop<scalar_t>(
7286
input_signal_windows, a_coeff_flipped, padded_output_waveform);
73-
});
87+
},
88+
AT_FLOATING_TYPES);
89+
return padded_output_waveform;
7490
}
7591

76-
void lfilter_core_generic_loop(
77-
const torch::Tensor& input_signal_windows,
78-
const torch::Tensor& a_coeff_flipped,
79-
torch::Tensor& padded_output_waveform) {
92+
Tensor lfilter_core_generic_loop(
93+
Tensor input_signal_windows,
94+
Tensor a_coeff_flipped,
95+
Tensor padded_output_waveform) {
8096
int64_t n_samples_input = input_signal_windows.size(2);
8197
int64_t n_order = a_coeff_flipped.size(1);
82-
auto coeff = a_coeff_flipped.unsqueeze(2);
98+
auto coeff = torchaudio::stable::unsqueeze(a_coeff_flipped, 2);
8399
for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) {
84-
auto windowed_output_signal =
85-
torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order)
86-
.transpose(0, 1);
87-
auto o0 = torch::select(input_signal_windows, 2, i_sample) -
88-
at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1);
89-
padded_output_waveform.index_put_(
90-
{torch::indexing::Slice(),
91-
torch::indexing::Slice(),
92-
i_sample + n_order - 1},
93-
o0);
100+
auto windowed_output_signal = torch::stable::transpose(
101+
torch::stable::narrow(
102+
padded_output_waveform, 2, i_sample, i_sample + n_order),
103+
0,
104+
1);
105+
auto o0 = torchaudio::stable::subtract(
106+
torchaudio::stable::select(input_signal_windows, 2, i_sample),
107+
torch::stable::transpose(
108+
torchaudio::stable::squeeze(
109+
torchaudio::stable::matmul(windowed_output_signal, coeff), 2),
110+
0,
111+
1));
112+
auto s = torchaudio::stable::select(
113+
padded_output_waveform, 2, i_sample + n_order - 1);
114+
torch::stable::copy_(s, o0);
94115
}
116+
return padded_output_waveform;
95117
}
96118

97119
} // namespace
98120

99-
TORCH_LIBRARY(torchaudio, m) {
121+
STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
100122
m.def(
101-
"torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
123+
"_lfilter_core_loop("
124+
"Tensor input_signal_windows,"
125+
"Tensor a_coeff_flipped,"
126+
"Tensor(a!) padded_output_waveform) -> Tensor(a!)");
102127
}
103128

104-
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
105-
m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
129+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
130+
m.impl("_lfilter_core_loop", TORCH_BOX(&cpu_lfilter_core_loop));
106131
}
107132

108133
#ifdef USE_CUDA
109-
TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
110-
m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
134+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
135+
m.impl("_lfilter_core_loop", TORCH_BOX(&cuda_lfilter_core_loop));
111136
}
112137
#endif
113138

114-
TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
115-
m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop);
139+
STABLE_TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
140+
m.impl("_lfilter_core_loop", TORCH_BOX(&lfilter_core_generic_loop));
116141
}

src/libtorchaudio/stable/ops.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,51 @@ T item(const Tensor& self) {
182182
}
183183
}
184184

185+
inline Tensor unsqueeze(const Tensor& self, int64_t dim) {
186+
const auto num_args = 2;
187+
std::array<StableIValue, num_args> stack{
188+
torch::stable::detail::from(self), torch::stable::detail::from(dim)};
189+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
190+
"aten::unsqueeze", "", stack.data(), TORCH_ABI_VERSION));
191+
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
192+
}
193+
194+
inline Tensor select(const Tensor& self, int64_t dim, int64_t index) {
195+
const auto num_args = 3;
196+
std::array<StableIValue, num_args> stack{
197+
torch::stable::detail::from(self),
198+
torch::stable::detail::from(dim),
199+
torch::stable::detail::from(index)};
200+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
201+
"aten::select", "", stack.data(), TORCH_ABI_VERSION));
202+
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
203+
}
204+
205+
inline Tensor squeeze(const Tensor& self, int64_t dim) {
206+
const auto num_args = 2;
207+
std::array<StableIValue, num_args> stack{
208+
torch::stable::detail::from(self), torch::stable::detail::from(dim)};
209+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
210+
"aten::squeeze", "dim", stack.data(), TORCH_ABI_VERSION));
211+
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
212+
}
213+
214+
inline Tensor matmul(const Tensor& self, const Tensor& other) {
215+
const auto num_args = 2;
216+
std::array<StableIValue, num_args> stack{
217+
torch::stable::detail::from(self), torch::stable::detail::from(other)};
218+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
219+
"aten::matmul", "", stack.data(), TORCH_ABI_VERSION));
220+
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
221+
}
222+
223+
inline Tensor subtract(const Tensor& self, const Tensor& other) {
224+
const auto num_args = 2;
225+
std::array<StableIValue, num_args> stack{
226+
torch::stable::detail::from(self), torch::stable::detail::from(other)};
227+
TORCH_ERROR_CODE_CHECK(torch_call_dispatcher(
228+
"aten::subtract", "Tensor", stack.data(), TORCH_ABI_VERSION));
229+
return torch::stable::detail::to<torch::stable::Tensor>(stack[0]);
230+
}
231+
185232
} // namespace torchaudio::stable

0 commit comments

Comments
 (0)