From 807d6898c8317733a849b9e5a5c2e5c1b6a921a2 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Sun, 23 Oct 2022 20:46:57 -0400 Subject: [PATCH 1/2] Added complex broadcasting support --- Project.toml | 4 +- src/lib/broadcast.jl | 104 ++++++++++++++++++++++++++++++++++++++----- 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/Project.toml b/Project.toml index 7d277f688..209da2578 100644 --- a/Project.toml +++ b/Project.toml @@ -10,7 +10,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" # not loaded, just a version bound +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IRTools = "7869d1d1-7146-5819-86e3-90919afe41df" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -33,7 +33,7 @@ ChainRulesTestUtils = "1" DiffRules = "1.4" FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13" ForwardDiff = "0.10" -GPUArrays = "8.4.2" # not loaded, just a version bound +GPUArrays = "8.4.2" GPUArraysCore = "0.1.1" IRTools = "0.4.4" LogExpFunctions = "0.3.1" diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 58f7ecf99..4c81a2dac 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -120,6 +120,9 @@ end @adjoint broadcasted(::typeof(imag), x::Numeric) = imag.(x), z̄ -> (nothing, im .* real.(z̄)) +# @adjoint broadcasted(::typeof(abs2), x::Numeric) = +# abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) + @adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool) y = b === false ? a : a .+ b y, Δ -> (nothing, Δ, nothing) @@ -190,7 +193,7 @@ _dual_safearg(x) = false # Avoid generic broadcasting in two easy cases: if T == Bool return (f.(args...), _ -> nothing) - elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving() + elseif T <: Union{Real, Complex} && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving() return broadcast_forward(f, args...) end len = inclen(args) @@ -232,23 +235,64 @@ end import ForwardDiff using ForwardDiff: Dual -dual(x, p) = x -dual(x::Real, p) = Dual(x, p) -dual(x::Bool, p) = x +dual(x, p, pc=()) = x +dual(x::Real, p, pc=()) = Dual(x, p) +dual(x::Bool, p, pc=()) = x +dual(x::Complex, p, pc) = Complex(Dual(real(x), p), Dual(imag(x), pc)) function dual_function(f::F) where F function (args::Vararg{Any,N}) where N - ds = map(args, ntuple(identity,Val(N))) do x, i - dual(x, ntuple(j -> i==j, Val(N))) + if any(a isa Complex for a in args) + ds = map(args, ntuple(identity, Val(N))) do x, i + dual(x, ntuple(j -> i==j, Val(2N)), ntuple(j -> N+i==j, Val(2N))) + end + return f(ds...) + else + ds = map(args, ntuple(identity,Val(N))) do x, i + dual(x, ntuple(j -> i==j, Val(N))) + end + return f(ds...) end - return f(ds...) end end +# function dual_function(f::F) where F +# function (args::Vararg{Any,N}) where N +# ds = map(args, ntuple(identity,Val(N))) do x, i +# dual(x, ntuple(j -> i==j, Val(N))) +# end +# return f(ds...) +# end +# end + +# @inline function broadcast_forward(f, args::Vararg{Any,N}) where N +# valN = Val(N) +# out = dual_function(f).(args...) +# eltype(out) <: Dual || return (out, _ -> nothing) +# y = broadcast(x -> x.value, out) +# function bc_fwd_back(ȳ) +# dargs = ntuple(valN) do i +# unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out)) +# end +# (nothing, nothing, dargs...) # nothings for broadcasted & f +# end +# return y, bc_fwd_back +# end + + @inline function broadcast_forward(f, args::Vararg{Any,N}) where N - valN = Val(N) out = dual_function(f).(args...) - eltype(out) <: Dual || return (out, _ -> nothing) + eltype(out) <: Union{Dual, Complex} || return (out, _ -> nothing) + if any(eltype(a) <: Complex for a in args) + _broadcast_forward_complex(out, args...) + else + _broadcast_forward(out, args...) + end +end + +# Real input and real output +function _broadcast_forward(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) where {N} + valN = Val(N) y = broadcast(x -> x.value, out) function bc_fwd_back(ȳ) dargs = ntuple(valN) do i @@ -259,6 +303,47 @@ end return y, bc_fwd_back end +# This handles complex output and real input +function _broadcast_forward(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out)) + end + (nothing, nothing, dargs...) # nothings for broadcasted & f + end + return y, bc_fwd_back + end + +# This handles complex input and real output +function _broadcast_forward_complex(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> x.value, out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(o1.partials[i], o1.partials[i+N]), ȳ, out)) + end + (nothing, nothing, dargs...) # nothings for broadcasted & f + end + return y, bc_fwd_back +end + +# This is for complex input and complex output +# I am a little confused what derivative we want to use here so it hasn't been implemented +function _broadcast_forward_complex(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} + throw("Complex output and input not supported in Zygote broadcast_forward") + # valN = Val(N) + # y = broadcast(x -> Complex(x.re.value, x.im.value), out) + # function bc_fwd_back(ȳ) + # dargs = ntuple(valN) do i + # unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(o1.partials[i], o1.partials[i+N-1]), ȳ, out)) + # end + # (nothing, nothing, dargs...) # nothings for broadcasted & f + # end + # return y, bc_fwd_back +end + using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame # Ordinary broadcasting calls broadcast_forward anyway when certain its' safe, @@ -287,4 +372,3 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve end pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz] - From 2972fafacc79e11a26c5bf36be6caf1deb39ffc2 Mon Sep 17 00:00:00 2001 From: Paul Tiede Date: Mon, 24 Oct 2022 17:08:50 -0400 Subject: [PATCH 2/2] Added tests and clean up the code --- src/lib/broadcast.jl | 60 +++++++++++------------------------------ test/complex.jl | 1 - test/cuda.jl | 64 +++++++++++++++++++++++++++++--------------- 3 files changed, 58 insertions(+), 67 deletions(-) diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl index 4c81a2dac..5e5ed38ed 100644 --- a/src/lib/broadcast.jl +++ b/src/lib/broadcast.jl @@ -120,8 +120,8 @@ end @adjoint broadcasted(::typeof(imag), x::Numeric) = imag.(x), z̄ -> (nothing, im .* real.(z̄)) -# @adjoint broadcasted(::typeof(abs2), x::Numeric) = -# abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) +@adjoint broadcasted(::typeof(abs2), x::Numeric) = + abs2.(x), z̄ -> (nothing, 2 .* real.(z̄) .* x) @adjoint function broadcasted(::typeof(+), a::AbstractArray{<:Number}, b::Bool) y = b === false ? a : a .+ b @@ -235,11 +235,13 @@ end import ForwardDiff using ForwardDiff: Dual +# Updated to use proposal from 961 dual(x, p, pc=()) = x dual(x::Real, p, pc=()) = Dual(x, p) dual(x::Bool, p, pc=()) = x dual(x::Complex, p, pc) = Complex(Dual(real(x), p), Dual(imag(x), pc)) +# Updated to use proposal from 961 function dual_function(f::F) where F function (args::Vararg{Any,N}) where N if any(a isa Complex for a in args) @@ -256,30 +258,6 @@ function dual_function(f::F) where F end end -# function dual_function(f::F) where F -# function (args::Vararg{Any,N}) where N -# ds = map(args, ntuple(identity,Val(N))) do x, i -# dual(x, ntuple(j -> i==j, Val(N))) -# end -# return f(ds...) -# end -# end - -# @inline function broadcast_forward(f, args::Vararg{Any,N}) where N -# valN = Val(N) -# out = dual_function(f).(args...) -# eltype(out) <: Dual || return (out, _ -> nothing) -# y = broadcast(x -> x.value, out) -# function bc_fwd_back(ȳ) -# dargs = ntuple(valN) do i -# unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out)) -# end -# (nothing, nothing, dargs...) # nothings for broadcasted & f -# end -# return y, bc_fwd_back -# end - - @inline function broadcast_forward(f, args::Vararg{Any,N}) where N out = dual_function(f).(args...) eltype(out) <: Union{Dual, Complex} || return (out, _ -> nothing) @@ -303,18 +281,19 @@ function _broadcast_forward(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) wh return y, bc_fwd_back end -# This handles complex output and real input -function _broadcast_forward(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} - valN = Val(N) - y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) - function bc_fwd_back(ȳ) - dargs = ntuple(valN) do i - unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out)) - end - (nothing, nothing, dargs...) # nothings for broadcasted & f +# This handles complex output and real input and uses the definition from +# ChainRules.jl's section on complex numbers +function _broadcast_forward(out::AbstractArray{<:Complex{<:Dual}}, args::Vararg{Any, N}) where {N} + valN = Val(N) + y = broadcast(x -> Complex.(real(x).value, imag(x).value), out) + function bc_fwd_back(ȳ) + dargs = ntuple(valN) do i + unbroadcast(args[i], broadcast((y1, o1) -> (real(y1)*real(o1).partials[i] + imag(y1)*imag(o1).partials[i]), ȳ, out)) end - return y, bc_fwd_back + (nothing, nothing, dargs...) # nothings for broadcasted & f end + return y, bc_fwd_back +end # This handles complex input and real output function _broadcast_forward_complex(out::AbstractArray{<:Dual}, args::Vararg{Any, N}) where {N} @@ -333,15 +312,6 @@ end # I am a little confused what derivative we want to use here so it hasn't been implemented function _broadcast_forward_complex(out::AbstractArray{<:Complex}, args::Vararg{Any, N}) where {N} throw("Complex output and input not supported in Zygote broadcast_forward") - # valN = Val(N) - # y = broadcast(x -> Complex(x.re.value, x.im.value), out) - # function bc_fwd_back(ȳ) - # dargs = ntuple(valN) do i - # unbroadcast(args[i], broadcast((y1, o1) -> y1 * Complex(o1.partials[i], o1.partials[i+N-1]), ȳ, out)) - # end - # (nothing, nothing, dargs...) # nothings for broadcasted & f - # end - # return y, bc_fwd_back end using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame diff --git a/test/complex.jl b/test/complex.jl index efb1e06dd..e50c57486 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -120,4 +120,3 @@ end end @test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0] end - diff --git a/test/cuda.jl b/test/cuda.jl index 5cb1c8cdc..113def69e 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -26,7 +26,7 @@ end g_gpu = gradient(x -> v(x, 7), a_gpu)[1] @test g_gpu isa CuArray @test g_gpu |> collect ≈ g - + w(x) = sum(broadcast(log, x)) g = gradient(x -> w(x), a)[1] g_gpu = gradient(x -> w(x), a_gpu)[1] @@ -38,7 +38,7 @@ end @test gradient(x -> sum(x .> 3), a_gpu) == (nothing,) g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018 - @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] + @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1] # Projection: eltype preservation: @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32} @@ -90,40 +90,40 @@ end @testset "gradient algebra" begin w, b = rand(2) |> cu, rand(2) |> cu x1, x2 = rand(2) |> cu, rand(2) |> cu - - gs1 = gradient(() -> sum(w .* x1), Params([w])) - gs2 = gradient(() -> sum(w .* x2), Params([w])) + + gs1 = gradient(() -> sum(w .* x1), Params([w])) + gs2 = gradient(() -> sum(w .* x2), Params([w])) @test .- gs1 isa Grads - @test gs1 .- gs2 isa Grads + @test gs1 .- gs2 isa Grads @test .+ gs1 isa Grads - @test gs1 .+ gs2 isa Grads - @test 2 .* gs1 isa Grads + @test gs1 .+ gs2 isa Grads + @test 2 .* gs1 isa Grads @test (2 .* gs1)[w] ≈ 2 * gs1[w] - @test gs1 .* 2 isa Grads - @test gs1 ./ 2 isa Grads - @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w] + @test gs1 .* 2 isa Grads + @test gs1 ./ 2 isa Grads + @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w] gs12 = gs1 .+ gs2 gs1 .+= gs2 - @test gs12[w] ≈ gs1[w] + @test gs12[w] ≈ gs1[w] gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b - gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b])) + gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b])) @test .- gs3 isa Grads - @test gs3 .- gs4 isa Grads + @test gs3 .- gs4 isa Grads @test .+ gs3 isa Grads - @test gs3 .+ gs4 isa Grads - @test 2 .* gs3 isa Grads - @test gs3 .* 2 isa Grads - @test gs3 ./ 2 isa Grads + @test gs3 .+ gs4 isa Grads + @test 2 .* gs3 isa Grads + @test gs3 .* 2 isa Grads + @test gs3 ./ 2 isa Grads @test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w] - @test (gs3 .+ gs4)[b] ≈ gs4[b] - + @test (gs3 .+ gs4)[b] ≈ gs4[b] + @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3)) - @test gs3 isa Grads + @test gs3 isa Grads @test_throws ArgumentError gs1 .+ gs4 end @@ -140,3 +140,25 @@ end @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32} end + +@testset "CUDA complex broadcasting" begin + # Issue #995 test + x = rand(Float32, 50) + y = complex(rand(Float32, 50)) + + xgpu = cu(x) + ygpu = cu(y) + + f995(A) = norm(@. A*xgpu*ygpu) + g1 = Zygote.gradient(f995, 1f0) + gradcheck(f995, 1f0) + + # Issue 961 and 1121 and 1215 + g1 = Zygote.gradient(x->sum(abs2, x), ygpu) + g2 = Zygote.gradient(x->sum(abs2.(x)), ygpu) + g3 = Zygote.graient(x->sum(abs2, x), y) + @test g1 isa CUDA.CuArray{Float32} + @test g2 isa CUDA.CuArray{Float32} + @test g1 ≈ g2 + @test g1 ≈ g3 +end