Skip to content

Commit

Permalink
Merge pull request FluxML#1420 from maleadt/dev
Browse files Browse the repository at this point in the history
Fixes for Julia 1.10
  • Loading branch information
ToucheSir authored May 12, 2023
2 parents 2287a86 + d1bce98 commit 0e211ea
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 28 deletions.
10 changes: 5 additions & 5 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,25 @@ The first return value is `true` if the `rrule` exists, `false` otherwise.
If it does not, then the second argument is a list of edges to attach to the CodeInfo for a generated function,
such that if a suitable rule is defined later, the generated function will recompile.
"""
function has_chain_rrule(T)
function has_chain_rrule(T, world)
config_T, arg_Ts = Iterators.peel(T.parameters)
configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...})
configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...}; world)
is_ambig = configured_rrule_m === nothing # this means there was an ambiguity error, on configured_rrule


if !is_ambig && _is_rrule_redispatcher(configured_rrule_m.method)
# The config is not being used:
# it is being redispatched without config, so we need the method it redispatches to
rrule_m = meta(Tuple{typeof(rrule), arg_Ts...})
rrule_m = meta(Tuple{typeof(rrule), arg_Ts...}; world)
# Thus any no_rrule that might apply must also not have a config because if there was a
# no_rrule with a config that applied then there would also be a rrule with config that applied
no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), arg_Ts...})
no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), arg_Ts...}; world)
else
# Not being redispatched: it does have a config
rrule_m = configured_rrule_m
# Thus any no_rrule that might apply must also have a config because if it applied
# it will be identical, and if it doesn't we don't care what it is.
no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...})
no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...}; world)
end

is_ambig |= rrule_m === nothing # this means there was an ambiguity error on unconfigured rrule
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/emit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ end

varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing

function _generate_pullback_via_decomposition(T)
(m = meta(T)) === nothing && return
function _generate_pullback_via_decomposition(T, world)
(m = meta(T; world)) === nothing && return
va = varargs(m.method, length(T.parameters))
forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T)
m, forw, back
Expand Down
64 changes: 57 additions & 7 deletions src/compiler/interface2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function edge!(m::IRTools.Meta, edge::Core.MethodInstance)
return
end

@generated function _pullback(ctx::AContext, f, args...)
function _generate_pullback(ctx, world, f, args...)
# Try using ChainRulesCore
if is_kwfunc(f, args...)
# if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
Expand All @@ -17,17 +17,21 @@ end
chain_rrule_f = :chain_rrule
end

hascr, cr_edge = has_chain_rrule(cr_T)
hascr, cr_edge = has_chain_rrule(cr_T, world)
hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))

# No ChainRule, going to have to work it out.
T = Tuple{f,args...}
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))

g = try
_generate_pullback_via_decomposition(T)
_generate_pullback_via_decomposition(T, world)
catch e
rethrow(CompileError(T,e))
if VERSION < v"1.8"
# work around Julia bug
rethrow(CompileError(T,e))
end
return :(throw($(CompileError(T,e))))
end
g === nothing && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
Expand All @@ -40,12 +44,16 @@ end
return update!(meta.code, forw)
end

@generated function (j::Pullback{T})(Δ) where T
function _generate_callable_pullback(j::Type{<:Pullback{T}}, world, Δ) where T
ignore_sig(T) && return :nothing
g = try
_generate_pullback_via_decomposition(T)
_generate_pullback_via_decomposition(T, world)
catch e
rethrow(CompileError(T,e))
if VERSION < v"1.8"
# work around Julia bug
rethrow(CompileError(T,e))
end
return :(throw($(CompileError(T,e))))
end
if g === nothing
Δ == Nothing && return :nothing
Expand All @@ -57,3 +65,45 @@ end
back = slots!(inlineable!(back))
return update!(meta.code, back)
end

if VERSION >= v"1.10.0-DEV.873"

# on Julia 1.10, generated functions need to keep track of the world age

function _pullback_generator(world::UInt, source, self, ctx, f, args)
ret = _generate_pullback(ctx, world, f, args...)
ret isa Core.CodeInfo && return ret

stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ctx, :f, :args), Core.svec())
stub(world, source, ret)
end

@eval function _pullback(ctx::AContext, f, args...)
$(Expr(:meta, :generated, _pullback_generator))
$(Expr(:meta, :generated_only))
end

function _callable_pullback_generator(world::UInt, source, self, Δ)
ret = _generate_callable_pullback(self, world, Δ)
ret isa Core.CodeInfo && return ret

stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, ), Core.svec())
stub(world, source, ret)
end

@eval function (j::Pullback)(Δ)
$(Expr(:meta, :generated, _callable_pullback_generator))
$(Expr(:meta, :generated_only))
end

else

@generated function _pullback(ctx::AContext, f, args...)
_generate_pullback(ctx, nothing, f, args...)
end

@generated function (j::Pullback)(Δ)
_generate_callable_pullback(j, nothing, Δ)
end

end
61 changes: 48 additions & 13 deletions src/lib/literal_getproperty.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Mostly copied over from Cassette in `src/overdub.jl`
# Return `Reflection` for signature `sigtypes` and `world`, if possible. Otherwise, return `nothing`.
function reflect(@nospecialize(sigtypes::Tuple), world::UInt = typemax(UInt))
function reflect(@nospecialize(sigtypes::Tuple), world::UInt)
if length(sigtypes) > 2 && sigtypes[1] === typeof(invoke)
@assert sigtypes[3] <: Type{<:Tuple}
sigtypes = (sigtypes[2], sigtypes[3].parameters[1].parameters...)
Expand Down Expand Up @@ -41,23 +41,32 @@ end


# ugly hack to make differentiating `getproperty` infer a lot better
@generated function _pullback(cx::AContext, ::typeof(literal_getproperty), x, ::Val{f}) where f
function _generate_literal_getproperty(ctx, world, x, ::Type{Val{f}}) where f
world = something(world, typemax(UInt))

sig(x) = Tuple{x, typeof(f)}
rrule_sig(x) = Tuple{typeof(getproperty), x, typeof(f)}
pb_sig(x) = Tuple{cx, typeof(getproperty), x, typeof(f)}
pb_sig(x) = Tuple{ctx, typeof(getproperty), x, typeof(f)}
@static if VERSION >= v"1.10.0-DEV.65"
which(f, t) = Base._which(Base.signature_type(f, t); world).method
else
which(f, t) = Base.which(f, t)
end

# either `getproperty` has a custom implementation or `_pullback(cx, getproperty, x, f)`
# either `getproperty` has a custom implementation or `_pullback(ctx, getproperty, x, f)`
# / `rrule(getproperty, x, f) is overloaded directly
is_getfield_fallback = which(getproperty, sig(x)) == which(getproperty, sig(Any)) &&
which(_pullback, pb_sig(x)) == which(_pullback, pb_sig(Any)) &&
which(rrule, rrule_sig(x)) == which(rrule, rrule_sig(Any))

#ccall(:jl_safe_printf, Cvoid, (Cstring,), "$is_getfield_fallback: $x\n")

if is_getfield_fallback
# just copy pullback of `literal_getfield`
mi, _sig, sparams = reflect((typeof(_pullback), cx, typeof(literal_getfield), x, Val{f}))
ci = copy(Core.Compiler.retrieve_code_info(mi))
mi, _sig, sparams = reflect((typeof(_pullback), ctx, typeof(literal_getfield), x, Val{f}), world)
ci = if VERSION >= v"1.10.0-DEV.873"
copy(Core.Compiler.retrieve_code_info(mi, world))
else
copy(Core.Compiler.retrieve_code_info(mi))
end

# we need to change the second arg to `_pullback` from `literal_getproperty` to
# `literal_getfield`
Expand All @@ -69,18 +78,44 @@ end

# backedge for `_pullback`, see https://docs.julialang.org/en/v1/devdocs/ast/#MethodInstance
# this will cause a backedge to this particular MethodInstance to be attached to
# `_pullback(cx, getproperty, x, f)`
mi_pb_getproperty, _, _ = reflect((typeof(_pullback), pb_sig(x).parameters...))
mi_getproperty, _, _ = reflect((typeof(getproperty), sig(x).parameters...))
mi_rrule, _, _ = reflect((typeof(rrule), rrule_sig(x).parameters...))
# `_pullback(ctx, getproperty, x, f)`
mi_pb_getproperty, _, _ = reflect((typeof(_pullback), pb_sig(x).parameters...), world)
mi_getproperty, _, _ = reflect((typeof(getproperty), sig(x).parameters...), world)
mi_rrule, _, _ = reflect((typeof(rrule), rrule_sig(x).parameters...), world)
ci.edges = Core.MethodInstance[mi, mi_pb_getproperty, mi_getproperty, mi_rrule]
# XXX: on 1.10, we should also set metadata like min-world and max-world

return ci
else
# nothing to optimize here, need to recurse into `getproperty`
return quote
Base.@_inline_meta
_pullback(cx, getproperty, x, $(QuoteNode(f)))
_pullback(ctx, getproperty, x, $(QuoteNode(f)))
end
end
end

if VERSION >= v"1.10.0-DEV.873"

# on Julia 1.10, generated functions need to keep track of the world age

function _literal_getproperty_pullback_generator(world::UInt, source, self, ctx, literal_getproperty, x, f)
ret = _generate_literal_getproperty(ctx, world, x, f)
ret isa Core.CodeInfo && return ret

stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ctx, :literal_getproperty, :x, :f), Core.svec())
stub(world, source, ret)
end

@eval function _pullback(ctx::AContext, ::typeof(literal_getproperty), x, f)
$(Expr(:meta, :generated, _literal_getproperty_pullback_generator))
$(Expr(:meta, :generated_only))
end

else

@generated function _pullback(ctx::AContext, ::typeof(literal_getproperty), x, f)
_generate_literal_getproperty(ctx, nothing, x, f)
end

end
6 changes: 5 additions & 1 deletion test/features.jl
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,11 @@ global_param = 3
y, back = Zygote._pullback(cx, x -> x*global_param, 2)
@test y == 6
@test back(1) == (nothing, 3)
Zygote.cache(cx)[GlobalRef(Main, :global_param)] == 2
ref = first(keys(Zygote.cache(cx)))
@test ref isa GlobalRef
@test ref.mod == Main
@test ref.name == :global_param
@test Zygote.cache(cx)[ref] == 2
end

function pow_try(x)
Expand Down

0 comments on commit 0e211ea

Please sign in to comment.