-
-
Notifications
You must be signed in to change notification settings - Fork 210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Mutation #75
base: master
Are you sure you want to change the base?
Support Mutation #75
Conversation
b560693
to
bfcacfa
Compare
Tried this out with julia> using TransformVariables
julia> t = as((μ = asℝ, σ = asℝ₊))
TransformVariables.TransformNamedTuple{(:μ, :σ),Tuple{TransformVariables.Identity,TransformVariables.ShiftedExp{true,Float64}}}((TransformVariables.Identity(), TransformVariables.ShiftedExp{true,Float64}(0.0)), 2)
julia> sum(inverse(t, transform(t, ones(2))))
2.0
julia> gradient(s -> sum(inverse(t, transform(t, s))), ones(2))
ERROR: Compiling Tuple{typeof(inverse!),Array{Float64,1},TransformVariables.TransformNamedTuple{(:μ, :σ),Tuple{TransformVariables.Identity,TransformVariables.ShiftedExp{true,Float64}}},NamedTuple{(:μ, :σ),Tuple{Float64,Float64}}}: MethodError: no method matching exprtype(::Core.Compiler.IRCode, ::ArgCheck.ArgCheckFlavor)
Closest candidates are:
exprtype(::Core.Compiler.IRCode, ::Expr) at /Users/andreasnoack/.julia/dev/Zygote/src/tools/ir.jl:64
exprtype(::Core.Compiler.IRCode, ::QuoteNode) at /Users/andreasnoack/.julia/dev/Zygote/src/tools/ir.jl:61
exprtype(::Core.Compiler.IRCode, ::GlobalRef) at /Users/andreasnoack/.julia/dev/Zygote/src/tools/ir.jl:60
...
Stacktrace:
[1] _broadcast_getindex_evalf at ./broadcast.jl:625 [inlined]
[2] _broadcast_getindex at ./broadcast.jl:598 [inlined]
[3] getindex at ./broadcast.jl:558 [inlined]
[4] copyto_nonleaf!(::Array{DataType,1}, ::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1},Tuple{Base.OneTo{Int64}},typeof(Zygote.exprtype),Tuple{Base.RefValue{Core.Compiler.IRCode},Base.Broadcast.Extruded{Array{Any,1},Tuple{Bool},Tuple{Int64}}}}, ::Base.OneTo{Int64}, ::Int64, ::Int64) at ./broadcast.jl:982
[5] copy at ./broadcast.jl:836 [inlined] |
Oh. I'm now reading the actual error more carefully and can see that it's related to |
340bcb0
to
6b1b80f
Compare
Small update. Maintaining this separately is kind of a pain. I think I'd like to merge this and keep it under a feature flag like |
remember to add this (or something equivalent) later to make tuple vector conversion work. Zygote.@adjoint! function copyto!(xs::AbstractVector, ys::Tuple)
xs_ = copy(xs)
copyto!(xs, ys), function (dxs)
copyto!(xs_, xs)
return (nothing, Tuple(dxs))
end
end |
Is there a way to support mutation of arrays whose gradient is dropped (
So here I want to treat |
If you don't care about gradient at all you can define something like ignore(f) = f()
@nograd ignore and you can do the mutation inside a |
So what is the plan for supporting mutation? |
Specifically, mutation of arrays/values, as opposed to mutation of data structures, which we already support well.
This introduces some internal complexity and makes performance a little trickier, so it remains open whether we'll actually want to merge it in. The main goal right now is that people can play with this and test it, and in particular I'd like to get some nice benchmarks with differential equations (#37).
It is of course not ideal to maintain this separately. One option might be to make mutation optional and compile different code if it's enabled, though again this is a significant additional complexity in the system.