Skip to content

Commit

Permalink
Merge pull request FluxML#1231 from DomCRose/nonholomorphic_test_fix
Browse files Browse the repository at this point in the history
Fix non-holomorphic tests
  • Loading branch information
CarloLucibello authored Jun 20, 2022
2 parents c4b4fa9 + 9f417db commit b29d5b2
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions test/complex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,17 @@ fs_C_to_C_holomorphic = (cos,
@testset "C->C holomorphic" begin
for f in fs_C_to_C_holomorphic
for z in (1.0+2.0im, -2.0+pi*im)
grad_zygote = gradient(realf, z)[1]
grad_zygote_r = gradient(realf, z)[1]
grad_zygote_i = gradient(imagf, z)[1]
ε = 1e-8
grad_fd_r = (f(z+ε)-f(z))/ε
grad_fd_i = (f(z+ε*im)-f(z))/*im)
@assert abs(grad_fd_r - grad_fd_i) < sqrt(ε) # check the function is indeed holomorphic
@test abs(grad_zygote - conj(grad_fd_r)) < sqrt(ε)
grad_fd_i = (f(z + ε * im) - f(z)) /* im)
# Check the function is indeed holomorphic
@assert abs(grad_fd_r - grad_fd_i) < sqrt(ε)
# Check Zygote derivatives agree with holomorphic definition
@test grad_zygote_r -im*grad_zygote_i
# Check derivative agrees with finite differences
@test abs(grad_zygote_r - conj(grad_fd_r)) < sqrt(ε)
end
end
end
Expand All @@ -76,14 +81,20 @@ fs_C_to_C_non_holomorphic = (conj,
z->im*abs2(z),
z->z'z,
z->conj(z)*z^2,
z->imag(z)^2+real(sin(z))^3*1im,
)
@testset "C->C non-holomorphic" begin
for f in (fs_C_to_C_holomorphic...,fs_C_to_C_holomorphic...)
for f in fs_C_to_C_non_holomorphic
for z in (1.0+2.0im, -2.0+pi*im)
grad_zygote = gradient(realf, z)[1]
grad_zygote_r = gradient(realf, z)[1]
grad_zygote_i = gradient(imagf, z)[1]
ε = 1e-8
grad_fd = real(f(z+ε)-f(z))/ε + im*real(f(z+ε*im)-f(z))/ε
@test abs(grad_zygote - grad_fd) < sqrt(ε)
grad_fd_r = real(f(z+ε)-f(z))/ε + im*real(f(z+ε*im)-f(z))/ε
grad_fd_i = imag(f(z+ε)-f(z))/ε + im*imag(f(z+ε*im)-f(z))/ε
# Check derivative of both real and imaginary parts of f as these may differ
# for non-holomorphic functions
@test abs(grad_zygote_r - grad_fd_r) < sqrt(ε)
@test abs(grad_zygote_i - grad_fd_i) < sqrt(ε)
end
end
end
Expand Down

0 comments on commit b29d5b2

Please sign in to comment.