1 | | -#include <torch/script.h> |
2 | | -#include <torch/torch.h> |
| 1 | +#include <libtorchaudio/utils.h> |
| 2 | +#include <torch/csrc/stable/library.h> |
| 3 | +#include <torch/csrc/stable/ops.h> |
| 4 | +#include <torch/csrc/stable/tensor.h> |
| 5 | +#include <torch/headeronly/core/Dispatch_v2.h> |
| 6 | +#include <torch/headeronly/core/ScalarType.h> |
3 | 7 |
4 | 8 | #ifdef USE_CUDA |
5 | 9 | #include <libtorchaudio/iir_cuda.h> |
6 | 10 | #endif |
7 | 11 |
8 | 12 | namespace { |
9 | 13 |
| 14 | +using torch::headeronly::ScalarType; |
| 15 | +using torch::stable::Tensor; |
| 16 | + |
10 | 17 | template <typename scalar_t> |
11 | 18 | void host_lfilter_core_loop( |
12 | | - const torch::Tensor& input_signal_windows, |
13 | | - const torch::Tensor& a_coeff_flipped, |
14 | | - torch::Tensor& padded_output_waveform) { |
| 19 | + const Tensor& input_signal_windows, |
| 20 | + const Tensor& a_coeff_flipped, |
| 21 | + Tensor& padded_output_waveform) { |
15 | 22 | int64_t n_batch = input_signal_windows.size(0); |
16 | 23 | int64_t n_channel = input_signal_windows.size(1); |
17 | 24 | int64_t n_samples_input = input_signal_windows.size(2); |
18 | 25 | int64_t n_samples_output = padded_output_waveform.size(2); |
19 | 26 | int64_t n_order = a_coeff_flipped.size(1); |
20 | | - scalar_t* output_data = padded_output_waveform.data_ptr<scalar_t>(); |
21 | | - const scalar_t* input_data = input_signal_windows.data_ptr<scalar_t>(); |
22 | | - const scalar_t* a_coeff_flipped_data = a_coeff_flipped.data_ptr<scalar_t>(); |
23 | | - |
24 | | - at::parallel_for(0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) { |
25 | | - for (auto i = begin; i < end; i++) { |
26 | | - int64_t offset_input = i * n_samples_input; |
27 | | - int64_t offset_output = i * n_samples_output; |
28 | | - int64_t i_channel = i % n_channel; |
29 | | - for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { |
30 | | - scalar_t a0 = input_data[offset_input + i_sample]; |
31 | | - for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { |
32 | | - a0 -= output_data[offset_output + i_sample + i_coeff] * |
33 | | - a_coeff_flipped_data[i_coeff + i_channel * n_order]; |
| 27 | + scalar_t* output_data = |
| 28 | + reinterpret_cast<scalar_t*>(padded_output_waveform.data_ptr()); |
| 29 | + const scalar_t* input_data = |
| 30 | + reinterpret_cast<scalar_t*>(input_signal_windows.data_ptr()); |
| 31 | + const scalar_t* a_coeff_flipped_data = |
| 32 | + reinterpret_cast<scalar_t*>(a_coeff_flipped.data_ptr()); |
| 33 | + |
| 34 | + torch::stable::parallel_for( |
| 35 | + 0, n_channel * n_batch, 1, [&](int64_t begin, int64_t end) { |
| 36 | + for (auto i = begin; i < end; i++) { |
| 37 | + int64_t offset_input = i * n_samples_input; |
| 38 | + int64_t offset_output = i * n_samples_output; |
| 39 | + int64_t i_channel = i % n_channel; |
| 40 | + for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { |
| 41 | + scalar_t a0 = input_data[offset_input + i_sample]; |
| 42 | + for (int64_t i_coeff = 0; i_coeff < n_order; i_coeff++) { |
| 43 | + a0 -= output_data[offset_output + i_sample + i_coeff] * |
| 44 | + a_coeff_flipped_data[i_coeff + i_channel * n_order]; |
| 45 | + } |
| 46 | + output_data[offset_output + i_sample + n_order - 1] = a0; |
| 47 | + } |
34 | 48 | } |
35 | | - output_data[offset_output + i_sample + n_order - 1] = a0; |
36 | | - } |
37 | | - } |
38 | | - }); |
| 49 | + }); |
39 | 50 | } |
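
The loop above is the recursive (feedback) half of the filter: for each batch/channel slice it computes y[t + order - 1] = x[t] - sum_k a_flipped[k] * y[t + k], parallelized over n_batch * n_channel. A minimal standalone sketch of that recurrence for a single channel, using plain std::vector and made-up toy values rather than the tensor kernel itself:

#include <cstdio>
#include <vector>

// Same recurrence as the inner loop of host_lfilter_core_loop, for one channel:
// y has length x.size() + order - 1 and starts zero-initialized (the padding),
// and y[t + order - 1] = x[t] - sum_k a_flipped[k] * y[t + k].
std::vector<double> iir_core_sketch(
    const std::vector<double>& x,
    const std::vector<double>& a_flipped) {
  const std::size_t order = a_flipped.size();
  std::vector<double> y(x.size() + order - 1, 0.0);
  for (std::size_t t = 0; t < x.size(); ++t) {
    double acc = x[t];
    for (std::size_t k = 0; k < order; ++k) {
      acc -= y[t + k] * a_flipped[k];
    }
    y[t + order - 1] = acc;
  }
  return y;
}

int main() {
  // Toy data: unit impulse through a first-order feedback section.
  std::vector<double> x{1.0, 0.0, 0.0, 0.0};
  for (double v : iir_core_sketch(x, {0.5, 1.0})) {
    std::printf("%f\n", v);
  }
  return 0;
}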
40 | 51 |
41 | | -void cpu_lfilter_core_loop( |
42 | | - const torch::Tensor& input_signal_windows, |
43 | | - const torch::Tensor& a_coeff_flipped, |
44 | | - torch::Tensor& padded_output_waveform) { |
45 | | - TORCH_CHECK( |
46 | | - input_signal_windows.device().is_cpu() && |
47 | | - a_coeff_flipped.device().is_cpu() && |
48 | | - padded_output_waveform.device().is_cpu()); |
| 52 | +Tensor cpu_lfilter_core_loop( |
| 53 | + Tensor input_signal_windows, |
| 54 | + Tensor a_coeff_flipped, |
| 55 | + Tensor padded_output_waveform) { |
| 56 | + STD_TORCH_CHECK( |
| 57 | + input_signal_windows.is_cpu() && a_coeff_flipped.is_cpu() && |
| 58 | + padded_output_waveform.is_cpu()); |
49 | 59 |
50 | | - TORCH_CHECK( |
| 60 | + STD_TORCH_CHECK( |
51 | 61 | input_signal_windows.is_contiguous() && a_coeff_flipped.is_contiguous() && |
52 | 62 | padded_output_waveform.is_contiguous()); |
53 | 63 |
54 | | - TORCH_CHECK( |
55 | | - (input_signal_windows.dtype() == torch::kFloat32 || |
56 | | - input_signal_windows.dtype() == torch::kFloat64) && |
57 | | - (a_coeff_flipped.dtype() == torch::kFloat32 || |
58 | | - a_coeff_flipped.dtype() == torch::kFloat64) && |
59 | | - (padded_output_waveform.dtype() == torch::kFloat32 || |
60 | | - padded_output_waveform.dtype() == torch::kFloat64)); |
| 64 | + STD_TORCH_CHECK( |
| 65 | + (input_signal_windows.scalar_type() == ScalarType::Float || |
| 66 | + input_signal_windows.scalar_type() == ScalarType::Double) && |
| 67 | + (a_coeff_flipped.scalar_type() == ScalarType::Float || |
| 68 | + a_coeff_flipped.scalar_type() == ScalarType::Double) && |
| 69 | + (padded_output_waveform.scalar_type() == ScalarType::Float || |
| 70 | + padded_output_waveform.scalar_type() == ScalarType::Double)); |
61 | 71 |
62 | | - TORCH_CHECK(input_signal_windows.size(0) == padded_output_waveform.size(0)); |
63 | | - TORCH_CHECK(input_signal_windows.size(1) == padded_output_waveform.size(1)); |
| 72 | + STD_TORCH_CHECK( |
| 73 | + input_signal_windows.size(0) == padded_output_waveform.size(0)); |
| 74 | + STD_TORCH_CHECK( |
| 75 | + input_signal_windows.size(1) == padded_output_waveform.size(1)); |
64 | 76 |
65 | | - TORCH_CHECK( |
| 77 | + STD_TORCH_CHECK( |
66 | 78 | input_signal_windows.size(2) + a_coeff_flipped.size(1) - 1 == |
67 | 79 | padded_output_waveform.size(2)); |
68 | 80 |
69 | | - AT_DISPATCH_FLOATING_TYPES( |
70 | | - input_signal_windows.scalar_type(), "lfilter_core_loop", [&] { |
| 81 | + THO_DISPATCH_V2( |
| 82 | + input_signal_windows.scalar_type(), |
| 83 | + "lfilter_core_loop", |
| 84 | + [&] { |
71 | 85 | host_lfilter_core_loop<scalar_t>( |
72 | 86 | input_signal_windows, a_coeff_flipped, padded_output_waveform); |
73 | | - }); |
| 87 | + }, |
| 88 | + AT_FLOATING_TYPES); |
| 89 | + return padded_output_waveform; |
74 | 90 | } |
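
The THO_DISPATCH_V2 call above selects the template instantiation of host_lfilter_core_loop from the runtime dtype. A hand-rolled sketch of the same idea, relying on the Tensor and ScalarType aliases declared earlier in this file (the function name and error message here are illustrative, not part of this change):

// Conceptual stand-in for the dispatch macro: route to the float or double
// instantiation based on the tensor's scalar type.
void dispatch_lfilter_core_loop(
    const Tensor& input_signal_windows,
    const Tensor& a_coeff_flipped,
    Tensor& padded_output_waveform) {
  switch (input_signal_windows.scalar_type()) {
    case ScalarType::Float:
      host_lfilter_core_loop<float>(
          input_signal_windows, a_coeff_flipped, padded_output_waveform);
      break;
    case ScalarType::Double:
      host_lfilter_core_loop<double>(
          input_signal_windows, a_coeff_flipped, padded_output_waveform);
      break;
    default:
      STD_TORCH_CHECK(false, "lfilter_core_loop: unsupported dtype");
  }
}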
75 | 91 |
76 | | -void lfilter_core_generic_loop( |
77 | | - const torch::Tensor& input_signal_windows, |
78 | | - const torch::Tensor& a_coeff_flipped, |
79 | | - torch::Tensor& padded_output_waveform) { |
| 92 | +Tensor lfilter_core_generic_loop( |
| 93 | + Tensor input_signal_windows, |
| 94 | + Tensor a_coeff_flipped, |
| 95 | + Tensor padded_output_waveform) { |
80 | 96 | int64_t n_samples_input = input_signal_windows.size(2); |
81 | 97 | int64_t n_order = a_coeff_flipped.size(1); |
82 | | - auto coeff = a_coeff_flipped.unsqueeze(2); |
| 98 | + auto coeff = torchaudio::stable::unsqueeze(a_coeff_flipped, 2); |
83 | 99 | for (int64_t i_sample = 0; i_sample < n_samples_input; i_sample++) { |
84 | | - auto windowed_output_signal = |
85 | | - torch::narrow(padded_output_waveform, 2, i_sample, i_sample + n_order) |
86 | | - .transpose(0, 1); |
87 | | - auto o0 = torch::select(input_signal_windows, 2, i_sample) - |
88 | | - at::matmul(windowed_output_signal, coeff).squeeze(2).transpose(0, 1); |
89 | | - padded_output_waveform.index_put_( |
90 | | - {torch::indexing::Slice(), |
91 | | - torch::indexing::Slice(), |
92 | | - i_sample + n_order - 1}, |
93 | | - o0); |
| 100 | + auto windowed_output_signal = torch::stable::transpose( |
| 101 | + torch::stable::narrow( |
| 102 | + padded_output_waveform, 2, i_sample, i_sample + n_order), |
| 103 | + 0, |
| 104 | + 1); |
| 105 | + auto o0 = torchaudio::stable::subtract( |
| 106 | + torchaudio::stable::select(input_signal_windows, 2, i_sample), |
| 107 | + torch::stable::transpose( |
| 108 | + torchaudio::stable::squeeze( |
| 109 | + torchaudio::stable::matmul(windowed_output_signal, coeff), 2), |
| 110 | + 0, |
| 111 | + 1)); |
| 112 | + auto s = torchaudio::stable::select( |
| 113 | + padded_output_waveform, 2, i_sample + n_order - 1); |
| 114 | + torch::stable::copy_(s, o0); |
94 | 115 | } |
| 116 | + return padded_output_waveform; |
95 | 117 | } |
96 | 118 |
97 | 119 | } // namespace |
98 | 120 |
99 | | -TORCH_LIBRARY(torchaudio, m) { |
| 121 | +STABLE_TORCH_LIBRARY_FRAGMENT(torchaudio, m) { |
100 | 122 | m.def( |
101 | | - "torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()"); |
| 123 | + "_lfilter_core_loop(" |
| 124 | + "Tensor input_signal_windows," |
| 125 | + "Tensor a_coeff_flipped," |
| 126 | + "Tensor(a!) padded_output_waveform) -> Tensor(a!)"); |
102 | 127 | } |
103 | 128 |
104 | | -TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { |
105 | | - m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop); |
| 129 | +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CPU, m) { |
| 130 | + m.impl("_lfilter_core_loop", TORCH_BOX(&cpu_lfilter_core_loop)); |
106 | 131 | } |
107 | 132 |
108 | 133 | #ifdef USE_CUDA |
109 | | -TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { |
110 | | - m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop); |
| 134 | +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) { |
| 135 | + m.impl("_lfilter_core_loop", TORCH_BOX(&cuda_lfilter_core_loop)); |
111 | 136 | } |
112 | 137 | #endif |
113 | 138 |
114 | | -TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) { |
115 | | - m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop); |
| 139 | +STABLE_TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) { |
| 140 | + m.impl("_lfilter_core_loop", TORCH_BOX(&lfilter_core_generic_loop)); |
116 | 141 | } |