#ifndef TORCH_RADON_FLOATCAST_H
#define TORCH_RADON_FLOATCAST_H

// Declares __half and the __float2half conversion intrinsic used below.
#include <cuda_fp16.h>

// Device-side conversion from a float to the kernel's output element type T.
//
// Only the explicit specializations below are defined; instantiating toType
// for any other T compiles against this declaration but fails at link time
// because no generic definition exists.
template<typename T>
__device__ T toType(float);

// float -> float: identity, no conversion needed.
template<>
inline __device__ float toType<float>(float f)
{
    return f;
}

// float -> half precision via the CUDA intrinsic.
template<>
inline __device__ __half toType<__half>(float f)
{
    return __float2half(f);
}

// float -> unsigned short: a plain static_cast, i.e. the fractional part is
// discarded. Values outside the representable range are not checked here.
template<>
inline __device__ unsigned short toType<unsigned short>(float f)
{
    return static_cast<unsigned short>(f);
}

#endif // TORCH_RADON_FLOATCAST_H
rosh } } +#ifdef __GNUC__ #pragma GCC pop_options +#endif // __GNUC__ #endif \ No newline at end of file diff --git a/setup.py b/setup.py index a5ba69d..fff39fd 100644 --- a/setup.py +++ b/setup.py @@ -1,14 +1,45 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension import os -from make import build with open("README.md", "r") as fh: long_description = fh.read() -cuda_home = os.getenv("CUDA_HOME", "/usr/local/cuda") -print(f"Using CUDA_HOME={cuda_home}") -build(cuda_home=cuda_home) +def _create_cuda_extension(os_name: str): + if os_name == 'posix': + return CUDAExtension( + 'torch_radon_cuda', + [os.path.abspath('src/pytorch.cpp')], + include_dirs=[os.path.abspath('include')], + library_dirs=[os.path.abspath("objs")], + libraries=["m", "c", "gcc", "stdc++", "cufft", "radon"], + extra_compile_args=["-fopenmp"], + # extra_compile_args=["-static", "-static-libgcc", "-static-libstdc++"], + # strip debug symbols + extra_link_args=["-Wl,--strip-all"] #, "-static-libgcc", "-static-libstdc++"] + ) + + if os_name == 'nt': + return CUDAExtension( + 'torch_radon_cuda', + [ + os.path.abspath(os.path.join('src', f)) + for f in os.listdir('src') + if f.endswith('.cpp') or f.endswith('.cu') + ], + include_dirs=[os.path.abspath('include')], + libraries=["cufft"], + extra_compile_args={"cxx": ["/std:c++17"]}, + ) + + raise NotImplementedError(f"OS \"{os.name}\" not implemented.") + +if os.name == 'posix': + from torch.utils.cpp_extension import CUDA_HOME + from make import build + compiler = os.environ.get('CXX', 'g++') + print(f"Using CUDA_HOME={CUDA_HOME}, CXX={compiler}") + build(cuda_home=CUDA_HOME, cxx=compiler) setup(name='torch_radon', version="2.0.0", @@ -24,26 +55,18 @@ 'torch_radon': './torch_radon', }, ext_modules=[ - CUDAExtension('torch_radon_cuda', [os.path.abspath('src/pytorch.cpp')], - include_dirs=[os.path.abspath('include')], - library_dirs=[os.path.abspath("objs")], - libraries=["m", "c", "gcc", "stdc++", "cufft", 
"radon"], - # extra_compile_args=["-static", "-static-libgcc", "-static-libstdc++"], - # strip debug symbols - extra_link_args=["-Wl,--strip-all"] #, "-static-libgcc", "-static-libstdc++"] - ) + _create_cuda_extension(os.name) ], cmdclass={'build_ext': BuildExtension}, zip_safe=False, classifiers=[ "Programming Language :: Python :: 3", "Operating System :: POSIX :: Linux", + "Operating System :: Microsoft :: Windows", "License :: OSI Approved :: GNU General Public License v3 (GPLv3)" ], - install_requires=[ - "scipy", - "alpha-transform" - ], + install_requires=["scipy"] + + (["alpha-transform"] if os.name=="posix" else []), ) diff --git a/src/backprojection.cu b/src/backprojection.cu index ec2beed..9879cfe 100644 --- a/src/backprojection.cu +++ b/src/backprojection.cu @@ -3,10 +3,12 @@ #include #include +#include "floatcast.h" #include "utils.h" #include "texture.h" #include "backprojection.h" + template __global__ void radon_backward_kernel(T *__restrict__ output, cudaTextureObject_t texture, const float *__restrict__ angles, @@ -98,7 +100,7 @@ radon_backward_kernel(T *__restrict__ output, cudaTextureObject_t texture, const #pragma unroll for (int b = 0; b < channels; b++) { - output[base + b * pitch] = accumulator[b] * ids; + output[base + b * pitch] = toType(accumulator[b] * ids); } } } @@ -247,7 +249,7 @@ radon_backward_kernel_3d(T *__restrict__ output, cudaTextureObject_t texture, co #pragma unroll for (int b = 0; b < channels; b++) { - output[b * pitch + index] = accumulator[b] * ids; + output[b * pitch + index] = toType(accumulator[b] * ids); } } } diff --git a/src/forward.cu b/src/forward.cu index b963a51..d7b9d31 100644 --- a/src/forward.cu +++ b/src/forward.cu @@ -1,8 +1,8 @@ -#include #include #include #include +#include "floatcast.h" #include "utils.h" #include "texture.h" #include "parameter_classes.h" @@ -75,7 +75,7 @@ radon_forward_kernel(T *__restrict__ output, cudaTextureObject_t texture, const // if ray volume intersection is empty exit if 
(alpha_s > alpha_e) { #pragma unroll - for (int b = 0; b < channels; b++) output[base + b * mem_pitch] = 0.0f; + for (int b = 0; b < channels; b++) output[base + b * mem_pitch] = toType(0.0f); return; } @@ -119,7 +119,7 @@ radon_forward_kernel(T *__restrict__ output, cudaTextureObject_t texture, const } #pragma unroll - for (int b = 0; b < channels; b++) output[base + b * mem_pitch] = accumulator[b] * n; + for (int b = 0; b < channels; b++) output[base + b * mem_pitch] = toType(accumulator[b] * n); } } @@ -255,7 +255,7 @@ radon_forward_kernel_3d(T *__restrict__ output, cudaTextureObject_t texture, con if (alpha_s > alpha_e) { #pragma unroll - for (int b = 0; b < channels; b++) output[b * mem_pitch + index] = 0.0f; + for (int b = 0; b < channels; b++) output[b * mem_pitch + index] = toType(0.0f); return; } @@ -308,7 +308,7 @@ radon_forward_kernel_3d(T *__restrict__ output, cudaTextureObject_t texture, con // output #pragma unroll for (int b = 0; b < channels; b++) { - output[b * mem_pitch + index] = accumulator[b] * n; + output[b * mem_pitch + index] = toType(accumulator[b] * n); } } } diff --git a/src/symbolic.cpp b/src/symbolic.cpp index 878f32a..71cfd67 100644 --- a/src/symbolic.cpp +++ b/src/symbolic.cpp @@ -2,25 +2,23 @@ #include "log.h" #include "rmath.h" -using namespace rosh; - Gaussian::Gaussian(float _k, float _cx, float _cy, float _a, float _b) : k(_k), cx(_cx), cy(_cy), a(_a), b(_b) {} float Gaussian::line_integral(float s_x, float s_y, float e_x, float e_y) const { - float x0 = a * sq(s_x); - float x1 = b * sq(s_y); + float x0 = a * rosh::sq(s_x); + float x1 = b * rosh::sq(s_y); float x2 = 2 * a * e_x; float x3 = 2 * b * e_y; float x4 = s_x * x2; float x5 = s_y * x3; float x6 = 2 * a * cx * s_x + 2 * b * cy * s_y; float x7 = -cx * x2 - cy * x3 - 2 * x0 - 2 * x1 + x4 + x5 + x6; - float x8 = max(a * sq(e_x) + b * sq(e_y) + x0 + x1 - x4 - x5, 1e-6); - float x9 = sqrt(x8); + float x8 = rosh::max(a * rosh::sq(e_x) + b * rosh::sq(e_y) + x0 + x1 - x4 - x5, 
1e-6f); + float x9 = rosh::sqrt(x8); float x10 = (1.0 / 2.0) / x9; float x11 = x10 * x7; - float lg_x12 = log(sqrt(rosh::pi) * x10) - a * sq(cx) - b * sq(cy) - x0 - x1 + x6 + (1.0 / 4.0) * sq(x7) / x8; + float lg_x12 = rosh::log(rosh::sqrt(rosh::pi) * x10) - a * rosh::sq(cx) - b * rosh::sq(cy) - x0 - x1 + x6 + (1.0 / 4.0) * rosh::sq(x7) / x8; // this is not precise if (lg_x12 >= 5) @@ -28,9 +26,9 @@ float Gaussian::line_integral(float s_x, float s_y, float e_x, float e_y) const return 0.0f; } - float len = hypot(e_x - s_x, e_y - s_y); + float len = rosh::hypot(e_x - s_x, e_y - s_y); - float y = k * len * exp(lg_x12) * (-erf(x11) + erf(x11 + x9)); + float y = k * len * rosh::exp(lg_x12) * (-rosh::erf(x11) + rosh::erf(x11 + x9)); // if(y != y){ // LOG_ERROR("len: " << len << " x11: " << x11 << " erf(x11): " << erf(x11) << " x8: " << x8); @@ -45,7 +43,7 @@ float Gaussian::evaluate(float x, float y) const float dx = x - cx; float dy = y - cy; - return k * exp(-a * dx * dx - b * dy * dy); + return k * rosh::exp(-a * dx * dx - b * dy * dy); } void Gaussian::move(float dx, float dy) @@ -80,11 +78,11 @@ float Ellipse::line_integral(float s_x, float s_y, float e_x, float e_y) const return 0.0f; // min_clip to 1 to avoid getting empty rays - const float delta_sqrt = sqrt(delta); - const float alpha_s = min(max((-b - delta_sqrt) / a, 0.0f), 1.0f); - const float alpha_e = min(max((-b + delta_sqrt) / a, 0.0f), 1.0f); + const float delta_sqrt = rosh::sqrt(delta); + const float alpha_s = rosh::min(rosh::max((-b - delta_sqrt) / a, 0.0f), 1.0f); + const float alpha_e = rosh::min(rosh::max((-b + delta_sqrt) / a, 0.0f), 1.0f); - return hypot(dx, e_y - s_y) * (alpha_e - alpha_s); + return rosh::hypot(dx, e_y - s_y) * (alpha_e - alpha_s); } float Ellipse::evaluate(float x, float y) const @@ -94,15 +92,15 @@ float Ellipse::evaluate(float x, float y) const float dy = aspect * (cy - y); constexpr float r = 1.0f / 3; - tmp += hypot(dx - r, dy - r) <= radius_x; - tmp += hypot(dx - r, dy) <= 
radius_x; - tmp += hypot(dx - r, dy + r) <= radius_x; - tmp += hypot(dx, dy - r) <= radius_x; - tmp += hypot(dx, dy) <= radius_x; - tmp += hypot(dx, dy + r) <= radius_x; - tmp += hypot(dx + r, dy - r) <= radius_x; - tmp += hypot(dx + r, dy) <= radius_x; - tmp += hypot(dx + r, dy + r) <= radius_x; + tmp += rosh::hypot(dx - r, dy - r) <= radius_x; + tmp += rosh::hypot(dx - r, dy) <= radius_x; + tmp += rosh::hypot(dx - r, dy + r) <= radius_x; + tmp += rosh::hypot(dx, dy - r) <= radius_x; + tmp += rosh::hypot(dx, dy) <= radius_x; + tmp += rosh::hypot(dx, dy + r) <= radius_x; + tmp += rosh::hypot(dx + r, dy - r) <= radius_x; + tmp += rosh::hypot(dx + r, dy) <= radius_x; + tmp += rosh::hypot(dx + r, dy + r) <= radius_x; return tmp / 9.0f; } @@ -170,10 +168,10 @@ void SymbolicFunction::scale(float sx, float sy) float SymbolicFunction::max_distance_from_origin() const { - float x = max(abs(min_x), abs(max_x)); - float y = max(abs(min_y), abs(max_y)); + float x = rosh::max(rosh::abs(min_x), rosh::abs(max_x)); + float y = rosh::max(rosh::abs(min_y), rosh::abs(max_y)); - return hypot(x, y); + return rosh::hypot(x, y); } void SymbolicFunction::discretize(float *data, int h, int w) const @@ -208,15 +206,15 @@ float SymbolicFunction::line_integral(float s_x, float s_y, float e_x, float e_y // clip segment to function domain float dx = e_x - s_x; float dy = e_y - s_y; - dx = dx >= 0 ? max(dx, 1e-6f) : min(dx, -1e-6f); - dy = dy >= 0 ? max(dy, 1e-6f) : min(dy, -1e-6f); + dx = dx >= 0 ? rosh::max(dx, 1e-6f) : rosh::min(dx, -1e-6f); + dy = dy >= 0 ? 
rosh::max(dy, 1e-6f) : rosh::min(dy, -1e-6f); const float alpha_x_m = (min_x - s_x) / dx; const float alpha_x_p = (max_x - s_x) / dx; const float alpha_y_m = (min_y - s_y) / dy; const float alpha_y_p = (max_y - s_y) / dy; - const float alpha_s = max(min(alpha_x_p, alpha_x_m), min(alpha_y_p, alpha_y_m)); - const float alpha_e = min(max(alpha_x_p, alpha_x_m), max(alpha_y_p, alpha_y_m)); + const float alpha_s = rosh::max(rosh::min(alpha_x_p, alpha_x_m), rosh::min(alpha_y_p, alpha_y_m)); + const float alpha_e = rosh::min(rosh::max(alpha_x_p, alpha_x_m), rosh::max(alpha_y_p, alpha_y_m)); if (alpha_s >= alpha_e) { @@ -262,8 +260,8 @@ void symbolic_forward(const SymbolicFunction &f, const ProjectionCfg &proj, cons // rotate ray const float angle = angles[angle_id]; - const float cs = cos(angle); - const float sn = sin(angle); + const float cs = rosh::cos(angle); + const float sn = rosh::sin(angle); float rsx = sx * cs + sy * sn; float rsy = -sx * sn + sy * cs; diff --git a/src/texture.cu b/src/texture.cu index 375861a..d37c401 100644 --- a/src/texture.cu +++ b/src/texture.cu @@ -77,7 +77,7 @@ write_half_to_surface(const __half *data, cudaSurfaceObject_t surface, const int const int offset = (z * height + y) * width + x; __half tmp[4]; - for (int i = 0; i < 4; i++) tmp[i] = __float2half(data[i * pitch + offset]); + for (int i = 0; i < 4; i++) tmp[i] = data[i * pitch + offset]; switch(texture_type){ case TEX_1D_LAYERED: diff --git a/torch_radon/filtering.py b/torch_radon/filtering.py index a0deb9d..84a4b8c 100644 --- a/torch_radon/filtering.py +++ b/torch_radon/filtering.py @@ -1,6 +1,5 @@ import numpy as np import torch -import torch.fft try: import scipy.fft