Skip to content

Commit

Permalink
fixing type ambiguity of unbroadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianAment committed Mar 14, 2022
1 parent 843a52d commit c45fa66
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.36"
version = "0.6.37"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
12 changes: 7 additions & 5 deletions src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ?: accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::Tuple, x̄::Nothing) = nothing
# fixing issue #1184, not duplicate method, since the above allows for an empty tuple
unbroadcast(x::Tuple{<:Any}, x̄::Nothing) = nothing

unbroadcast(x::AbstractArray, x̄::Nothing) = nothing

Expand All @@ -81,7 +83,7 @@ _minus(::Nothing) = nothing
Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
_pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
_pullback(*, x, y)

@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
Expand Down Expand Up @@ -181,7 +183,7 @@ _dual_safearg(x) = false
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
return (f.(args...), _ -> nothing)
return (f.(args...), _ -> nothing)
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
return broadcast_forward(f, args...)
end
Expand Down Expand Up @@ -260,7 +262,7 @@ end
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(CUDA.cufunc(f), args...)

else # CUDA >= 3.0 -- don't need cufunc(f).
else # CUDA >= 3.0 -- don't need cufunc(f).
# Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
# so perhaps this can be deleted? Possible edge case here:
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
Expand All @@ -277,14 +279,14 @@ end
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end

# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end

@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
Expand Down

0 comments on commit c45fa66

Please sign in to comment.