Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into forward
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeInnes committed Jun 18, 2020
2 parents 5850d96 + ceabb33 commit e25ac87
Show file tree
Hide file tree
Showing 20 changed files with 663 additions and 212 deletions.
12 changes: 4 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.4.20"
version = "0.4.21"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
Expand All @@ -14,25 +14,21 @@ InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5"
ArrayLayouts = "0.1, 0.2, 0.3"
DiffRules = "0.0, 0.1, 1"
ChainRules = "0.6.0"
FillArrays = "0.8"
ForwardDiff = "0"
IRTools = "0.3"
IRTools = "0.4"
MacroTools = "0.5"
NNlib = "0.6.5"
NaNMath = "0"
Requires = "0.5, 1.0"
SpecialFunctions = "0"
ZygoteRules = "0.2"
julia = "1"

Expand Down
4 changes: 2 additions & 2 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "f8f5d2d4b4b07342e5811d2b6428e45524e241df"
git-tree-sha1 = "f0abb338b4d00306500056a3fd44c221b8473ef2"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.0.2"
version = "1.0.4"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down
10 changes: 10 additions & 0 deletions docs/src/adjoints.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
# Custom Adjoints

!!! note "Prefer to use ChainRules to define custom adjoints"
Zygote supports the use of [ChainRulesCore](http://www.juliadiff.org/ChainRulesCore.jl/stable/) to define custom sensitivities.
It is prefered to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Zygote.
These sensitivities can be added in your own package, or for Base functions they can be added to ChainRules.jl.

This documentation exists to descibe how Zygote works, and how adjoints can be directly defined for Zygote.
Defining adjoints this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Zygote works.
It allows for specific definitions of adjoints that are only defined for Zgyote (which might work differently to more generic definitions defined for all AD).


The `@adjoint` macro is an important part of Zygote's interface; customising your backwards pass is not only possible but widely used and encouraged. While there are specific utilities available for common things like gradient clipping, understanding adjoints will give you the most flexibility. We first give a bit more background on what these pullback things are.

## Pullbacks
Expand Down
24 changes: 12 additions & 12 deletions examples/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ version = "1.0.1"

[[BinaryProvider]]
deps = ["Libdl", "Logging", "SHA"]
git-tree-sha1 = "428e9106b1ff27593cbd979afac9b45b82372b8c"
git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058"
uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
version = "0.5.9"
version = "0.5.10"

[[CEnum]]
git-tree-sha1 = "62847acab40e6855a9b5905ccb99c2b5cf6b3ebb"
Expand Down Expand Up @@ -93,9 +93,9 @@ version = "1.3.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "6166ecfaf2b8bbf2b68d791bc1d54501f345d314"
git-tree-sha1 = "af6d9c86e191c917c2276fbede1137e8ea20157f"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.15"
version = "0.17.17"

[[Dates]]
deps = ["Printf"]
Expand Down Expand Up @@ -146,25 +146,25 @@ version = "2.0.1"

[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
git-tree-sha1 = "8845400bd2d9815d37720251f1b53d27a335e1f4"
git-tree-sha1 = "90ee39f9beaaa186e4968417ea2b8ed5673c91c0"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
version = "0.3.2"
version = "0.3.3"

[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[Juno]]
deps = ["Base64", "Logging", "Media", "Profile"]
git-tree-sha1 = "e1ba2a612645b3e07c773c3a208f215745081fe6"
git-tree-sha1 = "a686b0cf235fa3e491b79b4783c2d2382292b436"
uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
version = "0.8.1"
version = "0.8.2"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "93d2e1e960fe47db1a9015e86fad1d47cf67cf59"
git-tree-sha1 = "dd3f584c3dbefe39b2a8fbafa1a3b77e31e21255"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "1.4.1"
version = "1.5.1"

[[LibGit2]]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
Expand Down Expand Up @@ -234,9 +234,9 @@ uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

[[ProgressMeter]]
deps = ["Distributed", "Printf"]
git-tree-sha1 = "ea1f4fa0ff5e8b771bf130d87af5b7ef400760bd"
git-tree-sha1 = "b3cb8834eee5410c7246734cc6f4f586fe0dc50e"
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.2.0"
version = "1.3.0"

[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
Expand Down
2 changes: 2 additions & 0 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using ArrayLayouts: MemoryLayout, AbstractColumnMajor

import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback, literal_getproperty

using ChainRules: ChainRules, rrule, unthunk
using IRTools
using MacroTools, Requires
using MacroTools: @forward
Expand All @@ -21,6 +22,7 @@ using .Forward

include("compiler/reverse.jl")
include("compiler/emit.jl")
include("compiler/chainrules.jl")
include("compiler/interface.jl")
include("compiler/show.jl")

Expand Down
104 changes: 104 additions & 0 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
const chainrules_fallback = which(rrule, Tuple{Any})

"""
has_chain_rrule(T)
For a type-tuple `T` e.g. `Tuple{typeof(f), Int, Float64}`, checks if there is a `rrule` defined for it.
Excluding the generic fallback.
The first return value is `true` if the `rrule` exists, `false` otherwise.
If it does not, then the second argument is a list of edges to attach to the CodeInfo for a generated function,
such that if a suitable rule is defined later, the generated function will recompile.
"""
function has_chain_rrule(T)
m = meta(Tuple{typeof(rrule),T.parameters...})
if m.method !== chainrules_fallback
# found a rrule, no need to add any edges
return true, nothing
end

return false, m.instance
end

"""
is_kwfunc(sigt...)
Determines if `sigt` is the type signature of a kwfunction.
Each element of `sigt` should be a type.
Either the first 3 types are a kwfunc type, a NamedTuple and the matching base function type,
or the first argument is the base function type and it is not a kwfunction.
the remaining types in `sigt` are the types of the argument.
"""
is_kwfunc(::Vararg{Any}) = false
is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)


"""
wrap_chainrules_output(x)
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally
(including conjugating complex gradients).
"""
@inline wrap_chainrules_output(x) = conj(unthunk(x)) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
# than happy.
@eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer}
xp = map(wrap_chainrules_output, x)
convert($T_outer, xp)
end
end

"""
wrap_chainrules_input(x)
Convert `x` from the format Zygote uses internally (including conjugated complex gradients)
to differentials types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = conj(x)
@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, xs)
ChainRules.Composite{Any, typeof(xp)}(xp)
end

"""
ZBack{F}(back) <: Function
Wrapper for a ChainRules pullback `back`, that causes it to follow Zygote conventions.
(A functor here is used rather than a closure to avoid boxing issues);
"""
struct ZBack{F} <: Function
back::F
end
@inline (s::ZBack)(dy) = wrap_chainrules_output(s.back(wrap_chainrules_input(dy)))
# `nothing->nothing` can be deleted after https://github.com/FluxML/Zygote.jl/issues/603
# though it might be worth keeping as a performance optimization (benchmarking pending)
@inline (s::ZBack)(::Nothing) = nothing

"""
chain_rrule(f, args...)
Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`.
The pullback is appropriately wrapped up to follow Zygote conventions.
"""
@inline function chain_rrule(f, args...)
y, back = rrule(f, args...)
return y, ZBack(back)
end


"""
chain_rrule_kw(kwf, kwargs, f, args...)
As per [`chain_rrule`](@ref) but with support for kwargs.
`kwf` should be the kwfunc matching to `f`, and `kwargs` are a `NamedTuple` of keyword arguments.
"""
@inline function chain_rrule_kw(kwf, kwargs, f, args...)
y, back = rrule(f, args...; kwargs...)
kw_zpullback(dy) = (nothing, nothing, ZBack(back)(dy)...) # first two nothings are for kwfunc and kwargs
return y, kw_zpullback
end
12 changes: 5 additions & 7 deletions src/compiler/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ end
# interface2.jl

# Wrappers

_pullback(f, args...) = _pullback(Context(), f, args...)

tailmemaybe(::Nothing) = nothing
Expand Down Expand Up @@ -103,7 +102,7 @@ end
"""
copy!(ps::Params, x::AbstractVector)
copy!(x::AbstractVector, ps::Params)
Copies the content of array `x` into the parameters `ps` or viceversa.
The length of `x` has to be equal to the sum of the lengths
of all parameters.
Expand All @@ -122,7 +121,7 @@ function copy!(x::AbstractVector, ps::Params)
@assert length(x) == sum(length(p) for p in ps)
i = 0
for p in ps
x[i+1:i+length(p)] .= vec(p)
x[i+1:i+length(p)] .= vec(p)
i += length(p)
end
ps
Expand All @@ -147,9 +146,8 @@ end
copy!(gs::Grads, x::AbstractVector)
copy!(x::AbstractVector, gs::Grads)
Copies the content of array `x` into the gradient object `gs` or viceversa.
The length of `x` has to be equal to the sum of the lenghts
of all gradients.
Copies the content of array `x` into the gradient object `gs` or vice versa. The
length of `x` has to be equal to the sum of the lengths of all gradients.
"""
function copy!(gs::Grads, x::AbstractVector)
i = 0
Expand All @@ -163,7 +161,7 @@ end
function copy!(x::AbstractVector, gs::Grads)
i = 0
for p in gs.params
x[i+1:i+length(p)] .= vec(gs[p])
x[i+1:i+length(p)] .= vec(gs[p])
i += length(p)
end
x
Expand Down
24 changes: 21 additions & 3 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,41 @@
using IRTools: varargs!, inlineable!, pis!, slots!
using IRTools.Inner: argnames!, update!

ignore(T) = all(T -> T <: Type, T.parameters)
ignore_sig(T) = all(T -> T <: Type, T.parameters)

function edge!(m::IRTools.Meta, edge::Core.MethodInstance)
m.code.edges == nothing && (m.code.edges = Core.MethodInstance[])
push!(m.code.edges, edge)
return
end

@generated function _pullback(ctx::AContext, f, args...)
T = Tuple{f,args...}
ignore(T) && return :(f(args...), Pullback{$T}(()))
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))

iskw = is_kwfunc(f, args...)
# if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
base_T = iskw ? Tuple{args[2:end]...} : T
hascr, cr_edge = has_chain_rrule(base_T)
chain_rrule_f = iskw ? :chain_rrule_kw : :chain_rrule
hascr && return :($chain_rrule_f(f, args...))

g = try _lookup_grad(T) catch e e end
!(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
forw = varargs!(meta, forw, 3)
# IRTools.verify(forw)
forw = slots!(pis!(inlineable!(forw)))
@static if VERSION >= v"1.3" # no edges pre-1.3
# be ready to swap to using chainrule if one is declared
cr_edge != nothing && edge!(meta, cr_edge)
end
return update!(meta.code, forw)
end

@generated function (j::Pullback{T})(Δ) where T
ignore(T) && return :nothing
ignore_sig(T) && return :nothing
g = try _lookup_grad(T)
catch e
rethrow(CompileError(T,e))
Expand Down
Loading

0 comments on commit e25ac87

Please sign in to comment.