This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Ability to pass groupcount parameter for depthwise and groupwise convolutions #523

Closed · wants to merge 2 commits
4 changes: 2 additions & 2 deletions src/dnn/compat.jl
@@ -1,11 +1,11 @@
 # Compatibility shims until users upgrade to new NNlib format
 function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}; pad=0, stride=1, flipkernel=0, dilation=1, kwargs...) where {T<:CUDNNFloat}
-    cdims = DenseConvDims(x, w; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
+    cdims = ConvDims(x, w; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
     return conv!(y, x, w, cdims; kwargs...)
 end
 
 function ∇conv_filter!(dw::CuArray{T}, dy::CuArray{T}, x::CuArray{T}; pad=0, stride=1, flipkernel=0, dilation=1, kwargs...) where {T<:CUDNNFloat}
-    cdims = DenseConvDims(x, dw; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
+    cdims = ConvDims(x, dw; padding=pad, stride=stride, flipkernel=flipkernel, dilation=dilation)
     # NOTE!!! This compat shim re-arranges the argument order!
     return ∇conv_filter!(dw, x, dy, cdims; kwargs...)
 end

Author, commenting on the new ConvDims lines: I need help with these shims.
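One possible shape for these shims, sketched here only under the assumption that the companion NNlib change gives the ConvDims constructor a groupcount keyword (hypothetical at this point), is to thread the parameter through with a backwards-compatible default:

# Hypothetical shim extension: groupcount=1 preserves the old dense behaviour,
# while callers can opt into grouped/depthwise convolution.
function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T};
               pad=0, stride=1, flipkernel=0, dilation=1, groupcount=1,
               kwargs...) where {T<:CUDNNFloat}
    cdims = ConvDims(x, w; padding=pad, stride=stride, flipkernel=flipkernel,
                     dilation=dilation, groupcount=groupcount)
    return conv!(y, x, w, cdims; kwargs...)
end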
15 changes: 8 additions & 7 deletions src/dnn/conv.jl
@@ -1,4 +1,4 @@
-using NNlib: DenseConvDims
+using NNlib: ConvDims
 
 
 # descriptor
@@ -28,7 +28,7 @@ end
 
 Base.cconvert(::Type{cudnnConvolutionMode_t}, x::Bool) = x ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION
 
-function ConvDesc(T, N, padding, stride, dilation, mode)
+function ConvDesc(T, N, padding, stride, dilation, mode, groupcount)
     cd = Ref{cudnnConvolutionDescriptor_t}()
     cudnnCreateConvolutionDescriptor(cd)
     if version() >= v"4"
@@ -38,18 +38,19 @@ function ConvDesc(T, N, padding, stride, dilation, mode)
     else
         cudnnSetConvolutionNdDescriptor(cd[],N,cdsize(padding,N),cdsize(stride,N),cdsize(dilation,N),mode)
     end
+    cudnnSetConvolutionGroupCount(cd[], Cint(groupcount))
     this = ConvDesc(cd[])
     finalizer(unsafe_free!, this)
     return this
 end
 
-function ConvDesc(T, cdims::DenseConvDims)
+function ConvDesc(T, cdims::ConvDims)
     pd = NNlib.padding(cdims)
     if !all(pd[1:2:end] .== pd[2:2:end])
         @warn("CuDNN does not support asymmetric padding; defaulting to symmetric choice")
     end
     return ConvDesc(T, NNlib.spatial_dims(cdims), pd[1:2:end], NNlib.stride(cdims),
-                    NNlib.dilation(cdims), NNlib.flipkernel(cdims))
+                    NNlib.dilation(cdims), NNlib.flipkernel(cdims), NNlib.group_count(cdims))
 end
@@ -68,7 +69,7 @@ function cudnnConvolutionBiasActivationForward(y::CuArray{T,N}, x::CuArray{T,N},
 end
 
 function cudnnConvolutionForward(y::CuArray{T,N}, x::CuArray{T,N}, w::CuArray{T,N},
-                                 cdims::DenseConvDims; algo=0, alpha=1, beta=0) where {T,N}
+                                 cdims::ConvDims; algo=0, alpha=1, beta=0) where {T,N}
     @workspace size=@argout(
         cudnnGetConvolutionForwardWorkspaceSize(
             handle(), TensorDesc(x),
@@ -86,7 +87,7 @@ function cudnnConvolutionForward(y::CuArray{T,N}, x::CuArray{T,N}, w::CuArray{T,
 end
 
 function cudnnConvolutionBackwardData(dx::CuArray{T,N}, w::CuArray{T,N}, dy::CuArray{T,N},
-                                      cdims::DenseConvDims; algo=0, alpha=1, beta=0) where {T,N}
+                                      cdims::ConvDims; algo=0, alpha=1, beta=0) where {T,N}
     @workspace size=@argout(
         cudnnGetConvolutionBackwardDataWorkspaceSize(
             handle(), FilterDesc(w),
@@ -105,7 +106,7 @@ function cudnnConvolutionBackwardData(dx::CuArray{T,N}, w::CuArray{T,N}, dy::CuArray{T,N},
 end
 
 function cudnnConvolutionBackwardFilter(dw::CuArray{T,N}, x::CuArray{T,N}, dy::CuArray{T,N},
-                                        cdims::DenseConvDims; algo=0, alpha=1, beta=0) where {T,N}
+                                        cdims::ConvDims; algo=0, alpha=1, beta=0) where {T,N}
     @workspace size=@argout(
         cudnnGetConvolutionBackwardFilterWorkspaceSize(
             handle(), TensorDesc(x),
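For orientation: cuDNN's grouped convolution expects the filter's third dimension to be C_in ÷ groupcount, and setting groupcount equal to the number of input channels yields a depthwise convolution. A minimal sketch of exercising the new descriptor path end to end; the shapes are illustrative and the groupcount keyword on ConvDims is assumed from the companion NNlib change:

# Depthwise convolution as a grouped convolution: groupcount == C_in, so each
# 3×3 filter sees a single input channel (third filter dim is C_in ÷ groups = 1).
x = CuArray(rand(Float32, 28, 28, 16, 1))   # W×H×C_in×N
w = CuArray(rand(Float32, 3, 3, 1, 16))     # kW×kH×(C_in÷groups)×C_out
cdims = ConvDims(x, w; groupcount=16)       # assumed keyword (companion NNlib change)
y = similar(x, 26, 26, 16, 1)               # 28 - 3 + 1 = 26 spatial output
conv!(y, x, w, cdims)                       # reaches cudnnConvolutionForward with the grouped descriptor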
6 changes: 3 additions & 3 deletions src/dnn/nnlib.jl
@@ -41,7 +41,7 @@ end
 
 # Convolution
 
-function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::DenseConvDims;
+function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::ConvDims;
                alpha=1, algo=0) where T<:CUDNNFloat
     if version() < v"6"
         all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
@@ -51,7 +51,7 @@ function conv!(y::CuArray{T}, x::CuArray{T}, w::CuArray{T}, cdims::DenseConvDims
 end
 
 function ∇conv_filter!(dw::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
-                       cdims::DenseConvDims; alpha=1, algo=0) where T<:CUDNNFloat
+                       cdims::ConvDims; alpha=1, algo=0) where T<:CUDNNFloat
     if version() < v"6"
         all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
     end
@@ -60,7 +60,7 @@ function ∇conv_filter!(dw::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
 end
 
 function ∇conv_data!(dx::CuArray{T}, dy::CuArray{T}, w::CuArray{T},
-                     cdims::DenseConvDims; alpha=1, algo=0) where T<:CUDNNFloat
+                     cdims::ConvDims; alpha=1, algo=0) where T<:CUDNNFloat
     if version() < v"6"
         all(x -> x == 1, dilation(cdims)) || error("Only dilation = 1 is supported in cuDNN version < 6")
     end
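Since the group count travels inside the convolution descriptor, the backward paths need no new arguments; continuing the depthwise sketch above, under the same assumptions:

# Gradients reuse the cdims that described the forward pass.
dy = CuArray(rand(Float32, 26, 26, 16, 1))
∇conv_data!(similar(x), dy, w, cdims)       # gradient w.r.t. the input
∇conv_filter!(similar(w), x, dy, cdims)     # gradient w.r.t. the filter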
4 changes: 2 additions & 2 deletions test/dnn.jl
@@ -17,7 +17,7 @@ else
                softmax, ∇softmax, logsoftmax, ∇logsoftmax
     a, b, c = rand(Float64, 10, 10, 3, 1), rand(Float64, 2, 2, 3, 4), rand(Float64, 9, 9, 4, 1)
     da, db, dc = CuArray(a), CuArray(b), CuArray(c)
-    cdims = DenseConvDims(a, b)
+    cdims = ConvDims(a, b)
     @test NNlib.conv(a, b, cdims) ≈ collect(NNlib.conv(da, db, cdims))
     @test ∇conv_data(c, b, cdims) ≈ collect(∇conv_data(dc, db, cdims))
     @test ∇conv_filter(a, c, cdims) ≈ collect(∇conv_filter(da, dc, cdims))
@@ -35,7 +35,7 @@ else
     algos = (1, 0, 1, 1,)
 
     for (opts, algo) in zip(options, algos)
-        cdims = DenseConvDims(x, w; opts...)
+        cdims = ConvDims(x, w; opts...)
         y = NNlib.conv(x, w, cdims)
 
         # Test that basic convolution is equivalent across GPU/CPU
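The diff as shown does not add a test that exercises groupcount; one way to cover it, sketched under the same assumption about the ConvDims keyword and with illustrative shapes, is to compare a grouped GPU convolution against per-group dense convolutions on the CPU:

# Hypothetical grouped-convolution test: 4 input channels split into 2 groups,
# each group convolved by 3 of the 6 output filters.
x = rand(Float32, 8, 8, 4, 1)
w = rand(Float32, 3, 3, 2, 6)               # third dim = C_in ÷ groups = 2
cdims = ConvDims(x, w; groupcount=2)        # assumed keyword (companion NNlib change)
y_gpu = collect(NNlib.conv(CuArray(x), CuArray(w), cdims))

# CPU reference: run each group as an ordinary dense convolution.
y_ref = cat([NNlib.conv(x[:, :, 2g-1:2g, :], w[:, :, :, 3g-2:3g])
             for g in 1:2]...; dims=3)
@test y_gpu ≈ y_ref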