diff --git a/Project.toml b/Project.toml index c9647205..02e14607 100644 --- a/Project.toml +++ b/Project.toml @@ -48,4 +48,4 @@ ScopedValues = "1.3.0" SpecialFunctions = "2" Statistics = "1" cuDNN = "1" -julia = "1.9" +julia = "1.10" diff --git a/src/dim_helpers/ConvDims.jl b/src/dim_helpers/ConvDims.jl index e8bcc08f..9e02010d 100644 --- a/src/dim_helpers/ConvDims.jl +++ b/src/dim_helpers/ConvDims.jl @@ -73,7 +73,7 @@ function im2col_dims(c::ConvDims) # Size of single dotproduct within convolution prod(kernel_size(c))*channels_in(c), # One workspace per thread - VERSION > v"1.9.0-0" ? Threads.nthreads(:default) : Threads.nthreads(), + Threads.nthreads(:default), ) end diff --git a/src/gemm.jl b/src/gemm.jl index 9a3c6cd5..e05174d1 100644 --- a/src/gemm.jl +++ b/src/gemm.jl @@ -95,7 +95,7 @@ for (gemm, elt) in gemm_datatype_mappings strC = Base.stride(C, 3) n_threads = min( - VERSION > v"1.9.0-0" ? Threads.nthreads(:default) : Threads.nthreads(), + Threads.nthreads(:default), 1 + max(length(A), length(B)) ÷ 8000) # In some tests, size (20,20,20) is worth splitting between two threads, # as is size (32,32,8). diff --git a/src/utils.jl b/src/utils.jl index baf95c8d..6d82a81e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -144,21 +144,3 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f: rrule_via_ad(cfg, broadcast, f, x, ys...) end -# Could get this from Compat.jl instead -# https://github.com/JuliaLang/julia/pull/39794 -if VERSION < v"1.7.0-DEV.793" - struct Returns{V} <: Function - value::V - Returns{V}(value) where {V} = new{V}(value) - Returns(value) = new{Core.Typeof(value)}(value) - end - - (obj::Returns)(args...; kw...) = obj.value - function Base.show(io::IO, obj::Returns) - show(io, typeof(obj)) - print(io, "(") - show(io, obj.value) - print(io, ")") - end -end - diff --git a/test/conv.jl b/test/conv.jl index cf323277..8e52c846 100644 --- a/test/conv.jl +++ b/test/conv.jl @@ -908,7 +908,7 @@ end gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w) end -@static if Test_Enzyme +if NNLIB_TEST_ENZYME @testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3) x = rand(rng, repeat([5], spatial_rank)..., 3, 2) diff --git a/test/dropout.jl b/test/dropout.jl index 0da70111..65aac8b6 100644 --- a/test/dropout.jl +++ b/test/dropout.jl @@ -16,9 +16,6 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme @test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4) x2 = Diagonal(randn(Float32, 10)) # Just to check it runs on weird matrices. - if VERSION > v"1.8-" # on 1.6 this makes a sparse array. - @test dropout(x2, 0.3) isa Matrix{Float32} # does not infer, but that's OK? - end # Values @test dropout(x1, 0) == x1 @@ -76,7 +73,7 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme @test_throws ArgumentError dropout!(y1, x1, 3) end -@static if Test_Enzyme +if NNLIB_TEST_ENZYME @testset "EnzymeRules: dropout " begin rng = Random.default_rng() diff --git a/test/pooling.jl b/test/pooling.jl index f9d57ade..1b11a1ae 100644 --- a/test/pooling.jl +++ b/test/pooling.jl @@ -948,7 +948,7 @@ end gradtest(x -> sum(meanpool(x, k)), x) end -@static if Test_Enzyme +if NNLIB_TEST_ENZYME @testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2), (pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!)) diff --git a/test/runtests.jl b/test/runtests.jl index b8080b6b..6805672e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,10 +18,10 @@ import ReverseDiff as RD # used in `pooling.jl` import Pkg using SpecialFunctions -const Test_Enzyme = VERSION <= v"1.10-" DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true) +const NNLIB_TEST_ENZYME = true # ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests # ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests # ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests diff --git a/test/testsuite/gather.jl b/test/testsuite/gather.jl index 92e3bfb7..18953338 100644 --- a/test/testsuite/gather.jl +++ b/test/testsuite/gather.jl @@ -154,7 +154,7 @@ function gather_testsuite(Backend) gradtest_fn((s, i) -> gather(s, i), src, idx) end - @static if Test_Enzyme + if NNLIB_TEST_ENZYME @testset "EnzymeRules: gather! gradient for scalar index" begin src = device(Float64[3, 4, 5, 6, 7]) diff --git a/test/testsuite/scatter.jl b/test/testsuite/scatter.jl index aa0b1c41..ddbf8eb6 100644 --- a/test/testsuite/scatter.jl +++ b/test/testsuite/scatter.jl @@ -208,7 +208,7 @@ function scatter_testsuite(Backend) end - @static if Test_Enzyme + if NNLIB_TEST_ENZYME @testset "EnzymeRules" begin idx = device([2, 2, 3, 4, 4])