Windows compatibility #29

Open
wants to merge 4 commits into base: v2
Changes from 2 commits
6 changes: 3 additions & 3 deletions examples/fbp.py
@@ -3,17 +3,17 @@
import torch
from utils import show_images

from torch_radon import Radon
from torch_radon import ParallelBeam

device = torch.device('cuda')

img = np.load("phantom.npy")
image_size = img.shape[0]
n_angles = image_size

# Instantiate Radon transform. clip_to_circle should be True when using filtered backprojection.
# Instantiate Radon transform.
angles = np.linspace(0, np.pi, n_angles, endpoint=False)
radon = Radon(image_size, angles, clip_to_circle=True)
radon = ParallelBeam(image_size, angles)

with torch.no_grad():
x = torch.FloatTensor(img).to(device)
16 changes: 14 additions & 2 deletions include/rmath.h
@@ -3,12 +3,18 @@

#include <stdint.h>
#include <limits>
#include <cstdlib>

#ifdef __GNUC__
#define EXPECT_FALSE(x) __builtin_expect(x, false)
#define EXPECT_TRUE(x) __builtin_expect(x, true)

#pragma GCC push_options
#pragma GCC optimize ("O3", "no-fast-math")
#else
#define EXPECT_FALSE(x) x
#define EXPECT_TRUE(x) x
#endif // __GNUC__

namespace rosh
{
@@ -200,10 +206,14 @@ namespace rosh

inline float sqrt(float x)
{
#ifdef __GNUC__
__asm__("sqrtss %1, %0"
: "=x"(x)
: "x"(x));
: "=x"(x)
: "x"(x));
return x;
#else
return std::sqrtf(x);
#endif // __GNUC__
}

inline float hypot(float x, float y)
@@ -388,5 +398,7 @@ namespace rosh
}
}

#ifdef __GNUC__
#pragma GCC pop_options
#endif // __GNUC__
#endif
25 changes: 17 additions & 8 deletions setup.py
@@ -1,14 +1,14 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
from make import build
Owner:

The setup is also used in the CI/CD pipeline to create the precompiled packages, so I would keep both the Linux version with the make.py file and the new Windows version, and add an if os.name=='posix' check to choose between the two versions of the code.

Author:

Ah I see. I will add the if os.name=='posix'.
Since this will include setting cuda_home = os.getenv("CUDA_HOME", "/usr/local/cuda") again, what do you think about replacing this with from torch.utils.cpp_extension import CUDA_HOME? On Manjaro, for example, the path to cuda is CUDA_PATH="/opt/cuda" and CUDA_HOME is not set. The torch import would take care of that.

Owner:

Perfect, thanks.
Replacing the CUDA_HOME definition seems a good idea!
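
A minimal sketch of the approach agreed above — keeping the make.py build on Linux/CI, compiling the sources directly on Windows, and importing CUDA_HOME from torch.utils.cpp_extension — might look like this. It is illustrative only; the paths, flags, and source lists are taken from the diff below, not from a final implementation:

```python
# Hedged sketch only: names follow this PR's diff, not the final setup.py.
import os
from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension

if os.name == 'posix':
    # Keep the existing Linux/CI path: build the radon library via make.py and link against it.
    from make import build
    build(cuda_home=CUDA_HOME)  # CUDA_HOME resolved by torch (handles e.g. /opt/cuda on Manjaro)
    sources = [os.path.abspath('src/pytorch.cpp')]
    libraries = ["m", "c", "gcc", "stdc++", "cufft", "radon"]
    cxx_flags = ["-std=c++17"]
else:
    # Windows: let CUDAExtension compile every .cpp/.cu source directly.
    sources = [
        os.path.abspath(os.path.join('src', f))
        for f in os.listdir('src')
        if f.endswith('.cpp') or f.endswith('.cu')
    ]
    libraries = ["cufft"]
    cxx_flags = ["/std:c++17"]

extension = CUDAExtension(
    'torch_radon_cuda',
    sources,
    include_dirs=[os.path.abspath('include')],
    library_dirs=[os.path.abspath("objs")],
    libraries=libraries,
    extra_compile_args={"cxx": cxx_flags},
)
```

The extension object would then be passed to setup() exactly as in the diff below.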

# from make import build

with open("README.md", "r") as fh:
long_description = fh.read()

cuda_home = os.getenv("CUDA_HOME", "/usr/local/cuda")
print(f"Using CUDA_HOME={cuda_home}")
build(cuda_home=cuda_home)
# cuda_home = os.getenv("CUDA_HOME", "/usr/local/cuda")
# print(f"Using CUDA_HOME={cuda_home}")
# build(cuda_home=cuda_home)

setup(name='torch_radon',
version="2.0.0",
@@ -24,13 +24,22 @@
'torch_radon': './torch_radon',
},
ext_modules=[
CUDAExtension('torch_radon_cuda', [os.path.abspath('src/pytorch.cpp')],
CUDAExtension('torch_radon_cuda',
[
os.path.abspath(os.path.join('src', f))
for f in os.listdir('src')
if f.endswith('.cpp') or f.endswith('.cu')
],
include_dirs=[os.path.abspath('include')],
library_dirs=[os.path.abspath("objs")],
libraries=["m", "c", "gcc", "stdc++", "cufft", "radon"],
libraries=["cufft"],
extra_compile_args={
"cxx": ["-std=c++17" if os.name=='posix' else "/std:c++17"]
},
# libraries=["m", "c", "gcc", "stdc++", "cufft", "radon"],
# extra_compile_args=["-static", "-static-libgcc", "-static-libstdc++"],
# strip debug symbols
extra_link_args=["-Wl,--strip-all"] #, "-static-libgcc", "-static-libstdc++"]
# extra_link_args=["-Wl,--strip-all"] #, "-static-libgcc", "-static-libstdc++"]
)
],
cmdclass={'build_ext': BuildExtension},
@@ -42,7 +51,7 @@
],
install_requires=[
"scipy",
"alpha-transform"
# "alpha-transform" # not available on Windows
],
)

28 changes: 26 additions & 2 deletions src/backprojection.cu
@@ -7,6 +7,30 @@
#include "texture.h"
#include "backprojection.h"

namespace
{
template<typename T>
__device__ T toType(float);

template<>
__device__ float toType(float f)
Owner (matteo-ronchetti, Sep 13, 2021):

Why are these required?

Author:

The CUDA compiler on Windows doesn't allow implicit casts from 32-bit floats to 16-bit floats. In radon_backward_kernel, for example, accumulator is a 32-bit float array, but output might be a 16-bit float. toType casts 32-bit floats to the target type.

Author:

I just found out that CUDAExtension defines __CUDA_NO_HALF_CONVERSIONS__ for nvcc, which deactivates these implicit conversions. Since you don't compile the cuda files with CUDAExtension, you didn't get these errors.
There are 2 possibilities now:

  1. keep the templated casts in the cuda files
  2. undefine __CUDA_NO_HALF_CONVERSIONS__ for CUDAExtension
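
For illustration, option 2 would amount to something like the following hedged sketch (not part of this PR; COMMON_NVCC_FLAGS is an internal list in torch.utils.cpp_extension and may change between PyTorch releases):

```python
# Hedged sketch of option 2 (not what was adopted): strip the nvcc define that
# torch.utils.cpp_extension adds by default, so implicit float -> __half casts compile again.
from torch.utils import cpp_extension

cpp_extension.COMMON_NVCC_FLAGS = [
    flag for flag in cpp_extension.COMMON_NVCC_FLAGS
    if flag != '-D__CUDA_NO_HALF_CONVERSIONS__'
]
```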

Owner:

I am ok with keeping the templated casts. Please put them in a separate file in order to avoid duplication of code

{
return f;
};

template<>
__device__ __half toType(float f)
{
return __float2half(f);
};

template<>
__device__ unsigned short toType(float f)
{
return static_cast<unsigned short>(f);
};
}
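
As requested in the review thread above, these specializations could live in one shared header instead of being repeated in backprojection.cu and forward.cu. A hedged sketch follows; the file name include/cast.h is hypothetical and not part of this PR:

```cuda
// include/cast.h (illustrative name): device-side casts from float to the kernel's
// output type T, needed because CUDAExtension passes -D__CUDA_NO_HALF_CONVERSIONS__
// to nvcc, which disables the implicit float -> __half conversion.
#pragma once

#include <cuda_fp16.h>

template <typename T>
__device__ inline T toType(float f);

template <>
__device__ inline float toType<float>(float f) { return f; }

template <>
__device__ inline __half toType<__half>(float f) { return __float2half(f); }

// Truncating cast, mirroring the unsigned short specialization in this PR.
template <>
__device__ inline unsigned short toType<unsigned short>(float f)
{
    return static_cast<unsigned short>(f);
}
```

Marking the specializations inline lets the header be included from both .cu files without duplicate-definition issues.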

template<bool parallel_beam, int channels, typename T>
__global__ void
radon_backward_kernel(T *__restrict__ output, cudaTextureObject_t texture, const float *__restrict__ angles,
@@ -98,7 +122,7 @@ radon_backward_kernel(T *__restrict__ output, cudaTextureObject_t texture, const

#pragma unroll
for (int b = 0; b < channels; b++) {
output[base + b * pitch] = accumulator[b] * ids;
output[base + b * pitch] = toType<T>(accumulator[b] * ids);
}
}
}
@@ -247,7 +271,7 @@ radon_backward_kernel_3d(T *__restrict__ output, cudaTextureObject_t texture, co

#pragma unroll
for (int b = 0; b < channels; b++) {
output[b * pitch + index] = accumulator[b] * ids;
output[b * pitch + index] = toType<T>(accumulator[b] * ids);
}
}
}
34 changes: 29 additions & 5 deletions src/forward.cu
@@ -1,4 +1,3 @@
#include <iostream>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
@@ -9,6 +8,31 @@
#include "log.h"


namespace
{
template<typename T>
__device__ T toType(float);

template<>
__device__ float toType(float f)
{
return f;
};

template<>
__device__ __half toType(float f)
{
return __float2half(f);
};

template<>
__device__ unsigned short toType(float f)
{
return static_cast<unsigned short>(f);
};
}


template<bool parallel_beam, int channels, typename T>
__global__ void
radon_forward_kernel(T *__restrict__ output, cudaTextureObject_t texture, const float *__restrict__ angles,
@@ -75,7 +99,7 @@ radon_forward_kernel(T *__restrict__ output, cudaTextureObject_t texture, const
// if ray volume intersection is empty exit
if (alpha_s > alpha_e) {
#pragma unroll
for (int b = 0; b < channels; b++) output[base + b * mem_pitch] = 0.0f;
for (int b = 0; b < channels; b++) output[base + b * mem_pitch] = toType<T>(0.0f);
return;
}

@@ -119,7 +143,7 @@ radon_forward_kernel(T *__restrict__ output, cudaTextureObject_t texture, const
}

#pragma unroll
for (int b = 0; b < channels; b++) output[base + b * mem_pitch] = accumulator[b] * n;
for (int b = 0; b < channels; b++) output[base + b * mem_pitch] = toType<T>(accumulator[b] * n);
}
}

@@ -255,7 +279,7 @@ radon_forward_kernel_3d(T *__restrict__ output, cudaTextureObject_t texture, con

if (alpha_s > alpha_e) {
#pragma unroll
for (int b = 0; b < channels; b++) output[b * mem_pitch + index] = 0.0f;
for (int b = 0; b < channels; b++) output[b * mem_pitch + index] = toType<T>(0.0f);
return;
}

Expand Down Expand Up @@ -308,7 +332,7 @@ radon_forward_kernel_3d(T *__restrict__ output, cudaTextureObject_t texture, con
// output
#pragma unroll
for (int b = 0; b < channels; b++) {
output[b * mem_pitch + index] = accumulator[b] * n;
output[b * mem_pitch + index] = toType<T>(accumulator[b] * n);
}
}
}
62 changes: 30 additions & 32 deletions src/symbolic.cpp
@@ -2,35 +2,33 @@
#include "log.h"
#include "rmath.h"

using namespace rosh;

Gaussian::Gaussian(float _k, float _cx, float _cy, float _a, float _b) : k(_k), cx(_cx), cy(_cy), a(_a), b(_b) {}

float Gaussian::line_integral(float s_x, float s_y, float e_x, float e_y) const
{
float x0 = a * sq(s_x);
float x1 = b * sq(s_y);
float x0 = a * rosh::sq(s_x);
float x1 = b * rosh::sq(s_y);
float x2 = 2 * a * e_x;
float x3 = 2 * b * e_y;
float x4 = s_x * x2;
float x5 = s_y * x3;
float x6 = 2 * a * cx * s_x + 2 * b * cy * s_y;
float x7 = -cx * x2 - cy * x3 - 2 * x0 - 2 * x1 + x4 + x5 + x6;
float x8 = max(a * sq(e_x) + b * sq(e_y) + x0 + x1 - x4 - x5, 1e-6);
float x9 = sqrt(x8);
float x8 = rosh::max(a * rosh::sq(e_x) + b * rosh::sq(e_y) + x0 + x1 - x4 - x5, 1e-6f);
float x9 = rosh::sqrt(x8);
float x10 = (1.0 / 2.0) / x9;
float x11 = x10 * x7;
float lg_x12 = log(sqrt(rosh::pi) * x10) - a * sq(cx) - b * sq(cy) - x0 - x1 + x6 + (1.0 / 4.0) * sq(x7) / x8;
float lg_x12 = rosh::log(rosh::sqrt(rosh::pi) * x10) - a * rosh::sq(cx) - b * rosh::sq(cy) - x0 - x1 + x6 + (1.0 / 4.0) * rosh::sq(x7) / x8;

// this is not precise
if (lg_x12 >= 5)
{
return 0.0f;
}

float len = hypot(e_x - s_x, e_y - s_y);
float len = rosh::hypot(e_x - s_x, e_y - s_y);

float y = k * len * exp(lg_x12) * (-erf(x11) + erf(x11 + x9));
float y = k * len * rosh::exp(lg_x12) * (-rosh::erf(x11) + rosh::erf(x11 + x9));

// if(y != y){
// LOG_ERROR("len: " << len << " x11: " << x11 << " erf(x11): " << erf(x11) << " x8: " << x8);
@@ -45,7 +43,7 @@ float Gaussian::evaluate(float x, float y) const
float dx = x - cx;
float dy = y - cy;

return k * exp(-a * dx * dx - b * dy * dy);
return k * rosh::exp(-a * dx * dx - b * dy * dy);
}

void Gaussian::move(float dx, float dy)
@@ -80,11 +78,11 @@ float Ellipse::line_integral(float s_x, float s_y, float e_x, float e_y) const
return 0.0f;

// min_clip to 1 to avoid getting empty rays
const float delta_sqrt = sqrt(delta);
const float alpha_s = min(max((-b - delta_sqrt) / a, 0.0f), 1.0f);
const float alpha_e = min(max((-b + delta_sqrt) / a, 0.0f), 1.0f);
const float delta_sqrt = rosh::sqrt(delta);
const float alpha_s = rosh::min(rosh::max((-b - delta_sqrt) / a, 0.0f), 1.0f);
const float alpha_e = rosh::min(rosh::max((-b + delta_sqrt) / a, 0.0f), 1.0f);

return hypot(dx, e_y - s_y) * (alpha_e - alpha_s);
return rosh::hypot(dx, e_y - s_y) * (alpha_e - alpha_s);
}

float Ellipse::evaluate(float x, float y) const
@@ -94,15 +92,15 @@ float Ellipse::evaluate(float x, float y) const
float dy = aspect * (cy - y);
constexpr float r = 1.0f / 3;

tmp += hypot(dx - r, dy - r) <= radius_x;
tmp += hypot(dx - r, dy) <= radius_x;
tmp += hypot(dx - r, dy + r) <= radius_x;
tmp += hypot(dx, dy - r) <= radius_x;
tmp += hypot(dx, dy) <= radius_x;
tmp += hypot(dx, dy + r) <= radius_x;
tmp += hypot(dx + r, dy - r) <= radius_x;
tmp += hypot(dx + r, dy) <= radius_x;
tmp += hypot(dx + r, dy + r) <= radius_x;
tmp += rosh::hypot(dx - r, dy - r) <= radius_x;
tmp += rosh::hypot(dx - r, dy) <= radius_x;
tmp += rosh::hypot(dx - r, dy + r) <= radius_x;
tmp += rosh::hypot(dx, dy - r) <= radius_x;
tmp += rosh::hypot(dx, dy) <= radius_x;
tmp += rosh::hypot(dx, dy + r) <= radius_x;
tmp += rosh::hypot(dx + r, dy - r) <= radius_x;
tmp += rosh::hypot(dx + r, dy) <= radius_x;
tmp += rosh::hypot(dx + r, dy + r) <= radius_x;

return tmp / 9.0f;
}
@@ -170,10 +168,10 @@ void SymbolicFunction::scale(float sx, float sy)

float SymbolicFunction::max_distance_from_origin() const
{
float x = max(abs(min_x), abs(max_x));
float y = max(abs(min_y), abs(max_y));
float x = rosh::max(rosh::abs(min_x), rosh::abs(max_x));
float y = rosh::max(rosh::abs(min_y), rosh::abs(max_y));

return hypot(x, y);
return rosh::hypot(x, y);
}

void SymbolicFunction::discretize(float *data, int h, int w) const
@@ -208,15 +206,15 @@ float SymbolicFunction::line_integral(float s_x, float s_y, float e_x, float e_y
// clip segment to function domain
float dx = e_x - s_x;
float dy = e_y - s_y;
dx = dx >= 0 ? max(dx, 1e-6f) : min(dx, -1e-6f);
dy = dy >= 0 ? max(dy, 1e-6f) : min(dy, -1e-6f);
dx = dx >= 0 ? rosh::max(dx, 1e-6f) : rosh::min(dx, -1e-6f);
dy = dy >= 0 ? rosh::max(dy, 1e-6f) : rosh::min(dy, -1e-6f);

const float alpha_x_m = (min_x - s_x) / dx;
const float alpha_x_p = (max_x - s_x) / dx;
const float alpha_y_m = (min_y - s_y) / dy;
const float alpha_y_p = (max_y - s_y) / dy;
const float alpha_s = max(min(alpha_x_p, alpha_x_m), min(alpha_y_p, alpha_y_m));
const float alpha_e = min(max(alpha_x_p, alpha_x_m), max(alpha_y_p, alpha_y_m));
const float alpha_s = rosh::max(rosh::min(alpha_x_p, alpha_x_m), rosh::min(alpha_y_p, alpha_y_m));
const float alpha_e = rosh::min(rosh::max(alpha_x_p, alpha_x_m), rosh::max(alpha_y_p, alpha_y_m));

if (alpha_s >= alpha_e)
{
@@ -262,8 +260,8 @@ void symbolic_forward(const SymbolicFunction &f, const ProjectionCfg &proj, cons

// rotate ray
const float angle = angles[angle_id];
const float cs = cos(angle);
const float sn = sin(angle);
const float cs = rosh::cos(angle);
const float sn = rosh::sin(angle);

float rsx = sx * cs + sy * sn;
float rsy = -sx * sn + sy * cs;
2 changes: 1 addition & 1 deletion src/texture.cu
@@ -77,7 +77,7 @@ write_half_to_surface(const __half *data, cudaSurfaceObject_t surface, const int
const int offset = (z * height + y) * width + x;

__half tmp[4];
for (int i = 0; i < 4; i++) tmp[i] = __float2half(data[i * pitch + offset]);
for (int i = 0; i < 4; i++) tmp[i] = data[i * pitch + offset];

switch(texture_type){
case TEX_1D_LAYERED:
1 change: 0 additions & 1 deletion torch_radon/filtering.py
@@ -1,6 +1,5 @@
import numpy as np
import torch
import torch.fft

try:
import scipy.fft