Skip to content

Commit bcc3cdd

Browse files
committed
init
0 parents  commit bcc3cdd

File tree

10 files changed

+452
-0
lines changed

10 files changed

+452
-0
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
*.DS_Store
2+
*.jl.cov
3+
*.jl.*.cov
4+
*.jl.mem

.travis.yml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# Documentation: http://docs.travis-ci.com/user/languages/julia/
2+
language: julia
3+
os:
4+
- linux
5+
julia:
6+
- 0.7
7+
- 1.0
8+
- nightly
9+
matrix:
10+
allow_failures:
11+
- julia: nightly
12+
notifications:
13+
email: false
14+
# uncomment the following lines to override the default test script
15+
#script:
16+
# - if [[ -a .git/shallow ]]; then git fetch --unshallow; fi
17+
# - julia -e 'Pkg.clone(pwd()); Pkg.build("ChainRules"); Pkg.test("ChainRules"; coverage=true)'
18+
after_success:
19+
# push coverage results to Coveralls
20+
- julia -e 'cd(Pkg.dir("ChainRules")); Pkg.add("Coverage"); using Coverage; Coveralls.submit(Coveralls.process_folder())'
21+
- julia -e 'ps=Pkg.PackageSpec(name="Documenter", version="0.19"); Pkg.add(ps); Pkg.pin(ps)'
22+
- julia -e 'cd(Pkg.dir("ChainRules")); include(joinpath("docs", "make.jl"))'

LICENSE.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
The ChainRules.jl package is licensed under the MIT "Expat" License:
2+
3+
> Copyright (c) 2018: Jarrett Revels.
4+
>
5+
> Permission is hereby granted, free of charge, to any person obtaining a copy
6+
> of this software and associated documentation files (the "Software"), to deal
7+
> in the Software without restriction, including without limitation the rights
8+
> to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
> copies of the Software, and to permit persons to whom the Software is
10+
> furnished to do so, subject to the following conditions:
11+
>
12+
> The above copyright notice and this permission notice shall be included in all
13+
> copies or substantial portions of the Software.
14+
>
15+
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
> IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
> FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
> AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
> LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
> OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
> SOFTWARE.
22+
>

README.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# ChainRules
2+
3+
The ChainRules package provides a variety of common utilities that can be used
4+
by downstream automatic differentiation (AD) tools to define and execute
5+
forward-, reverse-, and mixed-mode primitives.
6+
7+
This package is a WIP; the framework is essentially there, but there are only a
8+
few toy rules right now, a bunch of TODOs, virtually no tests, etc. PRs welcome!
9+
Documentation is incoming, which should help if you'd like to contribute.
10+
11+
Here are some of the basic goals for the package:
12+
13+
- First-class support for complex differentiation via Wirtinger derivatives.
14+
15+
- Mixed-mode composability without being coupled to a specific AD implementation.
16+
17+
- Propagation semantics built-in, with default implementations that allow rule
18+
authors to easily opt-in to common optimizations (fusion, increment elision, etc.).
19+
20+
- Control-inverted design: rule authors can fully specify derivatives in
21+
a concise manner while naturally allowing the caller to compute only what they
22+
need.
23+
24+
- Genericity/Overloadability: rules are well-specified independently of target
25+
function's input/output values' types, though these types can be specialized
26+
on when desired. Furthermore, properties like storage device, tensor shape,
27+
domain etc. can be specified by callers (and thus exploited by rule authors)
28+
independently of these types.
29+
30+
The ChainRules source code follows the [YASGuide](https://github.com/jrevels/YASGuide).

REQUIRE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
julia 0.7

src/ChainRules.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
module ChainRules
2+
3+
using Base.Broadcast: materialize, materialize!, broadcasted
4+
5+
include("markup.jl")
6+
include("interface.jl")
7+
include("rules.jl")
8+
9+
end # module

src/interface.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#####
2+
##### `Thunk`
3+
#####
4+
5+
macro thunk(body)
6+
return :(Thunk(() -> $(esc(body))))
7+
end
8+
9+
struct Thunk{F}
10+
f::F
11+
end
12+
13+
@inline (thunk::Thunk{F})() where {F} = (thunk.f)()
14+
15+
#####
16+
##### `forward_chain`
17+
#####
18+
19+
forward_chain(args...) = materialize(_forward_chain(args...))
20+
21+
@inline _forward_chain(∂::Thunk, ẋ::Nothing) = false
22+
@inline _forward_chain(∂::Thunk, ẋ) = broadcasted(*, (), ẋ)
23+
_forward_chain(∂::Thunk, ẋ, args...) = broadcasted(+, _forward_chain(∂, ẋ), _forward_chain(args...))
24+
25+
#####
26+
##### `reverse_chain!`
27+
#####
28+
29+
@inline reverse_chain!(x̄::Nothing, ∂::Thunk) = false
30+
31+
@inline function reverse_chain!(x̄, ∂::Thunk)
32+
thunk = ()
33+
x̄_value = adjoint_value(x̄)
34+
casted = should_increment(x̄) ? broadcasted(+, x̄_value, thunk) : thunk
35+
if should_materialize_into(x̄)
36+
return materialize!(x̄_value, casted)
37+
else
38+
return materialize(casted)
39+
end
40+
end
41+
42+
adjoint_value(x̄) =
43+
44+
should_increment(::Any) = true
45+
46+
should_materialize_into(::Any) = false
47+
48+
#####
49+
##### miscellanous defaults
50+
#####
51+
52+
# TODO: More defaults, obviously!
53+
54+
markup(::Any) = Ignore()
55+
markup(::Real) = RealScalar()
56+
markup(::Complex) = ComplexScalar()
57+
markup(x::Tuple{Vararg{<:Real}}) = RealTensor(layout(x))
58+
markup(x::Tuple{Vararg{<:Complex}}) = ComplexTensor(layout(x))
59+
markup(x::AbstractArray{<:Real}) = RealTensor(layout(x))
60+
markup(x::AbstractArray{<:Complex}) = ComplexTensor(layout(x))
61+
markup(x::AbstractArray) = error("Cannot infer domain of array from eltype", x)
62+
63+
layout(x::Tuple) = Layout(length(x), (length(x),), CPUDevice(), true)
64+
layout(x::Array) = Layout(length(x), size(x), CPUDevice(), true)
65+
66+
should_materialize_into(::Array) = true

src/markup.jl

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#####
2+
##### `AbstractLayout`
3+
#####
4+
5+
abstract type AbstractLayout end
6+
7+
struct CPUDevice end
8+
9+
struct Layout{L, S, D} <: AbstractLayout
10+
length::L
11+
size::S
12+
device::D
13+
ismutable::Bool
14+
end
15+
16+
Base.length(layout::Layout) = layout.length
17+
18+
Base.size(layout::Layout) = layout.size
19+
20+
device(layout::Layout) = layout.device
21+
22+
ismutable(layout::Layout) = layout.ismutable
23+
24+
#####
25+
##### `AbstractDomain`
26+
#####
27+
28+
abstract type AbstractDomain end
29+
30+
struct RealDomain <: AbstractDomain end
31+
32+
struct ComplexDomain <: AbstractDomain end
33+
34+
#####
35+
##### `AbstractArgument`
36+
#####
37+
38+
abstract type AbstractArgument end
39+
40+
#== `Signature` ==#
41+
42+
struct Signature{I <: Tuple{Vararg{AbstractArgument}},
43+
O <: Tuple{Vararg{AbstractArgument}}}
44+
input::I
45+
output::O
46+
end
47+
48+
Signature(input, output) = Signature(markupify.(tuplify(input)),
49+
markupify.(tuplify(output)))
50+
51+
tuplify(x) = tuple(x)
52+
tuplify(x::Tuple) = x
53+
54+
markupify(x::AbstractArgument) = x
55+
markupify(x) = markup(x)
56+
57+
#== `Ignore` ==#
58+
59+
struct Ignore <: AbstractArgument end
60+
61+
Base.length(::Ignore) = 0
62+
63+
#== `AbstractVariable` ==#
64+
65+
abstract type AbstractVariable <: AbstractArgument end
66+
67+
struct Scalar{D <: AbstractDomain} <: AbstractVariable
68+
domain::D
69+
end
70+
71+
const RealScalar = Scalar{RealDomain}
72+
73+
RealScalar() = Scalar(RealDomain())
74+
75+
const ComplexScalar = Scalar{ComplexDomain}
76+
77+
ComplexScalar() = Scalar(ComplexDomain())
78+
79+
struct Tensor{D <: AbstractDomain, L <: AbstractLayout} <: AbstractVariable
80+
domain::D
81+
layout::L
82+
end
83+
84+
const RealTensor = Tensor{RealDomain}
85+
86+
RealTensor(layout) = Tensor(RealDomain(), layout)
87+
88+
const ComplexTensor = Tensor{ComplexDomain}
89+
90+
ComplexTensor(layout) = Tensor(ComplexDomain(), layout)
91+
92+
Base.length(::Scalar) = 1
93+
Base.length(t::Tensor) = length(t.layout)
94+
95+
#####
96+
##### `@sig`
97+
#####
98+
99+
macro sig(expr)
100+
signature_type_from_expr(expr)
101+
end
102+
103+
const MALFORMED_SIG_ERROR_MESSAGE = "Malformed expression given to `@sig`; see `@sig` docstring for proper format."
104+
105+
function signature_type_from_expr(expr)
106+
@assert(expr.head === :call && expr.args[1] === : && length(expr.args) === 3, MALFORMED_SIG_ERROR_MESSAGE)
107+
input_types = map(parse_into_markup_type, split_infix_args(expr.args[2], :))
108+
output_types = map(parse_into_markup_type, split_infix_args(expr.args[3], :))
109+
return :(Signature{<:Tuple{$(input_types...)}, <:Tuple{$(output_types...)}})
110+
end
111+
112+
split_infix_args(invocation::Symbol, ::Symbol) = (invocation,)
113+
114+
function split_infix_args(invocation::Expr, op::Symbol)
115+
if invocation.head === :call && invocation.args[1] === op
116+
return (split_infix_args(invocation.args[2], op)..., invocation.args[3])
117+
end
118+
return (invocation,)
119+
end
120+
121+
function parse_into_markup_type(x)
122+
if x === :R
123+
return :(Scalar{RealDomain})
124+
elseif x === :C
125+
return :(Scalar{ComplexDomain})
126+
elseif x === :_
127+
return :(Ignore)
128+
elseif isa(x, Expr) && length(x.args) === 1
129+
if x.head === :vect
130+
domain = x.args[1]
131+
if domain === :R
132+
return :(Tensor{RealDomain})
133+
elseif domain === :C
134+
return :(Tensor{ComplexDomain})
135+
end
136+
elseif x.head === :braces
137+
vararg_type = parse_into_markup_type(x.args[1])
138+
return :(Vararg{$vararg_type})
139+
end
140+
end
141+
error(string("Encountered unparseable signature element `", x, "`. ",
142+
MALFORMED_SIG_ERROR_MESSAGE))
143+
end

0 commit comments

Comments
 (0)