Merge pull request #97 from PumasAI/stackmemory
Stack Memory
chriselrod authored Sep 2, 2022
2 parents fa0982d + 8768d8f commit 13d60f0
Showing 21 changed files with 1,034 additions and 465 deletions.
6 changes: 4 additions & 2 deletions Project.toml
@@ -1,10 +1,11 @@
name = "SimpleChains"
uuid = "de6bee2f-e2f4-4ec7-b6ed-219cc6f6e9e5"
authors = ["Chris Elrod <[email protected]> and contributors"]
version = "0.2.12"
version = "0.3.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
CPUSummary = "2a0fbf3d-bb9c-48f3-b0a9-814d99fd7ab9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CloseOpenIntervals = "fb6a15b2-703c-40df-9091-08a04967cfa9"
@@ -27,6 +28,7 @@ VectorizedRNG = "33b4df10-0173-11e9-2a0c-851a7edac40e"

[compat]
ArrayInterface = "6"
ArrayInterfaceCore = "0.1.14"
CPUSummary = "0.1.8"
ChainRulesCore = "0.8, 0.9, 0.10, 1"
CloseOpenIntervals = "0.1.6"
@@ -43,7 +45,7 @@ Static = "0.7"
StaticArrays = "1"
StrideArraysCore = "0.3.5"
UnPack = "1"
VectorizationBase = "0.21.30"
VectorizationBase = "0.21.40"
VectorizedRNG = "0.2.13"
julia = "1.6"

25 changes: 17 additions & 8 deletions src/SimpleChains.jl
@@ -9,6 +9,7 @@ using UnPack,
StrideArraysCore,
Static,
VectorizedRNG
using ArrayInterfaceCore: CPUPointer
using ArrayInterface:
size,
strides,
@@ -29,7 +30,7 @@ using SIMDTypes: Bit, NativeTypes
using VectorizationBase: align, relu, stridedpointer, AbstractSIMD, NativeTypesV
using HostCPUFeatures: static_sizeof, register_size, register_count, static_sizeof
using CPUSummary: cache_linesize, num_threads, num_cores
using LayoutPointers: bytestrideindex, stridedpointer, zero_offsets, val_dense_dims
using LayoutPointers: bytestrideindex, stridedpointer, zstridedpointer, zero_offsets, val_dense_dims
using Static: One, lt
using CloseOpenIntervals: CloseOpen
using StrideArraysCore: zview, @gc_preserve
@@ -39,6 +40,7 @@ import Random
import ChainRulesCore
import ForwardDiff
import LoopVectorization
import StaticArrays

using LoopVectorization: matmul_params, @turbo
# using LoopVectorization: matmul_params
@@ -67,6 +69,7 @@ export SimpleChain,

const Integer = Union{StaticInt,Base.Integer}

include("memory.jl")
include("simple_chain.jl")
include("utils.jl")
include("activation.jl")
@@ -81,13 +84,19 @@ include("penalty.jl")
include("chain_rules.jl")
include("optimize.jl")

if VERSION >= v"1.7.0"
if hasfield(Method, :recursion_relation)
dont_limit = Returns(true)
for f = (chain_valgrad!, _chain, output_size, _numparam)
for m in methods(f)
m.recursion_relation = dont_limit
end
if VERSION >= v"1.7.0" && hasfield(Method, :recursion_relation)
dont_limit = Returns(true)
for f in (
chain_valgrad!,
chain_valgrad_pullback!,
__chain,
output_size,
forward_output_size,
_numparam,
pullback_layer!,
)
for m in methods(f)
m.recursion_relation = dont_limit
end
end
end
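
For context, a standalone sketch of what lifting the recursion limit buys (hypothetical `sumtail` stands in for the tuple-recursive functions listed above; assumes Julia ≥ 1.7, where `Method` gained the `recursion_relation` field and `Returns` was introduced):

sumtail(t::Tuple{}) = 0
sumtail(t::Tuple) = first(t) + sumtail(Base.tail(t))

if VERSION >= v"1.7.0" && hasfield(Method, :recursion_relation)
    for m in methods(sumtail)
        # Always permit recursion, so inference can fully unroll long tuples
        # instead of giving up and returning an abstract (type-unstable) result.
        m.recursion_relation = Returns(true)
    end
end

sumtail((1, 2.0, 0x03))  # now inferred concretely, even for long heterogeneous tuples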
5 changes: 3 additions & 2 deletions src/activation.jl
@@ -11,11 +11,12 @@ struct Activation{F}
f::F
end
parameter_free(::Activation) = true
numparam(::Activation, id) = 0, id
numparam(::Activation, id) = static(0), id
init_params!(::Activation, p, id) = p, id
_check_input_dims(::Activation, _) = nothing

layer_output_size(::Val{T}, a::Activation, s) where {T} = align(prod(s) * (2sizeof(T))), s
forward_layer_output_size(::Val{T}, a::Activation, s) where {T} =
align(prod(s) * static_sizeof(T)), s

Base.show(io::IO, a::Activation) = print(io, "Activation layer applying: ", a.f)

153 changes: 96 additions & 57 deletions src/chain_rules.jl
@@ -1,5 +1,5 @@

@static if isdefined(ChainRulesCore, :NoTangent)
if isdefined(ChainRulesCore, :NoTangent)
const NoTangent = ChainRulesCore.NoTangent
else
const NoTangent = ChainRulesCore.DoesNotExist
@@ -24,33 +24,59 @@ function pullback_layer!(pbl::PullBackLayer, lgrad)
end
pullback_layer!(pbl::Ptr{UInt8}, grad) = grad, pbl

struct PullBack{PBL<:PullBackLayer,G,P,M}

#TODO: add support for not getting gradient with respect to input `x`
# struct PullBackParam{T,L,A,PBL}
# pg::Ptr{T}
# l::L
# arg::A
# p::Ptr{T}
# pu::Ptr{UInt8}
# pbl::PBL # either another `PullBackLayer`, or the last memory pointer from the forward pass (to start the reverse)
# end
# function pullback_layer!(pbl::PullBackParam, lgrad)
# grad, _ = pullback_layer!(pbl.pbl, lgrad)
# pullback_param!(pbl.pg, pbl.l, grad, pbl.arg, pbl.p, pbl.pu)
# end

# struct PullBack{PBL<:Union{PullBackLayer,PullBackParam},G,P,M}
struct PullBack{SA,PBL<:PullBackLayer,G,P,M}
pbl::PBL
grad::G
params::P
memory::M
function PullBack{SA}(pbl::PBL, grad::G, params::P, memory::M) where {SA,PBL,G,P,M}
new{SA,PBL,G,P,M}(pbl, grad, params, memory)
end
end
function (pb::PullBack)(x)
@inline function (pb::PullBack{SA})(x) where {SA}
@unpack pbl, grad, params, memory = pb
GC.@preserve grad params memory begin
lgrad, pu4 = pullback_layer!(pbl, x)
lgrad, _ = pullback_layer!(pbl, x)
end
if SA
NoTangent(),
_maybe_sarray(StrideArraysCore.StrideArray(lgrad, memory)),
_maybe_sarray(StrideArraysCore.StrideArray(grad, memory))
else
NoTangent(),
StrideArraysCore.StrideArray(lgrad, memory),
StrideArraysCore.StrideArray(grad, memory)
end
NoTangent(),
StrideArraysCore.StrideArray(lgrad, memory),
StrideArraysCore.StrideArray(grad, memory)
end


function unsafe_valgrad_pullback!(g, layers, params, memory::Vector{UInt8}, arg)
GC.@preserve g params memory begin
# @show pointer(g) pointer(params) pointer(memory)
l, pbl =
chain_valgrad_pullback!(pointer(g), arg, layers, pointer(params), pointer(memory))
@inline function (pb::PullBack)(x::StaticArrays.SArray)
@unpack pbl, grad, params, memory = pb
mx = StaticArrays.MArray(x);
GC.@preserve mx grad params memory begin
lgrad, _ = pullback_layer!(pbl, PtrArray(mx))
end
l, PullBack(pbl, g, params, memory)
NoTangent(),
_maybe_sarray(StrideArraysCore.StrideArray(lgrad, memory)),
_maybe_sarray(StrideArraysCore.StrideArray(grad, memory))
end
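
The `SArray` method above uses a pattern worth noting: copy the immutable `SArray` into an `MArray`, `GC.@preserve` it, and operate through a pointer-backed `PtrArray` view. A minimal standalone sketch of that pattern (hypothetical example, not from this commit):

using StaticArrays, StrideArraysCore

function double_static(x::SArray)
    mx = MArray(x)            # mutable copy that can yield a stable pointer
    GC.@preserve mx begin
        px = PtrArray(mx)     # pointer-backed view, no heap allocation
        @inbounds for i in eachindex(px)
            px[i] *= 2
        end
    end
    SArray(mx)                # convert back to an immutable static array
end

double_static(@SVector [1.0, 2.0, 3.0])  # returns SVector(2.0, 4.0, 6.0)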

function chain_valgrad_pullback!(

@inline function chain_valgrad_pullback!(
pg,
arg,
layers::Tuple{X1,X2,Vararg},
@@ -60,75 +86,88 @@ function chain_valgrad_pullback!(
l = getfield(layers, 1)
pg2, larg, p2, pu2 = valgrad_layer!(pg, l, arg, p, pu)

# val, grad, pu3, pbl = chain_valgrad_pullback!(pg2, larg, Base.tail(layers), p2, pu2)
val, pbl = chain_valgrad_pullback!(pg2, larg, Base.tail(layers), p2, pu2)
pbl_ret = PullBackLayer(pg, l, arg, p, pu, pbl)
return val, pbl_ret
# lgrad, pu4 = pullback!(pg, l, grad, arg, p, pu, pu3)
# return val, lgrad, pu4
end
function chain_valgrad_pullback!(
@inline function chain_valgrad_pullback!(
pg,
arg,
layers::Tuple{X1},
p::Ptr,
pu::Ptr{UInt8},
) where {X1}
l = getfield(layers, 1)
pg2, val, p2, pu2 = valgrad_layer!(pg, l, arg, p, pu)
_, val, __, pu2 = valgrad_layer!(pg, l, arg, p, pu)

# val, grad, pu3, pbl = chain_valgrad!(pg2, larg, Base.tail(layers), p2, pu2)
# pu2 gets fed into eventual `pullback!` call
pbl_ret = PullBackLayer(pg, l, arg, p, pu, pu2)
return val, pbl_ret
# lgrad, pu4 = pullback!(pg, l, grad, arg, p, pu, pu3)
# return val, lgrad, pu4
end

# No loss: chain closures.
function _rrule(sc, arg, params, memory, ::False)
valgrad_noloss(sc, arg, params, memory)
function _rrule(sc, arg, params, ::False)
valgrad_noloss(sc, arg, params)
end
function valgrad_noloss(sc, arg::AbstractArray{S}, params::StaticArrays.SVector{T}) where {T,S}
mp = StaticArrays.MVector(params);
@gc_preserve valgrad_noloss(sc, arg, mp)
end
function valgrad_noloss(sc, arg, params::AbstractVector{T}, memory = sc.memory) where {T}
function valgrad_noloss(sc, arg::AbstractArray{S}, params::AbstractVector{T}) where {T,S}
c = getchain(sc)
@unpack layers = c
parg = maybe_static_size_arg(c.inputdim, arg)
arglen = length(parg)
barg = preserve_buffer(arg)
off = align(resize_memory!(layers, memory, parg, length(parg) * sizeof(eltype(parg))))
GC.@preserve barg memory begin
g = PtrArray(reinterpret(Ptr{T}, pointer(memory) + off), (static_length(params),))
l, pullback = unsafe_valgrad_pullback!(g, layers, params, memory, parg)
end
return l, pullback
# return StrideArraysCore.StrideArray(l, memory), pullback
end

glen = _try_static(numparam(sc), static_length(params))
goff = align(glen * static_sizeof(T))
aoff = align(arglen * static_sizeof(S))

num_bytes = required_bytes(Val{T}(), layers, size(parg), aoff + goff)
memory = get_heap_memory(sc, num_bytes)

# Loss: call `valgrad`.
function _rrule(sc, arg, params, memory, ::True)
l, g = valgrad(sc, arg, params, memory)
# assumes no grad w/ respect to arg
pullback = let g = g
l̄ -> begin
if !isone(l̄)
@turbo for i ∈ eachindex(g)
g[i] *= l̄
end
end
NoTangent(), NoTangent(), g
GC.@preserve barg params memory begin
pm = align(pointer(memory))
parg2 = PtrArray(Ptr{S}(pm), _try_static(c.inputdim, size(parg)))
@inbounds @simd ivdep for i in eachindex(parg)
parg2[i] = parg[i]
end
pm += aoff
g = PtrArray(Ptr{T}(pm), (glen,))
pm += goff
# @show pointer(g) pointer(params) pointer(memory)
l, pbl = chain_valgrad_pullback!(pointer(g), parg2, layers, pointer(params), pm)
end
if arg isa StaticArrays.SArray
_maybe_sarray(l), PullBack{true}(pbl, g, params, memory)
else
l, PullBack{true}(pbl, g, params, memory)
end
l, pullback
end

function ChainRulesCore.rrule(
sc::AbstractPenalty, arg, params, memory = task_local_memory()
)
_rrule(sc, arg, params, memory, True())
struct ElementwisePullback{G}
g::G
end
function ChainRulesCore.rrule(
sc::SimpleChain, arg, params, memory = task_local_memory()
)
_rrule(sc, arg, params, memory, has_loss_typed(sc))
#TODO: add support for getting gradient with respect to `arg`
function (ep::ElementwisePullback)(l̄)
g = ep.g
if !isone(l̄)
@turbo for i ∈ eachindex(g)
g[i] *= l̄
end
end
# assumes no grad w/ respect to arg
NoTangent(), NoTangent(), g
end
# Loss: call `valgrad`.
function _rrule(sc, arg, params, ::True)
l, g = valgrad(sc, arg, params)
l, ElementwisePullback(g)
end
# TODO: support penalties without returning scalars
_returns_scalar(::AbstractPenalty) = True()
_returns_scalar(sc::SimpleChain) = has_loss_typed(sc)

function ChainRulesCore.rrule(sc::Chain, arg, params)
_rrule(sc, arg, params, _returns_scalar(sc))
end
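
A usage sketch of the path these definitions enable (not part of this commit; assumes Zygote.jl, which dispatches to `ChainRulesCore.rrule`, and typical SimpleChains API usage — treat the layer sizes as placeholders):

using SimpleChains, Zygote

sc  = SimpleChain(static(2), TurboDense(tanh, 8), TurboDense(identity, 1))
scl = SimpleChains.add_loss(sc, SquaredLoss(rand(Float32, 1, 16)))  # has_loss_typed(scl) === True()
p   = SimpleChains.init_params(scl)
x   = rand(Float32, 2, 16)

# Scalar-loss path: `_rrule(..., ::True)` calls `valgrad` once, then the
# returned `ElementwisePullback` scales the stored gradient by the cotangent l̄.
grad, = Zygote.gradient(p -> scl(x, p), p)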
4 changes: 2 additions & 2 deletions src/conv.jl
@@ -792,9 +792,9 @@ function getparams(c::Conv, p::Ptr{T}, inputdim::Tuple{Vararg{Integer}}) where {
(K, b), p + sizeof(T) * length(b)
end

function layer_output_size(::Val{T}, c::Conv, inputdim::Tuple) where {T}
function forward_layer_output_size(::Val{T}, c::Conv, inputdim::Tuple) where {T}
_, outputdim = numparam(c, inputdim)
2align(static_sizeof(T) * prod(outputdim)), outputdim
align(static_sizeof(T) * prod(outputdim)), outputdim
end
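
Read together with the matching `Activation` change above (where the full `layer_output_size` reserves `2sizeof(T)` per element but `forward_layer_output_size` only `static_sizeof(T)`), the drop from `2align(...)` to `align(...)` suggests a forward-only pass needs roughly half the per-layer scratch of forward-plus-backward. A rough sizing sketch with hypothetical dimensions:

T = Float32
outputdim = (26, 26, 16)                     # e.g. 28×28 input, 3×3 kernel, 16 channels
buffer_bytes = prod(outputdim) * sizeof(T)   # 26 * 26 * 16 * 4 = 43_264 bytes
# old layer_output_size:            2 * align(buffer_bytes)   (forward + gradient buffers)
# new forward_layer_output_size:        align(buffer_bytes)   (forward buffer only)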

function init_params!(c::Conv, p, inputdim)

2 comments on commit 13d60f0

@chriselrod (Contributor, Author)
@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/67590

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.3.0 -m "<description of version>" 13d60f02c96628e181e2abcd3eec11dfec3f5fbf
git push origin v0.3.0
