diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl index 223d3867c..10e7d8abb 100644 --- a/src/compiler/chainrules.jl +++ b/src/compiler/chainrules.jl @@ -15,25 +15,25 @@ The first return value is `true` if the `rrule` exists, `false` otherwise. If it does not, then the second argument is a list of edges to attach to the CodeInfo for a generated function, such that if a suitable rule is defined later, the generated function will recompile. """ -function has_chain_rrule(T) +function has_chain_rrule(T, world) config_T, arg_Ts = Iterators.peel(T.parameters) - configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...}) + configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...}; world) is_ambig = configured_rrule_m === nothing # this means there was an ambiguity error, on configured_rrule if !is_ambig && _is_rrule_redispatcher(configured_rrule_m.method) # The config is not being used: # it is being redispatched without config, so we need the method it redispatches to - rrule_m = meta(Tuple{typeof(rrule), arg_Ts...}) + rrule_m = meta(Tuple{typeof(rrule), arg_Ts...}; world) # Thus any no_rrule that might apply must also not have a config because if there was a # no_rrule with a config that applied then there would also be a rrule with config that applied - no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), arg_Ts...}) + no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), arg_Ts...}; world) else # Not being redispatched: it does have a config rrule_m = configured_rrule_m # Thus any no_rrule that might apply must also have a config because if it applied # it will be identical, and if it doesn't we don't care what it is. - no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...}) + no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...}; world) end is_ambig |= rrule_m === nothing # this means there was an ambiguity error on unconfigured rrule diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl index 1c82a44f1..ca79f11ce 100644 --- a/src/compiler/emit.jl +++ b/src/compiler/emit.jl @@ -95,8 +95,8 @@ end varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing -function _generate_pullback_via_decomposition(T) - (m = meta(T)) === nothing && return +function _generate_pullback_via_decomposition(T, world) + (m = meta(T; world)) === nothing && return va = varargs(m.method, length(T.parameters)) forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T) m, forw, back diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl index bf3692a30..31ae7eaf4 100644 --- a/src/compiler/interface2.jl +++ b/src/compiler/interface2.jl @@ -6,7 +6,7 @@ function edge!(m::IRTools.Meta, edge::Core.MethodInstance) return end -@generated function _pullback(ctx::AContext, f, args...) +function _generate_pullback(ctx, world, f, args...) # Try using ChainRulesCore if is_kwfunc(f, args...) # if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function @@ -17,7 +17,7 @@ end chain_rrule_f = :chain_rrule end - hascr, cr_edge = has_chain_rrule(cr_T) + hascr, cr_edge = has_chain_rrule(cr_T, world) hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...)) # No ChainRule, going to have to work it out. @@ -25,9 +25,13 @@ end ignore_sig(T) && return :(f(args...), Pullback{$T}(())) g = try - _generate_pullback_via_decomposition(T) + _generate_pullback_via_decomposition(T, world) catch e - rethrow(CompileError(T,e)) + if VERSION < v"1.8" + # work around Julia bug + rethrow(CompileError(T,e)) + end + return :(throw($(CompileError(T,e)))) end g === nothing && return :(f(args...), Pullback{$T}((f,))) meta, forw, _ = g @@ -40,12 +44,16 @@ end return update!(meta.code, forw) end -@generated function (j::Pullback{T})(Δ) where T +function _generate_callable_pullback(j::Type{<:Pullback{T}}, world, Δ) where T ignore_sig(T) && return :nothing g = try - _generate_pullback_via_decomposition(T) + _generate_pullback_via_decomposition(T, world) catch e - rethrow(CompileError(T,e)) + if VERSION < v"1.8" + # work around Julia bug + rethrow(CompileError(T,e)) + end + return :(throw($(CompileError(T,e)))) end if g === nothing Δ == Nothing && return :nothing @@ -57,3 +65,45 @@ end back = slots!(inlineable!(back)) return update!(meta.code, back) end + +if VERSION >= v"1.10.0-DEV.873" + +# on Julia 1.10, generated functions need to keep track of the world age + +function _pullback_generator(world::UInt, source, self, ctx, f, args) + ret = _generate_pullback(ctx, world, f, args...) + ret isa Core.CodeInfo && return ret + + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ctx, :f, :args), Core.svec()) + stub(world, source, ret) +end + +@eval function _pullback(ctx::AContext, f, args...) + $(Expr(:meta, :generated, _pullback_generator)) + $(Expr(:meta, :generated_only)) +end + +function _callable_pullback_generator(world::UInt, source, self, Δ) + ret = _generate_callable_pullback(self, world, Δ) + ret isa Core.CodeInfo && return ret + + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :Δ), Core.svec()) + stub(world, source, ret) +end + +@eval function (j::Pullback)(Δ) + $(Expr(:meta, :generated, _callable_pullback_generator)) + $(Expr(:meta, :generated_only)) +end + +else + +@generated function _pullback(ctx::AContext, f, args...) + _generate_pullback(ctx, nothing, f, args...) +end + +@generated function (j::Pullback)(Δ) + _generate_callable_pullback(j, nothing, Δ) +end + +end diff --git a/src/lib/literal_getproperty.jl b/src/lib/literal_getproperty.jl index c13f7a89b..c50cab171 100644 --- a/src/lib/literal_getproperty.jl +++ b/src/lib/literal_getproperty.jl @@ -1,6 +1,6 @@ # Mostly copied over from Cassette in `src/overdub.jl` # Return `Reflection` for signature `sigtypes` and `world`, if possible. Otherwise, return `nothing`. -function reflect(@nospecialize(sigtypes::Tuple), world::UInt = typemax(UInt)) +function reflect(@nospecialize(sigtypes::Tuple), world::UInt) if length(sigtypes) > 2 && sigtypes[1] === typeof(invoke) @assert sigtypes[3] <: Type{<:Tuple} sigtypes = (sigtypes[2], sigtypes[3].parameters[1].parameters...) @@ -41,23 +41,32 @@ end # ugly hack to make differentiating `getproperty` infer a lot better -@generated function _pullback(cx::AContext, ::typeof(literal_getproperty), x, ::Val{f}) where f +function _generate_literal_getproperty(ctx, world, x, ::Type{Val{f}}) where f + world = something(world, typemax(UInt)) + sig(x) = Tuple{x, typeof(f)} rrule_sig(x) = Tuple{typeof(getproperty), x, typeof(f)} - pb_sig(x) = Tuple{cx, typeof(getproperty), x, typeof(f)} + pb_sig(x) = Tuple{ctx, typeof(getproperty), x, typeof(f)} + @static if VERSION >= v"1.10.0-DEV.65" + which(f, t) = Base._which(Base.signature_type(f, t); world).method + else + which(f, t) = Base.which(f, t) + end - # either `getproperty` has a custom implementation or `_pullback(cx, getproperty, x, f)` + # either `getproperty` has a custom implementation or `_pullback(ctx, getproperty, x, f)` # / `rrule(getproperty, x, f) is overloaded directly is_getfield_fallback = which(getproperty, sig(x)) == which(getproperty, sig(Any)) && which(_pullback, pb_sig(x)) == which(_pullback, pb_sig(Any)) && which(rrule, rrule_sig(x)) == which(rrule, rrule_sig(Any)) - #ccall(:jl_safe_printf, Cvoid, (Cstring,), "$is_getfield_fallback: $x\n") - if is_getfield_fallback # just copy pullback of `literal_getfield` - mi, _sig, sparams = reflect((typeof(_pullback), cx, typeof(literal_getfield), x, Val{f})) - ci = copy(Core.Compiler.retrieve_code_info(mi)) + mi, _sig, sparams = reflect((typeof(_pullback), ctx, typeof(literal_getfield), x, Val{f}), world) + ci = if VERSION >= v"1.10.0-DEV.873" + copy(Core.Compiler.retrieve_code_info(mi, world)) + else + copy(Core.Compiler.retrieve_code_info(mi)) + end # we need to change the second arg to `_pullback` from `literal_getproperty` to # `literal_getfield` @@ -69,18 +78,44 @@ end # backedge for `_pullback`, see https://docs.julialang.org/en/v1/devdocs/ast/#MethodInstance # this will cause a backedge to this particular MethodInstance to be attached to - # `_pullback(cx, getproperty, x, f)` - mi_pb_getproperty, _, _ = reflect((typeof(_pullback), pb_sig(x).parameters...)) - mi_getproperty, _, _ = reflect((typeof(getproperty), sig(x).parameters...)) - mi_rrule, _, _ = reflect((typeof(rrule), rrule_sig(x).parameters...)) + # `_pullback(ctx, getproperty, x, f)` + mi_pb_getproperty, _, _ = reflect((typeof(_pullback), pb_sig(x).parameters...), world) + mi_getproperty, _, _ = reflect((typeof(getproperty), sig(x).parameters...), world) + mi_rrule, _, _ = reflect((typeof(rrule), rrule_sig(x).parameters...), world) ci.edges = Core.MethodInstance[mi, mi_pb_getproperty, mi_getproperty, mi_rrule] + # XXX: on 1.10, we should also set metadata like min-world and max-world return ci else # nothing to optimize here, need to recurse into `getproperty` return quote Base.@_inline_meta - _pullback(cx, getproperty, x, $(QuoteNode(f))) + _pullback(ctx, getproperty, x, $(QuoteNode(f))) end end end + +if VERSION >= v"1.10.0-DEV.873" + +# on Julia 1.10, generated functions need to keep track of the world age + +function _literal_getproperty_pullback_generator(world::UInt, source, self, ctx, literal_getproperty, x, f) + ret = _generate_literal_getproperty(ctx, world, x, f) + ret isa Core.CodeInfo && return ret + + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ctx, :literal_getproperty, :x, :f), Core.svec()) + stub(world, source, ret) +end + +@eval function _pullback(ctx::AContext, ::typeof(literal_getproperty), x, f) + $(Expr(:meta, :generated, _literal_getproperty_pullback_generator)) + $(Expr(:meta, :generated_only)) +end + +else + +@generated function _pullback(ctx::AContext, ::typeof(literal_getproperty), x, f) + _generate_literal_getproperty(ctx, nothing, x, f) +end + +end diff --git a/test/features.jl b/test/features.jl index 112c5b937..0499987d8 100644 --- a/test/features.jl +++ b/test/features.jl @@ -401,7 +401,11 @@ global_param = 3 y, back = Zygote._pullback(cx, x -> x*global_param, 2) @test y == 6 @test back(1) == (nothing, 3) - Zygote.cache(cx)[GlobalRef(Main, :global_param)] == 2 + ref = first(keys(Zygote.cache(cx))) + @test ref isa GlobalRef + @test ref.mod == Main + @test ref.name == :global_param + @test Zygote.cache(cx)[ref] == 2 end function pow_try(x)