Handle unreachable blocks in the adjoint CFG #1465

Merged 3 commits on Oct 18, 2023

Contributor

@Pangoraw Pangoraw commented Oct 15, 2023

The error described in #1118 and #1380 stopped happening in Julia 1.10-beta (maybe because of JuliaLang/julia#50943), but the gradient is now wrong.

julia> f(x) = @inbounds return x
f (generic function with 1 method)

julia> Zygote.gradient(f,1.)
(nothing,)

This PR adds unreachable branches at the end of blocks in the adjoint when those blocks are unreachable in the primal. Because this avoids implicit branches, it fixes the issue in both 1.9 and 1.10.

I also thought of removing unreachable blocks altogether, since some operations are invalid for these blocks in IRTools (see dominators for example), but that seemed to require a lot more work than this fix.
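To make "unreachable in the primal" concrete, here is a minimal sketch of how unreachable blocks could be detected in a control-flow graph. It is a toy model only: the `successors` dictionary and the `reachable_blocks` helper are hypothetical and do not use IRTools or reflect how Zygote represents its IR.

```julia
# Toy CFG (hypothetical): block id => ids of its successor blocks.
# Block 4 has no predecessor, so nothing can ever branch to it.
successors = Dict(1 => [2, 3], 2 => Int[], 3 => Int[], 4 => Int[])

# Collect every block reachable from `entry` with a depth-first traversal.
function reachable_blocks(successors::Dict{Int,Vector{Int}}, entry::Int)
    seen = Set{Int}()
    stack = [entry]
    while !isempty(stack)
        b = pop!(stack)
        b in seen && continue      # already visited
        push!(seen, b)
        append!(stack, successors[b])
    end
    return seen
end

reachable = reachable_blocks(successors, 1)
# Blocks that exist in the CFG but were never visited from the entry block.
unreachable = sort!(collect(setdiff(keys(successors), reachable)))
```

In this toy graph only block 4 ends up in `unreachable`; it plays the role of a primal block whose adjoint would get an explicit unreachable terminator under the approach described above.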

Fixes #1380
Fixes #1118

Note

It depends on FluxML/IRTools.jl#115.

Member

@oxinabox oxinabox left a comment

This seems good to me, based on the new test. But I am not that familiar with this bit of the code.

@Pangoraw
Contributor Author

Pangoraw commented Oct 17, 2023

I had to add the Core.throw call because having an unreachable on a reachable path throws a compiler error, and it is possible to manually construct a pullback which reaches these adjoint blocks whose primal blocks are unreachable. This is specifically for the versions concerned by FluxML/IRTools.jl#115 (>= 1.10), since unreachable is emitted as a Core.throw call otherwise.

julia> using Zygote

julia> function f(x, cond)
           if cond
               return x
           else
               return 2x
           end
           return 3x
       end
f (generic function with 1 method)

julia> back = Zygote.Pullback{Tuple{typeof(f), Float64, Bool}, Any}(([Returns([1,2,3])], [Returns([1,2,3])], 0x03))
(f)

julia> back(1.)
ERROR: "unreachable"
Stacktrace:
 [1] f
   @ ./REPL[2]:7 [inlined]
 [2] (::Zygote.Pullback{Tuple{typeof(f), Float64, Bool}, Any})(Δ::Float64)
   @ Zygote ~/Projects/Zygote.jl/src/compiler/interface2.jl:0
 [3] top-level scope
   @ REPL[5]:1

@ToucheSir
Member

My understanding is that this may incur an allocation for the string "unreachable" in some methods. Is this reliably optimized out in practice? If not, it might make sense to define a custom immutable error type. That type could also track more information such as file and line number to help with troubleshooting.
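A minimal sketch of what such an error type could look like follows. The name `UnreachableAdjointError`, its fields, and the `unreachable_adjoint` stand-in are all hypothetical; nothing here is part of Zygote or IRTools, and whether it is cheaper than throwing the string literal is not measured.

```julia
# Hypothetical immutable error type; not part of Zygote or IRTools.
struct UnreachableAdjointError <: Exception
    file::Symbol   # source file of the primal block
    line::Int      # source line of the primal block
end

# Human-readable message shown when the error is displayed.
function Base.showerror(io::IO, err::UnreachableAdjointError)
    print(io, "UnreachableAdjointError: reached the adjoint of an unreachable block at ",
          err.file, ":", err.line)
end

# Stand-in for an adjoint block whose primal block is unreachable.
unreachable_adjoint() = throw(UnreachableAdjointError(Symbol("REPL[2]"), 7))

err = try
    unreachable_adjoint()
catch exc
    exc
end
msg = sprint(showerror, err)
```

Carrying the location in the exception itself would let a caller inspect `err.file` and `err.line` instead of parsing a stack trace.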

@Pangoraw
Contributor Author

Pangoraw commented Oct 18, 2023

The "unreachable" string literal is actually allocated once by the parser and permanently rooted.

julia> using Zygote

julia> f(x) = @inbounds return x;

julia> invalid_pull = Zygote.Pullback{Tuple{typeof(f), Float64}, Any}(0x02,);

julia> m = only(methods(Zygote.adjointcfg));
       root = m.roots[findfirst(==("unreachable"), m.roots)];

julia> unreachable_string = try invalid_pull(1.); catch exc; exc end
"unreachable"

julia> unreachable_string === root
true

One problem is that it may taint the effects of the pullback compared to the case without unreachable blocks (which we could have if we removed unreachable blocks from the primal):

julia> Core.Compiler.infer_effects(invalid_pull, (Float64,))
(!c,!e,!n,!t,!s,!m,+i)′

julia> g(x) = return x;

julia> _, pull = Zygote.pullback(g, 1.);
       Core.Compiler.infer_effects(pull.back, (Float64,))
(+c,+e,+n,+t,+s,+m,+i)

But any sufficiently complex pullback (involving Zygote.Stack) can throw too, so it may not be a problem after all.
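The same tainting can be seen without Zygote. The sketch below, assuming Julia 1.9 or later where `Base.infer_effects` is available, compares a function that always returns with one that has a reachable `throw`; the two helper functions are hypothetical stand-ins for the pullbacks above.

```julia
# Hypothetical stand-ins for a pullback without and with a throwing block.
always_returns(x::Float64) = x
may_throw(x::Float64) = x > 0.0 ? x : throw("unreachable")

# Ask the compiler for the inferred effects of each call signature.
effects_clean = Base.infer_effects(always_returns, (Float64,))
effects_tainted = Base.infer_effects(may_throw, (Float64,))

# Query only the nothrow bit; the printed effects format varies across versions.
nothrow_clean = Core.Compiler.is_nothrow(effects_clean)
nothrow_tainted = Core.Compiler.is_nothrow(effects_tainted)
```

A reachable `throw` is enough to lose the nothrow guarantee, which is the `!n` entry in the effects printed for `invalid_pull`.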

@ToucheSir
Member

ToucheSir commented Oct 18, 2023

A quick sanity check before I merge: does this pass tests locally on nightly for you?

@Pangoraw
Contributor Author

The added testset does, but the entire test suite fails with CUDA and a missing make_seed attribute just like in CI.

@ToucheSir
Member

Could you try temporarily commenting out the CUDA test imports and the following test block?

if has_cuda()
  @testset "CUDA tests" begin
    include("cuda.jl")
  end
  @info "CUDA tests have run"
else
  @warn "CUDA not found - Skipping CUDA Tests"
end

Either locally or on this branch to run CI would be fine.

@Pangoraw
Contributor Author

I had to cherry-pick #1462, but there are still one @optout test and two gradient-check tests failing. I don't know if those are related to this PR.

@ToucheSir ToucheSir merged commit b152846 into FluxML:master Oct 18, 2023
11 of 13 checks passed
@Pangoraw Pangoraw deleted the unreachable_block branch October 18, 2023 16:03
Development

Successfully merging this pull request may close these issues.

Zygote.withgradient crashes Julia REPL
return within at-inbounds yields invalid IR