Skip to content

Commit

Permalink
Restrict type broadcast rule to numbers (FluxML#1179)
Browse files Browse the repository at this point in the history
* restrict type broadcast to number

* add a test
  • Loading branch information
mcabbott committed Mar 10, 2022
1 parent 2a2095c commit 843a52d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.35"
version = "0.6.36"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
2 changes: 1 addition & 1 deletion src/lib/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ end
end
end

@adjoint broadcasted(::Type{T}, x::Numeric) where T =
@adjoint broadcasted(::Type{T}, x::Numeric) where {T<:Number} =
T.(x), ȳ -> (nothing, _project(x, ȳ),)

# General Fallback
Expand Down
8 changes: 8 additions & 0 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1407,9 +1407,17 @@ end
@test all(gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) .≈ ([12, 12], [3 3 3]))
@test all(gradient((x,y) -> sum(x ./ y), [1,2], 5) .≈ ([0.2, 0.2], -0.12))

# https://github.com/FluxML/Zygote.jl/pull/1171
sm = sprand(5, 5, 0.5)
@test gradient(x -> sum(abs2, Float32.(x)), sm)[1] gradient(x -> sum(abs2, x), Matrix{Float32}(sm))[1]
@test gradient(x -> real(sum(ComplexF32.(x) .+ 1 .+ im)), sm)[1] isa SparseMatrixCSC{Float64}

# https://github.com/FluxML/Zygote.jl/issues/1178
function f1179(x)
fs = Ref.(x)
getindex.(fs)
end
@test gradient(sumf1179, ones(2)) == ([2.0, 2.0],)
end

using Zygote: Buffer
Expand Down

0 comments on commit 843a52d

Please sign in to comment.