Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 487c07d

Browse files
committed
Adapting ConvDims to Depthwise and Groupwise
1 parent b3aca6c commit 487c07d

File tree

4 files changed

+12
-12
lines changed

4 files changed

+12
-12
lines changed

src/dnn/compat.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# Compatibility shims until users upgrade to new NNlib format
22
function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}; pad=0, stride=1, flipkernel=0, dilation=1, kwargs...) where {T<:CUDNNFloat}
3-
cdims = DenseConvDims(x, w; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
3+
cdims = ConvDims(x, w; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
44
return conv!(y, x, w, cdims; kwargs...)
55
end
66

77
function ∇conv_filter!(dw::CuArray{T}, dy::CuArray{T}, x::CuArray{T}; pad=0, stride=1, flipkernel=0, dilation=1, kwargs...) where {T<:CUDNNFloat}
8-
cdims = DenseConvDims(x, dw; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
8+
cdims = ConvDims(x, dw; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
99
# NOTE!!! This compat shim re-arranges the argument order!
1010
return ∇conv_filter!(dw, x, dy, cdims; kwargs...)
1111
end

src/dnn/conv.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using NNlib: DenseConvDims
1+
using NNlib: ConvDims
22

33

44
# descriptor
@@ -44,7 +44,7 @@ function ConvDesc(T, N, padding, stride, dilation, mode, groupcount)
4444
return this
4545
end
4646

47-
function ConvDesc(T, cdims::DenseConvDims)
47+
function ConvDesc(T, cdims::ConvDims)
4848
pd = NNlib.padding(cdims)
4949
if !all(pd[1:2:end] .== pd[2:2:end])
5050
@warn("CuDNN does not support asymmetric padding; defaulting to symmetric choice")
@@ -69,7 +69,7 @@ function cudnnConvolutionBiasActivationForward(y::CuArray{T,N}, x::CuArray{T,N},
6969
end
7070

7171
function cudnnConvolutionForward(y::CuArray{T,N}, x::CuArray{T,N}, w::CuArray{T,N},
72-
cdims::DenseConvDims; algo=0, alpha=1, beta=0) where {T,N}
72+
cdims::ConvDims; algo=0, alpha=1, beta=0) where {T,N}
7373
@workspace size=@argout(
7474
cudnnGetConvolutionForwardWorkspaceSize(
7575
handle(), TensorDesc(x),
@@ -87,7 +87,7 @@ function cudnnConvolutionForward(y::CuArray{T,N}, x::CuArray{T,N}, w::CuArray{T,
8787
end
8888

8989
function cudnnConvolutionBackwardData(dx::CuArray{T,N}, w::CuArray{T,N}, dy::CuArray{T,N},
90-
cdims::DenseConvDims; algo=0, alpha=1, beta=0) where {T,N}
90+
cdims::ConvDims; algo=0, alpha=1, beta=0) where {T,N}
9191
@workspace size=@argout(
9292
cudnnGetConvolutionBackwardDataWorkspaceSize(
9393
handle(), FilterDesc(w),
@@ -106,7 +106,7 @@ function cudnnConvolutionBackwardData(dx::CuArray{T,N}, w::CuArray{T,N}, dy::CuA
106106
end
107107

108108
function cudnnConvolutionBackwardFilter(dw::CuArray{T,N}, x::CuArray{T,N}, dy::CuArray{T,N},
109-
cdims::DenseConvDims; algo=0, alpha=1, beta=0) where {T,N}
109+
cdims::ConvDims; algo=0, alpha=1, beta=0) where {T,N}
110110
@workspace size=@argout(
111111
cudnnGetConvolutionBackwardFilterWorkspaceSize(
112112
handle(), TensorDesc(x),

src/dnn/nnlib.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ end
4141

4242
# Convolution
4343

44-
function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::DenseConvDims;
44+
function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::ConvDims;
4545
alpha=1, algo=0) where T<:CUDNNFloat
4646
if version() < v"6"
4747
all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
@@ -51,7 +51,7 @@ function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::DenseConvDims
5151
end
5252

5353
function ∇conv_filter!(dw::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
54-
cdims::DenseConvDims; alpha=1, algo=0) where T<:CUDNNFloat
54+
cdims::ConvDims; alpha=1, algo=0) where T<:CUDNNFloat
5555
if version() < v"6"
5656
all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
5757
end
@@ -60,7 +60,7 @@ function ∇conv_filter!(dw::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
6060
end
6161

6262
function ∇conv_data!(dx::CuArray{T}, dy::CuArray{T}, w::CuArray{T},
63-
cdims::DenseConvDims; alpha=1, algo=0) where T<:CUDNNFloat
63+
cdims::ConvDims; alpha=1, algo=0) where T<:CUDNNFloat
6464
if version() < v"6"
6565
all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
6666
end

test/dnn.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ else
1717
softmax, ∇softmax, logsoftmax, ∇logsoftmax
1818
a, b, c = rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4), rand(Float64, 9, 9, 4, 1)
1919
da, db, dc = CuArray(a), CuArray(b), CuArray(c)
20-
cdims = DenseConvDims(a, b)
20+
cdims = ConvDims(a, b)
2121
@test NNlib.conv(a, b, cdims) ≈ collect(NNlib.conv(da, db, cdims))
2222
@test ∇conv_data(c, b, cdims) ≈ collect(∇conv_data(dc, db, cdims))
2323
@test ∇conv_filter(a, c, cdims) ≈ collect(∇conv_filter(da, dc, cdims))
@@ -35,7 +35,7 @@ else
3535
algos = (1, 0, 1, 1,)
3636

3737
for (opts, algo) in zip(options, algos)
38-
cdims = DenseConvDims(x, w; opts...)
38+
cdims = ConvDims(x, w; opts...)
3939
y = NNlib.conv(x, w, cdims)
4040

4141
# Test that basic convolution is equivalent across GPU/CPU

0 commit comments

Comments (0)