From be7e88998f431fefd0c636a0b4308826e911a95d Mon Sep 17 00:00:00 2001
From: ho-oto
Date: Mon, 11 Jan 2021 02:58:45 +0900
Subject: [PATCH 001/490] fix broken link
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index d248793f6..8551bca87 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
-![CI Testing](https://github.com/FluxML/Zygote.jl/workflows/CI/badge.svg)
+[![CI Testing](https://github.com/FluxML/Zygote.jl/workflows/CI/badge.svg)](https://github.com/FluxML/Zygote.jl/actions)
[![Dev Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://fluxml.ai/Zygote.jl/dev)
`] add Zygote`
From 3e54503e29c6625be91f69962b826b86e84dfaa3 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 11 Jan 2021 17:45:26 +0530
Subject: [PATCH 002/490] add scalar method for vcat with number
---
src/lib/broadcast.jl | 2 ++
test/cuda.jl | 7 +++++++
2 files changed, 9 insertions(+)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index db43b163c..80e1f663f 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -239,4 +239,6 @@ end
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
+
+ pull_block_vert(sz, Δ::CuArray, A::Number) = CUDA.@allowscalar Δ[sz:sz]
end
diff --git a/test/cuda.jl b/test/cuda.jl
index 2820a776b..6fb83cd36 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -15,3 +15,10 @@ end
log_grada = cu(Float32[1.0, 0.5, 0.33333334, 0.25, 0.2, 0.16666667, 0.14285715, 0.125, 0.11111111])
@test gradient(x -> w(x) |> sum, a) == (log_grada,)
end
+
+@testset "vcat scalar indexing" begin
+ r = cu(rand(Float32, 3))
+ grads = (cu(ones(Float32, 3)), nothing)
+ @test gradient((x,y) -> sum(vcat(x,y)), r, 5) == grads
+end
+
From 8b9cc74f3313a58099c613bb28ee8b6a3cf6e32e Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Tue, 12 Jan 2021 21:19:55 +0530
Subject: [PATCH 003/490] Rm extra line
---
src/lib/broadcast.jl | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 80e1f663f..8ca5aa044 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -239,6 +239,5 @@ end
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
-
pull_block_vert(sz, Δ::CuArray, A::Number) = CUDA.@allowscalar Δ[sz:sz]
end
From bac62ef15150f1aae3d4be4cc62284fdffb97a93 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Tue, 12 Jan 2021 21:20:06 +0530
Subject: [PATCH 004/490] typo
---
src/lib/broadcast.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 8ca5aa044..d90e03a04 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -239,5 +239,5 @@ end
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
- pull_block_vert(sz, Δ::CuArray, A::Number) = CUDA.@allowscalar Δ[sz:sz]
+ pull_block_vert(sz, Δ::CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
end
From 845685847afb56d06c88efdeda86e8ba904a05d4 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Wed, 13 Jan 2021 14:18:02 +0530
Subject: [PATCH 005/490] qualify CUDA
---
src/lib/broadcast.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index d90e03a04..a3cde8094 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -239,5 +239,5 @@ end
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
- pull_block_vert(sz, Δ::CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
+ pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
end
From af6fa3ef846a91dd2cd7e12481fb2fcee8a870c1 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Mon, 19 Apr 2021 21:57:08 +0200
Subject: [PATCH 006/490] Define custom adjoints for LogExpFunctions instead of
StatsFuns
---
Project.toml | 4 +-
src/Zygote.jl | 2 +-
src/lib/{statsfuns.jl => logexpfunctions.jl} | 3 +-
test/gradcheck.jl | 76 ++++++++++----------
4 files changed, 42 insertions(+), 43 deletions(-)
rename src/lib/{statsfuns.jl => logexpfunctions.jl} (97%)
diff --git a/Project.toml b/Project.toml
index 872df418e..80bef388f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -41,8 +41,8 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
-StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
-test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "StatsFuns", "Test"]
+test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
diff --git a/src/Zygote.jl b/src/Zygote.jl
index 614cd9a53..5c0d743fd 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -40,7 +40,7 @@ include("lib/forward.jl")
include("lib/utils.jl")
include("lib/range.jl")
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("lib/distances.jl")
-@init @require StatsFuns="4c63d2b9-4356-54db-8cca-17b64c39e42c" include("lib/statsfuns.jl")
+@init @require LogExpFunctions="2ab3a3ac-af41-5b50-aa03-7779005ae688" include("lib/logexpfunctions.jl")
# we need to define this late, so that the genfuncs see lib.jl
# Move using statements out of this file to help with sysimage building
diff --git a/src/lib/statsfuns.jl b/src/lib/logexpfunctions.jl
similarity index 97%
rename from src/lib/statsfuns.jl
rename to src/lib/logexpfunctions.jl
index 85916cae8..1e5e4c0b6 100644
--- a/src/lib/statsfuns.jl
+++ b/src/lib/logexpfunctions.jl
@@ -1,5 +1,4 @@
-import .StatsFuns
-using .StatsFuns: xlogx, xlogy, logistic, logit, log1psq, log1pexp,
+using .LogExpFunctions: xlogx, xlogy, logistic, logit, log1psq, log1pexp,
logsumexp, logaddexp, logsubexp
using Base.Broadcast: broadcasted
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index b0619a194..89ff11fc4 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1196,44 +1196,44 @@ end
@test gradcheck(x -> muladd(x[1], x[2], x[3]), [2.0, 3.0, 5.0])
end
-import StatsFuns
+import LogExpFunctions
Zygote.refresh()
@testset "xlogx" begin
- @test gradcheck(x->2.5 * StatsFuns.xlogx(x[1]), [1.0])
- @test gradcheck(x->2.5 * StatsFuns.xlogx(x[1]), [2.45])
- @test gradtest(x -> StatsFuns.xlogx.(x), (3,3))
+ @test gradcheck(x->2.5 * LogExpFunctions.xlogx(x[1]), [1.0])
+ @test gradcheck(x->2.5 * LogExpFunctions.xlogx(x[1]), [2.45])
+ @test gradtest(x -> LogExpFunctions.xlogx.(x), (3,3))
end
@testset "xlogy" begin
- @test gradcheck(x -> StatsFuns.xlogy(x[1], x[2]), [1.0, 2.0])
- @test gradcheck(x -> StatsFuns.xlogy(x[1], x[2]), [0.0, 2.0])
- @test gradtest((x,y) -> StatsFuns.xlogy.(x,y), (3,3), (3,3))
+ @test gradcheck(x -> LogExpFunctions.xlogy(x[1], x[2]), [1.0, 2.0])
+ @test gradcheck(x -> LogExpFunctions.xlogy(x[1], x[2]), [0.0, 2.0])
+ @test gradtest((x,y) -> LogExpFunctions.xlogy.(x,y), (3,3), (3,3))
end
@testset "logistic" begin
- @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [-5.0])
- @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [-1.0])
- @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [-eps()])
- @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [0.0])
- @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [eps()])
- @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [1.0])
- @test gradcheck(x->3.0 * StatsFuns.logistic(x[1]), [5.0])
+ @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [-5.0])
+ @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [-1.0])
+ @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [-eps()])
+ @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [0.0])
+ @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [eps()])
+ @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [1.0])
+ @test gradcheck(x->3.0 * LogExpFunctions.logistic(x[1]), [5.0])
end
@testset "logit" begin
- @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.1])
- @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.3])
- @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.5])
- @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.7])
- @test gradcheck(x->5.0 * StatsFuns.logit(x[1]), [0.9])
+ @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.1])
+ @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.3])
+ @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.5])
+ @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.7])
+ @test gradcheck(x->5.0 * LogExpFunctions.logit(x[1]), [0.9])
end
function test_log1pexp(T, xs)
y = T(4.3)
for x in xs
- @test gradcheck(x->y * StatsFuns.log1pexp(x[1]), [x])
+ @test gradcheck(x->y * LogExpFunctions.log1pexp(x[1]), [x])
end
end
@@ -1249,43 +1249,43 @@ end
test_log1pexp(Float64, [33.3, 33.3 + eps(), 100.0])
end
end
- @test gradcheck(x->2.5 * StatsFuns.log1pexp(x[1]), [1.0])
- @test gradcheck(x->2.5 * StatsFuns.log1pexp(x[1]), [2.45])
- @test gradtest(x -> StatsFuns.log1pexp.(x), (3,3))
+ @test gradcheck(x->2.5 * LogExpFunctions.log1pexp(x[1]), [1.0])
+ @test gradcheck(x->2.5 * LogExpFunctions.log1pexp(x[1]), [2.45])
+ @test gradtest(x -> LogExpFunctions.log1pexp.(x), (3,3))
end
@testset "log1psq" begin
rng = MersenneTwister(123456)
@testset "Float64" begin
for x in [-10.0, -5.0, -1.0, -eps(), 0.0, eps(), 1.0, 5.0, 10.0]
- @test gradcheck(x->5.1 * StatsFuns.log1psq(x[1]), [x])
+ @test gradcheck(x->5.1 * LogExpFunctions.log1psq(x[1]), [x])
end
end
end
@testset "logaddexp" begin
- @test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [1.0, 2.0])
- @test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [1.0, -1.0])
- @test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [-2.0, -3.0])
- @test gradcheck(x -> StatsFuns.logaddexp(x[1], x[2]), [5.0, 5.0])
- @test gradtest((x,y) -> StatsFuns.logaddexp.(x,y), (3,3), (3,3))
+ @test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [1.0, 2.0])
+ @test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [1.0, -1.0])
+ @test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [-2.0, -3.0])
+ @test gradcheck(x -> LogExpFunctions.logaddexp(x[1], x[2]), [5.0, 5.0])
+ @test gradtest((x,y) -> LogExpFunctions.logaddexp.(x,y), (3,3), (3,3))
end
@testset "logsubexp" begin
- @test gradcheck(x -> StatsFuns.logsubexp(x[1], x[2]), [1.0, 2.0])
- @test gradcheck(x -> StatsFuns.logsubexp(x[1], x[2]), [1.0, -1.0])
- @test gradcheck(x -> StatsFuns.logsubexp(x[1], x[2]), [-2.0, -3.0])
- @test gradtest((x,y) -> StatsFuns.logsubexp.(x,y), (3,3), (3,3))
+ @test gradcheck(x -> LogExpFunctions.logsubexp(x[1], x[2]), [1.0, 2.0])
+ @test gradcheck(x -> LogExpFunctions.logsubexp(x[1], x[2]), [1.0, -1.0])
+ @test gradcheck(x -> LogExpFunctions.logsubexp(x[1], x[2]), [-2.0, -3.0])
+ @test gradtest((x,y) -> LogExpFunctions.logsubexp.(x,y), (3,3), (3,3))
end
@testset "logsumexp" begin
rng = MersenneTwister(123456)
@testset "Float64" begin
- @test gradtest(StatsFuns.logsumexp, randn(rng, 1))
- @test gradtest(StatsFuns.logsumexp, randn(rng, 1, 1))
- @test gradtest(StatsFuns.logsumexp, randn(rng, 3))
- @test gradtest(StatsFuns.logsumexp, randn(rng, 3, 4, 5))
- @test gradtest(x -> sum(StatsFuns.logsumexp(x; dims=1)), randn(rng, 4, 4))
+ @test gradtest(LogExpFunctions.logsumexp, randn(rng, 1))
+ @test gradtest(LogExpFunctions.logsumexp, randn(rng, 1, 1))
+ @test gradtest(LogExpFunctions.logsumexp, randn(rng, 3))
+ @test gradtest(LogExpFunctions.logsumexp, randn(rng, 3, 4, 5))
+ @test gradtest(x -> sum(LogExpFunctions.logsumexp(x; dims=1)), randn(rng, 4, 4))
end
end
From 4f910f25f9362f7364b4fe2294f0021aab0f7e72 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Mon, 19 Apr 2021 21:57:47 +0200
Subject: [PATCH 007/490] Bump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 80bef388f..93172abd7 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.9"
+version = "0.6.10"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From a34553eeac36969d62d8672e5b2b810ec82a32d7 Mon Sep 17 00:00:00 2001
From: Dmitri Iouchtchenko
Date: Wed, 21 Apr 2021 11:08:57 -0400
Subject: [PATCH 008/490] Handle nothing in map
---
src/lib/array.jl | 9 ++++++---
test/gradcheck.jl | 35 +++++++++++++++++++++++++++++++++++
2 files changed, 41 insertions(+), 3 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 3bfad52f6..344db8dcd 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -203,6 +203,7 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
else
ys, backs = unzip(ys_and_backs)
ys, function (Δ)
+ isnothing(Δ) && return nothing
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = $mapfunc((f, δ) -> f(δ), _tryreverse($mapfunc, backs, Δ)...)
Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
@@ -234,11 +235,13 @@ end
@nograd workers
function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
- y, back = ∇map(cx, g.f, g.iter)
- y, function (ȳ)
- f̄, x̄ = back(ȳ)
+ y, b = ∇map(cx, g.f, g.iter)
+ back(::Nothing) = nothing
+ function back(ȳ)
+ f̄, x̄ = b(ȳ)
(nothing, (f = f̄, iter = x̄),)
end
+ y, back
end
@adjoint iterate(r::UnitRange, i...) = iterate(r, i...), _ -> nothing
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index b0619a194..6f1e5a996 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1672,3 +1672,38 @@ end
gradient(x->norm(x*[1im, 1]), 1.23)
gradient(x->norm(x*[1im 1]), 1.23)
end
+
+# https://github.com/FluxML/Zygote.jl/issues/804
+@testset "Unused comprehension" begin
+ # Comprehension is used.
+ io = IOBuffer()
+ s = 0.0
+ gs = gradient([1.0, 2.0]) do xs
+ sum([(print(io, x); s += x; s * x) for x in xs])
+ end
+ @test String(take!(io)) == "1.02.0"
+ @test s == 3.0
+ @test gs == ([4.0, 5.0],)
+
+ # Comprehension is not used.
+ io = IOBuffer()
+ s = 0.0
+ gs = gradient([1.0, 2.0]) do xs
+ sum([(print(io, x); s += x; s * x) for x in xs])
+ 0.0
+ end
+ @test String(take!(io)) == "1.02.0"
+ @test s == 3.0
+ @test gs == (nothing,)
+
+ # Comprehension is empty and not used.
+ io = IOBuffer()
+ s = 0.0
+ gs = gradient([]) do xs
+ [(print(io, x); s += x; s * x) for x in xs]
+ 0.0
+ end
+ @test String(take!(io)) == ""
+ @test s == 0.0
+ @test gs == (nothing,)
+end
From 148ebeb6aec920ec33d1c1e1e0335a0d83dc3d7b Mon Sep 17 00:00:00 2001
From: Simeon Schaub
Date: Mon, 26 Apr 2021 21:50:00 +0200
Subject: [PATCH 009/490] fix adjoint for sum
addresses the second part of #897
---
src/lib/array.jl | 11 +++--------
test/lib/array.jl | 4 ++++
test/runtests.jl | 1 +
3 files changed, 8 insertions(+), 8 deletions(-)
create mode 100644 test/lib/array.jl
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 3bfad52f6..321753860 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -270,14 +270,9 @@ end
sum(xs, dims = dims), Δ -> (nothing,)
end
-_normalize_kws(kws::NamedTuple) = kws
-_normalize_kws(kws) = NamedTuple()
-
-function _pullback(cx::AContext, kwtype, kws, ::typeof(sum), f, xs::AbstractArray)
- norm_kws = _normalize_kws(kws)
- @assert !haskey(norm_kws, :init) # TODO add init support (julia 1.6)
- y, back = pullback(cx, (f, xs) -> sum(f.(xs); norm_kws...), f, xs)
- y, ȳ -> (nothing, nothing, nothing, back(ȳ)...)
+@adjoint function sum(f, xs::AbstractArray; kws...)
+ @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
+ return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
@adjoint function sum(::typeof(abs2), X::AbstractArray; dims = :)
diff --git a/test/lib/array.jl b/test/lib/array.jl
new file mode 100644
index 000000000..380d1bb8f
--- /dev/null
+++ b/test/lib/array.jl
@@ -0,0 +1,4 @@
+using LinearAlgebra
+
+# issue 897
+@test gradient(x -> sum(sin, Diagonal(x)), ones(2)) == ([0.5403023058681398, 0.5403023058681398],)
diff --git a/test/runtests.jl b/test/runtests.jl
index b6b7aab0b..f20b59a7e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -25,6 +25,7 @@ end
@testset "lib" begin
include("lib/number.jl")
include("lib/lib.jl")
+ include("lib/array.jl")
end
@testset "Features" begin
From 3218e50551e5c525452185de2cf8bfd2fba6fd65 Mon Sep 17 00:00:00 2001
From: Simeon Schaub
Date: Mon, 26 Apr 2021 22:02:56 +0200
Subject: [PATCH 010/490] fix differentiation of loopinfo exprs
addresses the first part of #897
---
src/compiler/reverse.jl | 2 +-
test/compiler.jl | 3 +++
test/features.jl | 2 +-
3 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl
index e144be68d..e746684f7 100644
--- a/src/compiler/reverse.jl
+++ b/src/compiler/reverse.jl
@@ -275,7 +275,7 @@ function adjoint(pr::Primal)
end
elseif ex isa Core.PiNode
grads[ex.val] = grads[v]
- elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta)
+ elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo)
elseif isexpr(ex)
push!(rb, stmt(xcall(Base, :error, "Can't differentiate $(ex.head) expression"),
line = b[v].line))
diff --git a/test/compiler.jl b/test/compiler.jl
index c97a50f61..af8e6ccb7 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -143,3 +143,6 @@ end
ms = MyStruct(1, 2)
@test Zygote.gradient(sumall, ms) == ((a = 2, b = 2),)
end
+
+# issue 897
+@test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] ≈ fill(0.5773502691896258, 3, 400)
diff --git a/test/features.jl b/test/features.jl
index 5531ebd0b..48df0c87c 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -402,7 +402,7 @@ function pow_simd(x, n)
return r
end
-@test_broken gradient(pow_simd, 2, 3) == (12,nothing)
+@test gradient(pow_simd, 2, 3) == (12,nothing)
@testset "tuple getindex" begin
@test gradient(x -> size(x)[2], ones(2,2,2)) == (nothing,)
From 703e5bce12171b0e9fa4d4e29f5d30bba3775d4b Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 26 Apr 2021 23:19:47 -0400
Subject: [PATCH 011/490] add diagonal of Hessian function
---
docs/src/utils.md | 1 +
src/Zygote.jl | 2 +-
src/lib/forward.jl | 30 ++++++++++++++++++++++++++++++
src/lib/grad.jl | 28 ++++++++++++++++++++++++++++
test/utils.jl | 21 +++++++++++++++++++++
5 files changed, 81 insertions(+), 1 deletion(-)
diff --git a/docs/src/utils.md b/docs/src/utils.md
index c46b646ce..4d3063ee2 100644
--- a/docs/src/utils.md
+++ b/docs/src/utils.md
@@ -6,6 +6,7 @@ or a Hessian (by taking a second derivative).
```@docs
Zygote.jacobian
Zygote.hessian
+Zygote.diaghessian
```
Zygote also provides a set of helpful utilities. These are all "user-level" tools –
diff --git a/src/Zygote.jl b/src/Zygote.jl
index 614cd9a53..dc5c785ed 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -12,7 +12,7 @@ using MacroTools, Requires
using MacroTools: @forward
import Distributed: pmap, CachingPool, workers
-export Params, gradient, jacobian, hessian, pullback, pushforward, @code_adjoint
+export Params, gradient, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint
const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 7a6e125ff..f40ff8f3e 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -46,6 +46,36 @@ vec_scalar(x::Real) = [x]
reshape_scalar(x, y) = reshape(y, size(x))
reshape_scalar(x::Real, y) = y[]
+function extract_diag(offset, xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
+ D = similar(xs, V, N)
+ for j in 1:min(N, length(xs)-offset)
+ D[j] = xs[offset+j].partials.values[j]
+ end
+ return map(x -> x.value, xs), D
+end
+
+function forward_diag(f, x::AbstractArray, ::Val{N}) where N
+ y, _D = extract_diag(0, f(seed(x, Val(N))))
+ D = similar(_D, size(x)...)
+ D[1:N] = _D
+ offset = 0
+ while offset + N < length(x)
+ offset += N
+ _, _D = extract_diag(offset, f(seed(x, Val(N), offset)))
+ range = (1+offset):min(N+offset,length(x))
+ D[range] = @view _D[range.-offset]
+ end
+ return y, D
+end
+
+function forward_diag(f, x::AbstractArray)
+ if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
+ forward_diag(f, x, Val(length(x)))
+ else
+ forward_diag(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
+ end
+end
+
"""
forwarddiff(f, x) -> f(x)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 7a1b0bdd8..77e0968dc 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -213,3 +213,31 @@ function jacobian(f, pars::Params)
end
Grads(out, pars)
end
+
+"""
+ diaghessian(f, args...)
+
+Diagonal part of the Hessian, literally `diaghessian(f, x)[1] == diag(hessian(f,x))`
+for one vector argument `x`. In general this returns a tuple, with an array the same shape
+as each argument, `d[i] = ∂²y/∂x[i]∂x[i]`, where `y = f(args...)` must be a real number.
+
+Like [`hessian`](@ref) it uses ForwardDiff over Zygote.
+
+!!! warning
+ For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`.
+"""
+function diaghessian(f, args...)
+ ntuple(length(args)) do n
+ x = args[n]
+ if x isa AbstractArray
+ forward_diag(x -> gradient(f, _splice(x, args, Val(n))...)[n], x)[2]
+ elseif x isa Number
+ ForwardDiff.derivative(x -> gradient(f, _splice(x, args, Val(n))...)[n], x)
+ end
+ end
+end
+
+# diaghessian(f, x::AbstractArray) = (forward_diag(x -> gradient(f, x)[1], x)[2],)
+
+_splice(x, args, ::Val{n}) where {n} = ntuple(i -> i==n ? x : args[i], length(args))
+
diff --git a/test/utils.jl b/test/utils.jl
index d09fc2dc2..7aef69807 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -18,6 +18,27 @@ using Zygote: hessian_dual, hessian_reverse
@test_throws Exception hess(identity, randn(2))
end
+@testset "diagonal hessian" begin
+ @test diaghessian(x -> x[1]*x[2]^2, [1, pi]) == ([0, 2],)
+
+ xs, y = randn(2,3), rand()
+ f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
+ dx, dy = diaghessian(f34, xs, y)
+ @test size(dx) == size(xs)
+ @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
+ @test dy ≈ hessian(y -> f34(xs,y), y)
+
+ zs = randn(7,13) # test chunk mode
+ @test length(zs) > ForwardDiff.DEFAULT_CHUNK_THRESHOLD
+ @test length(zs) % ForwardDiff.DEFAULT_CHUNK_THRESHOLD != 0
+ f713(zs) = sum(vec(zs)' .* exp.(vec(zs)))
+ @test vec(diaghessian(f713, zs)[1]) ≈ diag(hessian(f713, zs))
+
+ @test_throws Exception diaghessian(sin, im*pi)
+ @test_throws Exception diaghessian(x -> x+im, pi)
+ @test_throws Exception diaghessian(identity, randn(2))
+end
+
@testset "jacobian(f, args...)" begin
@test jacobian(identity, [1,2])[1] == [1 0; 0 1]
From 887f9b34450aefb152c7e6df97e07673bd8c29db Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Apr 2021 00:28:09 -0400
Subject: [PATCH 012/490] forward over ForwardDiff
---
src/lib/forward.jl | 6 ++++++
test/utils.jl | 11 +++++++++++
2 files changed, 17 insertions(+)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index f40ff8f3e..b6913dd4a 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -131,3 +131,9 @@ forwarddiff(f, x) = f(x)
y, J = forward_jacobian(f, x)
return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ)))
end
+
+# Second derivatives
+@adjoint ForwardDiff.derivative(f, x) = pullback(forwarddiff, x -> ForwardDiff.derivative(f, x), x)
+@adjoint ForwardDiff.gradient(f, x) = pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
+@adjoint ForwardDiff.jacobian(f, x) = pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
+
diff --git a/test/utils.jl b/test/utils.jl
index 7aef69807..d8ebae5de 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -82,3 +82,14 @@ end
@test Jxy[ys] ≈ [1 0 0; 0 1 0]
@test Jxy[xs] ≈ [2 6 4 8; 2 6 4 8]
end
+
+@testset "adjoints of ForwardDiff functions" begin
+ f1(x) = ForwardDiff.gradient(x -> sum(exp.(x.+1)), x)
+ x1 = randn(3,7)
+ @test Zygote.jacobian(f1, x1)[1] ≈ ForwardDiff.jacobian(f1, x1)
+
+ f2(x) = ForwardDiff.jacobian(x -> log.(x[1:3] .+ x[2:4]), x)
+ x2 = rand(5) .+ 1
+ @test Zygote.jacobian(f2, x2)[1] ≈ ForwardDiff.jacobian(f2, x2)
+end
+
From 0e52277388d049a3b023af5c4e81ec8ed725f9d9 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Tue, 27 Apr 2021 16:38:23 +0200
Subject: [PATCH 013/490] Improve performance of `Base.Fix1` and `Base.Fix2`
---
src/lib/base.jl | 14 ++++++++++++++
test/gradcheck.jl | 15 +++++++++++++++
2 files changed, 29 insertions(+)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index d90998c2c..67f8b2c5e 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -140,3 +140,17 @@ end
@adjoint Base.nameof(x::UnionAll) = nameof(x), _ -> (nothing,)
@nograd typeintersect
+
+# Base.Fix1 and Base.Fix2: https://github.com/FluxML/Zygote.jl/issues/957
+@adjoint function (g::Base.Fix1)(y)
+ f = g.f
+ x = g.x
+ fallback_Fix1(y) = f(x, y)
+ return _pullback(__context__, fallback_Fix1, y)
+end
+@adjoint function (g::Base.Fix2)(y)
+ f = g.f
+ x = g.x
+ fallback_Fix2(y) = f(y, x)
+ return _pullback(__context__, fallback_Fix2, y)
+end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index b0619a194..a5be4fb98 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1672,3 +1672,18 @@ end
gradient(x->norm(x*[1im, 1]), 1.23)
gradient(x->norm(x*[1im 1]), 1.23)
end
+
+@testset "Fix1 and Fix2" begin
+ @test gradcheck(x -> prod(Base.Fix1(+, 1), x), randn(100))
+ @test gradcheck(x -> prod(Base.Fix2(+, 1), x), randn(100))
+
+ # compile once and check the execution times compared with a closure
+ # https://github.com/FluxML/Zygote.jl/issues/957
+ x = randn(100)
+ gradient(x -> prod(y -> y + 1, x), x)
+ t = @elapsed(gradient(x -> prod(y -> y + 1, x), x))
+ gradient(x -> prod(Base.Fix1(+, 1), x), x)
+ @test @elapsed(gradient(x -> prod(Base.Fix1(+, 1), x), x)) < 2 * t
+ gradient(x -> prod(Base.Fix1(+, 1), x), x)
+ @test @elapsed(gradient(x -> prod(Base.Fix2(+, 1), x), x)) < 2 * t
+end
From 419d9da62c52dc9329b92d844b6fd6270ca6bf02 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Tue, 27 Apr 2021 16:41:18 +0200
Subject: [PATCH 014/490] Bump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 68d7366a9..3decc2e10 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.10"
+version = "0.6.11"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 29a624789ca29bc319e230b832ead54d932db639 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Tue, 27 Apr 2021 16:45:24 +0200
Subject: [PATCH 015/490] Fix typo
---
test/gradcheck.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index a5be4fb98..ddd0ebd3b 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1684,6 +1684,6 @@ end
t = @elapsed(gradient(x -> prod(y -> y + 1, x), x))
gradient(x -> prod(Base.Fix1(+, 1), x), x)
@test @elapsed(gradient(x -> prod(Base.Fix1(+, 1), x), x)) < 2 * t
- gradient(x -> prod(Base.Fix1(+, 1), x), x)
+ gradient(x -> prod(Base.Fix2(+, 1), x), x)
@test @elapsed(gradient(x -> prod(Base.Fix2(+, 1), x), x)) < 2 * t
end
From dad3e2067c27173c3fbe4606bf29af9224753dde Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Apr 2021 13:13:02 -0400
Subject: [PATCH 016/490] docstring
---
src/lib/grad.jl | 24 +++++++++++++++++++-----
1 file changed, 19 insertions(+), 5 deletions(-)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 77e0968dc..6146e2154 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -217,14 +217,30 @@ end
"""
diaghessian(f, args...)
-Diagonal part of the Hessian, literally `diaghessian(f, x)[1] == diag(hessian(f,x))`
-for one vector argument `x`. In general this returns a tuple, with an array the same shape
-as each argument, `d[i] = ∂²y/∂x[i]∂x[i]`, where `y = f(args...)` must be a real number.
+Diagonal part of the Hessian. Returns a tuple containing
+an array `h` the same shape as each argument `x`,
+with `Hᵢᵢ = h[i] = ∂²y/∂x[i]∂x[i]`.
+The original evaluation `y = f(args...)` must give a real number `y`.
+For one vector argument `x`, this is equivalent to `(diag(hessian(f,x)),)`.
Like [`hessian`](@ref) it uses ForwardDiff over Zygote.
!!! warning
For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`.
+
+# Examples
+```jldoctest; setup=:(using Zygote, LinearAlgebra)
+julia> diaghessian(x -> sum(x.^3), [1 2; 3 4])[1]
+2×2 Matrix{$Int}:
+ 6 12
+ 18 24
+
+julia> Diagonal(vec(ans)) == hessian(x -> sum(x.^3), [1 2; 3 4])
+true
+
+julia> diaghessian((x,y) -> sum(x .* y .* y'), [1 22; 333 4], [0.5, 0.666])
+([0.0 0.0; 0.0 0.0], [2.0, 8.0])
+```
"""
function diaghessian(f, args...)
ntuple(length(args)) do n
@@ -237,7 +253,5 @@ function diaghessian(f, args...)
end
end
-# diaghessian(f, x::AbstractArray) = (forward_diag(x -> gradient(f, x)[1], x)[2],)
-
_splice(x, args, ::Val{n}) where {n} = ntuple(i -> i==n ? x : args[i], length(args))
From 052d5266064074e0a9a4526725b1e444a17dc559 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Apr 2021 13:23:45 -0400
Subject: [PATCH 017/490] load ForwardDiff
---
test/utils.jl | 1 +
1 file changed, 1 insertion(+)
diff --git a/test/utils.jl b/test/utils.jl
index d8ebae5de..9b5171481 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -1,4 +1,5 @@
using LinearAlgebra
+using ForwardDiff
using Zygote: hessian_dual, hessian_reverse
@testset "hessian: $hess" for hess in [hessian_dual, hessian_reverse]
From 040e047feee36b195b869070fa414a986ea8aa41 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Apr 2021 13:31:46 -0400
Subject: [PATCH 018/490] examples
---
src/lib/grad.jl | 12 ++++++++++--
1 file changed, 10 insertions(+), 2 deletions(-)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 6146e2154..065c244ce 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -235,11 +235,19 @@ julia> diaghessian(x -> sum(x.^3), [1 2; 3 4])[1]
6 12
18 24
-julia> Diagonal(vec(ans)) == hessian(x -> sum(x.^3), [1 2; 3 4])
+julia> Diagonal(vec(ans)) == hessian(x -> sum(x.^3), [1 2; 3 4]) # full Hessian is diagonal
true
-julia> diaghessian((x,y) -> sum(x .* y .* y'), [1 22; 333 4], [0.5, 0.666])
+julia> diaghessian((x,y) -> sum(x .* y .* y'), [1 22; 333 4], [0.5, 0.666]) # two array arguments
([0.0 0.0; 0.0 0.0], [2.0, 8.0])
+
+julia> diaghessian(atan, 1, 2) # scalar arguments
+(-0.16, 0.16)
+
+julia> hessian(xy -> atan(xy[1], xy[2]), [1, 2]) # full Hessian is not diagonal
+2×2 Matrix{Float64}:
+ -0.16 -0.12
+ -0.12 0.16
```
"""
function diaghessian(f, args...)
From 4098e7048149a1af7aa69223c2f101291358d915 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Fri, 30 Apr 2021 11:49:42 +0200
Subject: [PATCH 019/490] Add lower bound for StatsFuns
---
Project.toml | 1 +
1 file changed, 1 insertion(+)
diff --git a/Project.toml b/Project.toml
index 93172abd7..1cb75371f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -33,6 +33,7 @@ MacroTools = "0.5"
NaNMath = "0.3"
Requires = "1.1"
SpecialFunctions = "0.10, 1.0"
+StatsFuns = "0.9.8"
ZygoteRules = "0.2.1"
julia = "1.3"
From 7ac4d8193c8ea7ae6a782cbf46482652dd2b23af Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Apr 2021 14:58:32 -0400
Subject: [PATCH 020/490] speed improvement
---
src/lib/grad.jl | 17 ++++++++---------
1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 065c244ce..1375c9d63 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -217,9 +217,8 @@ end
"""
diaghessian(f, args...)
-Diagonal part of the Hessian. Returns a tuple containing
-an array `h` the same shape as each argument `x`,
-with `Hᵢᵢ = h[i] = ∂²y/∂x[i]∂x[i]`.
+Diagonal part of the Hessian. Returns a tuple containing, for each argument `x`,
+`h` of the same shape with `h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]`.
The original evaluation `y = f(args...)` must give a real number `y`.
For one vector argument `x`, this is equivalent to `(diag(hessian(f,x)),)`.
@@ -252,14 +251,14 @@ julia> hessian(xy -> atan(xy[1], xy[2]), [1, 2]) # full Hessian is not diagonal
"""
function diaghessian(f, args...)
ntuple(length(args)) do n
- x = args[n]
- if x isa AbstractArray
- forward_diag(x -> gradient(f, _splice(x, args, Val(n))...)[n], x)[2]
- elseif x isa Number
- ForwardDiff.derivative(x -> gradient(f, _splice(x, args, Val(n))...)[n], x)
+ let x = args[n], valn = Val(n) # let Val improves speed, sometimes
+ if x isa AbstractArray
+ forward_diag(x -> gradient(f, _splice(x, args, valn)...)[n], x)[2]
+ elseif x isa Number
+ ForwardDiff.derivative(x -> gradient(f, _splice(x, args, valn)...)[n], x)
+ end
end
end
end
_splice(x, args, ::Val{n}) where {n} = ntuple(i -> i==n ? x : args[i], length(args))
-
From 125f4146764bcef4d93d7b0eeaaae939ed6d3432 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Fri, 30 Apr 2021 13:52:05 +0200
Subject: [PATCH 021/490] Add StatsFuns to `[extras]` (otherwise Julia
complains)
---
Project.toml | 1 +
1 file changed, 1 insertion(+)
diff --git a/Project.toml b/Project.toml
index 1cb75371f..32849f933 100644
--- a/Project.toml
+++ b/Project.toml
@@ -43,6 +43,7 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
From b097a0dff57d539408389f12ef41a8fc7dea8a8a Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 30 Apr 2021 07:39:36 -0400
Subject: [PATCH 022/490] reduce allocations
---
src/lib/forward.jl | 20 +++++++++++---------
src/lib/grad.jl | 4 +++-
2 files changed, 14 insertions(+), 10 deletions(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index b6913dd4a..7e6b6e0fa 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -46,26 +46,28 @@ vec_scalar(x::Real) = [x]
reshape_scalar(x, y) = reshape(y, size(x))
reshape_scalar(x::Real, y) = y[]
-function extract_diag(offset, xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
- D = similar(xs, V, N)
+# very similar functions needed for diaghessian:
+
+function extract_diag!(_D, offset, xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
for j in 1:min(N, length(xs)-offset)
- D[j] = xs[offset+j].partials.values[j]
+ _D[j] = xs[offset+j].partials.values[j]
end
- return map(x -> x.value, xs), D
end
-function forward_diag(f, x::AbstractArray, ::Val{N}) where N
- y, _D = extract_diag(0, f(seed(x, Val(N))))
- D = similar(_D, size(x)...)
+function forward_diag(f, x::AbstractArray{T}, ::Val{N}) where {N,T}
+ fx = f(seed(x, Val(N)))
+ D = similar(x, ForwardDiff.valtype(eltype(fx)))
+ _D = similar(D, N)
+ extract_diag!(_D, 0, fx)
D[1:N] = _D
offset = 0
while offset + N < length(x)
offset += N
- _, _D = extract_diag(offset, f(seed(x, Val(N), offset)))
+ extract_diag!(_D, offset, f(seed(x, Val(N), offset)))
range = (1+offset):min(N+offset,length(x))
D[range] = @view _D[range.-offset]
end
- return y, D
+ return map(y -> y.value, fx), D
end
function forward_diag(f, x::AbstractArray)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 1375c9d63..ab09bffeb 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -50,6 +50,8 @@ is higher-dimensional.
This uses forward over reverse, ForwardDiff over Zygote, calling `hessian_dual(f, x)`.
See [`hessian_reverse`](@ref) for an all-Zygote alternative.
+See also [`diaghessian`](@ref) to compute only the diagonal part.
+
# Examples
```jldoctest; setup=:(using Zygote)
@@ -240,7 +242,7 @@ true
julia> diaghessian((x,y) -> sum(x .* y .* y'), [1 22; 333 4], [0.5, 0.666]) # two array arguments
([0.0 0.0; 0.0 0.0], [2.0, 8.0])
-julia> diaghessian(atan, 1, 2) # scalar arguments
+julia> diaghessian(atan, 1, 2) # two scalar arguments
(-0.16, 0.16)
julia> hessian(xy -> atan(xy[1], xy[2]), [1, 2]) # full Hessian is not diagonal
From c4c77a0bf24774e4003e9756b3770cff04454a0a Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 30 Apr 2021 22:29:24 -0400
Subject: [PATCH 023/490] speed improvements
---
src/lib/forward.jl | 25 ++++++++++++++-----------
src/lib/grad.jl | 2 +-
2 files changed, 15 insertions(+), 12 deletions(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 7e6b6e0fa..1637df165 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -8,6 +8,12 @@ function seed(x, ::Val{N}, offset = 0) where N
Dual(x, ntuple(j -> j+offset == i, Val(N)))
end
end
+function seed!(xplus, x, ::Val{N}, offset) where N
+ @assert size(x) == size(xplus)
+ map!(xplus, x, reshape(1:length(x), size(x))) do x, i
+ Dual(x, ntuple(j -> j+offset == i, Val(N)))
+ end
+end
extract(x::ForwardDiff.Dual) = x.value, [x.partials...]
@@ -48,26 +54,23 @@ reshape_scalar(x::Real, y) = y[]
# very similar functions needed for diaghessian:
-function extract_diag!(_D, offset, xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
+function extract_diag!(out, offset, xs::AbstractArray{ForwardDiff.Dual{T,V,N}}) where {T,V,N}
for j in 1:min(N, length(xs)-offset)
- _D[j] = xs[offset+j].partials.values[j]
+ out[offset+j] = xs[offset+j].partials.values[j]
end
end
function forward_diag(f, x::AbstractArray{T}, ::Val{N}) where {N,T}
- fx = f(seed(x, Val(N)))
- D = similar(x, ForwardDiff.valtype(eltype(fx)))
- _D = similar(D, N)
- extract_diag!(_D, 0, fx)
- D[1:N] = _D
+ xplus = seed(x, Val(N))
+ fx = f(xplus)
+ out = similar(x, ForwardDiff.valtype(eltype(fx)))
+ extract_diag!(out, 0, fx)
offset = 0
while offset + N < length(x)
offset += N
- extract_diag!(_D, offset, f(seed(x, Val(N), offset)))
- range = (1+offset):min(N+offset,length(x))
- D[range] = @view _D[range.-offset]
+ extract_diag!(out, offset, f(seed!(xplus, x, Val(N), offset)))
end
- return map(y -> y.value, fx), D
+ return map(y -> y.value, fx), out
end
function forward_diag(f, x::AbstractArray)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index ab09bffeb..f88aa4a72 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -217,7 +217,7 @@ function jacobian(f, pars::Params)
end
"""
- diaghessian(f, args...)
+ diaghessian(f, args...) -> Tuple
Diagonal part of the Hessian. Returns a tuple containing, for each argument `x`,
`h` of the same shape with `h[i] = Hᵢᵢ = ∂²y/∂x[i]∂x[i]`.
From d0fa801be22fb109b196e8ace9812f2c89408469 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 30 Apr 2021 22:30:07 -0400
Subject: [PATCH 024/490] another approach, not faster?
---
src/lib/forward.jl | 47 ++++++++++++++++++++++++++++++++++++++++++++++
src/lib/grad.jl | 14 ++++++++++++++
2 files changed, 61 insertions(+)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 1637df165..06f2c62ff 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -81,6 +81,53 @@ function forward_diag(f, x::AbstractArray)
end
end
+# and another approach: all forward, directly 2nd order
+
+seed_2nd(x::Real, ::Val) = Dual(Dual(x, true), true)
+
+function seed_2nd(xs, ::Val{N}, offset = 0) where N
+ map(xs, reshape(1:length(xs), size(xs))) do x, i
+ b = ntuple(j -> j+offset == i, Val(N))
+ Dual(Dual(x, b), b)
+ end
+end
+function seed_2nd!(xplus, xs, ::Val{N}, offset) where N
+ map!(xplus, xs, reshape(1:length(xs), size(xs))) do x, i
+ b = ntuple(j -> j+offset == i, Val(N))
+ Dual(Dual(x, b), b)
+ end
+end
+
+function extract_2nd!(out, fx::ForwardDiff.Dual{T,V,N}, offset) where {T,V,N}
+ for j in 1:min(N, length(out)-offset)
+ out[j+offset] = fx.partials.values[j].partials.values[j]
+ end
+end
+
+function forward_2nd(f, x::AbstractArray{T}, ::Val{N}) where {N,T}
+ xplus = seed_2nd(x, Val(N), 0)
+ fx = f(xplus)
+ out = similar(x, ForwardDiff.valtype(ForwardDiff.valtype(typeof(fx))))
+ extract_2nd!(out, fx, 0)
+ offset = 0
+ while offset + N < length(x)
+ offset += N
+ fx = f(seed_2nd!(xplus, x, Val(N), offset))
+ extract_2nd!(out, fx, offset)
+ end
+ return fx.value.value, out
+end
+
+function forward_2nd(f, x::AbstractArray)
+ # if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
+ # forward_2nd(f, x, Val(length(x)))
+ # else
+ # forward_2nd(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD ÷ 2))
+ forward_2nd(f, x, Val(3))
+ # end
+end
+
+
"""
forwarddiff(f, x) -> f(x)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index f88aa4a72..ab1815b71 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -264,3 +264,17 @@ function diaghessian(f, args...)
end
_splice(x, args, ::Val{n}) where {n} = ntuple(i -> i==n ? x : args[i], length(args))
+
+function diaghessian_2nd(f, args...)
+ ntuple(length(args)) do n
+ let x = args[n], valn = Val(n)
+ if x isa AbstractArray
+ forward_2nd(x -> f(_splice(x, args, valn)...), x)[2]
+ elseif x isa Number
+ ForwardDiff.hessian(x -> f(_splice(x[1], args, valn)...), [x])[1]
+ end
+ end
+ end
+end
+
+export diaghessian_2nd
From 7b3e9c20dcd37806e566014418a443a488a12fb9 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 30 Apr 2021 22:45:41 -0400
Subject: [PATCH 025/490] Revert "another approach, not faster?"
This reverts commit d0fa801be22fb109b196e8ace9812f2c89408469.
---
src/lib/forward.jl | 47 ----------------------------------------------
src/lib/grad.jl | 14 --------------
2 files changed, 61 deletions(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 06f2c62ff..1637df165 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -81,53 +81,6 @@ function forward_diag(f, x::AbstractArray)
end
end
-# and another approach: all forward, directly 2nd order
-
-seed_2nd(x::Real, ::Val) = Dual(Dual(x, true), true)
-
-function seed_2nd(xs, ::Val{N}, offset = 0) where N
- map(xs, reshape(1:length(xs), size(xs))) do x, i
- b = ntuple(j -> j+offset == i, Val(N))
- Dual(Dual(x, b), b)
- end
-end
-function seed_2nd!(xplus, xs, ::Val{N}, offset) where N
- map!(xplus, xs, reshape(1:length(xs), size(xs))) do x, i
- b = ntuple(j -> j+offset == i, Val(N))
- Dual(Dual(x, b), b)
- end
-end
-
-function extract_2nd!(out, fx::ForwardDiff.Dual{T,V,N}, offset) where {T,V,N}
- for j in 1:min(N, length(out)-offset)
- out[j+offset] = fx.partials.values[j].partials.values[j]
- end
-end
-
-function forward_2nd(f, x::AbstractArray{T}, ::Val{N}) where {N,T}
- xplus = seed_2nd(x, Val(N), 0)
- fx = f(xplus)
- out = similar(x, ForwardDiff.valtype(ForwardDiff.valtype(typeof(fx))))
- extract_2nd!(out, fx, 0)
- offset = 0
- while offset + N < length(x)
- offset += N
- fx = f(seed_2nd!(xplus, x, Val(N), offset))
- extract_2nd!(out, fx, offset)
- end
- return fx.value.value, out
-end
-
-function forward_2nd(f, x::AbstractArray)
- # if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
- # forward_2nd(f, x, Val(length(x)))
- # else
- # forward_2nd(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD ÷ 2))
- forward_2nd(f, x, Val(3))
- # end
-end
-
-
"""
forwarddiff(f, x) -> f(x)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index ab1815b71..f88aa4a72 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -264,17 +264,3 @@ function diaghessian(f, args...)
end
_splice(x, args, ::Val{n}) where {n} = ntuple(i -> i==n ? x : args[i], length(args))
-
-function diaghessian_2nd(f, args...)
- ntuple(length(args)) do n
- let x = args[n], valn = Val(n)
- if x isa AbstractArray
- forward_2nd(x -> f(_splice(x, args, valn)...), x)[2]
- elseif x isa Number
- ForwardDiff.hessian(x -> f(_splice(x[1], args, valn)...), [x])[1]
- end
- end
- end
-end
-
-export diaghessian_2nd
From 30419f4c95e9e578dbb7449288b5a1e820be8fef Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 4 May 2021 13:14:49 -0400
Subject: [PATCH 026/490] sparse getindex gradient
---
src/lib/array.jl | 16 ++++++++++++++--
1 file changed, 14 insertions(+), 2 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 344db8dcd..2d9b5311c 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -32,8 +32,10 @@ end
@adjoint view(x::AbstractArray, inds...) = view(x, inds...), ∇getindex(x, inds)
-∇getindex(x::AbstractArray, inds) = dy -> begin
- if inds isa NTuple{<:Any, Integer}
+∇getindex(x::AbstractArray{T,N}, inds) where {T,N} = dy -> begin
+ if inds isa NTuple{N,Int} && T <: Number
+ dx = OneElement(dy, inds, axes(x))
+ elseif inds isa NTuple{<:Any, Integer}
dx = _zero(x, typeof(dy))
dx[inds...] = dy
else
@@ -44,6 +46,16 @@ end
return (dx, map(_->nothing, inds)...)
end
+struct OneElement{T,N,I,A} <: AbstractArray{T,N}
+ val::T
+ index::I
+ axes::A
+ OneElement(x::T, i::I, a::A) where {T,I<:NTuple{N,Int},A} where {N} = new{T,N,I,A}(x, i, a)
+end
+Base.size(A::OneElement) = map(length, A.axes)
+Base.axes(A::OneElement) = A.axes
+Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.index, A.val, zero(T))
+
_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
_zero(xs::AbstractArray, T) = fill!(similar(xs, Union{Nothing, T}), nothing)
From 8fafdf040004d1e163edb99a0edef6050c875cb2 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 4 May 2021 13:15:12 -0400
Subject: [PATCH 027/490] a step towards in-place accumulation
---
src/lib/array.jl | 2 ++
src/lib/lib.jl | 1 +
2 files changed, 3 insertions(+)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 2d9b5311c..c22b02d9a 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -56,6 +56,8 @@ Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.index, A.val, zero(T))
+accum(x::Array, y::OneElement) = (@inbounds x[y.index...] = accum(x[y.index...], y.val); x)
+
_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
_zero(xs::AbstractArray, T) = fill!(similar(xs, Union{Nothing, T}), nothing)
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index f4d321c29..3aaf87cde 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -14,6 +14,7 @@ accum(x, y, zs...) = accum(accum(x, y), zs...)
accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)
+accum(x::DenseArray, y::AbstractArray) = x .= accum.(x, y)
@generated function accum(x::NamedTuple, y::NamedTuple)
# assumes that y has no keys apart from those also in x
From 431d3d46400e1f7a693c198829cfb738bf75f234 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 5 May 2021 12:09:10 -0400
Subject: [PATCH 028/490] four seven seven four
---
src/lib/lib.jl | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index 3aaf87cde..58a551cb6 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -15,6 +15,8 @@ accum(x, y, zs...) = accum(accum(x, y), zs...)
accum(x::Tuple, y::Tuple) = accum.(x, y)
accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)
accum(x::DenseArray, y::AbstractArray) = x .= accum.(x, y)
+# work around bug fixed in https://github.com/JuliaLang/julia/pull/39859
+accum(x::DenseVector, y::AbstractArray) = x .= accum.(x, vec(y))
@generated function accum(x::NamedTuple, y::NamedTuple)
# assumes that y has no keys apart from those also in x
From 549671e4341436635bf1c65b4275fc6523ba3755 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 5 May 2021 17:49:31 -0400
Subject: [PATCH 029/490] change == to isapprox in some tests
---
test/gradcheck.jl | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 1b509d159..7e713f950 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -555,14 +555,14 @@ end
@testset "Cholesky" begin
# Check that the forwards pass computes the correct thing.
f(X, Y) = cholesky(X * X' + I) \ Y
- @test Zygote.pullback(X -> f(X, Y), X)[1] == cholesky(X * X' + I) \ Y
+ @test Zygote.pullback(X -> f(X, Y), X)[1] ≈ cholesky(X * X' + I) \ Y
@test gradtest(X -> f(X, Y), X)
@test gradtest(Y -> f(X, Y), Y)
@test gradtest(X -> f(X, y), X)
@test gradtest(y -> f(X, y), y)
g(X) = cholesky(X * X' + I)
- @test Zygote.pullback(g, X)[2]((factors=LowerTriangular(X),)) ==
- Zygote.pullback(g, X)[2]((factors=Matrix(LowerTriangular(X)),))
+ @test Zygote.pullback(g, X)[2]((factors=LowerTriangular(X),))[1] ≈
+ Zygote.pullback(g, X)[2]((factors=Matrix(LowerTriangular(X)),))[1]
@test_throws PosDefException Zygote.pullback(X -> cholesky(X, check = false), X)[2]((factors=X,))
# https://github.com/FluxML/Zygote.jl/issues/932
@@ -689,8 +689,8 @@ end
@test gradtest(Diagonal, d)
y, back = Zygote.pullback(Diagonal, d)
D̄ = randn(rng, P, P)
- @test back(D̄) == back(Diagonal(D̄))
- @test back(D̄) == back((diag=diag(D̄),))
+ @test back(D̄)[1] ≈ back(Diagonal(D̄))[1]
+ @test back(D̄)[1] ≈ back((diag=diag(D̄),))[1]
end
@testset "dense + UniformScaling" begin
@@ -705,7 +705,7 @@ end
@testset "cholesky - dense" begin
rng, N = MersenneTwister(123456), 5
A = randn(rng, N, N)
- @test cholesky(A' * A + I) == first(Zygote.pullback(A->cholesky(A' * A + I), A))
+ @test cholesky(A' * A + I).U ≈ first(Zygote.pullback(A->cholesky(A' * A + I), A)).U
@test gradtest(A->cholesky(A' * A + I).U, A)
@test gradtest(A->logdet(cholesky(A' * A + I)), A)
@test gradtest(B->cholesky(Symmetric(B)).U, A * A' + I)
From 0811a8c1425babf8ed21bec89c601f76087f6f34 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Thu, 6 May 2021 18:31:34 +0200
Subject: [PATCH 030/490] Use BenchmarkTools
---
Project.toml | 3 ++-
test/gradcheck.jl | 14 +++++++-------
2 files changed, 9 insertions(+), 8 deletions(-)
diff --git a/Project.toml b/Project.toml
index 3decc2e10..ce17f0551 100644
--- a/Project.toml
+++ b/Project.toml
@@ -37,6 +37,7 @@ ZygoteRules = "0.2.1"
julia = "1.3"
[extras]
+BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
@@ -45,4 +46,4 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
-test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "StatsFuns", "Test"]
+test = ["BenchmarkTools", "CUDA", "Distances", "FFTW", "FiniteDifferences", "StatsFuns", "Test"]
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index a5be4fb98..c01a18ca1 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -4,6 +4,7 @@ using Zygote: gradient
using Base.Broadcast: broadcast_shape
using Distributed: pmap, CachingPool, workers
import FiniteDifferences
+using BenchmarkTools
function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
@@ -1677,13 +1678,12 @@ end
@test gradcheck(x -> prod(Base.Fix1(+, 1), x), randn(100))
@test gradcheck(x -> prod(Base.Fix2(+, 1), x), randn(100))
- # compile once and check the execution times compared with a closure
+ # check the execution times compared with a closure
# https://github.com/FluxML/Zygote.jl/issues/957
x = randn(100)
- gradient(x -> prod(y -> y + 1, x), x)
- t = @elapsed(gradient(x -> prod(y -> y + 1, x), x))
- gradient(x -> prod(Base.Fix1(+, 1), x), x)
- @test @elapsed(gradient(x -> prod(Base.Fix1(+, 1), x), x)) < 2 * t
- gradient(x -> prod(Base.Fix1(+, 1), x), x)
- @test @elapsed(gradient(x -> prod(Base.Fix2(+, 1), x), x)) < 2 * t
+ tclosure = @belapsed(gradient($(x -> prod(y -> y + 1, x)), $x))
+ tfix1 = @belapsed(gradient($(x -> prod(Base.Fix1(+, 1), x)), $x))
+ tfix2 = @belapsed(gradient($(x -> prod(Base.Fix2(+, 1), x)), $x))
+ @test tfix1 < 2 * tclosure
+ @test tfix2 < 2 * tclosure
end
From 85fb41603193573032991cc3ee1b4817472bea5a Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 8 May 2021 09:02:01 -0400
Subject: [PATCH 031/490] rm forward forward
---
src/lib/forward.jl | 5 -----
test/utils.jl | 11 -----------
2 files changed, 16 deletions(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 1637df165..4e0bb9c82 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -137,8 +137,3 @@ forwarddiff(f, x) = f(x)
return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ)))
end
-# Second derivatives
-@adjoint ForwardDiff.derivative(f, x) = pullback(forwarddiff, x -> ForwardDiff.derivative(f, x), x)
-@adjoint ForwardDiff.gradient(f, x) = pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
-@adjoint ForwardDiff.jacobian(f, x) = pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
-
diff --git a/test/utils.jl b/test/utils.jl
index 9b5171481..73a7d65c4 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -83,14 +83,3 @@ end
@test Jxy[ys] ≈ [1 0 0; 0 1 0]
@test Jxy[xs] ≈ [2 6 4 8; 2 6 4 8]
end
-
-@testset "adjoints of ForwardDiff functions" begin
- f1(x) = ForwardDiff.gradient(x -> sum(exp.(x.+1)), x)
- x1 = randn(3,7)
- @test Zygote.jacobian(f1, x1)[1] ≈ ForwardDiff.jacobian(f1, x1)
-
- f2(x) = ForwardDiff.jacobian(x -> log.(x[1:3] .+ x[2:4]), x)
- x2 = rand(5) .+ 1
- @test Zygote.jacobian(f2, x2)[1] ≈ ForwardDiff.jacobian(f2, x2)
-end
-
From bc7823191d8c83df56e04b5c344673f790e54549 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 8 May 2021 09:17:03 -0400
Subject: [PATCH 032/490] adjoint for ForwardDiff.jacobian
---
src/lib/forward.jl | 4 ++++
test/utils.jl | 10 ++++++++++
2 files changed, 14 insertions(+)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 7a6e125ff..1de719c27 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -101,3 +101,7 @@ forwarddiff(f, x) = f(x)
y, J = forward_jacobian(f, x)
return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ)))
end
+
+# Use this to allow second derivatives
+@adjoint ForwardDiff.gradient(f, x) = pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
+@adjoint ForwardDiff.jacobian(f, x) = pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
diff --git a/test/utils.jl b/test/utils.jl
index d09fc2dc2..e86bbf338 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -61,3 +61,13 @@ end
@test Jxy[ys] ≈ [1 0 0; 0 1 0]
@test Jxy[xs] ≈ [2 6 4 8; 2 6 4 8]
end
+
+@testset "adjoints of ForwardDiff functions" begin
+ f1(x) = ForwardDiff.gradient(x -> sum(exp.(x.+1)), x)
+ x1 = randn(3,7)
+ @test Zygote.jacobian(f1, x1)[1] ≈ ForwardDiff.jacobian(f1, x1)
+
+ f2(x) = ForwardDiff.jacobian(x -> log.(x[1:3] .+ x[2:4]), x)
+ x2 = rand(5) .+ 1
+ @test Zygote.jacobian(f2, x2)[1] ≈ ForwardDiff.jacobian(f2, x2)
+end
From 6bff8f8f2d5a36ce4ce894bc1e20299f1022da6d Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 8 May 2021 09:35:50 -0400
Subject: [PATCH 033/490] add tests from
https://github.com/FluxML/Zygote.jl/issues/769
---
test/utils.jl | 19 +++++++++++++++++++
1 file changed, 19 insertions(+)
diff --git a/test/utils.jl b/test/utils.jl
index e86bbf338..473277b5b 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -62,6 +62,8 @@ end
@test Jxy[xs] ≈ [2 6 4 8; 2 6 4 8]
end
+using ForwardDiff
+
@testset "adjoints of ForwardDiff functions" begin
f1(x) = ForwardDiff.gradient(x -> sum(exp.(x.+1)), x)
x1 = randn(3,7)
@@ -70,4 +72,21 @@ end
f2(x) = ForwardDiff.jacobian(x -> log.(x[1:3] .+ x[2:4]), x)
x2 = rand(5) .+ 1
@test Zygote.jacobian(f2, x2)[1] ≈ ForwardDiff.jacobian(f2, x2)
+
+ # Tests from https://github.com/FluxML/Zygote.jl/issues/769
+ f(x) = [2x[1]^2 + x[1],x[2]^2 * x[1]]
+ g1(x) = sum(ForwardDiff.jacobian(f,x))
+ out,back = Zygote.pullback(g1,[2.0,3.2])
+ stakehouse = back(1.0)[1]
+ @test typeof(stakehouse) <: Vector
+ @test size(stakehouse) == (2,)
+ @test stakehouse ≈ ForwardDiff.gradient(g1,[2.0,3.2])
+
+ g2(x) = prod(ForwardDiff.jacobian(f,x))
+ out,back = Zygote.pullback(g2,[2.0,3.2])
+ @test_skip back(1.0)[1] == ForwardDiff.gradient(g2,[2.0,3.2]) # contains NaN, @adjoint prod isn't careful
+
+ g3(x) = sum(abs2,ForwardDiff.jacobian(f,x))
+ out,back = Zygote.pullback(g3,[2.0,3.2])
+ @test back(1.0)[1] == ForwardDiff.gradient(g3,[2.0,3.2])
end
From e95ba74c34f1aaedc43522f13ce0fc3f8efd98bf Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 9 May 2021 18:07:28 -0400
Subject: [PATCH 034/490] avoid mutation
---
src/lib/array.jl | 1 -
src/lib/lib.jl | 7 ++-----
test/features.jl | 19 +++++++++++++++++++
3 files changed, 21 insertions(+), 6 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index c22b02d9a..8bda68bbf 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -56,7 +56,6 @@ Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.index, A.val, zero(T))
-accum(x::Array, y::OneElement) = (@inbounds x[y.index...] = accum(x[y.index...], y.val); x)
_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index 58a551cb6..045a494de 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -12,11 +12,8 @@ accum(x, y) =
accum(x, y, zs...) = accum(accum(x, y), zs...)
-accum(x::Tuple, y::Tuple) = accum.(x, y)
-accum(x::AbstractArray, y::AbstractArray) = accum.(x, y)
-accum(x::DenseArray, y::AbstractArray) = x .= accum.(x, y)
-# work around bug fixed in https://github.com/JuliaLang/julia/pull/39859
-accum(x::DenseVector, y::AbstractArray) = x .= accum.(x, vec(y))
+accum(x::Tuple, ys::Tuple...) = accum.(x, ys...)
+accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...)
@generated function accum(x::NamedTuple, y::NamedTuple)
# assumes that y has no keys apart from those also in x
diff --git a/test/features.jl b/test/features.jl
index 5531ebd0b..874e3916e 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -481,3 +481,22 @@ end
Zygote.gradient(loss_adjoint,[1.0])
@test x[1] == x[2]
end
+
+@testset "accumulation" begin
+ # from https://github.com/FluxML/Zygote.jl/issues/905
+ function net(x1)
+ x2 = x1
+ x3 = x1 + x2
+ x4 = x1 + x2 + x3
+ x5 = x1 + x2 + x3 + x4
+ x6 = x1 + x2 + x3 + x4 + x5
+ x7 = x1 + x2 + x3 + x4 + x5 + x6
+ x8 = x1 + x2 + x3 + x4 + x5 + x6 + x7
+ x9 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8
+ x10 = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9
+ end
+ loss(x) = sum(abs2, net(x))
+ @test gradient(loss, ones(10,10))[1] == fill(131072, 10, 10)
+ @test 150_000_000 > @allocated gradient(loss, ones(1000,1000))
+end
+
From 4b4cdf2f75c6549ea1fab690f0174f5bbf9f2a03 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Fri, 14 May 2021 17:29:14 +0200
Subject: [PATCH 035/490] fix gradient algebra on gpu
---
src/compiler/interface.jl | 2 ++
test/cuda.jl | 44 ++++++++++++++++++++++++++++++++++++++-
test/interface.jl | 2 +-
3 files changed, 46 insertions(+), 2 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 7d72f9537..f5336ac03 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -80,6 +80,8 @@ end
Base.copy(ps::Params) = union!(Params(), ps)
Base.union(ps::Params, itrs...) = union!(copy(ps), itrs...)
+Base.issetequal(ps1::Params, ps2::Params) = issetequal(ps1.params, ps2.params)
+# Base.issetequal(ps1::Params, x::AbstractSet) = issetequal(ps1.params, x)
function Base.intersect!(ps::Params, itrs...)
for itr in itrs
diff --git a/test/cuda.jl b/test/cuda.jl
index 0766ff986..f90bca6a6 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -1,4 +1,6 @@
using CUDA
+using Zygote: Grads
+using Random: randn!
# Test GPU movement inside the call to `gradient`
@testset "GPU movement" begin
@@ -21,7 +23,6 @@ end
g_gpu = gradient(x -> w(x), a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g
-
end
@testset "jacobian" begin
@@ -37,3 +38,44 @@ end
@test j2[v1] isa CuArray
@test j2[v1] ≈ cu(res2)
end
+
+@testset "gradient algebra" begin
+ w, b = rand(2) |> cu, rand(2) |> cu
+ x1, x2 = rand(2) |> cu, rand(2) |> cu
+
+ gs1 = gradient(() -> sum(w .* x1), Params([w]))
+ gs2 = gradient(() -> sum(w .* x2), Params([w]))
+
+ @test .- gs1 isa Grads
+ @test gs1 .- gs2 isa Grads
+ @test .+ gs1 isa Grads
+ @test gs1 .+ gs2 isa Grads
+ @test 2 .* gs1 isa Grads
+ @test (2 .* gs1)[w] ≈ 2 * gs1[w]
+ @test gs1 .* 2 isa Grads
+ @test gs1 ./ 2 isa Grads
+ @test (gs1 .+ gs2)[w] ≈ gs1[w] .+ gs2[w]
+
+ gs12 = gs1 .+ gs2
+ gs1 .+= gs2
+ @test gs12[w] ≈ gs1[w]
+
+ gs3 = gradient(() -> sum(w .* x1), Params([w, b])) # grad nothing with respect to b
+ gs4 = gradient(() -> sum(w .* x2 .+ b), Params([w, b]))
+
+ @test .- gs3 isa Grads
+ @test gs3 .- gs4 isa Grads
+ @test .+ gs3 isa Grads
+ @test gs3 .+ gs4 isa Grads
+ @test 2 .* gs3 isa Grads
+ @test gs3 .* 2 isa Grads
+ @test gs3 ./ 2 isa Grads
+ @test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
+ @test (gs3 .+ gs4)[b] ≈ gs4[b]
+
+ @test gs3 .+ Dict(w => similar(w), b => similar(b)) isa Grads
+ gs3 .+= Dict(p => randn!(similar(p)) for p in keys(gs3))
+ @test gs3 isa Grads
+
+ @test_throws ArgumentError gs1 .+ gs4
+end
\ No newline at end of file
diff --git a/test/interface.jl b/test/interface.jl
index 087da74f3..584228c84 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -120,7 +120,7 @@ end
gs3 .+= Dict(p => randn(size(p)) for p in keys(gs3))
@test gs3 isa Grads
- @test_throws ArgumentError gs1 .+ gs4
+ @test_throws ArgumentError gs1 .+ gs4
end
@testset "map and broadcast" begin
From 85b93e6c4920f397fa878dccf7db9ffbce74eb1b Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Fri, 14 May 2021 17:45:45 +0200
Subject: [PATCH 036/490] mark one test as broken
---
src/compiler/interface.jl | 1 -
test/cuda.jl | 9 ++++++---
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index f5336ac03..2e2725480 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -81,7 +81,6 @@ end
Base.copy(ps::Params) = union!(Params(), ps)
Base.union(ps::Params, itrs...) = union!(copy(ps), itrs...)
Base.issetequal(ps1::Params, ps2::Params) = issetequal(ps1.params, ps2.params)
-# Base.issetequal(ps1::Params, x::AbstractSet) = issetequal(ps1.params, x)
function Base.intersect!(ps::Params, itrs...)
for itr in itrs
diff --git a/test/cuda.jl b/test/cuda.jl
index f90bca6a6..e08f9a500 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -1,6 +1,7 @@
using CUDA
using Zygote: Grads
using Random: randn!
+CUDA.allowscalar(false)
# Test GPU movement inside the call to `gradient`
@testset "GPU movement" begin
@@ -73,9 +74,11 @@ end
@test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] ≈ gs4[b]
- @test gs3 .+ Dict(w => similar(w), b => similar(b)) isa Grads
- gs3 .+= Dict(p => randn!(similar(p)) for p in keys(gs3))
- @test gs3 isa Grads
+ @test_broken begin
+ gs3 .+ Dict(w => similar(w), b => similar(b)) isa Grads
+ gs3 .+= Dict(p => randn!(similar(p)) for p in keys(gs3))
+ gs3 isa Grads
+ end
@test_throws ArgumentError gs1 .+ gs4
end
\ No newline at end of file
From 478fa7d00afe8e4f6ace269c0111b4d4306c721f Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 14 May 2021 13:40:23 -0400
Subject: [PATCH 037/490] comment
---
src/lib/forward.jl | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 1de719c27..113b6aa6c 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -102,6 +102,7 @@ forwarddiff(f, x) = f(x)
return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ)))
end
-# Use this to allow second derivatives
+# Use this to allow second derivatives -- this is forward-over-forward,
+# see https://github.com/FluxML/Zygote.jl/issues/769 for a forward-over-reverse proposal
@adjoint ForwardDiff.gradient(f, x) = pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
@adjoint ForwardDiff.jacobian(f, x) = pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
From 734a8d563281a782775f629ab6bbd8f4abc8e7b8 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 14 May 2021 17:57:25 -0400
Subject: [PATCH 038/490] treat derivative and hessian the same way
---
src/lib/forward.jl | 5 +++++
test/utils.jl | 9 +++++++++
2 files changed, 14 insertions(+)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 113b6aa6c..21f39e29c 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -106,3 +106,8 @@ end
# see https://github.com/FluxML/Zygote.jl/issues/769 for a forward-over-reverse proposal
@adjoint ForwardDiff.gradient(f, x) = pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
@adjoint ForwardDiff.jacobian(f, x) = pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
+
+@adjoint ForwardDiff.derivative(f, x) = pullback(forwarddiff, x -> ForwardDiff.derivative(f, x), x)
+@adjoint ForwardDiff.hessian(f, x) = pullback(forwarddiff, x -> ForwardDiff.hessian(f, x), x)
+
+
diff --git a/test/utils.jl b/test/utils.jl
index 473277b5b..dcf877c7d 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -73,6 +73,15 @@ using ForwardDiff
x2 = rand(5) .+ 1
@test Zygote.jacobian(f2, x2)[1] ≈ ForwardDiff.jacobian(f2, x2)
+ f3(x) = sum(ForwardDiff.hessian(x -> sum(x .^2 .* x'), x)[1:4:end])
+ x3 = rand(3)
+ @test Zygote.gradient(f3, x3)[1] ≈ ForwardDiff.gradient(f3, x3)
+
+ @test gradient(x -> ForwardDiff.derivative(x -> x^4, x), 7) == (4 * 3 * 7^2,)
+
+ f4(x) = ForwardDiff.derivative(x -> [x,x^2,x^3], x)
+ @test Zygote.jacobian(f4, pi)[1] ≈ ForwardDiff.derivative(f4, pi)
+
# Tests from https://github.com/FluxML/Zygote.jl/issues/769
f(x) = [2x[1]^2 + x[1],x[2]^2 * x[1]]
g1(x) = sum(ForwardDiff.jacobian(f,x))
From c12cfca39974855172bdcb9721348cf345fb0b44 Mon Sep 17 00:00:00 2001
From: CarloLucibello
Date: Sat, 15 May 2021 10:20:18 +0200
Subject: [PATCH 039/490] use IdDict instead of Dict
---
docs/src/utils.md | 4 ++--
src/compiler/interface.jl | 1 +
test/cuda.jl | 10 ++++------
test/interface.jl | 4 ++--
4 files changed, 9 insertions(+), 10 deletions(-)
diff --git a/docs/src/utils.md b/docs/src/utils.md
index c46b646ce..92c8df03a 100644
--- a/docs/src/utils.md
+++ b/docs/src/utils.md
@@ -42,8 +42,8 @@ gs = gs1 .+ gs2
@test gs[w] ≈ gs1[w] + gs2[w]
@test gs[b] ≈ gs1[b] + gs2[b]
-# gradients and dictionaries interact nicely
-gs .+= Dict(p => randn(size(p)) for p in keys(gs))
+# gradients and IdDict interact nicely
+gs .+= IdDict(p => randn(size(p)) for p in keys(gs))
# clip gradients
map(x -> clamp.(x, -0.1, 0.1), gs)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 2e2725480..1dbf73be6 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -81,6 +81,7 @@ end
Base.copy(ps::Params) = union!(Params(), ps)
Base.union(ps::Params, itrs...) = union!(copy(ps), itrs...)
Base.issetequal(ps1::Params, ps2::Params) = issetequal(ps1.params, ps2.params)
+Base.issetequal(ps1::Params, x::Base.AbstractSet) = issetequal(ps1.params, x)
function Base.intersect!(ps::Params, itrs...)
for itr in itrs
diff --git a/test/cuda.jl b/test/cuda.jl
index e08f9a500..a54402999 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -74,11 +74,9 @@ end
@test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] ≈ gs4[b]
- @test_broken begin
- gs3 .+ Dict(w => similar(w), b => similar(b)) isa Grads
- gs3 .+= Dict(p => randn!(similar(p)) for p in keys(gs3))
- gs3 isa Grads
- end
+ @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads
+ gs3 .+= IdDict(p => randn!(similar(p)) for p in keys(gs3))
+ @test gs3 isa Grads
@test_throws ArgumentError gs1 .+ gs4
-end
\ No newline at end of file
+end
diff --git a/test/interface.jl b/test/interface.jl
index 584228c84..0ffb933f6 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -116,8 +116,8 @@ end
@test (gs3 .+ gs4)[w] ≈ gs3[w] .+ gs4[w]
@test (gs3 .+ gs4)[b] ≈ gs4[b]
- @test gs3 .+ Dict(w => similar(w), b => similar(b)) isa Grads
- gs3 .+= Dict(p => randn(size(p)) for p in keys(gs3))
+ @test gs3 .+ IdDict(w => similar(w), b => similar(b)) isa Grads
+ gs3 .+= IdDict(p => randn(size(p)) for p in keys(gs3))
@test gs3 isa Grads
@test_throws ArgumentError gs1 .+ gs4
From f0708406fb21743a07b83fb15be5fa3dc90d0008 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 15 May 2021 11:39:09 -0400
Subject: [PATCH 040/490] un-broadcast using mapreduce
---
src/lib/broadcast.jl | 21 ++++++++++++++++++---
test/cuda.jl | 8 ++++++++
test/gradcheck.jl | 6 ++++++
3 files changed, 32 insertions(+), 3 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 1219d7883..6f0774190 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -71,12 +71,27 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
@adjoint broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y,
Δ -> (nothing, unbroadcast(x, Δ), -unbroadcast(y, Δ))
-@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
- z̄ -> (nothing, unbroadcast(x, z̄ .* conj.(y)), unbroadcast(y, z̄ .* conj.(x)))
+@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x .* y,
+ Δ -> (nothing, mul_unbroadcast(x, Δ, y), mul_unbroadcast(y, Δ, x))
+
+mul_unbroadcast(x, Δ, y) = funbroadcast(x, (δ,y₁) -> δ * conj(y₁), Δ, y)
+
+# This optimisation is only safe when all args... have same size:
+funbroadcast(::Number, f, args...) = mapreduce(f, +, args...)
+funbroadcast(x, f, args...) = unbroadcast(x, f.(args...)) # fallback
+
+@adjoint function broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number)
+ res = x ./ y
+ res, Δ -> begin
+ Δx = funbroadcast(x, δ -> δ / conj(y), Δ)
+ Δy = funbroadcast(y, (δ,r) -> -δ * conj(r / y), Δ, res)
+ (nothing, Δx, Δy)
+ end
+end
@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
res = x ./ y
- res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, -Δ .* conj.(res ./ y)))
+ res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y)))
end
@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
diff --git a/test/cuda.jl b/test/cuda.jl
index 0766ff986..20cd0d14e 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -24,6 +24,14 @@ end
end
+@testset "un-broadcasting *, / with mapreduce" begin
+ cu12 = cu(Float32[1,2])
+ @test gradient((x,y) -> sum(x .* y), cu12, 5) == ([5, 5], 3)
+ @test gradient((x,y) -> sum(x .* y), 5, cu12) == (3, [5, 5])
+ @test gradient((x,y) -> sum(x .* y), cu12, [3 4 5]) == ([12, 12], [3 3 3])
+ @test gradient((x,y) -> sum(x ./ y), cu12, 5) == ([0.2, 0.2], -0.12)
+end
+
@testset "jacobian" begin
v1 = cu(collect(1:3f0))
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 7e713f950..2d5499dcd 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1308,6 +1308,12 @@ end
x1 = rand(3, 3)
@test gradient(x -> sum(x .== 0.5), x1)[1] === nothing
@test gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1))
+
+ # tests for un-broadcasting *, / with mapreduce
+ @test gradient((x,y) -> sum(x .* y), [1,2], 5) == ([5, 5], 3)
+ @test gradient((x,y) -> sum(x .* y), 5, [1,2]) == (3, [5, 5])
+ @test gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) == ([12, 12], [3 3 3])
+ @test gradient((x,y) -> sum(x ./ y), [1,2], 5) == ([0.2, 0.2], -0.12)
end
using Zygote: Buffer
From f0511296cc55b6f69e0f2ead4dea3563ad48b003 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 15 May 2021 12:11:19 -0400
Subject: [PATCH 041/490] cu test
---
test/cuda.jl | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/test/cuda.jl b/test/cuda.jl
index 20cd0d14e..1ca2debb9 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -28,8 +28,9 @@ end
cu12 = cu(Float32[1,2])
@test gradient((x,y) -> sum(x .* y), cu12, 5) == ([5, 5], 3)
@test gradient((x,y) -> sum(x .* y), 5, cu12) == (3, [5, 5])
- @test gradient((x,y) -> sum(x .* y), cu12, [3 4 5]) == ([12, 12], [3 3 3])
- @test gradient((x,y) -> sum(x ./ y), cu12, 5) == ([0.2, 0.2], -0.12)
+ cu345 = cu(Float32[3 4 5])
+ @test all(gradient((x,y) -> sum(x .* y), cu12, cu345) .≈ ([12, 12], [3 3 3]))
+ @test all(gradient((x,y) -> sum(x ./ y), cu12, 5) .≈ ([0.2, 0.2], -0.12))
end
@testset "jacobian" begin
From 189a4bcb1d69fcedfaaf234a300a05e24d68c84b Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 15 May 2021 18:16:08 -0400
Subject: [PATCH 042/490] tests
---
src/lib/grad.jl | 4 ++--
test/cuda.jl | 8 +++++---
2 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 7a1b0bdd8..0bedb0793 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -124,7 +124,7 @@ julia> jacobian((a,d) -> prod(a, dims=d), [1 2; 3 4; 5 6], 2)
!!! warning
For arguments of any type except `Number` & `AbstractArray`, the result is `nothing`.
-```jldoctest; setup=:(using Zygote)
+```
julia> jacobian((a,s) -> a.^length(s), [1,2,3], "str")
([3 0 0; 0 12 0; 0 0 27], nothing)
@@ -132,7 +132,7 @@ julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)
julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient undersands the tuple
-([4, 4, 4], (6, 1))
+(Fill(4, 3), (6, 1))
```
"""
function jacobian(f, args...)
diff --git a/test/cuda.jl b/test/cuda.jl
index 1ca2debb9..55eb6b7ac 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -26,11 +26,13 @@ end
@testset "un-broadcasting *, / with mapreduce" begin
cu12 = cu(Float32[1,2])
- @test gradient((x,y) -> sum(x .* y), cu12, 5) == ([5, 5], 3)
+ @test gradient((x,y) -> sum(x .* y), cu12, 5) == ([5, 5]), 3)
@test gradient((x,y) -> sum(x .* y), 5, cu12) == (3, [5, 5])
+ @test gradient((x,y) -> sum(z -> z, x .* y), cu12, 5) == ([5, 5], 3)
+ @test gradient((x,y) -> sum(z -> z, x .* y), 5, cu12) == (3, [5, 5])
cu345 = cu(Float32[3 4 5])
- @test all(gradient((x,y) -> sum(x .* y), cu12, cu345) .≈ ([12, 12], [3 3 3]))
- @test all(gradient((x,y) -> sum(x ./ y), cu12, 5) .≈ ([0.2, 0.2], -0.12))
+ @test all(gradient((x,y) -> sum(x .* y), cu12, cu345) .≈ (cu([12, 12]), cu([3 3 3])))
+ @test all(gradient((x,y) -> sum(x ./ y), cu12, 5) .≈ (cu([0.2, 0.2]), -0.12))
end
@testset "jacobian" begin
From 014c13122f5e16ed54152c11ede8f7f6a4ac19a9 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 16 May 2021 07:07:29 -0400
Subject: [PATCH 043/490] tests, III
---
test/cuda.jl | 9 ++++-----
1 file changed, 4 insertions(+), 5 deletions(-)
diff --git a/test/cuda.jl b/test/cuda.jl
index 55eb6b7ac..86d902670 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -26,12 +26,11 @@ end
@testset "un-broadcasting *, / with mapreduce" begin
cu12 = cu(Float32[1,2])
- @test gradient((x,y) -> sum(x .* y), cu12, 5) == ([5, 5]), 3)
- @test gradient((x,y) -> sum(x .* y), 5, cu12) == (3, [5, 5])
- @test gradient((x,y) -> sum(z -> z, x .* y), cu12, 5) == ([5, 5], 3)
- @test gradient((x,y) -> sum(z -> z, x .* y), 5, cu12) == (3, [5, 5])
+ cu55 = cu(Float32[5,5])
+ @test gradient((x,y) -> sum(x .* y), cu12, 5) == (cu55, 3)
+ @test gradient((x,y) -> sum(x .* y), 5, cu12) == (3, cu55)
cu345 = cu(Float32[3 4 5])
- @test all(gradient((x,y) -> sum(x .* y), cu12, cu345) .≈ (cu([12, 12]), cu([3 3 3])))
+ @test gradient((x,y) -> sum(x .* y), cu12, cu345) == (cu([12, 12]), cu([3 3 3]))
@test all(gradient((x,y) -> sum(x ./ y), cu12, 5) .≈ (cu([0.2, 0.2]), -0.12))
end
From d9509013f718d800a47dbdcfc51e1beb9565824d Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 16 May 2021 18:00:19 +0200
Subject: [PATCH 044/490] Update src/compiler/interface.jl
---
src/compiler/interface.jl | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 1dbf73be6..3c62d6586 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -82,6 +82,7 @@ Base.copy(ps::Params) = union!(Params(), ps)
Base.union(ps::Params, itrs...) = union!(copy(ps), itrs...)
Base.issetequal(ps1::Params, ps2::Params) = issetequal(ps1.params, ps2.params)
Base.issetequal(ps1::Params, x::Base.AbstractSet) = issetequal(ps1.params, x)
+Base.issetequal(x::Base.AbstractSet, ps1::Params) = issetequal(x, ps1.params)
function Base.intersect!(ps::Params, itrs...)
for itr in itrs
From 75e2fb7942f8c5d84e62e548101ffd8730f5c497 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 16 May 2021 12:25:46 -0400
Subject: [PATCH 045/490] fill example
---
src/lib/grad.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 0bedb0793..7462ef4ed 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -132,7 +132,7 @@ julia> jacobian((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5))
([4 4 4], nothing)
julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient undersands the tuple
-(Fill(4, 3), (6, 1))
+([4 4 4], (6, 1))
```
"""
function jacobian(f, args...)
From 6dd98dd389066c0d4de92488da2a0a64873c034d Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Sun, 16 May 2021 12:34:28 -0500
Subject: [PATCH 046/490] Update docs/src/utils.md
---
docs/src/utils.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/docs/src/utils.md b/docs/src/utils.md
index 92c8df03a..84596357f 100644
--- a/docs/src/utils.md
+++ b/docs/src/utils.md
@@ -43,6 +43,7 @@ gs = gs1 .+ gs2
@test gs[b] ≈ gs1[b] + gs2[b]
# gradients and IdDict interact nicely
+# note that an IdDict must be used for gradient algebra on the GPU
gs .+= IdDict(p => randn(size(p)) for p in keys(gs))
# clip gradients
From 69a4ca021517b6b3d93d68cecb4d5b518774fce8 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 16 May 2021 14:17:28 -0400
Subject: [PATCH 047/490] change to use existing scalar * array rules
---
src/lib/broadcast.jl | 29 +++++++++++++----------------
test/cuda.jl | 4 +++-
2 files changed, 16 insertions(+), 17 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 6f0774190..8694b394c 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -71,28 +71,25 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
@adjoint broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y,
Δ -> (nothing, unbroadcast(x, Δ), -unbroadcast(y, Δ))
-@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x .* y,
- Δ -> (nothing, mul_unbroadcast(x, Δ, y), mul_unbroadcast(y, Δ, x))
-
-mul_unbroadcast(x, Δ, y) = funbroadcast(x, (δ,y₁) -> δ * conj(y₁), Δ, y)
-
-# This optimisation is only safe when all args... have same size:
-funbroadcast(::Number, f, args...) = mapreduce(f, +, args...)
-funbroadcast(x, f, args...) = unbroadcast(x, f.(args...)) # fallback
-
-@adjoint function broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number)
- res = x ./ y
- res, Δ -> begin
- Δx = funbroadcast(x, δ -> δ / conj(y), Δ)
- Δy = funbroadcast(y, (δ,r) -> -δ * conj(r / y), Δ, res)
- (nothing, Δx, Δy)
- end
+@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
+ Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
+@adjoint function broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number})
+ z, back = pullback(*, x, y) # this uses dot(y,Δ) instead of Δ .* conj.(y)
+ z, Δ -> (nothing, back(Δ)...)
+end
+@adjoint function broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number)
+ z, back = pullback(*, x, y)
+ z, Δ -> (nothing, back(Δ)...)
end
@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
res = x ./ y
res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y)))
end
+@adjoint function broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number)
+ z, back = pullback(/, x, y)
+ z, Δ -> (nothing, back(Δ)...)
+end
@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
y = Base.literal_pow.(^, x, exp)
diff --git a/test/cuda.jl b/test/cuda.jl
index 86d902670..b45d066c4 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -24,11 +24,13 @@ end
end
-@testset "un-broadcasting *, / with mapreduce" begin
+@testset "un-broadcasting .*, ./ with scalars" begin
cu12 = cu(Float32[1,2])
cu55 = cu(Float32[5,5])
@test gradient((x,y) -> sum(x .* y), cu12, 5) == (cu55, 3)
@test gradient((x,y) -> sum(x .* y), 5, cu12) == (3, cu55)
+ # @test gradient((x,y) -> sum(z -> z, x .* y), cu12, 5) == (cu55, 3)
+ # @test gradient((x,y) -> sum(z -> z, x .* y), 5, cu12) == (3, cu55)
cu345 = cu(Float32[3 4 5])
@test gradient((x,y) -> sum(x .* y), cu12, cu345) == (cu([12, 12]), cu([3 3 3]))
@test all(gradient((x,y) -> sum(x ./ y), cu12, 5) .≈ (cu([0.2, 0.2]), -0.12))
From 615100fd6730e6e0fe8ec7426c625de28f77e71c Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 16 May 2021 14:31:37 -0400
Subject: [PATCH 048/490] tests
---
test/cuda.jl | 12 ------------
test/gradcheck.jl | 10 +++++-----
2 files changed, 5 insertions(+), 17 deletions(-)
diff --git a/test/cuda.jl b/test/cuda.jl
index b45d066c4..0766ff986 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -24,18 +24,6 @@ end
end
-@testset "un-broadcasting .*, ./ with scalars" begin
- cu12 = cu(Float32[1,2])
- cu55 = cu(Float32[5,5])
- @test gradient((x,y) -> sum(x .* y), cu12, 5) == (cu55, 3)
- @test gradient((x,y) -> sum(x .* y), 5, cu12) == (3, cu55)
- # @test gradient((x,y) -> sum(z -> z, x .* y), cu12, 5) == (cu55, 3)
- # @test gradient((x,y) -> sum(z -> z, x .* y), 5, cu12) == (3, cu55)
- cu345 = cu(Float32[3 4 5])
- @test gradient((x,y) -> sum(x .* y), cu12, cu345) == (cu([12, 12]), cu([3 3 3]))
- @test all(gradient((x,y) -> sum(x ./ y), cu12, 5) .≈ (cu([0.2, 0.2]), -0.12))
-end
-
@testset "jacobian" begin
v1 = cu(collect(1:3f0))
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 2d5499dcd..43f482d01 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1309,11 +1309,11 @@ end
@test gradient(x -> sum(x .== 0.5), x1)[1] === nothing
@test gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1))
- # tests for un-broadcasting *, / with mapreduce
- @test gradient((x,y) -> sum(x .* y), [1,2], 5) == ([5, 5], 3)
- @test gradient((x,y) -> sum(x .* y), 5, [1,2]) == (3, [5, 5])
- @test gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) == ([12, 12], [3 3 3])
- @test gradient((x,y) -> sum(x ./ y), [1,2], 5) == ([0.2, 0.2], -0.12)
+ # tests for un-broadcasting *, / via scalar rules
+ @test all(gradient((x,y) -> sum(x .* y), [1,2], 5) .≈ ([5, 5], 3))
+ @test all(gradient((x,y) -> sum(x .* y), 5, [1,2]) .≈ (3, [5, 5]))
+ @test all(gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) .≈ ([12, 12], [3 3 3]))
+ @test all(gradient((x,y) -> sum(x ./ y), [1,2], 5) .≈ ([0.2, 0.2], -0.12))
end
using Zygote: Buffer
From 6dca3a96a66029bd41d0546897321cc3266f009f Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Wed, 19 May 2021 09:50:55 +0100
Subject: [PATCH 049/490] rename chain rules differential types
---
src/compiler/chainrules.jl | 6 +++---
test/chainrules.jl | 4 ++--
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index c573d98f7..8392a27ce 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -47,7 +47,7 @@ for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
# than happy.
- @eval @inline function wrap_chainrules_output(x::ChainRules.Composite{P, T}) where {P, T<:$T_outer}
+ @eval @inline function wrap_chainrules_output(x::ChainRules.Tangent{P, T}) where {P, T<:$T_outer}
xp = map(wrap_chainrules_output, canonicalize(x))
convert($T_outer, xp)
end
@@ -59,10 +59,10 @@ end
Convert `x` from the format Zygote uses internally to differentials types ChainRules uses.
"""
@inline wrap_chainrules_input(x) = x
-@inline wrap_chainrules_input(::Nothing) = ChainRules.Zero()
+@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, xs)
- ChainRules.Composite{Any, typeof(xp)}(xp)
+ ChainRules.Tangent{Any, typeof(xp)}(xp)
end
"""
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 7fd8c6be5..8b7034753 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -131,7 +131,7 @@ using Zygote, Test, ChainRules
not_diff_eg(x, i) = [10, 20][i]
function ChainRules.rrule(::typeof(not_diff_eg), x, i)
function not_diff_eg_pullback(Δ)
- return ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist()
+ return ChainRules.NO_FIELDS, ChainRules.ZeroTangent(), ChainRules.NoTangent()
end
return not_diff_eg(x, i), not_diff_eg_pullback
end
@@ -204,7 +204,7 @@ using Zygote, Test, ChainRules
not_diff_kw_eg(x, i; kw=1.0) = [10, 20][i]
function ChainRules.rrule(::typeof(not_diff_kw_eg), x, i; kwargs...)
function not_diff_kw_eg_pullback(Δ)
- return ChainRules.NO_FIELDS, ChainRules.Zero(), ChainRules.DoesNotExist()
+ return ChainRules.NO_FIELDS, ChainRules.ZeroTangent(), ChainRules.NoTangent()
end
return not_diff_kw_eg(x, i; kwargs...), not_diff_kw_eg_pullback
end
From f583d0dc2248f52f1a8fd8ef88cbe187bc6f3634 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Wed, 19 May 2021 09:51:11 +0100
Subject: [PATCH 050/490] version bump and compat bump
---
Project.toml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/Project.toml b/Project.toml
index 32849f933..0bdee524c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.10"
+version = "0.6.11"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "0.7.55"
-ChainRulesCore = "0.9.32"
+ChainRulesCore = "0.9.44"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10"
From 07f1f94dac577e411932616c9afd8bca04fa692e Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 19 May 2021 13:55:05 -0400
Subject: [PATCH 051/490] tuple un-broadcast
---
src/lib/broadcast.jl | 2 ++
test/features.jl | 11 +++++++++++
2 files changed, 13 insertions(+)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 1219d7883..ddadfcbcd 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -46,6 +46,7 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
end
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
+trim(x::Tuple, Δ) = ntuple(k -> Δ[k], length(x))
unbroadcast(x::AbstractArray, x̄) =
size(x) == size(x̄) ? x̄ :
@@ -55,6 +56,7 @@ unbroadcast(x::AbstractArray, x̄) =
unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
+unbroadcast(x::Tuple, x̄) = trim(x, accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
diff --git a/test/features.jl b/test/features.jl
index 48df0c87c..254034202 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -481,3 +481,14 @@ end
Zygote.gradient(loss_adjoint,[1.0])
@test x[1] == x[2]
end
+
+@testset "tuples & broadcasting" begin
+ @test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
+ @test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
+
+ # https://github.com/FluxML/Zygote.jl/issues/975
+ gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
+ gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
+ @test gt[1] == gv[1]
+ @test collect(gt[2]) ≈ gv[2]
+end
From 2dc6a51c225634b4db8bb55aa02306ed2d8a7bd5 Mon Sep 17 00:00:00 2001
From: Marius Millea
Date: Wed, 19 May 2021 13:38:24 -0700
Subject: [PATCH 052/490] minor typo fix
Minor typo fix: makes it so the user won't confuse the first clause as declarative.
---
src/compiler/interface.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 3c62d6586..b066b73f6 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -43,7 +43,7 @@ end
sensitivity(y::Number) = one(y)
sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
-sensitivity(y::AbstractArray) = error("output an array, so the gradient is not defined. Perhaps you wanted jacobian.")
+sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")
"""
From 33f1d6de9f1fddae3a0fc166f66e921b2e219d9b Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 20 May 2021 08:37:01 -0400
Subject: [PATCH 053/490] skip the sum, sometimes
---
src/lib/broadcast.jl | 2 +-
test/features.jl | 1 +
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index ddadfcbcd..451d8794c 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -56,7 +56,7 @@ unbroadcast(x::AbstractArray, x̄) =
unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
-unbroadcast(x::Tuple, x̄) = trim(x, accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
+unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
diff --git a/test/features.jl b/test/features.jl
index 254034202..6843acbf6 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -485,6 +485,7 @@ end
@testset "tuples & broadcasting" begin
@test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
@test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
+ @test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)
# https://github.com/FluxML/Zygote.jl/issues/975
gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
From 892022ca20512ceef1342f6705ec8209a8acf078 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Thu, 20 May 2021 14:19:02 +0100
Subject: [PATCH 054/490] add dev to gitignore
---
.gitignore | 1 +
1 file changed, 1 insertion(+)
diff --git a/.gitignore b/.gitignore
index 78756acf1..aa5ffbd93 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,3 +3,4 @@
*.jl.mem
docs/build
Manifest.toml
+dev/
From 81c0aefedaf6b5a25ea092989f2a1b5fa55352f5 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Sat, 22 May 2021 07:06:24 +0530
Subject: [PATCH 055/490] cleanup
---
src/compiler/interface.jl | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index b066b73f6..48d6146ac 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -139,8 +139,8 @@ function copy!(ps::Params, x::AbstractVector)
@assert length(x) == sum(length(p) for p in ps)
i = 0
for p in ps
- p .= reshape(x[i+1:i+length(p)], size(p))
- i += length(p)
+ p .= reshape(x[i+1:i+length(p)], size(p))
+ i += length(p)
end
ps
end
@@ -149,8 +149,8 @@ function copy!(x::AbstractVector, ps::Params)
@assert length(x) == sum(length(p) for p in ps)
i = 0
for p in ps
- x[i+1:i+length(p)] .= vec(p)
- i += length(p)
+ x[i+1:i+length(p)] .= vec(p)
+ i += length(p)
end
ps
end
@@ -196,8 +196,8 @@ length of `x` has to be equal to the sum of the lengths of all gradients.
function copy!(gs::Grads, x::AbstractVector)
i = 0
for p in gs.params
- gs[p] .= reshape(x[i+1:i+length(p)], size(p))
- i += length(p)
+ gs[p] .= reshape(x[i+1:i+length(p)], size(p))
+ i += length(p)
end
x
end
@@ -205,8 +205,8 @@ end
function copy!(x::AbstractVector, gs::Grads)
i = 0
for p in gs.params
- x[i+1:i+length(p)] .= vec(gs[p])
- i += length(p)
+ x[i+1:i+length(p)] .= vec(gs[p])
+ i += length(p)
end
x
end
From 1f492a7d6c3fc32c787cd96a95c1b90478fc615f Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 25 May 2021 10:34:00 -0400
Subject: [PATCH 056/490] add docstring for OneElement
---
src/lib/array.jl | 11 ++++++++---
1 file changed, 8 insertions(+), 3 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index d7fedd4d5..9a8f94d4d 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -46,15 +46,20 @@ end
return (dx, map(_->nothing, inds)...)
end
+"""
+ OneElement(val, ind, axes) <: AbstractArray
+
+Extremely simple `struct` used for the gradient of scalar `getindex`.
+"""
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
val::T
- index::I
+ ind::I
axes::A
- OneElement(x::T, i::I, a::A) where {T,I<:NTuple{N,Int},A} where {N} = new{T,N,I,A}(x, i, a)
+ OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A} where {N} = new{T,N,I,A}(val, ind, axes)
end
Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
-Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.index, A.val, zero(T))
+Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
From 5f87aa7170c2b8157686c50247da0a86d0800a16 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Wed, 2 Jun 2021 17:13:43 +0100
Subject: [PATCH 057/490] Update ChainRules and ChainRulesCore
---
Project.toml | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/Project.toml b/Project.toml
index 0bdee524c..a6c376d27 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.11"
+version = "0.6.12"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -23,8 +23,8 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "0.7.55"
-ChainRulesCore = "0.9.44"
+ChainRules = "0.7.55, 0.8"
+ChainRulesCore = "0.9.44, 0.10"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10"
From 1e1f60ede2097e226ff9a42cb264df389d4f23ee Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 8 Jun 2021 15:25:55 +0100
Subject: [PATCH 058/490] delete rule for prod to use ChainRules'
---
src/lib/array.jl | 5 -----
1 file changed, 5 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 9a8f94d4d..f8389d4e3 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -300,11 +300,6 @@ end
return sum(abs2, X; dims=dims), Δ::Union{Number, AbstractArray}->(nothing, ((2Δ) .* X))
end
-@adjoint function prod(xs::AbstractArray; dims = :)
- p = prod(xs; dims = dims)
- p, Δ -> (p ./ xs .* Δ,)
-end
-
function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs)
y, ȳ -> (nothing, back(ȳ)...)
From 67dbc72575525845c8c9684d77f245484220b18c Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 8 Jun 2021 20:32:03 +0100
Subject: [PATCH 059/490] bump chainrules version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index a6c376d27..2a89b4cd1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,7 +23,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "0.7.55, 0.8"
+ChainRules = "0.7.66, 0.8"
ChainRulesCore = "0.9.44, 0.10"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
From df762eb8236008d970c2c9fceba4df7993ec59ac Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Fri, 11 Jun 2021 18:18:23 +0530
Subject: [PATCH 060/490] fix #941
---
src/compiler/interface.jl | 126 ++++++++++----------------------------
1 file changed, 32 insertions(+), 94 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 48d6146ac..5883f4269 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -1,8 +1,12 @@
using InteractiveUtils
using InteractiveUtils: typesof
using Core: Typeof
-import Base: copy!
-import Base.Broadcast: broadcasted, materialize!
+
+@static if VERSION >= v"1.1"
+ import Base: copy!
+else
+ import Future: copy!
+end
mutable struct Context <: AContext
cache::Union{IdDict{Any,Any},Nothing}
@@ -43,17 +47,8 @@ end
sensitivity(y::Number) = one(y)
sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
-sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")
-"""
- gradient(f, args...)
-
-Returns a tuple containing `∂f/∂x` for each argument `x`,
-the derivative (for scalar x) or the gradient.
-
-`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
-"""
function gradient(f, args...)
y, back = pullback(f, args...)
return back(sensitivity(y))
@@ -65,42 +60,35 @@ Base.adjoint(f::Function) = x -> gradient(f, x)[1]
# TODO store ids only
struct Params
- order::Buffer{Any, Vector{Any}}
+ order::Buffer # {Any, Vector{Any}}
params::IdSet{Any}
- Params() = new(Buffer([], false), IdSet())
end
+Params() = Params(Buffer([], false), IdSet())
+Params(xs) = Params(Buffer(xs, false), IdSet(xs))
+
@forward Params.order Base.iterate, Base.length, Base.getindex
-@forward Params.params Base.in
-function Base.union!(ps::Params, itrs...)
- foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
+function Base.push!(ps::Params, x)
+ if !(x in ps.params)
+ push!(ps.order, x)
+ push!(ps.params, x)
+ end
return ps
end
-Base.copy(ps::Params) = union!(Params(), ps)
-Base.union(ps::Params, itrs...) = union!(copy(ps), itrs...)
-Base.issetequal(ps1::Params, ps2::Params) = issetequal(ps1.params, ps2.params)
-Base.issetequal(ps1::Params, x::Base.AbstractSet) = issetequal(ps1.params, x)
-Base.issetequal(x::Base.AbstractSet, ps1::Params) = issetequal(x, ps1.params)
-
-function Base.intersect!(ps::Params, itrs...)
- for itr in itrs
- for x in collect(ps)
- x ∉ itr && delete!(ps, x)
- end
+@adjoint! function Base.push!(xs::IdSet, x...)
+ l = length(x)
+ push!(xs, x...), Δ -> begin
+ (Δ, ntuple(_ -> nothing, l)...)
end
- return ps
end
-Base.intersect(ps::Params, itrs...) = intersect!(copy(ps), itrs...)
-
-function Base.push!(ps::Params, x)
- if !(x in ps.params)
- push!(ps.order, x)
- push!(ps.params, x)
+@adjoint! function Base.push!(xs::Params, x::AbstractArray{T}...) where T
+ sz_x = size.(x)
+ push!(xs, x...), Δ -> begin
+ (Δ, map(x -> Ones{T}(x...), sz_x)...)
end
- return ps
end
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
@@ -139,8 +127,8 @@ function copy!(ps::Params, x::AbstractVector)
@assert length(x) == sum(length(p) for p in ps)
i = 0
for p in ps
- p .= reshape(x[i+1:i+length(p)], size(p))
- i += length(p)
+ p .= reshape(x[i+1:i+length(p)], size(p))
+ i += length(p)
end
ps
end
@@ -149,8 +137,8 @@ function copy!(x::AbstractVector, ps::Params)
@assert length(x) == sum(length(p) for p in ps)
i = 0
for p in ps
- x[i+1:i+length(p)] .= vec(p)
- i += length(p)
+ x[i+1:i+length(p)] .= vec(p)
+ i += length(p)
end
ps
end
@@ -163,23 +151,7 @@ end
Base.show(io::IO, ps::Grads) = print(io, "Grads(...)")
-@forward Grads.grads Base.setindex!
-@forward Grads.params Base.length
-
-const ADictOrGrads = Union{AbstractDict, Grads}
-
-# Dictionary interface.
-# Don't use the IdDict directly since it may contain some spurious pairs.
-Base.haskey(gs::Grads, x) = x ∈ gs.params
-Base.keys(gs::Grads) = gs.params
-Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)
-
-function Base.iterate(gs::Grads, state...)
- res = iterate(gs.params, state...)
- isnothing(res) && return nothing
- p, next_state = res
- return gs[p], next_state
-end
+@forward Grads.grads Base.getindex, Base.haskey
function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
@@ -196,8 +168,8 @@ length of `x` has to be equal to the sum of the lengths of all gradients.
function copy!(gs::Grads, x::AbstractVector)
i = 0
for p in gs.params
- gs[p] .= reshape(x[i+1:i+length(p)], size(p))
- i += length(p)
+ gs[p] .= reshape(x[i+1:i+length(p)], size(p))
+ i += length(p)
end
x
end
@@ -205,46 +177,12 @@ end
function copy!(x::AbstractVector, gs::Grads)
i = 0
for p in gs.params
- x[i+1:i+length(p)] .= vec(gs[p])
- i += length(p)
+ x[i+1:i+length(p)] .= vec(gs[p])
+ i += length(p)
end
x
end
-broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)
-
-broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
-broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs)
-
-function materialize!(gs1::Grads, gs2::Grads)
- issetequal(gs1.params, gs2.params) ||
- throw(ArgumentError("Expected Grads objects with the same Params."))
- for p in gs1.params
- gs1[p] = gs2[p]
- end
- return gs1
-end
-
-
-function Base.map(f, gs1::Grads, gss::ADictOrGrads...)
- gsout = Grads(IdDict{Any,Any}(), Params(gs1.params))
- return map!(f, gsout, gs1, gss...)
-end
-
-function Base.map!(f, gsout::Grads, gss::ADictOrGrads...)
- all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
- throw(ArgumentError("map! expects Grads objects with the same Params."))
- for p in gsout.params
- gsout[p] = f((_getformap(gs, p) for gs in gss)...)
- end
- return gsout
-end
-
-function _getformap(gs, p)
- g = gs[p]
- isnothing(g) ? fill!(similar(p), 0) : g
-end
-
function pullback(f, ps::Params)
cx = Context()
y, back = _pullback(cx, f)
From e1b94cc7285ecb9bd08a7ec0e7470f623d97f227 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Fri, 11 Jun 2021 18:25:29 +0530
Subject: [PATCH 061/490] remove internal constructor
---
src/compiler/interface.jl | 110 +++++++++++++++++++++++++++++++-------
1 file changed, 91 insertions(+), 19 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 5883f4269..baeba003e 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -1,12 +1,8 @@
using InteractiveUtils
using InteractiveUtils: typesof
using Core: Typeof
-
-@static if VERSION >= v"1.1"
- import Base: copy!
-else
- import Future: copy!
-end
+import Base: copy!
+import Base.Broadcast: broadcasted, materialize!
mutable struct Context <: AContext
cache::Union{IdDict{Any,Any},Nothing}
@@ -47,8 +43,15 @@ end
sensitivity(y::Number) = one(y)
sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
+sensitivity(y::AbstractArray) = error("Output is an array, so the gradient is not defined. Perhaps you wanted jacobian.")
sensitivity(y) = error("Output should be scalar; gradients are not defined for output $(repr(y))")
+"""
+ gradient(f, args...)
+Returns a tuple containing `∂f/∂x` for each argument `x`,
+the derivative (for scalar x) or the gradient.
+`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
+"""
function gradient(f, args...)
y, back = pullback(f, args...)
return back(sensitivity(y))
@@ -68,6 +71,29 @@ Params() = Params(Buffer([], false), IdSet())
Params(xs) = Params(Buffer(xs, false), IdSet(xs))
@forward Params.order Base.iterate, Base.length, Base.getindex
+@forward Params.params Base.in
+
+function Base.union!(ps::Params, itrs...)
+ foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
+ return ps
+end
+
+Base.copy(ps::Params) = union!(Params(), ps)
+Base.union(ps::Params, itrs...) = union!(copy(ps), itrs...)
+Base.issetequal(ps1::Params, ps2::Params) = issetequal(ps1.params, ps2.params)
+Base.issetequal(ps1::Params, x::Base.AbstractSet) = issetequal(ps1.params, x)
+Base.issetequal(x::Base.AbstractSet, ps1::Params) = issetequal(x, ps1.params)
+
+function Base.intersect!(ps::Params, itrs...)
+ for itr in itrs
+ for x in collect(ps)
+ x ∉ itr && delete!(ps, x)
+ end
+ end
+ return ps
+end
+
+Base.intersect(ps::Params, itrs...) = intersect!(copy(ps), itrs...)
function Base.push!(ps::Params, x)
if !(x in ps.params)
@@ -102,8 +128,6 @@ function Base.delete!(ps::Params, x)
return ps
end
-Params(xs) = push!(Params(), xs...)
-
Base.Broadcast.broadcasted(f, ps::Params) = broadcasted(f, ps.order)
Base.:(==)(x::Params, y::Params) = x.order.data == y.order.data
@@ -118,7 +142,6 @@ end
"""
copy!(ps::Params, x::AbstractVector)
copy!(x::AbstractVector, ps::Params)
-
Copies the content of array `x` into the parameters `ps` or viceversa.
The length of `x` has to be equal to the sum of the lengths
of all parameters.
@@ -127,8 +150,8 @@ function copy!(ps::Params, x::AbstractVector)
@assert length(x) == sum(length(p) for p in ps)
i = 0
for p in ps
- p .= reshape(x[i+1:i+length(p)], size(p))
- i += length(p)
+ p .= reshape(x[i+1:i+length(p)], size(p))
+ i += length(p)
end
ps
end
@@ -137,8 +160,8 @@ function copy!(x::AbstractVector, ps::Params)
@assert length(x) == sum(length(p) for p in ps)
i = 0
for p in ps
- x[i+1:i+length(p)] .= vec(p)
- i += length(p)
+ x[i+1:i+length(p)] .= vec(p)
+ i += length(p)
end
ps
end
@@ -151,7 +174,23 @@ end
Base.show(io::IO, ps::Grads) = print(io, "Grads(...)")
-@forward Grads.grads Base.getindex, Base.haskey
+@forward Grads.grads Base.setindex!
+@forward Grads.params Base.length
+
+const ADictOrGrads = Union{AbstractDict, Grads}
+
+# Dictionary interface.
+# Don't use the IdDict directly since it may contain some spurious pairs.
+Base.haskey(gs::Grads, x) = x ∈ gs.params
+Base.keys(gs::Grads) = gs.params
+Base.values(gs::Grads) = (gs.grads[p] for p in gs.params)
+
+function Base.iterate(gs::Grads, state...)
+ res = iterate(gs.params, state...)
+ isnothing(res) && return nothing
+ p, next_state = res
+ return gs[p], next_state
+end
function Base.getindex(gs::Grads, x)
isbits(x) && error("Only reference types can be differentiated with `Params`.")
@@ -161,15 +200,14 @@ end
"""
copy!(gs::Grads, x::AbstractVector)
copy!(x::AbstractVector, gs::Grads)
-
Copies the content of array `x` into the gradient object `gs` or vice versa. The
length of `x` has to be equal to the sum of the lengths of all gradients.
"""
function copy!(gs::Grads, x::AbstractVector)
i = 0
for p in gs.params
- gs[p] .= reshape(x[i+1:i+length(p)], size(p))
- i += length(p)
+ gs[p] .= reshape(x[i+1:i+length(p)], size(p))
+ i += length(p)
end
x
end
@@ -177,12 +215,46 @@ end
function copy!(x::AbstractVector, gs::Grads)
i = 0
for p in gs.params
- x[i+1:i+length(p)] .= vec(gs[p])
- i += length(p)
+ x[i+1:i+length(p)] .= vec(gs[p])
+ i += length(p)
end
x
end
+broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)
+
+broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
+broadcasted(f, gs::Grads, a::Numeric) = map(x -> f(x, a), gs)
+
+function materialize!(gs1::Grads, gs2::Grads)
+ issetequal(gs1.params, gs2.params) ||
+ throw(ArgumentError("Expected Grads objects with the same Params."))
+ for p in gs1.params
+ gs1[p] = gs2[p]
+ end
+ return gs1
+end
+
+
+function Base.map(f, gs1::Grads, gss::ADictOrGrads...)
+ gsout = Grads(IdDict{Any,Any}(), Params(gs1.params))
+ return map!(f, gsout, gs1, gss...)
+end
+
+function Base.map!(f, gsout::Grads, gss::ADictOrGrads...)
+ all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
+ throw(ArgumentError("map! expects Grads objects with the same Params."))
+ for p in gsout.params
+ gsout[p] = f((_getformap(gs, p) for gs in gss)...)
+ end
+ return gsout
+end
+
+function _getformap(gs, p)
+ g = gs[p]
+ isnothing(g) ? fill!(similar(p), 0) : g
+end
+
function pullback(f, ps::Params)
cx = Context()
y, back = _pullback(cx, f)
From 56c7c084ddd303999ca5504ba2f5e88a587946b9 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 14 Jun 2021 15:07:42 +0530
Subject: [PATCH 062/490] add tests
---
test/interface.jl | 40 ++++++++++++++++++++++++++++++++++++++++
1 file changed, 40 insertions(+)
diff --git a/test/interface.jl b/test/interface.jl
index 0ffb933f6..2c1af3c3d 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -163,4 +163,44 @@ end
@test all(abs.(gs[w]) .<= 1e-5)
@test all(abs.(gs[b]) .<= 1e-5)
end
+
+ @testset "Params nesting" begin
+ struct Dense{F,T,S}
+ W::T
+ b::S
+ σ::F
+ end
+
+ (d::Dense)(x) = d.σ.(d.W * x .+ d.b)
+ d = Dense(ones(Float32, 3,3), zeros(Float32, 3), identity)
+ ps = Zygote.Params([d.W, d.b])
+ r = ones(Float32, 3,3)
+
+ gs = gradient(ps) do
+ p, pb = pullback(ps) do
+ sum(d(r))
+ end
+ g = pb(p)
+ sum(g[d.W]) # + sum(g[d.b])
+ end
+
+ @test gs[d.W] ≈ fill(81f0, (3,3))
+
+ # Test L2
+ l2g = gradient(ps) do
+ sum(sum(x .^ 2) for x in ps)
+ end
+ @test l2g[d.W] ≈ fill(2.f0, size(d.W))
+ @test l2g[d.b] ≈ fill(0.f0, size(d.b))
+
+ # Can be safely removed - creating Params within
+ # gradient calls may break between releases.
+ sgs = gradient(ps) do
+ sum(sum(x) for x in Zygote.Params([d.W, d.b, b]))
+ end
+ @test sgs[d.W] ≈ fill(1.f0, size(d.W))
+ @test sgs[d.b] ≈ fill(1.f0, size(d.b))
+ end
+
+
end
From 0424158150bd727a147b6bcfb4663db0f83a2be4 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 14 Jun 2021 15:11:50 +0530
Subject: [PATCH 063/490] add Params(::Params)
---
src/compiler/interface.jl | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index baeba003e..75850fdbb 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -68,7 +68,8 @@ struct Params
end
Params() = Params(Buffer([], false), IdSet())
-Params(xs) = Params(Buffer(xs, false), IdSet(xs))
+Params(xs::Vector) = Params(Buffer(xs, false), IdSet(xs))
+Params(ps::Params) = ps
@forward Params.order Base.iterate, Base.length, Base.getindex
@forward Params.params Base.in
From 7e6c2c9e0aa224a4451f535e333c358c1302e27d Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 14 Jun 2021 15:27:03 +0530
Subject: [PATCH 064/490] fixup missing symbol
---
test/interface.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/interface.jl b/test/interface.jl
index 2c1af3c3d..159f4bce1 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -196,7 +196,7 @@ end
# Can be safely removed - creating Params within
# gradient calls may break between releases.
sgs = gradient(ps) do
- sum(sum(x) for x in Zygote.Params([d.W, d.b, b]))
+ sum(sum(x) for x in Zygote.Params([d.W, d.b]))
end
@test sgs[d.W] ≈ fill(1.f0, size(d.W))
@test sgs[d.b] ≈ fill(1.f0, size(d.b))
From d2785a6c5c3facf07605d89e72ed532140f55afd Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Mon, 14 Jun 2021 12:18:21 +0200
Subject: [PATCH 065/490] Disable regression tests
---
test/gradcheck.jl | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index c0aa8a665..53c29a047 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -4,7 +4,6 @@ using Zygote: gradient
using Base.Broadcast: broadcast_shape
using Distributed: pmap, CachingPool, workers
import FiniteDifferences
-using BenchmarkTools
function ngradient(f, xs::AbstractArray...)
grads = zero.(xs)
@@ -1719,6 +1718,7 @@ end
@test gradcheck(x -> prod(Base.Fix1(+, 1), x), randn(100))
@test gradcheck(x -> prod(Base.Fix2(+, 1), x), randn(100))
+#= regression tests are not included to reduce CI times
# check the execution times compared with a closure
# https://github.com/FluxML/Zygote.jl/issues/957
x = randn(100)
@@ -1727,4 +1727,5 @@ end
tfix2 = @belapsed(gradient($(x -> prod(Base.Fix2(+, 1), x)), $x))
@test tfix1 < 2 * tclosure
@test tfix2 < 2 * tclosure
+=#
end
From 5af6148496390344c96dc29662814883ae8f32db Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Mon, 14 Jun 2021 12:18:32 +0200
Subject: [PATCH 066/490] Clean test dependencies
---
Project.toml | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/Project.toml b/Project.toml
index b586ded24..fc02a9a3f 100644
--- a/Project.toml
+++ b/Project.toml
@@ -38,14 +38,12 @@ ZygoteRules = "0.2.1"
julia = "1.3"
[extras]
-BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
-StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
-test = ["BenchmarkTools", "CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
+test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
From d7308036ab964c81c7a5fb797ee1b1dcbd30e9c9 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Mon, 14 Jun 2021 12:21:16 +0200
Subject: [PATCH 067/490] Re-add StatsFuns dependency
---
Project.toml | 1 +
1 file changed, 1 insertion(+)
diff --git a/Project.toml b/Project.toml
index fc02a9a3f..56d72758a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -43,6 +43,7 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" # otherwise we can't add a compat bound
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
From 4ba326f7ddb4533d0443ef5ccb299eba0c8b3f8e Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 14 Jun 2021 20:59:47 +0530
Subject: [PATCH 068/490] add tuple constructor
---
src/compiler/interface.jl | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 75850fdbb..4c8eb0d53 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -68,8 +68,9 @@ struct Params
end
Params() = Params(Buffer([], false), IdSet())
-Params(xs::Vector) = Params(Buffer(xs, false), IdSet(xs))
+Params(xs) = Params(Buffer(xs, false), IdSet(xs))
Params(ps::Params) = ps
+Params(xs::Tuple) = Params(collect(xs))
@forward Params.order Base.iterate, Base.length, Base.getindex
@forward Params.params Base.in
From ce8eb91a592cdd21a1e6610e03fd9e5245e40ae3 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Tue, 15 Jun 2021 11:11:46 +0530
Subject: [PATCH 069/490] whitespace
---
src/compiler/interface.jl | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 4c8eb0d53..86e847dc4 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -48,8 +48,10 @@ sensitivity(y) = error("Output should be scalar; gradients are not defined for o
"""
gradient(f, args...)
+
Returns a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar x) or the gradient.
+
`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
"""
function gradient(f, args...)
@@ -144,6 +146,7 @@ end
"""
copy!(ps::Params, x::AbstractVector)
copy!(x::AbstractVector, ps::Params)
+
Copies the content of array `x` into the parameters `ps` or viceversa.
The length of `x` has to be equal to the sum of the lengths
of all parameters.
@@ -202,6 +205,7 @@ end
"""
copy!(gs::Grads, x::AbstractVector)
copy!(x::AbstractVector, gs::Grads)
+
Copies the content of array `x` into the gradient object `gs` or vice versa. The
length of `x` has to be equal to the sum of the lengths of all gradients.
"""
From c4373ddcd0e5ae248ac4f2e3cefe627b4d8b112a Mon Sep 17 00:00:00 2001
From: Simeon Schaub
Date: Wed, 16 Jun 2021 05:26:16 -0400
Subject: [PATCH 070/490] fix #996
---
src/lib/broadcast.jl | 1 +
test/gradcheck.jl | 4 ++++
2 files changed, 5 insertions(+)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index df9fdc9b5..879f62c1a 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -186,6 +186,7 @@ end
y, ∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y, function (ȳ)
dxs = ∂b(ȳ)
+ dxs === nothing && return nothing
(nothing, dxs...)
end
end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 43f482d01..aeea366cf 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1713,3 +1713,7 @@ end
@test s == 0.0
@test gs == (nothing,)
end
+
+# https://github.com/FluxML/Zygote.jl/issues/996
+a = rand(3)
+@test Zygote.gradient(x->sum(x .+ rand.()), a) == (ones(3),)
From b250d925b8d4db8d5f76ee506c0ba6a4e35fe689 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Fri, 18 Jun 2021 18:00:59 +0100
Subject: [PATCH 071/490] Use ChainRules RuleConfig (#990)
* draft test_gradient function
* draft
* update draft
* wrap input
* polish zygote_ad_rrule
* clean up chainrulestest utils
* remove multizeros
* rename zygote_ad_rrule to rrule_via_ad
* rename export
* add a real test example
* take nothing seriously
* skip all chainrules tests
* refresh often
* remove chainrules_fallback method
* Revert "refresh often"
This reverts commit 388bd0f060de9b36d89059ac534490a52123909e.
* remove one level of nesting
* add a test
* use ChainRules RuleConfigs
* Fix it so can use RuleConfig in Zygote
* make tests easier to use
* wip
* Mark testing of rrule_via_ad on round as broken (others work)
* debugging
* Don't take nothing seriously
* remove scratch file
* stop taking nothing seriously again
* fix typo
* Fix use of test_rrule to test Zygote
* renable ChainRules tests
* fix ChainRulesTest to use new ChainRulesCore
* bring back adjoint for sum on arrays of bools
* clash with names less
* Use old rule for sum on Arrays of Arrays
* Apply suggestions from code review
Co-authored-by: Miha Zgubic
* import ZygoteRuleConfig into the tests
* import ChainRules testing tools etc
* import ChainRules testing tools etc
Co-authored-by: Miha Zgubic
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Miha Zgubic
---
Project.toml | 6 +-
src/Zygote.jl | 2 +
src/compiler/chainrules.jl | 76 ++++++++++++++++++++-----
src/compiler/interface2.jl | 18 ++++--
src/lib/array.jl | 11 ++--
test/chainrules.jl | 113 ++++++++++++++++++++++++++++---------
test/features.jl | 6 +-
test/gradcheck.jl | 6 +-
test/lib/array.jl | 6 +-
test/runtests.jl | 2 +-
10 files changed, 183 insertions(+), 63 deletions(-)
diff --git a/Project.toml b/Project.toml
index 2a89b4cd1..3fe2ceb58 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,8 +23,9 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "0.7.66, 0.8"
+ChainRules = "0.8.12"
ChainRulesCore = "0.9.44, 0.10"
+ChainRulesTestUtils = "0.7.1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10"
@@ -39,6 +40,7 @@ julia = "1.3"
[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
@@ -47,4 +49,4 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
-test = ["CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
+test = ["ChainRulesTestUtils", "CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
diff --git a/src/Zygote.jl b/src/Zygote.jl
index 873707fce..895e65f3c 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -6,6 +6,7 @@ using LinearAlgebra: copytri!, AbstractTriangular
import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
literal_getproperty, literal_getfield
+using ChainRulesCore
using ChainRules: ChainRules, rrule, unthunk, canonicalize
using IRTools
using MacroTools, Requires
@@ -13,6 +14,7 @@ using MacroTools: @forward
import Distributed: pmap, CachingPool, workers
export Params, gradient, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint
+export rrule_via_ad
const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 8392a27ce..c4e72f07e 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -1,4 +1,11 @@
-const chainrules_fallback = which(rrule, Tuple{Any})
+struct ZygoteRuleConfig{CTX<:AContext} <: RuleConfig{Union{HasReverseMode,NoForwardsMode}}
+ context::CTX
+end
+ZygoteRuleConfig() = ZygoteRuleConfig(Context())
+
+
+const rrule_fallback_method = Base.which(rrule, Tuple{Any, Vararg{Any}})
+const rrule_redispatcher_method = Base.which(rrule, Tuple{RuleConfig, Any, Vararg{Any}})
"""
has_chain_rrule(T)
@@ -10,13 +17,20 @@ If it does not, then the second argument is a list of edges to attach to the Cod
such that if a suitable rule is defined later, the generated function will recompile.
"""
function has_chain_rrule(T)
- m = meta(Tuple{typeof(rrule),T.parameters...})
- if m.method !== chainrules_fallback
- # found a rrule, no need to add any edges
- return true, nothing
+ config_T, arg_Ts = Iterators.peel(T.parameters)
+ m_with_config = meta(Tuple{typeof(rrule), config_T, arg_Ts...})
+ if m_with_config.method === rrule_redispatcher_method
+ # it is being redispatched without config, so check it that hits the fallback
+ m_without_config = meta(Tuple{typeof(rrule), arg_Ts...})
+ if m_without_config.method === rrule_fallback_method
+ # no rrule exists, return instance for m_with_config as that will be invalidated
+ # directly if configured rule added, or indirectly if unconfigured rule added
+ return false, m_with_config.instance
+ end
end
-
- return false, m.instance
+ # otherwise found a rrule, no need to add any edges, as it will generate code with
+ # natural edges.
+ return true, nothing
end
"""
@@ -80,25 +94,25 @@ end
@inline (s::ZBack)(::Nothing) = nothing
"""
- chain_rrule(f, args...)
+ chain_rrule(config, f, args...)
Returns a the (primal) value of `f(args...)` and a pullback, by invoking `ChainRulesCore.rrule(f, args...)`.
The pullback is appropriately wrapped up to follow Zygote conventions.
"""
-@inline function chain_rrule(f, args...)
- y, back = rrule(f, args...)
+@inline function chain_rrule(config, f, args...)
+ y, back = rrule(config, f, args...)
return y, ZBack(back)
end
"""
- chain_rrule_kw(kwf, kwargs, f, args...)
+ chain_rrule_kw(config, kwf, kwargs, f, args...)
As per [`chain_rrule`](@ref) but with support for kwargs.
`kwf` should be the kwfunc matching to `f`, and `kwargs` are a `NamedTuple` of keyword arguments.
"""
-@inline function chain_rrule_kw(kwf, kwargs, f, args...)
- y, back = rrule(f, args...; kwargs...)
+@inline function chain_rrule_kw(config, kwf, kwargs, f, args...)
+ y, back = rrule(config, f, args...; kwargs...)
function kw_zpullback(dy)
dxs = ZBack(back)(dy)
if dxs === nothing # if dxs is nothing, then all partiaols are nothing
@@ -110,3 +124,39 @@ As per [`chain_rrule`](@ref) but with support for kwargs.
end
return y, kw_zpullback
end
+
+
+function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f, args...)
+ y, pb = _pullback(config.context, f, args...)
+ ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), (f, args...))
+ return y, ad_pullback
+end
+
+"""
+ zygote2differential(x)
+
+Convert input `x` from the Zygote format to the ChainRules differential types.
+"""
+zygote2differential(x, primal) = z2d(x, primal)
+zygote2differential(::Nothing, ::Any) = NoTangent()
+zygote2differential(t::Tuple, primal::Tuple) = map(z2d, t, primal)
+zygote2differential(t::Tuple, primal) = (@warn "primal should be a tuple, not $primal"; return t)
+z2d(x, ::Any) = x
+z2d(::Nothing, ::Any) = NoTangent()
+z2d(a::AbstractArray{<:Number}, primal::AbstractArray{T}) where T = a
+z2d(a::AbstractArray, primal::AbstractArray{T}) where T = z2d.(a, primal)
+z2d(x::Union{AbstractZero, Tangent}, ::Any) = (difftype_warn(x); return x)
+function z2d(t::Tuple, primal::Tuple)
+ tp::Tuple = map(z2d, t, primal)
+ primal_type = typeof(primal)
+ return canonicalize(Tangent{primal_type, typeof(tp)}(tp))
+end
+
+function z2d(t::NamedTuple, primal)
+ primal_type = typeof(primal)
+ fnames = fieldnames(primal_type)
+ complete_t = NamedTuple{fnames}(fn in keys(t) ? t[fn] : nothing for fn in fnames)
+ primals = NamedTuple{fnames}(getfield(primal, fn) for fn in fnames)
+ tp::NamedTuple = map(z2d, complete_t, primals)
+ return canonicalize(Tangent{primal_type, typeof(tp)}(tp))
+end
diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl
index 350599bba..b77c7e3a6 100644
--- a/src/compiler/interface2.jl
+++ b/src/compiler/interface2.jl
@@ -10,12 +10,18 @@ end
T = Tuple{f,args...}
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))
- iskw = is_kwfunc(f, args...)
- # if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
- base_T = iskw ? Tuple{args[2:end]...} : T
- hascr, cr_edge = has_chain_rrule(base_T)
- chain_rrule_f = iskw ? :chain_rrule_kw : :chain_rrule
- hascr && return :($chain_rrule_f(f, args...))
+ if is_kwfunc(f, args...)
+ # if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
+ cr_T = Tuple{ZygoteRuleConfig{ctx}, args[2:end]...}
+ chain_rrule_f = :chain_rrule_kw
+ else
+ cr_T = Tuple{ZygoteRuleConfig{ctx}, f, args...}
+ chain_rrule_f = :chain_rrule
+ end
+
+ hascr, cr_edge = has_chain_rrule(cr_T)
+
+ hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))
g = try _lookup_grad(T) catch e e end
!(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
diff --git a/src/lib/array.jl b/src/lib/array.jl
index f8389d4e3..f1e217fe2 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -287,19 +287,16 @@ end
end
end
-@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
- sum(xs, dims = dims), Δ -> (nothing,)
-end
-
-@adjoint function sum(f, xs::AbstractArray; kws...)
+@adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
-@adjoint function sum(::typeof(abs2), X::AbstractArray; dims = :)
- return sum(abs2, X; dims=dims), Δ::Union{Number, AbstractArray}->(nothing, ((2Δ) .* X))
+@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
+ sum(xs, dims = dims), Δ -> (nothing,)
end
+
function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs)
y, ȳ -> (nothing, back(ȳ)...)
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 8b7034753..519c10ad6 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -1,16 +1,14 @@
-using Zygote, Test, ChainRules
-
-
-@testset "ChainRules Integration" begin
- @testset "basic" begin
+using ChainRulesCore, ChainRulesTestUtils, Zygote
+@testset "ChainRules integration" begin
+ @testset "ChainRules basics" begin
cr_inner_demo_rrule_hitcount = Ref(0)
cr_inner_demo_pullback_hitcount = Ref(0)
cr_inner_demo(x) = 5x
- function ChainRules.rrule(::typeof(cr_inner_demo), x)
+ function ChainRulesCore.rrule(::typeof(cr_inner_demo), x)
cr_inner_demo_rrule_hitcount[] += 1
function cr_inner_demo_pullback(Δx)
cr_inner_demo_pullback_hitcount[] += 1
- return ChainRules.NO_FIELDS, 5.0*Δx
+ return NoTangent(), 5.0*Δx
end
return cr_inner_demo(x), cr_inner_demo_pullback
end
@@ -19,6 +17,7 @@ using Zygote, Test, ChainRules
2 + 10cr_inner_demo(x)
end
+ #
@testset "gradient inner" begin
cr_inner_demo_rrule_hitcount[] = 0
@@ -55,19 +54,19 @@ using Zygote, Test, ChainRules
simo_rrule_hitcount = Ref(0)
simo_pullback_hitcount = Ref(0)
simo(x) = (5x, 7x)
- function ChainRules.rrule(::typeof(simo), x)
+ function ChainRulesCore.rrule(::typeof(simo), x)
simo_rrule_hitcount[] += 1
function simo_pullback((Δa, Δb))
simo_pullback_hitcount[] += 1
- return ChainRules.NO_FIELDS, 5*Δa + 7*Δb
+ return NoTangent(), 5*Δa + 7*Δb
end
return simo(x), simo_pullback
end
-
+
simo_outer(x) = sum(simo(x))
- @assert simo_rrule_hitcount[] == 0
- @assert simo_pullback_hitcount[] == 0
+ simo_rrule_hitcount[] = 0
+ simo_pullback_hitcount[] = 0
@test (12,) == Zygote.gradient(simo_outer, π)
@test simo_rrule_hitcount[] == 1
@test simo_pullback_hitcount[] == 1
@@ -77,19 +76,20 @@ using Zygote, Test, ChainRules
miso_rrule_hitcount = Ref(0)
miso_pullback_hitcount = Ref(0)
miso(a, b) = 5a + 7b
- function ChainRules.rrule(::typeof(miso), a, b)
+ function ChainRulesCore.rrule(::typeof(miso), a, b)
miso_rrule_hitcount[] += 1
function miso_pullback(Δy)
miso_pullback_hitcount[] += 1
- return ChainRules.NO_FIELDS, 5Δy , 7Δy
+ return NoTangent(), 5Δy , 7Δy
end
return miso(a, b), miso_pullback
end
+
miso_outer(x) = miso(100x, 10x)
- @assert miso_rrule_hitcount[] == 0
- @assert miso_pullback_hitcount[] == 0
+ miso_rrule_hitcount[] = 0
+ miso_pullback_hitcount[] = 0
@test (570,) == Zygote.gradient(miso_outer, π)
@test miso_rrule_hitcount[] == 1
@test miso_pullback_hitcount[] == 1
@@ -99,17 +99,17 @@ using Zygote, Test, ChainRules
mimo_rrule_hitcount = Ref(0)
mimo_pullback_hitcount = Ref(0)
mimo(a, b) = (5a + 7b, 100a, 10b)
- function ChainRules.rrule(::typeof(mimo), a, b)
+ function ChainRulesCore.rrule(::typeof(mimo), a, b)
mimo_rrule_hitcount[] += 1
function mimo_pullback((Δx, Δy, Δz))
mimo_pullback_hitcount[] += 1
- return ChainRules.NO_FIELDS, 5Δx + 100Δy , 7Δx + 10Δz
+ return NoTangent(), 5Δx + 100Δy , 7Δx + 10Δz
end
return mimo(a, b), mimo_pullback
end
- @assert mimo_rrule_hitcount[] == 0
- @assert mimo_pullback_hitcount[] == 0
+ mimo_rrule_hitcount[] = 0
+ mimo_pullback_hitcount[] = 0
_, pb = Zygote.pullback(mimo, π, 2π)
@test (105, 17) == pb((1, 1, 1))
@test mimo_rrule_hitcount[] == 1
@@ -129,13 +129,14 @@ using Zygote, Test, ChainRules
# to a single `nothing` if they are all zero-like.
not_diff_eg(x, i) = [10, 20][i]
- function ChainRules.rrule(::typeof(not_diff_eg), x, i)
+ function ChainRulesCore.rrule(::typeof(not_diff_eg), x, i)
function not_diff_eg_pullback(Δ)
- return ChainRules.NO_FIELDS, ChainRules.ZeroTangent(), ChainRules.NoTangent()
+ return NoTangent(), ZeroTangent(), NoTangent()
end
return not_diff_eg(x, i), not_diff_eg_pullback
end
+
_, pb = Zygote.pullback(not_diff_eg, 10.4, 2)
@test pb(1.2) === nothing
end
@@ -175,14 +176,15 @@ using Zygote, Test, ChainRules
kwfoo_rrule_hitcount = Ref(0)
kwfoo_pullback_hitcount = Ref(0)
kwfoo(x; k=10) = x + k
- function ChainRules.rrule(::typeof(kwfoo), x; k=10)
+ function ChainRulesCore.rrule(::typeof(kwfoo), x; k=10)
kwfoo_rrule_hitcount[] += 1
function kwfoo_pullback(Δy)
kwfoo_pullback_hitcount[] += 1
- return ChainRules.NO_FIELDS, Δy
+ return NoTangent(), Δy
end
return kwfoo(x; k=k), kwfoo_pullback
end
+
kwfoo_outer_unused(x) = kwfoo(x)
kwfoo_outer_used(x) = kwfoo(x; k=-15)
@@ -196,24 +198,81 @@ using Zygote, Test, ChainRules
end
end
-
@testset "kwarg, with all AbstractZero partials" begin
# while ChainRules always has a partial for every input, Zygote combined them all
# to a single `nothing` if they are all zero-like.
not_diff_kw_eg(x, i; kw=1.0) = [10, 20][i]
- function ChainRules.rrule(::typeof(not_diff_kw_eg), x, i; kwargs...)
+ function ChainRulesCore.rrule(::typeof(not_diff_kw_eg), x, i; kwargs...)
function not_diff_kw_eg_pullback(Δ)
- return ChainRules.NO_FIELDS, ChainRules.ZeroTangent(), ChainRules.NoTangent()
+ return NoTangent(), ZeroTangent(), NoTangent()
end
return not_diff_kw_eg(x, i; kwargs...), not_diff_kw_eg_pullback
end
+
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2), 10.4)
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2; kw=2.0), 10.4)
end
end
+@testset "ChainRulesCore.rrule_via_ad" begin
+ @testset "basic" begin
+ # broken because Zygote compresses `(NoTangent(), NoTangent())` into just NoTangent()
+ # which ChainRulesTestUtils does not think is valid:
+ @test_broken(rrule_via_ad(ZygoteRuleConfig(), round, 2.2) isa Tuple{NoTangent,NoTangent})
+ # uncomment below when/if above is fixed
+ # test_rrule(ZygoteRuleConfig(), round, 2.2; rrule_f=rrule_via_ad)
+
+ test_rrule(ZygoteRuleConfig(), vcat, rand(3), rand(4); rrule_f=rrule_via_ad, check_inferred=false)
+ test_rrule(ZygoteRuleConfig(), getindex, rand(5), 3; rrule_f=rrule_via_ad)
+ end
+
+ @testset "struct" begin
+ struct Foo
+ x
+ y
+ end
+ makefoo(a, b) = Foo(a, b)
+ sumfoo(foo) = foo.x + foo.y
+
+
+ test_rrule(
+ ZygoteRuleConfig(), sumfoo, Foo(1.2, 2.3); rrule_f=rrule_via_ad, check_inferred=false
+ )
+ test_rrule(
+ ZygoteRuleConfig(), makefoo, 1.0, 2.0;
+ rrule_f=rrule_via_ad, check_inferred=false
+ )
+ end
+
+ @testset "tuples/namedtuples" begin
+ my_tuple(a, b, c) = (a+b, b+c)
+ my_namedtuple(a, b, c) = (a=a, b=b, c=0.0)
+
+ test_rrule(
+ ZygoteRuleConfig(), my_tuple, 1., 2., 3.; rrule_f=rrule_via_ad
+ )
+ test_rrule(
+ ZygoteRuleConfig(), my_namedtuple, 1., 2., 3.; rrule_f=rrule_via_ad
+ )
+ test_rrule(
+ ZygoteRuleConfig(), my_namedtuple, 1., (2.0, "str"), 3.; rrule_f=rrule_via_ad
+ )
+ test_rrule(ZygoteRuleConfig(), sum, (1.0, 2.0, 3.0); rrule_f=rrule_via_ad)
+ test_rrule(
+ ZygoteRuleConfig(), sum, (a=1.0, b=2.0); rrule_f=rrule_via_ad, check_inferred=false
+ )
+ end
+
+ @testset "arrays" begin
+ nada(x, y) = 1.0
+ test_rrule(ZygoteRuleConfig(), nada, rand(3), rand(2,3); rrule_f=rrule_via_ad)
+ test_rrule(ZygoteRuleConfig(), +, rand(3), rand(3); rrule_f=rrule_via_ad)
+ test_rrule(ZygoteRuleConfig(), *, rand(1, 3), rand(3); rrule_f=rrule_via_ad)
+ end
+end
+
@testset "FastMath support" begin
@test gradient(2.0) do x
@fastmath x^2.0
diff --git a/test/features.jl b/test/features.jl
index 4c63aca74..d10d60c19 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -179,13 +179,13 @@ end
@test gradient(x -> x.re*x.im, 2+3im) == ((re = 3, im = 2),)
-struct Foo{T}
+struct Bar{T}
a::T
b::T
end
function mul_struct(a, b)
- c = Foo(a, b)
+ c = Bar(a, b)
c.a * c.b
end
@@ -358,7 +358,7 @@ end
pop!(stk)
end == (1,)
-@test gradient(x -> [x][1].a, Foo(1, 1)) == ((a=1, b=nothing),)
+@test gradient(x -> [x][1].a, Bar(1, 1)) == ((a=1, b=nothing),)
@test gradient((a, b) -> Zygote.hook(-, a)*b, 2, 3) == (-3, 2)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index aeea366cf..5d95f3a5c 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -72,13 +72,13 @@ end
@testset "power" begin
@test gradient(x -> x^2, -2) == (-4,)
@test gradient(x -> x^10, -1.0) == (-10,) # literal_pow
- pow = 10
- @test gradient(x -> x^pow, -1.0) == (-pow,)
+ _pow = 10
+ @test gradient(x -> x^_pow, -1.0) == (-_pow,)
@test gradient(p -> real(2^p), 2)[1] ≈ 4*log(2)
@test gradient(xs ->sum(xs .^ 2), [2, -1]) == ([4, -2],)
@test gradient(xs ->sum(xs .^ 10), [3, -1]) == ([10*3^9, -10],)
- @test gradient(xs ->sum(xs .^ pow), [4, -1]) == ([pow*4^9, -10],)
+ @test gradient(xs ->sum(xs .^ _pow), [4, -1]) == ([_pow*4^9, -10],)
@test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,)
@test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ (-234 + 2im)*log(5 - 7im)
diff --git a/test/lib/array.jl b/test/lib/array.jl
index 380d1bb8f..6f72a4a2f 100644
--- a/test/lib/array.jl
+++ b/test/lib/array.jl
@@ -1,4 +1,8 @@
+using ChainRulesTestUtils
using LinearAlgebra
+using Zygote: ZygoteRuleConfig
# issue 897
-@test gradient(x -> sum(sin, Diagonal(x)), ones(2)) == ([0.5403023058681398, 0.5403023058681398],)
+
+test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), ones(2); rrule_f=rrule_via_ad, check_inferred=false)
+test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_via_ad, check_inferred=false)
diff --git a/test/runtests.jl b/test/runtests.jl
index f20b59a7e..67893a7a5 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,5 @@
using Zygote, Test
-using Zygote: gradient
+using Zygote: gradient, ZygoteRuleConfig
using CUDA: has_cuda
if has_cuda()
From 13647cd618c45d9c7e691d403e666a64d5d5309a Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Fri, 18 Jun 2021 18:02:23 +0100
Subject: [PATCH 072/490] bump version
---
Project.toml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/Project.toml b/Project.toml
index 3fe2ceb58..60476fca1 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.12"
+version = "0.6.13"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "0.8.12"
-ChainRulesCore = "0.9.44, 0.10"
+ChainRulesCore = "0.10.4"
ChainRulesTestUtils = "0.7.1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
From 61d4eebdc77aa4599e88991e32bf29a6ba8af47b Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 14:13:37 -0400
Subject: [PATCH 073/490] faster generic broadcasting
---
src/lib/broadcast.jl | 16 +++++++++-------
1 file changed, 9 insertions(+), 7 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 879f62c1a..219a9fbdb 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -172,23 +172,25 @@ collapse_nothings(xs) = xs
@adjoint function broadcasted(::AbstractArrayStyle, f, args...)
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
- y = map(x -> x[1], y∂b)
- ∂b = map(x -> x[2], y∂b)
- y, function (ȳ)
- dxs_zip = map((∂b, ȳ) -> ∂b(ȳ), ∂b, ȳ)
- dxs = collapse_nothings.(ntuple(i -> map(x -> _get(x, i), dxs_zip), len))
+ y = map(first, y∂b)
+ function ∇broadcasted(ȳ)
+ dxs_zip = map((pair, ȳ₁) -> last(pair)(ȳ₁), y∂b, ȳ)
+ dxs = ntuple(len) do i
+ collapse_nothings(map(StaticGetter{i}(), dxs_zip))
+ end
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
end
+ y, ∇broadcasted
end
@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...)
- len = inclen(args)
y, ∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
- y, function (ȳ)
+ function ∇broadcasted0(ȳ)
dxs = ∂b(ȳ)
dxs === nothing && return nothing
(nothing, dxs...)
end
+ y, ∇broadcasted0
end
# Use the `map` adjoint in this special case, which is the same but applies
From d949838b2996cfe9db49e840ee974c5a0e0dc24a Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Fri, 18 Jun 2021 19:18:18 +0100
Subject: [PATCH 074/490] for clarity check of _lookup grad is nothing rather
than is a tuple, thus matching the rest of the code referencing g
---
src/compiler/interface2.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl
index b77c7e3a6..141d90a77 100644
--- a/src/compiler/interface2.jl
+++ b/src/compiler/interface2.jl
@@ -24,7 +24,7 @@ end
hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))
g = try _lookup_grad(T) catch e e end
- !(g isa Tuple) && return :(f(args...), Pullback{$T}((f,)))
+ g === nothing && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
forw = varargs!(meta, forw, 3)
From 8d9fac7b19fa0cfd511d430fab1a61541c1b7626 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Fri, 18 Jun 2021 19:23:19 +0100
Subject: [PATCH 075/490] rename _lookup_grad
---
src/compiler/emit.jl | 2 +-
src/compiler/interface2.jl | 5 +++--
2 files changed, 4 insertions(+), 3 deletions(-)
diff --git a/src/compiler/emit.jl b/src/compiler/emit.jl
index d277621ab..1c82a44f1 100644
--- a/src/compiler/emit.jl
+++ b/src/compiler/emit.jl
@@ -95,7 +95,7 @@ end
varargs(m::Method, n) = m.isva ? n - m.nargs + 1 : nothing
-function _lookup_grad(T)
+function _generate_pullback_via_decomposition(T)
(m = meta(T)) === nothing && return
va = varargs(m.method, length(T.parameters))
forw, back = stacks!(Adjoint(IR(m), varargs = va, normalise = false), T)
diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl
index 141d90a77..ac4a5a76a 100644
--- a/src/compiler/interface2.jl
+++ b/src/compiler/interface2.jl
@@ -23,7 +23,7 @@ end
hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))
- g = try _lookup_grad(T) catch e e end
+ g = try _generate_pullback_via_decomposition(T) catch e e end
g === nothing && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
@@ -37,7 +37,8 @@ end
@generated function (j::Pullback{T})(Δ) where T
ignore_sig(T) && return :nothing
- g = try _lookup_grad(T)
+ g = try
+ _generate_pullback_via_decomposition(T)
catch e
rethrow(CompileError(T,e))
end
From 7f52f560d9f2bece25771e9a2774c2da1edc3b01 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 14:38:36 -0400
Subject: [PATCH 076/490] (StaticGetter)(::Nothing)
---
src/lib/array.jl | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index f1e217fe2..aa939b2f7 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -194,6 +194,7 @@ end
struct StaticGetter{i} end
(::StaticGetter{i})(v) where {i} = v[i]
+(::StaticGetter{i})(::Nothing) where {i} = nothing
@generated function _unzip(tuples, ::Val{N}) where {N}
Expr(:tuple, (:(map($(StaticGetter{i}()), tuples)) for i ∈ 1:N)...)
end
From 079c287e36220f070017b8de6f8e259dbf6f2d5e Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 14:39:12 -0400
Subject: [PATCH 077/490] style change
---
src/lib/broadcast.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 219a9fbdb..da62f0af1 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -174,7 +174,7 @@ collapse_nothings(xs) = xs
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y = map(first, y∂b)
function ∇broadcasted(ȳ)
- dxs_zip = map((pair, ȳ₁) -> last(pair)(ȳ₁), y∂b, ȳ)
+ dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
dxs = ntuple(len) do i
collapse_nothings(map(StaticGetter{i}(), dxs_zip))
end
From 458413654a5e53e7002c69c28013e14fab5957e3 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 14:47:48 -0400
Subject: [PATCH 078/490] use StaticGetter for CuArrays too
---
src/lib/broadcast.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index da62f0af1..16e1cdfa9 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -229,8 +229,8 @@ end
out = dual_function(f).(args...)
eltype(out) <: Dual || return (out, _ -> nothing)
y = map(x -> x.value, out)
- _back(ȳ, i) = unbroadcast(args[i], ((a, b) -> a*b.partials[i]).(ȳ, out))
- back(ȳ) = ntuple(i -> _back(ȳ, i), N)
+ _back(ȳ, geti) = unbroadcast(geti(args), ((a, b) -> a * geti(b.partials)).(ȳ, out))
+ back(ȳ) = ntuple(i -> _back(ȳ, StaticGetter{i}()), N)
return y, back
end
From b81c2e30e6a9b91363146cfd053db64cdd0b2eb0 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 14:48:45 -0400
Subject: [PATCH 079/490] rm _get
---
src/lib/broadcast.jl | 8 +++-----
1 file changed, 3 insertions(+), 5 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 16e1cdfa9..1f3e1d076 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -164,23 +164,21 @@ end
# Avoid hitting special cases for `Adjoint` etc.
_broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
-_get(x::Tuple, i) = x[i]
-_get(::Nothing, i) = nothing
collapse_nothings(xs::AbstractArray{Nothing}) = nothing
collapse_nothings(xs) = xs
-@adjoint function broadcasted(::AbstractArrayStyle, f, args...)
+@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y = map(first, y∂b)
- function ∇broadcasted(ȳ)
+ function ∇broadcasted(ȳ,y∂b::G) where {G}
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
dxs = ntuple(len) do i
collapse_nothings(map(StaticGetter{i}(), dxs_zip))
end
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
end
- y, ∇broadcasted
+ y, Base.Fix2(∇broadcasted,y∂b)
end
@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...)
From 072035832e0890d6f529d42818c2b88cf2e84fbf Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 20:27:07 -0400
Subject: [PATCH 080/490] Revert "use StaticGetter for CuArrays too", ++
This reverts commit 458413654a5e53e7002c69c28013e14fab5957e3.
---
src/lib/broadcast.jl | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 1f3e1d076..2098001b7 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -167,18 +167,18 @@ _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
collapse_nothings(xs::AbstractArray{Nothing}) = nothing
collapse_nothings(xs) = xs
-@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
+@adjoint function broadcasted(::AbstractArrayStyle, f, args...)
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y = map(first, y∂b)
- function ∇broadcasted(ȳ,y∂b::G) where {G}
+ function ∇broadcasted(ȳ)
dxs_zip = map(((_, pb), ȳ₁) -> pb(ȳ₁), y∂b, ȳ)
dxs = ntuple(len) do i
collapse_nothings(map(StaticGetter{i}(), dxs_zip))
end
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
end
- y, Base.Fix2(∇broadcasted,y∂b)
+ y, ∇broadcasted
end
@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...)
@@ -227,8 +227,8 @@ end
out = dual_function(f).(args...)
eltype(out) <: Dual || return (out, _ -> nothing)
y = map(x -> x.value, out)
- _back(ȳ, geti) = unbroadcast(geti(args), ((a, b) -> a * geti(b.partials)).(ȳ, out))
- back(ȳ) = ntuple(i -> _back(ȳ, StaticGetter{i}()), N)
+ _back(ȳ, i) = unbroadcast(args[i], ((a, b) -> a*b.partials[i]).(ȳ, out))
+ back(ȳ) = ntuple(i -> _back(ȳ, i), N)
return y, back
end
From 922716865534a678631acf56bf680ab892f8eba5 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 20:41:48 -0400
Subject: [PATCH 081/490] use forward mode sometimes
---
src/lib/broadcast.jl | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 2098001b7..418003d9f 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -167,7 +167,13 @@ _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
collapse_nothings(xs::AbstractArray{Nothing}) = nothing
collapse_nothings(xs) = xs
-@adjoint function broadcasted(::AbstractArrayStyle, f, args...)
+@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
+ # When safe, avoid generic broadcast & use ForwardDiff instead, often 100x faster
+ if all(a -> a isa Numeric{<:Real}, args) && Broadcast.combine_eltypes(f, args) <: Real
+ y, back = broadcast_forward(f, args...)
+ return y, ȳ -> (nothing, nothing, back(ȳ)...)
+ end
+
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y = map(first, y∂b)
From 9578ae7e28e738fa2e3a73298dccb31d8bc96b23 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 21:28:11 -0400
Subject: [PATCH 082/490] safer version
---
src/lib/broadcast.jl | 12 +++++++++---
1 file changed, 9 insertions(+), 3 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 418003d9f..ddb57c5ce 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -167,13 +167,19 @@ _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
collapse_nothings(xs::AbstractArray{Nothing}) = nothing
collapse_nothings(xs) = xs
+_purefun(::Type{F}) where {F<:Function} = isempty(fieldnames(F))
+_purefun(::Type{ComposedFunction{F,G}}) where {F,G} = _purefun(F) && _purefun(G)
+_purefun(::Type) = false
+
@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
- # When safe, avoid generic broadcast & use ForwardDiff instead, often 100x faster
- if all(a -> a isa Numeric{<:Real}, args) && Broadcast.combine_eltypes(f, args) <: Real
+ T = Broadcast.combine_eltypes(f, args)
+ # Avoid generic broadcasting in two easy cases:
+ if T == Bool
+ return f.(args...), _->nothing
+ elseif T <: Real && _purefun(F) && all(a -> a isa Numeric{<:Real}, args)
y, back = broadcast_forward(f, args...)
return y, ȳ -> (nothing, nothing, back(ȳ)...)
end
-
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
y = map(first, y∂b)
From bdb1d1f3796a8403eedc8b5e99169e13f7b1d940 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 21:44:02 -0400
Subject: [PATCH 083/490] power
---
src/lib/broadcast.jl | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index ddb57c5ce..b16572be6 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -171,6 +171,8 @@ _purefun(::Type{F}) where {F<:Function} = isempty(fieldnames(F))
_purefun(::Type{ComposedFunction{F,G}}) where {F,G} = _purefun(F) && _purefun(G)
_purefun(::Type) = false
+_purefun(::Type{typeof(^)}) = false # fix @testset "power" & @testset "diagonal hessian"
+
@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
From f7e96b19bb618f662dac4138c8f9f022d3bb1ac2 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 21:58:14 -0400
Subject: [PATCH 084/490] ComposedFunction is new
---
src/lib/broadcast.jl | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index b16572be6..f189c8246 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -168,9 +168,10 @@ collapse_nothings(xs::AbstractArray{Nothing}) = nothing
collapse_nothings(xs) = xs
_purefun(::Type{F}) where {F<:Function} = isempty(fieldnames(F))
-_purefun(::Type{ComposedFunction{F,G}}) where {F,G} = _purefun(F) && _purefun(G)
_purefun(::Type) = false
-
+if VERSION >= v"1.6"
+ _purefun(::Type{ComposedFunction{F,G}}) where {F,G} = _purefun(F) && _purefun(G)
+end
_purefun(::Type{typeof(^)}) = false # fix @testset "power" & @testset "diagonal hessian"
@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
From c34b2b98f1c26304e0ba5242b2b37cfe67432e1f Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 22:44:17 -0400
Subject: [PATCH 085/490] add & fix some tests
---
src/lib/broadcast.jl | 2 +-
test/features.jl | 32 ++++++++++++++++++++++----------
test/gradcheck.jl | 3 ++-
3 files changed, 25 insertions(+), 12 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index f189c8246..738508e53 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -179,7 +179,7 @@ _purefun(::Type{typeof(^)}) = false # fix @testset "power" & @testset "diagonal
# Avoid generic broadcasting in two easy cases:
if T == Bool
return f.(args...), _->nothing
- elseif T <: Real && _purefun(F) && all(a -> a isa Numeric{<:Real}, args)
+ elseif isconcretetype(T) && T <: Real && _purefun(F) && all(a -> a isa Numeric{<:Real}, args)
y, back = broadcast_forward(f, args...)
return y, ȳ -> (nothing, nothing, back(ȳ)...)
end
diff --git a/test/features.jl b/test/features.jl
index d10d60c19..57f743958 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -500,14 +500,26 @@ end
@test 150_000_000 > @allocated gradient(loss, ones(1000,1000))
end
-@testset "tuples & broadcasting" begin
- @test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
- @test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
- @test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)
-
- # https://github.com/FluxML/Zygote.jl/issues/975
- gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
- gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
- @test gt[1] == gv[1]
- @test collect(gt[2]) ≈ gv[2]
+@testset "tricky broadcasting" begin
+ @test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
+ @test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
+ @test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)
+
+ # https://github.com/FluxML/Zygote.jl/issues/975
+ gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
+ gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
+ @test gt[1] == gv[1]
+ @test collect(gt[2]) ≈ gv[2]
+
+ # closure captures y -- can't use ForwardDiff
+ @test gradient((x,y) -> sum((z->z^2+y[1]).(x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
+ @test gradient((x,y) -> sum((z->z^2+y[1]), x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
+ @test gradient((x,y) -> sum(map((z->z^2+y[1]), x)), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
+ @test gradient((x,y) -> mapreduce((z->z^2+y[1]), +, x), [1,2,3], [4,5]) == ([2, 4, 6], [3, 0])
+
+ # type unstable
+ @test gradient(xs -> sum((x -> x<2 ? false : x^2).(xs)), [1,2,3])[1][2:3] == [4, 6]
+ @test gradient(xs -> sum((x -> x<2 ? false : x^2), xs), [1,2,3])[1][2:3] == [4, 6]
+ @test gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6]
+ @test gradient(xs -> mapreduce((x -> x<2 ? false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6]
end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 5d95f3a5c..0baa0bb53 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1295,7 +1295,8 @@ end
end
@testset "broadcast" begin
- @test gradient(x -> sum(sin.(x)), Diagonal(randn(3)))[1][2] == 1
+ # Before https://github.com/FluxML/Zygote.jl/pull/1001 this gave [1 1 1; 1 0 1; 1 1 -1]
+ @test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1]
a = rand(3)
b = rand(2,2)
From 2c10638b894e86ebd104269a56593ed36188c955 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 18 Jun 2021 23:57:27 -0400
Subject: [PATCH 086/490] widen types
---
src/lib/broadcast.jl | 9 ++++++++-
test/features.jl | 8 ++++++++
2 files changed, 16 insertions(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 738508e53..8e3a0a2bb 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -174,12 +174,19 @@ if VERSION >= v"1.6"
end
_purefun(::Type{typeof(^)}) = false # fix @testset "power" & @testset "diagonal hessian"
+_dualsafe(x::Numeric{<:Real}) = true
+_dualsafe(x::Ref{<:Numeric{<:Real}}) = true
+_dualsafe(x::Val) = true
+_dualsafe(x::Type) = true
+_dualsafe(x::Symbol) = true
+_dualsafe(x) = false
+
@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
return f.(args...), _->nothing
- elseif isconcretetype(T) && T <: Real && _purefun(F) && all(a -> a isa Numeric{<:Real}, args)
+ elseif T <: Real && isconcretetype(T) && _purefun(F) && all(_dualsafe, args)
y, back = broadcast_forward(f, args...)
return y, ȳ -> (nothing, nothing, back(ȳ)...)
end
diff --git a/test/features.jl b/test/features.jl
index 57f743958..2bd2f828b 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -522,4 +522,12 @@ end
@test gradient(xs -> sum((x -> x<2 ? false : x^2), xs), [1,2,3])[1][2:3] == [4, 6]
@test gradient(xs -> sum(map((x -> x<2 ? false : x^2), xs)), [1,2,3])[1][2:3] == [4, 6]
@test gradient(xs -> mapreduce((x -> x<2 ? false : x^2), +, xs), [1,2,3])[1][2:3] == [4, 6]
+
+ # with Ref, Val, Symbol
+ @test gradient(x -> sum(x .+ Ref(x[1])), [1,2,3]) == ([4,1,1],)
+ @test gradient(x -> sum(x .+ (x[1],)), [1,2,3]) == ([4,1,1],)
+ @test gradient(x -> sum((first∘tuple).(x, :ignore)), [1,2,3]) == ([1,1,1],)
+ @test gradient(x -> sum((first∘tuple).(x, Symbol)), [1,2,3]) == ([1,1,1],)
+ _f(x,::Val{y}) where {y} = x/y
+ @test gradient(x -> sum(_f.(x, Val(2))), [1,2,3]) == ([0.5, 0.5, 0.5],)
end
From 86c3bb49314ea73ec19a2a4a6b70d6c1a44dc82b Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 19 Jun 2021 10:25:42 -0400
Subject: [PATCH 087/490] simpler purity check
---
src/lib/broadcast.jl | 5 +----
test/features.jl | 7 ++++++-
2 files changed, 7 insertions(+), 5 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 8e3a0a2bb..49efbc279 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -167,11 +167,8 @@ _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
collapse_nothings(xs::AbstractArray{Nothing}) = nothing
collapse_nothings(xs) = xs
-_purefun(::Type{F}) where {F<:Function} = isempty(fieldnames(F))
+_purefun(::Type{F}) where {F<:Function} = Base.issingletontype(F)
_purefun(::Type) = false
-if VERSION >= v"1.6"
- _purefun(::Type{ComposedFunction{F,G}}) where {F,G} = _purefun(F) && _purefun(G)
-end
_purefun(::Type{typeof(^)}) = false # fix @testset "power" & @testset "diagonal hessian"
_dualsafe(x::Numeric{<:Real}) = true
diff --git a/test/features.jl b/test/features.jl
index 2bd2f828b..dd39a75e8 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -528,6 +528,11 @@ end
@test gradient(x -> sum(x .+ (x[1],)), [1,2,3]) == ([4,1,1],)
@test gradient(x -> sum((first∘tuple).(x, :ignore)), [1,2,3]) == ([1,1,1],)
@test gradient(x -> sum((first∘tuple).(x, Symbol)), [1,2,3]) == ([1,1,1],)
- _f(x,::Val{y}) where {y} = x/y
+ _f(x,::Val{y}=Val(2)) where {y} = x/y
@test gradient(x -> sum(_f.(x, Val(2))), [1,2,3]) == ([0.5, 0.5, 0.5],)
+ @test gradient(x -> sum(_f.(x)), [1,2,3]) == ([0.5, 0.5, 0.5],)
+ @test gradient(x -> sum(map(_f, x)), [1,2,3]) == ([0.5, 0.5, 0.5],)
+
+ @test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],)
+ @test gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],)
end
From 570db214af521bd20d64b1c19b39e0f0443cec7c Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 19 Jun 2021 10:27:30 -0400
Subject: [PATCH 088/490] use purity check in map, too?
---
src/lib/array.jl | 22 +++++++++++++++-------
1 file changed, 15 insertions(+), 7 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index aa939b2f7..7b597f32c 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -215,19 +215,27 @@ _tryreverse(m, x) = x
_tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x)
for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
- @eval function $∇mapfunc(cx, f, args...)
+ @eval function $∇mapfunc(cx, f::F, args...) where {F}
ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
if isempty(ys_and_backs)
ys_and_backs, _ -> nothing
else
- ys, backs = unzip(ys_and_backs)
+ ys = map(first, ys_and_backs)
ys, function (Δ)
isnothing(Δ) && return nothing
- # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
- Δf_and_args_zipped = $mapfunc((f, δ) -> f(δ), _tryreverse($mapfunc, backs, Δ)...)
- Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
- Δf = reduce(accum, Δf_and_args[1])
- (Δf, Δf_and_args[2:end]...)
+ if _purefun(F) && length(args) == 1
+ Δarg = $mapfunc(((_,pb), δ) -> last(pb(δ)), ys_and_backs, Δ) # No unzip needed
+ (nothing, Δarg)
+ elseif _purefun(F)
+ Δargs = unzip($mapfunc(((_,pb), δ) -> Base.tail(pb(δ)), ys_and_backs, Δ))
+ (nothing, Δargs...)
+ else
+ # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
+ Δf_and_args_zipped = $mapfunc(((_,pb), δ) -> pb(δ), _tryreverse($mapfunc, ys_and_backs, Δ)...)
+ Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
+ Δf = reduce(accum, Δf_and_args[1])
+ (Δf, Δf_and_args[2:end]...)
+ end
end
end
end
From 72eb6810c5712aa5eb63ed4e2573f9e77ce593a5 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 19 Jun 2021 12:30:12 -0400
Subject: [PATCH 089/490] simplify
---
src/lib/array.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 7b597f32c..aaf122cb4 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -223,10 +223,10 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
ys = map(first, ys_and_backs)
ys, function (Δ)
isnothing(Δ) && return nothing
- if _purefun(F) && length(args) == 1
+ if Base.issingletontype(F) && length(args) == 1
Δarg = $mapfunc(((_,pb), δ) -> last(pb(δ)), ys_and_backs, Δ) # No unzip needed
(nothing, Δarg)
- elseif _purefun(F)
+ elseif Base.issingletontype(F) # Ensures `f` is pure: nothing captured & no state
Δargs = unzip($mapfunc(((_,pb), δ) -> Base.tail(pb(δ)), ys_and_backs, Δ))
(nothing, Δargs...)
else
From fbed3cec876e4c81614f6353dcf91604b1c441fc Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 19 Jun 2021 12:32:04 -0400
Subject: [PATCH 090/490] use Base.ismutabletype
---
src/lib/lib.jl | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index 045a494de..2e57850f0 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -1,4 +1,4 @@
-using Base: RefValue
+using Base: RefValue, ismutabletype
# Interfaces
@@ -278,19 +278,19 @@ Jnew{T}(g) where T = Jnew{T,typeof(g)}(g)
@adjoint! function __new__(T, args...)
x = __new__(T, args...)
- g = !T.mutable || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
+ g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
x, Jnew{T,typeof(g),false}(g)
end
@adjoint! function __splatnew__(T, args)
x = __splatnew__(T, args)
- g = !T.mutable || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
+ g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
x, Jnew{T,typeof(g),true}(g)
end
# TODO captured mutables + multiple calls to `back`
@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G}
- !T.mutable && Δ == Nothing && return :nothing
+ !ismutabletype(T) && Δ == Nothing && return :nothing
Δ = G == Nothing ? :Δ :
Δ <: RefValue ? :(back.g[]) :
:(accum(back.g[], Δ))
@@ -302,7 +302,7 @@ end
end
@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G}
- !T.mutable && Δ == Nothing && return :nothing
+ !ismutabletype(T) && Δ == Nothing && return :nothing
Δ = G == Nothing ? :Δ : :(back.g)
quote
x̄ = $Δ
From 58a13685428fa96b5cc1b055dde2c48885926bd4 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 19 Jun 2021 12:42:03 -0400
Subject: [PATCH 091/490] add version check
---
src/lib/lib.jl | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index 2e57850f0..762939ae3 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -1,4 +1,10 @@
-using Base: RefValue, ismutabletype
+using Base: RefValue
+
+if VERSION > v"1.7.0-DEV.204"
+ using Base: ismutabletype
+else
+ ismutabletype(::Type{T}) where T = T.mutable
+end
# Interfaces
From ec1a41bf86d59ac6017bb76783d34c9a6eeccf93 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 19 Jun 2021 12:57:47 -0400
Subject: [PATCH 092/490] make ismutabletype look like Base's
Co-authored-by: Johnny Chen
---
src/lib/lib.jl | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index 762939ae3..96422d78c 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -3,7 +3,10 @@ using Base: RefValue
if VERSION > v"1.7.0-DEV.204"
using Base: ismutabletype
else
- ismutabletype(::Type{T}) where T = T.mutable
+ function ismutabletype(@nospecialize(t::Type))
+ t = Base.unwrap_unionall(t)
+ return isa(t, DataType) && t.mutable
+ end
end
# Interfaces
From 8f4354f752525d9e9dc1a9bfbebf5c9f379681eb Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Sun, 20 Jun 2021 20:44:25 +0100
Subject: [PATCH 093/490] add tests to see if sum(f,x) is broken on GPU
---
test/cuda.jl | 11 +++++++++++
1 file changed, 11 insertions(+)
diff --git a/test/cuda.jl b/test/cuda.jl
index a54402999..59a7dda56 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -26,6 +26,17 @@ end
@test g_gpu |> collect ≈ g
end
+@testset "sum(f, x)" begin
+ a = Float32.(-4:4)
+ a_gpu = a |> cu
+
+ f(x) = sum(abs, x)
+ g = gradient(f, a)[1]
+ g_gpu = gradient(f, a_gpu)[1]
+ @test g_gpu isa CuArray
+ @test g_gpu |> collect ≈ g
+end
+
@testset "jacobian" begin
v1 = cu(collect(1:3f0))
From 075f530aa6efd1c1976216c297716f7e5775dcdb Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Sun, 20 Jun 2021 20:55:53 +0100
Subject: [PATCH 094/490] add back old way of doing sum(f, xs) for CuArrays
only
---
src/lib/broadcast.jl | 9 ++++++++-
1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 879f62c1a..9e89d43b9 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -254,7 +254,14 @@ end
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end
-
+
+ # Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
+ # Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
+ @adjoint function sum(f, xs::CuArray; kws...)
+ @assert !haskey(kws, :init) # TODO add init support (julia 1.6)
+ return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
+ end
+
@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.CuArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
From 22602b3633a7857e74720a583eb1a6c68dea9dcc Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Sun, 20 Jun 2021 21:33:21 +0100
Subject: [PATCH 095/490] Qualified names
---
src/lib/broadcast.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 9e89d43b9..0689c1627 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -257,7 +257,7 @@ end
# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
- @adjoint function sum(f, xs::CuArray; kws...)
+ @adjoint function sum(f, xs::CUDA.CuArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
From 7776fdd20cb1e0f4f8eae93208b24c7985b8f552 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Sun, 20 Jun 2021 21:43:29 +0100
Subject: [PATCH 096/490] don't test gradient of abs at 0
---
test/cuda.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/cuda.jl b/test/cuda.jl
index 59a7dda56..7f34aa99f 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -27,7 +27,7 @@ end
end
@testset "sum(f, x)" begin
- a = Float32.(-4:4)
+ a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01])
a_gpu = a |> cu
f(x) = sum(abs, x)
From bd8c5fb2e9a7659b1ba56d4dd6c4aa6a2822b1a8 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Sun, 20 Jun 2021 21:43:55 +0100
Subject: [PATCH 097/490] Bump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 60476fca1..e7109b349 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.13"
+version = "0.6.14"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 7ebb75b01ab26faf7db0efc7ca53239e371cc974 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 21 Jun 2021 21:46:39 -0400
Subject: [PATCH 098/490] rename, tidy, improve
---
src/lib/broadcast.jl | 34 ++++++++++++++++++++--------------
test/features.jl | 6 ++++++
2 files changed, 26 insertions(+), 14 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 49efbc279..46882acad 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -167,25 +167,31 @@ _broadcast(f::F, x...) where F = materialize(broadcasted(f, x...))
collapse_nothings(xs::AbstractArray{Nothing}) = nothing
collapse_nothings(xs) = xs
-_purefun(::Type{F}) where {F<:Function} = Base.issingletontype(F)
-_purefun(::Type) = false
-_purefun(::Type{typeof(^)}) = false # fix @testset "power" & @testset "diagonal hessian"
+_dual_purefun(::Type{F}) where {F<:Function} = Base.issingletontype(F)
+_dual_purefun(::Type) = false
+_dual_purefun(::Type{typeof(^)}) = false # avoid DomainError from negative powers
-_dualsafe(x::Numeric{<:Real}) = true
-_dualsafe(x::Ref{<:Numeric{<:Real}}) = true
-_dualsafe(x::Val) = true
-_dualsafe(x::Type) = true
-_dualsafe(x::Symbol) = true
-_dualsafe(x) = false
+_dual_safearg(x::Numeric{<:Real}) = true
+_dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
+_dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types
+_dual_safearg(x) = false
+
+# This is Broadcast.combine_eltypes but with dual eltypes:
+_combine_dual_eltypes(f, args::Tuple) =
+ Broadcast.promote_typejoin_union(Base._return_type(f, map(_dual_eltype, args)))
+_dual_eltype(x::Numeric{T}) where {T<:Real} = Dual{Nothing, T, 1} # typeof(Dual(one(T),true))
+_dual_eltype(x) = eltype(x)
@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
- T = Broadcast.combine_eltypes(f, args)
+ TD = _combine_dual_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
- if T == Bool
+ if TD <: Dual && isconcretetype(TD)
+ if _dual_purefun(F) && all(_dual_safearg, args)
+ y, back = broadcast_forward(f, args...)
+ return y, ȳ -> (nothing, nothing, back(ȳ)...)
+ end
+ elseif TD <: Real && isconcretetype(TD)
return f.(args...), _->nothing
- elseif T <: Real && isconcretetype(T) && _purefun(F) && all(_dualsafe, args)
- y, back = broadcast_forward(f, args...)
- return y, ȳ -> (nothing, nothing, back(ȳ)...)
end
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
diff --git a/test/features.jl b/test/features.jl
index dd39a75e8..f471a29e7 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -535,4 +535,10 @@ end
@test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],)
@test gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],)
+
+ # negative powers
+ @test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], [1,-1,2])[1] ≈ [1.0, -0.25, 8.0]
+ @test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625]
+ @test gradient((x,p) -> sum(z -> z^p, x), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625]
+ @test gradient((x,p) -> mapreduce(z -> z^p, +, x), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625]
end
From 94712ccbc41909a58a10c48ee7e14be26fdd79b0 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 21 Jun 2021 21:50:13 -0400
Subject: [PATCH 099/490] revert some of that due to 20% slowdown
---
src/lib/broadcast.jl | 20 ++++++--------------
1 file changed, 6 insertions(+), 14 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 46882acad..32e307e26 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -176,22 +176,14 @@ _dual_safearg(x::Ref{<:Numeric{<:Real}}) = true
_dual_safearg(x::Union{Type,Val,Symbol}) = true # non-differentiable types
_dual_safearg(x) = false
-# This is Broadcast.combine_eltypes but with dual eltypes:
-_combine_dual_eltypes(f, args::Tuple) =
- Broadcast.promote_typejoin_union(Base._return_type(f, map(_dual_eltype, args)))
-_dual_eltype(x::Numeric{T}) where {T<:Real} = Dual{Nothing, T, 1} # typeof(Dual(one(T),true))
-_dual_eltype(x) = eltype(x)
-
@adjoint function broadcasted(::AbstractArrayStyle, f::F, args...) where {F}
- TD = _combine_dual_eltypes(f, args)
+ T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
- if TD <: Dual && isconcretetype(TD)
- if _dual_purefun(F) && all(_dual_safearg, args)
- y, back = broadcast_forward(f, args...)
- return y, ȳ -> (nothing, nothing, back(ȳ)...)
- end
- elseif TD <: Real && isconcretetype(TD)
- return f.(args...), _->nothing
+ if T == Bool
+ return f.(args...), _->nothing
+ elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args)
+ y, back = broadcast_forward(f, args...)
+ return y, ȳ -> (nothing, nothing, back(ȳ)...)
end
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
From 8424c3e6c808196afbc665b7208993cdb356c0f8 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 21 Jun 2021 22:19:21 -0400
Subject: [PATCH 100/490] delete an unused definition
---
src/lib/broadcast.jl | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 32e307e26..92ecd1d0f 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -219,7 +219,7 @@ end
@adjoint! (b::typeof(broadcast))(f, args...) = _pullback(__context__, broadcasted, f, args...)
-# Forward Mode (mainly necessary for CUDA)
+# Forward Mode -- necessary for CUDA, also used as a fast path above
import ForwardDiff
using ForwardDiff: Dual
@@ -227,9 +227,6 @@ using ForwardDiff: Dual
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
-dualtype(::Type{Dual{G,T,P}}) where {G,T,P} = T
-dualtype(T) = T
-
function dual_function(f::F) where F
function (args::Vararg{Any,N}) where N
ds = map(args, ntuple(identity,Val(N))) do x, i
From 163e1731d0a3bbc71464772d61501572be931208 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 22 Jun 2021 18:45:47 +0100
Subject: [PATCH 101/490] use rrules even when all the arguments are types
---
src/compiler/interface2.jl | 10 ++++++----
test/chainrules.jl | 18 ++++++++++++++++++
2 files changed, 24 insertions(+), 4 deletions(-)
diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl
index ac4a5a76a..f0c4fa690 100644
--- a/src/compiler/interface2.jl
+++ b/src/compiler/interface2.jl
@@ -7,22 +7,24 @@ function edge!(m::IRTools.Meta, edge::Core.MethodInstance)
end
@generated function _pullback(ctx::AContext, f, args...)
- T = Tuple{f,args...}
- ignore_sig(T) && return :(f(args...), Pullback{$T}(()))
-
+ # Try using ChainRulesCore
if is_kwfunc(f, args...)
# if it is_kw then `args[1]` are the keyword args, `args[2]` is actual function
cr_T = Tuple{ZygoteRuleConfig{ctx}, args[2:end]...}
chain_rrule_f = :chain_rrule_kw
else
cr_T = Tuple{ZygoteRuleConfig{ctx}, f, args...}
+ Core.println("cr_T=", cr_T)
chain_rrule_f = :chain_rrule
end
hascr, cr_edge = has_chain_rrule(cr_T)
-
hascr && return :($chain_rrule_f(ZygoteRuleConfig(ctx), f, args...))
+ # No ChainRule, going to have to work it out.
+ T = Tuple{f,args...}
+ ignore_sig(T) && return :(f(args...), Pullback{$T}(()))
+
g = try _generate_pullback_via_decomposition(T) catch e e end
g === nothing && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 519c10ad6..66058c93d 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -214,6 +214,24 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2), 10.4)
@test (nothing,) == Zygote.gradient(x->not_diff_kw_eg(x, 2; kw=2.0), 10.4)
end
+
+ @testset "Type only rrule" begin
+ struct StructForTestingTypeOnlyRRules{T}
+ x::T
+ end
+ StructForTestingTypeOnlyRRules() = StructForTestingTypeOnlyRRules(1.0)
+
+ function ChainRulesCore.rrule(P::Type{<:StructForTestingTypeOnlyRRules})
+ # notice here we mess with the primal doing 2.0 rather than 1.0, this is for testing purposes
+ # and also because apparently people actually want to do this. Weird, but 🤷
+ # https://github.com/SciML/SciMLBase.jl/issues/69#issuecomment-865639754
+ P(2.0), _->NoTangent()
+ end
+
+ @assert StructForTestingTypeOnlyRRules().x == 1.0
+ aug_primal_val, _ = Zygote.pullback(x->StructForTestingTypeOnlyRRules(), 1.2)
+ @test aug_primal_val.x == 2.0
+ end
end
@testset "ChainRulesCore.rrule_via_ad" begin
From 2dab48fdfaddd8a908c341eb011ef82817fba0f9 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Wed, 23 Jun 2021 17:50:32 +0100
Subject: [PATCH 102/490] Remove leftover debugging statements
---
src/compiler/interface2.jl | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl
index f0c4fa690..0f7da4b32 100644
--- a/src/compiler/interface2.jl
+++ b/src/compiler/interface2.jl
@@ -14,7 +14,6 @@ end
chain_rrule_f = :chain_rrule_kw
else
cr_T = Tuple{ZygoteRuleConfig{ctx}, f, args...}
- Core.println("cr_T=", cr_T)
chain_rrule_f = :chain_rrule
end
From b9f186f8f044ad94b469773ae8fe722fea457983 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Thu, 24 Jun 2021 08:48:27 +0100
Subject: [PATCH 103/490] Update test/chainrules.jl
Co-authored-by: Dhairya Gandhi
---
test/chainrules.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 66058c93d..32bdd3799 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -225,7 +225,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
# notice here we mess with the primal doing 2.0 rather than 1.0, this is for testing purposes
# and also because apparently people actually want to do this. Weird, but 🤷
# https://github.com/SciML/SciMLBase.jl/issues/69#issuecomment-865639754
- P(2.0), _->NoTangent()
+ P(2.0), _ -> (NoTangent(),)
end
@assert StructForTestingTypeOnlyRRules().x == 1.0
From 46ef05ef2f839b7d9d95453654f223cf07c89a6e Mon Sep 17 00:00:00 2001
From: Akash Garg
Date: Fri, 25 Jun 2021 12:48:08 -0700
Subject: [PATCH 104/490] Use abstract GPU types for broadcast.
---
src/lib/broadcast.jl | 12 ++++++------
1 file changed, 6 insertions(+), 6 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 0689c1627..e4b8a431a 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -233,7 +233,7 @@ end
end
@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
- const CuArrayStyle = CUDA.CuArrayStyle
+ const CuArrayStyle = CUDA.AbstractGPUArrayStyle
if isdefined(CUDA, :cufunc)
@eval @adjoint function broadcasted(::CuArrayStyle, f, args...)
@@ -247,22 +247,22 @@ end
end
end
- @adjoint CUDA.CuArray{N,T}(xs::Array) where {N,T} =
- CUDA.CuArray{N,T}(xs), Δ -> (convert(Array, Δ), )
+ @adjoint CUDA.DenseArray{N,T}(xs::Array) where {N,T} =
+ CUDA.DenseArray{N,T}(xs), Δ -> (convert(Array, Δ), )
- @adjoint function sum(xs::CUDA.CuArray; dims = :)
+ @adjoint function sum(xs::CUDA.DenseArray; dims = :)
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end
# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
- @adjoint function sum(f, xs::CUDA.CuArray; kws...)
+ @adjoint function sum(f, xs::CUDA.DenseArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
- @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.CuArray}
+ @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.DenseArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
From c49222c7f144670cdb20c680d1a69ae97fadb4c6 Mon Sep 17 00:00:00 2001
From: Akash Garg
Date: Fri, 25 Jun 2021 14:43:22 -0700
Subject: [PATCH 105/490] Fix typo DenseArray -> DenseCuArray
---
src/lib/broadcast.jl | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index e4b8a431a..72929a592 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -247,22 +247,22 @@ end
end
end
- @adjoint CUDA.DenseArray{N,T}(xs::Array) where {N,T} =
- CUDA.DenseArray{N,T}(xs), Δ -> (convert(Array, Δ), )
+ @adjoint CUDA.DenseCuArray{N,T}(xs::Array) where {N,T} =
+ CUDA.DenseCuArray{N,T}(xs), Δ -> (convert(Array, Δ), )
- @adjoint function sum(xs::CUDA.DenseArray; dims = :)
+ @adjoint function sum(xs::CUDA.DenseCuArray; dims = :)
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end
# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
- @adjoint function sum(f, xs::CUDA.DenseArray; kws...)
+ @adjoint function sum(f, xs::CUDA.DenseCuArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
- @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.DenseArray}
+ @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.DenseCuArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
From 41b6862219a00dcd432030db0620a3df129232b2 Mon Sep 17 00:00:00 2001
From: Akash Garg
Date: Fri, 25 Jun 2021 14:48:38 -0700
Subject: [PATCH 106/490] Adjoint for CuArray
---
src/lib/broadcast.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 72929a592..0192507ef 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -247,8 +247,8 @@ end
end
end
- @adjoint CUDA.DenseCuArray{N,T}(xs::Array) where {N,T} =
- CUDA.DenseCuArray{N,T}(xs), Δ -> (convert(Array, Δ), )
+ @adjoint CUDA.CuArray{N,T}(xs::Array) where {N,T} =
+ CUDA.CuArray{N,T}(xs), Δ -> (convert(Array, Δ), )
@adjoint function sum(xs::CUDA.DenseCuArray; dims = :)
placeholder = similar(xs)
From d73dabe42ddc49dee5e4da9315615f541444018d Mon Sep 17 00:00:00 2001
From: Akash Garg
Date: Fri, 25 Jun 2021 15:13:24 -0700
Subject: [PATCH 107/490] updating broadcast to use AbstractGPUArray
---
src/lib/broadcast.jl | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 0192507ef..a5fa1adae 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -250,19 +250,19 @@ end
@adjoint CUDA.CuArray{N,T}(xs::Array) where {N,T} =
CUDA.CuArray{N,T}(xs), Δ -> (convert(Array, Δ), )
- @adjoint function sum(xs::CUDA.DenseCuArray; dims = :)
+ @adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end
# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
- @adjoint function sum(f, xs::CUDA.DenseCuArray; kws...)
+ @adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
- @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.DenseCuArray}
+ @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
From 8e0ed5d29cdbcedcaa6ac62d661eb96bf21ee8ab Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 26 Jun 2021 13:35:16 -0400
Subject: [PATCH 108/490] add withgradient function
---
docs/src/utils.md | 2 ++
src/Zygote.jl | 2 +-
src/compiler/interface.jl | 22 ++++++++++++++++++++++
src/lib/grad.jl | 24 +++++++++++++++++++-----
test/features.jl | 4 ++++
test/utils.jl | 1 +
6 files changed, 49 insertions(+), 6 deletions(-)
diff --git a/docs/src/utils.md b/docs/src/utils.md
index a03ab1c25..309e43786 100644
--- a/docs/src/utils.md
+++ b/docs/src/utils.md
@@ -14,6 +14,8 @@ in other words you could have written them easily yourself, but they live in
Zygote for convenience.
```@docs
+Zygote.withgradient
+Zygote.withjacobian
Zygote.@showgrad
Zygote.hook
Zygote.dropgrad
diff --git a/src/Zygote.jl b/src/Zygote.jl
index 895e65f3c..ae023213c 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -13,7 +13,7 @@ using MacroTools, Requires
using MacroTools: @forward
import Distributed: pmap, CachingPool, workers
-export Params, gradient, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint
+export Params, withgradient, gradient, withjacobian, jacobian, hessian, diaghessian, pullback, pushforward, @code_adjoint
export rrule_via_ad
const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 86e847dc4..b631763b0 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -53,6 +53,9 @@ Returns a tuple containing `∂f/∂x` for each argument `x`,
the derivative (for scalar x) or the gradient.
`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
+
+See also [`withgradient`](@ref) to keep the value `f(args...)`,
+and `pullback`](@ref) for value and back-propagator.
"""
function gradient(f, args...)
y, back = pullback(f, args...)
@@ -61,6 +64,25 @@ end
Base.adjoint(f::Function) = x -> gradient(f, x)[1]
+"""
+ withgradient(f, args...)
+
+Returns both the value `f(args...)` and the [`gradient`](@ref),
+`∂f/∂x` for each argument `x`, as a named tuple.
+
+```jldoctest
+julia> y, ∇ = withgradient(/, 1, 2)
+(val = 0.5, grad = (0.5, -0.25))
+
+julia> ∇ == gradient(/, 1, 2)
+true
+```
+"""
+function withgradient(f, args...)
+ y, back = pullback(f, args...)
+ (val=y, grad=back(sensitivity(y)))
+end
+
# Param-style wrappers
# TODO store ids only
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 827edbfc2..3a6eefa2c 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -105,7 +105,7 @@ This reverse-mode Jacobian needs to evaluate the pullback once for each element
Doing so is usually only efficient when `length(y)` is small compared to `length(a)`,
otherwise forward mode is likely to be better.
-See also [`hessian`](@ref), [`hessian_reverse`](@ref).
+See also [`withjacobian`](@ref), `hessian`](@ref), [`hessian_reverse`](@ref).
# Examples
@@ -137,7 +137,19 @@ julia> gradient((a,t) -> sum(a .* t[1]) + t[2], [1,2,3], (4,5)) # gradient unde
([4 4 4], (6, 1))
```
"""
-function jacobian(f, args...)
+jacobian(f, args...) = withjacobian(f, args...).grad
+
+"""
+ withjacobian(f, args...)
+
+Returns both the value `f(args...)` and the [`jacobian`](@ref) as a named tuple.
+
+```jldoctest
+julia> withjacobian(cumsum, [1,2,3])
+(val = [1, 3, 6], grad = ([1 0 0; 1 1 0; 1 1 1],))
+```
+"""
+function withjacobian(f, args...)
y, back = pullback(_jvec∘f, args...)
out = map(args) do x
T = promote_type(eltype(x), eltype(y))
@@ -153,7 +165,7 @@ function jacobian(f, args...)
_gradcopy!(view(dx,k,:), grad)
end
end
- out
+ (val=y, grad=out)
end
_jvec(x::AbstractArray) = vec(x)
@@ -197,7 +209,9 @@ julia> Jxy[xs]
2 6 4 8
```
"""
-function jacobian(f, pars::Params)
+jacobian(f, pars::Params) = withjacobian(f, pars::Params).grad
+
+function withjacobian(f, pars::Params)
y, back = pullback(_jvec∘f, pars)
out = IdDict()
for p in pars
@@ -213,7 +227,7 @@ function jacobian(f, pars::Params)
_gradcopy!(view(out[p],k,:), grads[p])
end
end
- Grads(out, pars)
+ (val=y, grad=Grads(out, pars))
end
"""
diff --git a/test/features.jl b/test/features.jl
index f471a29e7..ee981c8be 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -111,6 +111,7 @@ dx = back(4)
@test dx == (12, 8)
@test gradient(mul, 2, 3) == (3, 2)
+@test withgradient(mul, 2, 3) == (val = 6, grad = (3, 2))
bool = true
b(x) = bool ? 2x : x
@@ -287,6 +288,9 @@ y, back = pullback(() -> layer(x), Params([W]))
@test back([1, 1])[W] == [1 2; 1 2]
@test gradient(() -> sum(W * x), Params([W]))[W] == [1 2; 1 2]
+y, grad = withgradient(() -> sum(W * x), Params([W]))
+@test y == 3
+@test grad[W] == [1 2; 1 2]
let
p = [1]
diff --git a/test/utils.jl b/test/utils.jl
index 1e5366fbc..b5845a7ba 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -42,6 +42,7 @@ end
@testset "jacobian(f, args...)" begin
@test jacobian(identity, [1,2])[1] == [1 0; 0 1]
+ @test withjacobian(identity, [1,2]) == (val = [1,2], grad = ([1 0; 0 1],))
j1 = jacobian((a,x) -> a.^2 .* x, [1,2,3], 1)
@test j1[1] ≈ Diagonal([2,4,6])
From 84693def6161438409cdb111a2ee87e76a010ad9 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 26 Jun 2021 13:35:34 -0400
Subject: [PATCH 109/490] minimal docstrings re implicit gradients
---
src/compiler/interface.jl | 26 ++++++++++++++++++++++++--
1 file changed, 24 insertions(+), 2 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index b631763b0..eaa97010d 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -66,9 +66,11 @@ Base.adjoint(f::Function) = x -> gradient(f, x)[1]
"""
withgradient(f, args...)
+ withgradient(f, ::Params)
Returns both the value `f(args...)` and the [`gradient`](@ref),
`∂f/∂x` for each argument `x`, as a named tuple.
+With imiplicit parameters, the value is `f()`.
```jldoctest
julia> y, ∇ = withgradient(/, 1, 2)
@@ -85,10 +87,24 @@ end
# Param-style wrappers
-# TODO store ids only
+"""
+ gradient(() -> loss(), ::Params) -> Grads
+
+Gradient with implicit parameters. Returns a container, from which
+`grads[W]` extracts the gradient with respect to some array `W`,
+if this is among those being tracked, for example via `Params([W, A, B])`.
+"""
+gradient
+
+"""
+ Params([A, B, C...])
+
+Container for implicit parameters, differentiating a zero-argument
+funtion `() -> loss()` with respect to `A, B, C`.
+"""
struct Params
order::Buffer # {Any, Vector{Any}}
- params::IdSet{Any}
+ params::IdSet{Any} # TODO store ids only
end
Params() = Params(Buffer([], false), IdSet())
@@ -193,7 +209,13 @@ function copy!(x::AbstractVector, ps::Params)
ps
end
+"""
+ Grads(...)
+Dictionary-like container returned when taking gradients with
+respect to implicit parameters. For an array `W`, appearing
+within `Params([W, A, B...])`, the gradient is `g[W]`.
+"""
struct Grads
grads::IdDict{Any,Any}
params::Params
From 8f811ca9844c33488f0f76a415e4db1f270c9de5 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 26 Jun 2021 14:31:24 -0400
Subject: [PATCH 110/490] name clash
---
test/features.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/test/features.jl b/test/features.jl
index ee981c8be..b1434aac7 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -288,9 +288,9 @@ y, back = pullback(() -> layer(x), Params([W]))
@test back([1, 1])[W] == [1 2; 1 2]
@test gradient(() -> sum(W * x), Params([W]))[W] == [1 2; 1 2]
-y, grad = withgradient(() -> sum(W * x), Params([W]))
+y, gr = withgradient(() -> sum(W * x), Params([W]))
@test y == 3
-@test grad[W] == [1 2; 1 2]
+@test gr[W] == [1 2; 1 2]
let
p = [1]
From b361b18e60f3e21f58925836eff960692b90dca6 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 26 Jun 2021 14:33:41 -0400
Subject: [PATCH 111/490] fix jldoctest
---
src/compiler/interface.jl | 2 +-
src/lib/grad.jl | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index eaa97010d..148b9a238 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -72,7 +72,7 @@ Returns both the value `f(args...)` and the [`gradient`](@ref),
`∂f/∂x` for each argument `x`, as a named tuple.
With imiplicit parameters, the value is `f()`.
-```jldoctest
+```jldoctest; setup=:(using Zygote)
julia> y, ∇ = withgradient(/, 1, 2)
(val = 0.5, grad = (0.5, -0.25))
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 3a6eefa2c..4ac7708b6 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -144,7 +144,7 @@ jacobian(f, args...) = withjacobian(f, args...).grad
Returns both the value `f(args...)` and the [`jacobian`](@ref) as a named tuple.
-```jldoctest
+```jldoctest; setup=:(using Zygote)
julia> withjacobian(cumsum, [1,2,3])
(val = [1, 3, 6], grad = ([1 0 0; 1 1 0; 1 1 1],))
```
From c4b7306087dfcab1fcd9264a120b5a42a752a8d9 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 26 Jun 2021 19:35:32 -0400
Subject: [PATCH 112/490] tweak docstrings, one more test
---
src/compiler/interface.jl | 44 +++++++++++++++++++++++++++++----------
test/utils.jl | 4 ++++
2 files changed, 37 insertions(+), 11 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 148b9a238..0fdd9bbeb 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -50,12 +50,22 @@ sensitivity(y) = error("Output should be scalar; gradients are not defined for o
gradient(f, args...)
Returns a tuple containing `∂f/∂x` for each argument `x`,
-the derivative (for scalar x) or the gradient.
+the derivative (for scalar `x`) or the gradient.
`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
See also [`withgradient`](@ref) to keep the value `f(args...)`,
and `pullback`](@ref) for value and back-propagator.
+
+```jldoctest; setup=:(using Zygote)
+julia> gradient(*, 2, 3, 5)
+(15, 10, 6)
+
+julia> gradient([7,11,13]) do x
+ sum(abs2, x)
+ end
+([14, 22, 26],)
+```
"""
function gradient(f, args...)
y, back = pullback(f, args...)
@@ -68,9 +78,8 @@ Base.adjoint(f::Function) = x -> gradient(f, x)[1]
withgradient(f, args...)
withgradient(f, ::Params)
-Returns both the value `f(args...)` and the [`gradient`](@ref),
-`∂f/∂x` for each argument `x`, as a named tuple.
-With imiplicit parameters, the value is `f()`.
+Returns both the value `f(args...)` and the [`gradient`](@ref)
+as a named tuple. With implicit parameters, the value is `f()`.
```jldoctest; setup=:(using Zygote)
julia> y, ∇ = withgradient(/, 1, 2)
@@ -88,19 +97,32 @@ end
# Param-style wrappers
"""
- gradient(() -> loss(), ::Params) -> Grads
+ gradient(() -> loss(), ps::Params) -> Grads
+
+Gradient with implicit parameters. Takes a zero-argument function,
+and returns a dictionary-like container, whose keys are arrays `x in ps`.
-Gradient with implicit parameters. Returns a container, from which
-`grads[W]` extracts the gradient with respect to some array `W`,
-if this is among those being tracked, for example via `Params([W, A, B])`.
+```jldoctest; setup=:(using Zygote)
+julia> x = [1 2; 3 4]; y = [5, 6];
+
+julia> g = gradient(Params([x, y])) do
+ sum(x .* y .* y')
+ end
+Grads(...)
+
+julia> g[x]
+2×2 Matrix{Int64}:
+ 25 30
+ 30 36
+```
"""
gradient
"""
- Params([A, B, C...])
+ Params([A, B, C])
-Container for implicit parameters, differentiating a zero-argument
-funtion `() -> loss()` with respect to `A, B, C`.
+Container for implicit parameters, used when differentiating
+a zero-argument funtion `() -> loss()` with respect to `A, B, C`.
"""
struct Params
order::Buffer # {Any, Vector{Any}}
diff --git a/test/utils.jl b/test/utils.jl
index b5845a7ba..9a3d83ea5 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -83,6 +83,10 @@ end
Jxy = jacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))
@test Jxy[ys] ≈ [1 0 0; 0 1 0]
@test Jxy[xs] ≈ [2 6 4 8; 2 6 4 8]
+
+ z, grad = withjacobian(() -> ys[1:2] .+ sum(xs.^2), Params([xs, ys]))
+ @test z == [35, 37]
+ @test grad[ys] ≈ [1 0 0; 0 1 0]
end
using ForwardDiff
From 821a0d93ae1d366c9ff1442a104b71f53779db05 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 26 Jun 2021 19:50:50 -0400
Subject: [PATCH 113/490] further tweaks
---
src/compiler/interface.jl | 21 ++++++++++++---------
1 file changed, 12 insertions(+), 9 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 0fdd9bbeb..c494438f1 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -78,8 +78,8 @@ Base.adjoint(f::Function) = x -> gradient(f, x)[1]
withgradient(f, args...)
withgradient(f, ::Params)
-Returns both the value `f(args...)` and the [`gradient`](@ref)
-as a named tuple. With implicit parameters, the value is `f()`.
+Returns both the value of the function and the [`gradient`](@ref),
+as a named tuple.
```jldoctest; setup=:(using Zygote)
julia> y, ∇ = withgradient(/, 1, 2)
@@ -103,26 +103,29 @@ Gradient with implicit parameters. Takes a zero-argument function,
and returns a dictionary-like container, whose keys are arrays `x in ps`.
```jldoctest; setup=:(using Zygote)
-julia> x = [1 2; 3 4]; y = [5, 6];
+julia> x = [1 2 3; 4 5 6]; y = [7, 8]; z = [1, 10, 100];
julia> g = gradient(Params([x, y])) do
- sum(x .* y .* y')
+ sum(x .* y .* z')
end
Grads(...)
julia> g[x]
-2×2 Matrix{Int64}:
- 25 30
- 30 36
+2×3 Matrix{Int64}:
+ 7 70 700
+ 8 80 800
+
+julia> haskey(g, z) # only x and y are parameters
+false
```
"""
gradient
"""
- Params([A, B, C])
+ Params([A, B])
Container for implicit parameters, used when differentiating
-a zero-argument funtion `() -> loss()` with respect to `A, B, C`.
+a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`.
"""
struct Params
order::Buffer # {Any, Vector{Any}}
From 0416a253381cb8e5a456c9a5fe57d01465cead1b Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 27 Jun 2021 15:10:01 -0400
Subject: [PATCH 114/490] Update src/compiler/interface.jl
Co-authored-by: Carlo Lucibello
---
src/compiler/interface.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index c494438f1..605feb8f6 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -55,7 +55,7 @@ the derivative (for scalar `x`) or the gradient.
`f(args...)` must be a real number, see [`jacobian`](@ref) for array output.
See also [`withgradient`](@ref) to keep the value `f(args...)`,
-and `pullback`](@ref) for value and back-propagator.
+and [`pullback`](@ref) for value and back-propagator.
```jldoctest; setup=:(using Zygote)
julia> gradient(*, 2, 3, 5)
From 1121cc46fab1b6f9ca765feee9257a93c67bc5bb Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 27 Jun 2021 16:27:42 -0400
Subject: [PATCH 115/490] tweak anon func
---
src/compiler/interface.jl | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 605feb8f6..19b73b732 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -61,10 +61,14 @@ and [`pullback`](@ref) for value and back-propagator.
julia> gradient(*, 2, 3, 5)
(15, 10, 6)
-julia> gradient([7,11,13]) do x
- sum(abs2, x)
- end
+julia> gradient(x -> sum(abs2,x), [7, 11, 13])
([14, 22, 26],)
+
+julia> gradient([7, 11], 0, 1) do x, y, d
+ p = size(x, d)
+ sum(x.^p .+ y)
+ end
+([14.0, 22.0], 2, nothing)
```
"""
function gradient(f, args...)
From 0d777cd6bf99ecf9098ef5e24872ea0727f7befe Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Tue, 29 Jun 2021 16:24:42 +0530
Subject: [PATCH 116/490] in requires
---
src/lib/broadcast.jl | 3 ++-
test/runtests.jl | 1 +
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index f20a36663..756e18ac3 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -247,6 +247,7 @@ end
end
@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
+ using CUDA
const CuArrayStyle = CUDA.CuArrayStyle
if isdefined(CUDA, :cufunc)
@@ -280,5 +281,5 @@ end
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
- pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
+ @eval pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
end
diff --git a/test/runtests.jl b/test/runtests.jl
index 67893a7a5..022727fbe 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,5 +1,6 @@
using Zygote, Test
using Zygote: gradient, ZygoteRuleConfig
+using CUDA
using CUDA: has_cuda
if has_cuda()
From 42cd28e52d9571d8e36c79593431e2ba1c9b25e5 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Tue, 29 Jun 2021 16:53:24 +0530
Subject: [PATCH 117/490] test the gpu case the same as cpu
---
test/cuda.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/cuda.jl b/test/cuda.jl
index 2a4ce40e5..95bcdf373 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -94,7 +94,7 @@ end
@testset "vcat scalar indexing" begin
r = cu(rand(Float32, 3))
- grads = (cu(ones(Float32, 3)), nothing)
+ grads = (cu(ones(Float32, 3)), 1.f0)
@test gradient((x,y) -> sum(vcat(x,y)), r, 5) == grads
end
From bcc41b8afc144f5b26bde73afd8207ec861fa735 Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Fri, 2 Jul 2021 16:14:50 +0200
Subject: [PATCH 118/490] Documentation edit on implicit parameters
After struggling to handle implicit parameters with a Flux model, and following a discourse discussion (https://discourse.julialang.org/t/unrecognized-gradient-using-zygote-for-ad-with-universal-differential-equations/63791/2), I have decided to add some extra details on how to access and when to use implicit parameters. I hope this helps new users like me to avoid wasting time looking for this.
---
docs/src/index.md | 18 +++++++++++++++++-
1 file changed, 17 insertions(+), 1 deletion(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index be300581d..cac589ae0 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -131,6 +131,8 @@ julia> gradient(colordiff, RGB(1, 0, 0), RGB(0, 1, 0))
## Gradients of ML models
+### Explicit parameters
+
It's easy to work with even very large and complex models, and there are few ways to do this. Autograd-style models pass around a collection of weights.
```julia
@@ -170,7 +172,9 @@ julia> dmodel = gradient(model -> sum(model(x)), model)[1]
(W = [0.652543 … 0.683588], b = [1.0, 1.0])
```
-Zygote also support one more way to take gradients, via *implicit parameters* – this is a lot like autograd-style gradients, except we don't have to thread the parameter collection through all our code.
+### Implicit parameters
+
+Zygote also support one more way to take gradients, via *implicit parameters* – this is a lot like autograd-style gradients, except we don't have to thread the parameter collection through all our code. When working with Flux models, this is the recommended way of passing the gradients, as it ensures compatibility with Flux's built-in optimizers.
```julia
julia> W = rand(2, 5); b = rand(2);
@@ -181,8 +185,20 @@ linear (generic function with 2 methods)
julia> grads = gradient(() -> sum(linear(x)), Params([W, b]))
Grads(...)
+# Apply gradients to model parameters
julia> grads[W], grads[b]
([0.652543 … 0.683588], [1.0, 1.0])
```
+Unlike with explicit gradients, in order to see implicit gradients one needs to do:
+
+```julia
+julia> grads.grads
+IdDict{Any, Any} with 5 entries:
+ [0.467471 0.597815 … 0.678126 … => [0.579671 0.215381 … 0.635058 0.623832; 0.579671 0.215381 … …
+ :(Main.x) => [1.3377, 0.930234, 0.499161, 1.33827, 1.37791]
+ :(Main.W) => [0.579671 0.215381 … 0.635058 0.623832; 0.579671 0.215381 … …
+ [0.106308, 0.705531] => 2-element Fill{Float64}: entries equal to 1.0
+ :(Main.b) => 2-element Fill{Float64}: entries equal to 1.0
+```
However, implicit parameters exist mainly for compatibility with Flux's current AD; it's recommended to use the other approaches unless you need this.
From a970421a674d077b126ae8ffc4a56e172fbe880a Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Fri, 2 Jul 2021 18:24:32 +0200
Subject: [PATCH 119/490] Update docs/src/index.md
Co-authored-by: Kyle Daruwalla
---
docs/src/index.md | 11 ++---------
1 file changed, 2 insertions(+), 9 deletions(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index cac589ae0..657560022 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -189,16 +189,9 @@ Grads(...)
julia> grads[W], grads[b]
([0.652543 … 0.683588], [1.0, 1.0])
```
-Unlike with explicit gradients, in order to see implicit gradients one needs to do:
+To inspect the `Grads(...)` object returned for implicit parameters, you can index it using the parameters passed to `Params`:
```julia
-julia> grads.grads
-IdDict{Any, Any} with 5 entries:
- [0.467471 0.597815 … 0.678126 … => [0.579671 0.215381 … 0.635058 0.623832; 0.579671 0.215381 … …
- :(Main.x) => [1.3377, 0.930234, 0.499161, 1.33827, 1.37791]
- :(Main.W) => [0.579671 0.215381 … 0.635058 0.623832; 0.579671 0.215381 … …
- [0.106308, 0.705531] => 2-element Fill{Float64}: entries equal to 1.0
- :(Main.b) => 2-element Fill{Float64}: entries equal to 1.0
-```
+julia> [grads[p] for p in [W, b]]
However, implicit parameters exist mainly for compatibility with Flux's current AD; it's recommended to use the other approaches unless you need this.
From e860ac3dc3765a317b342e7219fe13b4d4c29968 Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Fri, 2 Jul 2021 19:14:05 +0200
Subject: [PATCH 120/490] Update on implicit/explicit parameters docs
An update following some suggestions.
---
docs/src/index.md | 18 +++++++++---------
1 file changed, 9 insertions(+), 9 deletions(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index 657560022..d55cdc014 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -129,11 +129,9 @@ julia> gradient(colordiff, RGB(1, 0, 0), RGB(0, 1, 0))
((r = 0.4590887719632896, g = -9.598786801605689, b = 14.181383399012862), (r = -1.7697549557037275, g = 28.88472330558805, b = -0.044793892637761346))
```
-## Gradients of ML models
+## Explicit and implicit parameters of ML models
-### Explicit parameters
-
-It's easy to work with even very large and complex models, and there are few ways to do this. Autograd-style models pass around a collection of weights.
+It's easy to work with even very large and complex models, and there are few ways to do this. Autograd-style models pass around a collection of weights. There are two ways of passing *explicit* parameters:
```julia
julia> linear(θ, x) = θ[:W] * x .+ θ[:b]
@@ -172,8 +170,6 @@ julia> dmodel = gradient(model -> sum(model(x)), model)[1]
(W = [0.652543 … 0.683588], b = [1.0, 1.0])
```
-### Implicit parameters
-
Zygote also support one more way to take gradients, via *implicit parameters* – this is a lot like autograd-style gradients, except we don't have to thread the parameter collection through all our code. When working with Flux models, this is the recommended way of passing the gradients, as it ensures compatibility with Flux's built-in optimizers.
```julia
@@ -184,14 +180,18 @@ linear (generic function with 2 methods)
julia> grads = gradient(() -> sum(linear(x)), Params([W, b]))
Grads(...)
+```
+To inspect the `Grads(...)` object returned for implicit parameters, you can access it using the parameters passed to `Params`:
+```julia
# Apply gradients to model parameters
julia> grads[W], grads[b]
([0.652543 … 0.683588], [1.0, 1.0])
```
-To inspect the `Grads(...)` object returned for implicit parameters, you can index it using the parameters passed to `Params`:
-```julia
-julia> [grads[p] for p in [W, b]]
+Here `grads` is a dictionary-like object, whose keys are the same parameters we
+indicated in `Params`. (In fact it contains a dictionary using `objectid(W)`, which
+does not change if the values in `W` are mutated.) These parameters `W, b` are global
+variables, but gradients with respect to other global variables are not stored.
However, implicit parameters exist mainly for compatibility with Flux's current AD; it's recommended to use the other approaches unless you need this.
From 51bda084777a642a3641901f833968a0104468e9 Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Fri, 2 Jul 2021 20:18:24 +0200
Subject: [PATCH 121/490] Update docs/src/index.md
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
docs/src/index.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index d55cdc014..d9e5ea7a0 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -129,7 +129,7 @@ julia> gradient(colordiff, RGB(1, 0, 0), RGB(0, 1, 0))
((r = 0.4590887719632896, g = -9.598786801605689, b = 14.181383399012862), (r = -1.7697549557037275, g = 28.88472330558805, b = -0.044793892637761346))
```
-## Explicit and implicit parameters of ML models
+## Explicit and Implicit Parameters
It's easy to work with even very large and complex models, and there are few ways to do this. Autograd-style models pass around a collection of weights. There are two ways of passing *explicit* parameters:
From 9033b71fff20a46b8e2a3bb80fd37f670e47f211 Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Fri, 2 Jul 2021 20:19:25 +0200
Subject: [PATCH 122/490] Update docs/src/index.md
Co-authored-by: Kyle Daruwalla
---
docs/src/index.md | 11 +++--------
1 file changed, 3 insertions(+), 8 deletions(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index d9e5ea7a0..60963608e 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -180,18 +180,13 @@ linear (generic function with 2 methods)
julia> grads = gradient(() -> sum(linear(x)), Params([W, b]))
Grads(...)
-```
-To inspect the `Grads(...)` object returned for implicit parameters, you can access it using the parameters passed to `Params`:
-```julia
-# Apply gradients to model parameters
-julia> grads[W], grads[b]
+julia> grads[W], grads[b] # access gradients using arrays as keys
([0.652543 … 0.683588], [1.0, 1.0])
```
Here `grads` is a dictionary-like object, whose keys are the same parameters we
-indicated in `Params`. (In fact it contains a dictionary using `objectid(W)`, which
-does not change if the values in `W` are mutated.) These parameters `W, b` are global
-variables, but gradients with respect to other global variables are not stored.
+indicated in `Params`. (In fact it wraps a dictionary using `objectid(W)` as keys, which
+does not change if the values in `W` are mutated).
However, implicit parameters exist mainly for compatibility with Flux's current AD; it's recommended to use the other approaches unless you need this.
From 49d50645f829794d7a499bb4b5efefe4563d7134 Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Fri, 2 Jul 2021 20:23:00 +0200
Subject: [PATCH 123/490] Formatting and new last paragraph in docs
Adding the last paragraph suggested by Michael.
---
docs/src/index.md | 6 ++----
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index 60963608e..b468cc27a 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -185,8 +185,6 @@ julia> grads[W], grads[b] # access gradients using arrays as keys
([0.652543 … 0.683588], [1.0, 1.0])
```
-Here `grads` is a dictionary-like object, whose keys are the same parameters we
-indicated in `Params`. (In fact it wraps a dictionary using `objectid(W)` as keys, which
-does not change if the values in `W` are mutated).
+Here `grads` is a dictionary-like object, whose keys are the same parameters we indicated in `Params`. (In fact it wraps a dictionary using `objectid(W)` as keys, which does not change if the values in `W` are mutated).
-However, implicit parameters exist mainly for compatibility with Flux's current AD; it's recommended to use the other approaches unless you need this.
+This implicit style is the one presently used by [Flux.jl](https://github.com/FluxML/Flux.jl), a closely related machine learning library. It uses structs like `Linear` above to define layers, and the function `Flux.params(model)` returns a `Params` object containing all the parameters of all layers. See [its documentation](https://fluxml.ai/Flux.jl/stable/models/basics/) for more details. When using Zygote for most other purposes, however, the explicit style is usually preferred.
From cfe501a942caa6eccb376136c55b0e6cfe8d5d28 Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Fri, 2 Jul 2021 21:25:32 +0200
Subject: [PATCH 124/490] Update docs/src/index.md
Co-authored-by: Kyle Daruwalla
---
docs/src/index.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index b468cc27a..1b10e4293 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -131,7 +131,7 @@ julia> gradient(colordiff, RGB(1, 0, 0), RGB(0, 1, 0))
## Explicit and Implicit Parameters
-It's easy to work with even very large and complex models, and there are few ways to do this. Autograd-style models pass around a collection of weights. There are two ways of passing *explicit* parameters:
+It's easy to work with even very large and complex models, and there are few ways to do this. Autograd-style models pass around a collection of weights. Depending on how you write your model, there are multiple ways to *explicitly* take gradients with respect to parameters. For example, the function `linear` accepts the parameters as an argument to the model. So, we directly pass in the parameters, `θ`, as an argument to the function being differentiated.
```julia
julia> linear(θ, x) = θ[:W] * x .+ θ[:b]
From b0db7287b08a869d4294808b0f7f8e0b24dbc41a Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Fri, 2 Jul 2021 21:29:17 +0200
Subject: [PATCH 125/490] Update index.md
---
docs/src/index.md | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index 1b10e4293..b70ca9eb7 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -170,7 +170,7 @@ julia> dmodel = gradient(model -> sum(model(x)), model)[1]
(W = [0.652543 … 0.683588], b = [1.0, 1.0])
```
-Zygote also support one more way to take gradients, via *implicit parameters* – this is a lot like autograd-style gradients, except we don't have to thread the parameter collection through all our code. When working with Flux models, this is the recommended way of passing the gradients, as it ensures compatibility with Flux's built-in optimizers.
+On the other hand, the *implicit* style is the one presently used by [Flux.jl](https://github.com/FluxML/Flux.jl), a closely related machine learning library. It uses structs like `Linear` above to define layers, and the function `Flux.params(model)` returns a `Params` object containing all the parameters of all layers. See [its documentation](https://fluxml.ai/Flux.jl/stable/models/basics/) for more details. When using Zygote for most other purposes, however, the explicit style is usually preferred.
```julia
julia> W = rand(2, 5); b = rand(2);
@@ -187,4 +187,3 @@ julia> grads[W], grads[b] # access gradients using arrays as keys
Here `grads` is a dictionary-like object, whose keys are the same parameters we indicated in `Params`. (In fact it wraps a dictionary using `objectid(W)` as keys, which does not change if the values in `W` are mutated).
-This implicit style is the one presently used by [Flux.jl](https://github.com/FluxML/Flux.jl), a closely related machine learning library. It uses structs like `Linear` above to define layers, and the function `Flux.params(model)` returns a `Params` object containing all the parameters of all layers. See [its documentation](https://fluxml.ai/Flux.jl/stable/models/basics/) for more details. When using Zygote for most other purposes, however, the explicit style is usually preferred.
From b63276052e5997e63d3feaa66993f738f7c885b9 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Sun, 4 Jul 2021 20:28:42 +0530
Subject: [PATCH 126/490] whitespace
---
src/compiler/interface.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 19b73b732..35a02e272 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -95,7 +95,7 @@ true
"""
function withgradient(f, args...)
y, back = pullback(f, args...)
- (val=y, grad=back(sensitivity(y)))
+ (val = y, grad = back(sensitivity(y)))
end
# Param-style wrappers
From d70f10544df4d2fa244ca90d85a799697c3fbd9c Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Sun, 4 Jul 2021 16:04:00 +0000
Subject: [PATCH 127/490] CompatHelper: bump compat for "FillArrays" to "0.12"
---
Project.toml | 4 +--
docs/Manifest.toml | 21 ++++---------
examples/Manifest.toml | 68 ++++++++++++++++++++++++------------------
3 files changed, 46 insertions(+), 47 deletions(-)
diff --git a/Project.toml b/Project.toml
index 3f9baa17c..6cd2fccdc 100644
--- a/Project.toml
+++ b/Project.toml
@@ -27,7 +27,7 @@ ChainRules = "0.8.12"
ChainRulesCore = "0.10.4"
ChainRulesTestUtils = "0.7.1"
DiffRules = "1.0"
-FillArrays = "0.8, 0.9, 0.10, 0.11"
+FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
ForwardDiff = "0.10"
IRTools = "0.4"
MacroTools = "0.5"
@@ -45,7 +45,7 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
-StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" # otherwise we can't add a compat bound
+StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
index 1d00bfcff..1628ace07 100644
--- a/docs/Manifest.toml
+++ b/docs/Manifest.toml
@@ -12,10 +12,10 @@ deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
[[DocStringExtensions]]
-deps = ["LibGit2", "Markdown", "Pkg", "Test"]
-git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1"
+deps = ["LibGit2"]
+git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
-version = "0.8.3"
+version = "0.8.5"
[[Documenter]]
deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
@@ -48,13 +48,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
[[Parsers]]
deps = ["Dates"]
-git-tree-sha1 = "50c9a9ed8c714945e01cd53a21007ed3865ed714"
+git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
-version = "1.0.15"
-
-[[Pkg]]
-deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
-uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+version = "1.1.0"
[[Printf]]
deps = ["Unicode"]
@@ -68,9 +64,6 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-[[SHA]]
-uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
-
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -81,9 +74,5 @@ uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-[[UUIDs]]
-deps = ["Random", "SHA"]
-uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
-
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
diff --git a/examples/Manifest.toml b/examples/Manifest.toml
index 2ecfa67d1..4ad93ad0d 100644
--- a/examples/Manifest.toml
+++ b/examples/Manifest.toml
@@ -82,9 +82,9 @@ version = "0.3.0"
[[Conda]]
deps = ["JSON", "VersionParsing"]
-git-tree-sha1 = "c0647249d785f1d5139c0cc96db8f6b32f7ec416"
+git-tree-sha1 = "299304989a5e6473d985212c28928899c74e9421"
uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
-version = "1.5.0"
+version = "1.5.2"
[[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
@@ -93,9 +93,9 @@ uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
version = "1.2.1"
[[DataAPI]]
-git-tree-sha1 = "ad84f52c0b8f05aa20839484dbaf01690b41ff84"
+git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385"
uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
-version = "1.4.0"
+version = "1.7.0"
[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
@@ -127,6 +127,11 @@ version = "1.0.2"
deps = ["Random", "Serialization", "Sockets"]
uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+[[ExprTools]]
+git-tree-sha1 = "555eab1f7c501166ba87eeb5d561e9f5e7d167d3"
+uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
+version = "0.1.4"
+
[[FFTW]]
deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
@@ -151,10 +156,10 @@ uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.9.0"
[[ForwardDiff]]
-deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "NaNMath", "Random", "SpecialFunctions", "StaticArrays"]
-git-tree-sha1 = "8de2519a83c6c1c2442c2f481dd9a8364855daf4"
+deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
+git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.14"
+version = "0.10.18"
[[GPUArrays]]
deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "Test"]
@@ -164,9 +169,9 @@ version = "1.0.4"
[[IRTools]]
deps = ["InteractiveUtils", "MacroTools", "Test"]
-git-tree-sha1 = "c67e7515a11f726f44083e74f218d134396d6510"
+git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c"
uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
-version = "0.4.2"
+version = "0.4.3"
[[InteractiveUtils]]
deps = ["Markdown"]
@@ -221,9 +226,9 @@ version = "0.5.0"
[[Missings]]
deps = ["DataAPI"]
-git-tree-sha1 = "ed61674a0864832495ffe0a7e889c0da76b0f4c8"
+git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7"
uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
-version = "0.4.4"
+version = "1.0.0"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
@@ -240,15 +245,15 @@ uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
version = "0.3.5"
[[OrderedCollections]]
-git-tree-sha1 = "cf59cfed2e2c12e8a2ff0a4f1e9b2cd8650da6db"
+git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
-version = "1.3.2"
+version = "1.4.1"
[[Parsers]]
deps = ["Dates"]
-git-tree-sha1 = "50c9a9ed8c714945e01cd53a21007ed3865ed714"
+git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
-version = "1.0.15"
+version = "1.1.0"
[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
@@ -264,9 +269,9 @@ uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
[[ProgressMeter]]
deps = ["Distributed", "Printf"]
-git-tree-sha1 = "45640774ee2efa24e52686dbdf895e88102e68fc"
+git-tree-sha1 = "afadeba63d90ff223a6a48d2009434ecee2ec9e8"
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
-version = "1.4.1"
+version = "1.7.1"
[[REPL]]
deps = ["InteractiveUtils", "Markdown", "Sockets"]
@@ -298,10 +303,10 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[SortingAlgorithms]]
-deps = ["DataStructures", "Random", "Test"]
-git-tree-sha1 = "03f5898c9959f8115e30bc7226ada7d0df554ddd"
+deps = ["DataStructures"]
+git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96"
uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
-version = "0.3.1"
+version = "1.0.0"
[[SparseArrays]]
deps = ["LinearAlgebra", "Random"]
@@ -315,29 +320,34 @@ version = "0.8.0"
[[StaticArrays]]
deps = ["LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "9da72ed50e94dbff92036da395275ed114e04d49"
+git-tree-sha1 = "896d55218776ab8f23fb7b222a5a4a946d4aafc2"
uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.0.1"
+version = "1.2.5"
[[Statistics]]
deps = ["LinearAlgebra", "SparseArrays"]
uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+[[StatsAPI]]
+git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
+uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
+version = "1.0.0"
+
[[StatsBase]]
-deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics"]
-git-tree-sha1 = "7bab7d4eb46b225b35179632852b595a3162cb61"
+deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
+git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d"
uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-version = "0.33.2"
+version = "0.33.8"
[[Test]]
deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[TimerOutputs]]
-deps = ["Printf"]
-git-tree-sha1 = "3318281dd4121ecf9713ce1383b9ace7d7476fdd"
+deps = ["ExprTools", "Printf"]
+git-tree-sha1 = "209a8326c4f955e2442c07b56029e88bb48299c7"
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.7"
+version = "0.5.12"
[[Tracker]]
deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
@@ -381,4 +391,4 @@ git-tree-sha1 = "d3c2ae55d116b5360a73b1e88d1a974b446d933a"
repo-rev = "ffc50480ff8f7662110bfb82b0b6d4f9cef6e59d"
repo-url = "https://github.com/FluxML/Zygote.jl.git"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.0+"
+version = "0.6.14+"
From 468d7a4d987907c7572b0b28e153eca76cd5c848 Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Mon, 5 Jul 2021 10:43:35 +0200
Subject: [PATCH 128/490] Update docs/src/index.md
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
docs/src/index.md | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index b70ca9eb7..0cfd80ae3 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -170,7 +170,7 @@ julia> dmodel = gradient(model -> sum(model(x)), model)[1]
(W = [0.652543 … 0.683588], b = [1.0, 1.0])
```
-On the other hand, the *implicit* style is the one presently used by [Flux.jl](https://github.com/FluxML/Flux.jl), a closely related machine learning library. It uses structs like `Linear` above to define layers, and the function `Flux.params(model)` returns a `Params` object containing all the parameters of all layers. See [its documentation](https://fluxml.ai/Flux.jl/stable/models/basics/) for more details. When using Zygote for most other purposes, however, the explicit style is usually preferred.
+Zygote also supports another way to take gradients, via *implicit parameters*. Here the loss function takes zero arguments, but the variables of interest are indicated by a special `Params` object. The function `linear` which depends on `W` and `b` is executed when the loss function `() -> sum(linear(x))` is called, and hence this dependence is visible to Zygote:
```julia
julia> W = rand(2, 5); b = rand(2);
@@ -186,4 +186,3 @@ julia> grads[W], grads[b] # access gradients using arrays as keys
```
Here `grads` is a dictionary-like object, whose keys are the same parameters we indicated in `Params`. (In fact it wraps a dictionary using `objectid(W)` as keys, which does not change if the values in `W` are mutated).
-
From 4f1f7e64d0b4afa7845df694c3bbd03cc68c0441 Mon Sep 17 00:00:00 2001
From: Jordi Bolibar
Date: Tue, 6 Jul 2021 11:24:31 +0200
Subject: [PATCH 129/490] Update docs/src/index.md
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
docs/src/index.md | 2 ++
1 file changed, 2 insertions(+)
diff --git a/docs/src/index.md b/docs/src/index.md
index 0cfd80ae3..36bf0ae8f 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -186,3 +186,5 @@ julia> grads[W], grads[b] # access gradients using arrays as keys
```
Here `grads` is a dictionary-like object, whose keys are the same parameters we indicated in `Params`. (In fact it wraps a dictionary using `objectid(W)` as keys, which does not change if the values in `W` are mutated).
+
+This implicit style is the one presently used by [Flux.jl](https://github.com/FluxML/Flux.jl), a closely related machine learning library. It uses structs like `Linear` above to define layers, and the function `Flux.params(model)` returns a `Params` object containing all the parameters of all layers. See [its documentation](https://fluxml.ai/Flux.jl/stable/models/basics/) for more details. When using Zygote for most other purposes, however, the explicit style is usually preferred.
From f02b1259bf56b8f26c6348247e8c21edc2048fb0 Mon Sep 17 00:00:00 2001
From: Shuhei Kadowaki
Date: Fri, 9 Jul 2021 14:50:11 +0900
Subject: [PATCH 130/490] documentation updates
---
docs/Manifest.toml | 15 +++++++++------
docs/src/adjoints.md | 4 ++--
docs/src/internals.md | 4 ++--
docs/src/utils.md | 10 +++++-----
4 files changed, 18 insertions(+), 15 deletions(-)
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
index 1628ace07..eac8a8401 100644
--- a/docs/Manifest.toml
+++ b/docs/Manifest.toml
@@ -7,10 +7,6 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
-[[Distributed]]
-deps = ["Random", "Serialization", "Sockets"]
-uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
-
[[DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
@@ -34,6 +30,7 @@ uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.1"
[[LibGit2]]
+deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
[[Logging]]
@@ -46,6 +43,9 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
+[[NetworkOptions]]
+uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
+
[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
@@ -57,13 +57,16 @@ deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
[[REPL]]
-deps = ["InteractiveUtils", "Markdown", "Sockets"]
+deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"]
uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
[[Random]]
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+[[SHA]]
+uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
+
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
@@ -71,7 +74,7 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
[[Test]]
-deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
+deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[[Unicode]]
diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md
index 89a76a2bd..9cf72c943 100644
--- a/docs/src/adjoints.md
+++ b/docs/src/adjoints.md
@@ -4,9 +4,9 @@
Zygote supports the use of [ChainRulesCore](http://www.juliadiff.org/ChainRulesCore.jl/stable/) to define custom sensitivities.
It is preferred to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Zygote.
These sensitivities can be added in your own package, or for Base/StdLib functions they can be added to [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/).
- To define custom sensitivities using ChainRulesCore, define a `ChainRulesCore.rrule(f, args...; kwargs...)` [ChainRules project's documentation for more information](https://www.juliadiff.org/ChainRulesCore.jl/stable/).
+ To define custom sensitivities using ChainRulesCore, define a `ChainRulesCore.rrule(f, args...; kwargs...)`. Head for [ChainRules project's documentation](https://www.juliadiff.org/ChainRulesCore.jl/stable/) for more information.
**If you are defining your custom adjoints using ChainRulesCore then you do not need to read this page**, and can consider it as documenting a legacy feature.
-
+
This page exists to describe how Zygote works, and how adjoints can be directly defined for Zygote.
Defining adjoints this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Zygote works.
It allows for specific definitions of adjoints that are only defined for Zygote (which might work differently to more generic definitions defined for all AD).
diff --git a/docs/src/internals.md b/docs/src/internals.md
index 3c808ad79..70651a697 100644
--- a/docs/src/internals.md
+++ b/docs/src/internals.md
@@ -137,7 +137,7 @@ We convert the code to SSA form using Julia's built-in IR data structure, after
julia> Zygote.@code_ir foo(1)
1 1 ─ %1 = (Main.bar)(_2)::Any
│ %2 = (Main.baz)(%1)::Any
- └── return %2
+ └── return %2
```
(There isn't much difference unless there's some control flow.)
@@ -202,7 +202,7 @@ function J(::typeof(foo), x)
return b, Pullback{typeof(foo)}((da, db))
end
-function(p::Pullback{typeof(foo)})(b̄)
+function (p::Pullback{typeof(foo)})(b̄)
da, db = p.data[1], p.data[2]
ā = db(b̄)
x̄ = da(ā)
diff --git a/docs/src/utils.md b/docs/src/utils.md
index 309e43786..b7e779185 100644
--- a/docs/src/utils.md
+++ b/docs/src/utils.md
@@ -37,17 +37,17 @@ using Zygote, Test
w, x1, x2, b = rand(2), rand(2), rand(2), rand(2)
-gs1 = gradient(() -> sum(tanh.(w .* x1 .+ b)), Params([w, b]))
-gs2 = gradient(() -> sum(tanh.(w .* x2 .+ b)), Params([w, b]))
+gs1 = gradient(() -> sum(tanh.(w .* x1 .+ b)), Params([w, b]))
+gs2 = gradient(() -> sum(tanh.(w .* x2 .+ b)), Params([w, b]))
# accumulate gradients
gs = gs1 .+ gs2
-@test gs[w] ≈ gs1[w] + gs2[w]
-@test gs[b] ≈ gs1[b] + gs2[b]
+@test gs[w] ≈ gs1[w] + gs2[w]
+@test gs[b] ≈ gs1[b] + gs2[b]
# gradients and IdDict interact nicely
# note that an IdDict must be used for gradient algebra on the GPU
-gs .+= IdDict(p => randn(size(p)) for p in keys(gs))
+gs .+= IdDict(p => randn(size(p)) for p in keys(gs))
# clip gradients
map(x -> clamp.(x, -0.1, 0.1), gs)
From d087bb6c63043c7b2e2fde54b15f3d19a76bba84 Mon Sep 17 00:00:00 2001
From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com>
Date: Fri, 9 Jul 2021 20:19:54 +0900
Subject: [PATCH 131/490] Update docs/src/adjoints.md
Co-authored-by: Dhairya Gandhi
---
docs/src/adjoints.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md
index 9cf72c943..1d6ddc527 100644
--- a/docs/src/adjoints.md
+++ b/docs/src/adjoints.md
@@ -4,7 +4,7 @@
Zygote supports the use of [ChainRulesCore](http://www.juliadiff.org/ChainRulesCore.jl/stable/) to define custom sensitivities.
It is preferred to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Zygote.
These sensitivities can be added in your own package, or for Base/StdLib functions they can be added to [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/).
- To define custom sensitivities using ChainRulesCore, define a `ChainRulesCore.rrule(f, args...; kwargs...)`. Head for [ChainRules project's documentation](https://www.juliadiff.org/ChainRulesCore.jl/stable/) for more information.
+ To define custom sensitivities using ChainRulesCore, define a `ChainRulesCore.rrule(f, args...; kwargs...)`. Head to [ChainRules project's documentation](https://www.juliadiff.org/ChainRulesCore.jl/stable/) for more information.
**If you are defining your custom adjoints using ChainRulesCore then you do not need to read this page**, and can consider it as documenting a legacy feature.
This page exists to describe how Zygote works, and how adjoints can be directly defined for Zygote.
From 2f91749b06f27d894ffb314db0ebc4eea697faeb Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 11 Jul 2021 22:34:37 -0400
Subject: [PATCH 132/490] add error for broadcasting over Params
---
src/compiler/interface.jl | 5 +++++
test/interface.jl | 3 +++
2 files changed, 8 insertions(+)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 35a02e272..3ba61ad0c 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -200,6 +200,11 @@ function Base.delete!(ps::Params, x)
end
Base.Broadcast.broadcasted(f, ps::Params) = broadcasted(f, ps.order)
+# Broadcast.broadcastable(ps::Params) = ps.order
+
+@adjoint function Broadcast.broadcasted(f::Function, ps::Params)
+ f.(ps), _ -> throw(ArgumentError("Zygote.Params does not support broadcasting within gradients, try iteration `for p in ps`"))
+end
Base.:(==)(x::Params, y::Params) = x.order.data == y.order.data
diff --git a/test/interface.jl b/test/interface.jl
index 159f4bce1..18b0a9875 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -38,7 +38,10 @@ using Zygote: Grads
x, y = [1,2], [1]
ps = Params([x, y])
@test length.(ps) == length.([x, y]) # 617
+ @test size.(ps, 1) == [2, 1]
@test all(Params([[1,1]]) .== Params([[1,1]]))
+
+ @test_throws ArgumentError gradient(() -> sum(sum.(ps)), ps)
end
@testset "indexing" begin
From 6bdc92a88baad913365697acb1350f49ff9a55c8 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 11 Jul 2021 22:35:45 -0400
Subject: [PATCH 133/490] extend Mutating arrays is not supported messages
---
src/lib/array.jl | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index aaf122cb4..48092191f 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -73,14 +73,14 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
@adjoint getindex(::Type{T}, xs...) where {T} = T[xs...], dy -> (nothing, dy...)
@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
- _ -> error("Mutating arrays is not supported")
+ _ -> error("Mutating arrays is not supported -- called setindex!(::$(typeof(xs)), ...)")
@adjoint! copyto!(args...) = copyto!(args...),
- _ -> error("Mutating arrays is not supported")
+ _ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), ...)")
for f in [push!, pop!, pushfirst!, popfirst!]
@eval @adjoint! $f(xs, x...) =
- push!(xs, x...), _ -> error("Mutating arrays is not supported")
+ push!(xs, x...), _ -> error("Mutating arrays is not supported -- called $f(::$(typeof(xs)), ...)")
end
# This is kind of bad, but at least we don't materialize the whole
From 6ffb0d1fd4ba19a12825e79446be1f0cb4f1ffa0 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 11 Jul 2021 22:47:19 -0400
Subject: [PATCH 134/490] remove "Flux-style models" comment
---
docs/src/index.md | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/docs/src/index.md b/docs/src/index.md
index 36bf0ae8f..1eec9768f 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -153,7 +153,9 @@ Dict{Any,Any} with 2 entries:
:W => [0.628998 … 0.433006]
```
-An extension of this is the Flux-style model in which we use call overloading to combine the weight object with the pullback pass (equivalent to a closure).
+We can combine the role of the dictionary and the function here by making a callable struct which
+contains the parameters, equivalent to a closure. Passed explicitly to `gradient`, we get a named tuple
+with the same field names:
```julia
julia> struct Linear
@@ -170,7 +172,7 @@ julia> dmodel = gradient(model -> sum(model(x)), model)[1]
(W = [0.652543 … 0.683588], b = [1.0, 1.0])
```
-Zygote also support another way to take gradients, via *implicit parameters*. Here the loss function takes zero arguments, but the variables of interest are indicated by a special `Params` object. The function `linear` which depends on `W` and `b` is executed when the loss function `() -> sum(linear(x))` is called, and hence this dependence is visible to Zygote:
+Zygote also supports another way to take gradients, via *implicit parameters*. Here the loss function takes zero arguments, but the variables of interest are indicated by a special `Params` object. The function `linear` which depends on `W` and `b` is executed when the loss function `() -> sum(linear(x))` is called, and hence this dependence is visible to Zygote:
```julia
julia> W = rand(2, 5); b = rand(2);
From 8f659c8a26e50472c5f67a2c2c686f84d555e3eb Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 11 Jul 2021 23:06:01 -0400
Subject: [PATCH 135/490] double-quote
---
src/lib/array.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 48092191f..246f12dff 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -80,7 +80,7 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
for f in [push!, pop!, pushfirst!, popfirst!]
@eval @adjoint! $f(xs, x...) =
- push!(xs, x...), _ -> error("Mutating arrays is not supported -- called $f(::$(typeof(xs)), ...)")
+ push!(xs, x...), _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(xs)), ...)")
end
# This is kind of bad, but at least we don't materialize the whole
From 98f4590e3c9dd7db2335ae039c4548ca80bbbb5f Mon Sep 17 00:00:00 2001
From: Shuhei Kadowaki
Date: Mon, 12 Jul 2021 16:36:01 +0900
Subject: [PATCH 136/490] make sure to throw explicit `CompileError`
The previous error handling is user-unfriendly IMHO.
> before
```julia
julia> gradient(pow_try, 1)
ERROR: MethodError: no method matching iterate(::ErrorException)
Closest candidates are:
iterate(::Union{LinRange, StepRangeLen}) at range.jl:806
iterate(::Union{LinRange, StepRangeLen}, ::Int64) at range.jl:806
iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at dict.jl:695
...
Stacktrace:
[1] indexed_iterate(I::ErrorException, i::Int64)
@ Base ./tuple.jl:91
[2] #s3061#1245
@ ~/julia/packages/Zygote/src/compiler/interface2.jl:34 [inlined]
...
```
> after
```julia
julia> gradient(pow_try, 1)
ERROR: Compiling Tuple{typeof(pow_try), Int64}: try/catch is not supported.
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] instrument(ir::IRTools.Inner.IR)
@ Zygote ~/julia/packages/Zygote/src/compiler/reverse.jl:121
...
```
---
src/compiler/interface2.jl | 8 ++++++--
test/features.jl | 3 ++-
2 files changed, 8 insertions(+), 3 deletions(-)
diff --git a/src/compiler/interface2.jl b/src/compiler/interface2.jl
index 0f7da4b32..bf3692a30 100644
--- a/src/compiler/interface2.jl
+++ b/src/compiler/interface2.jl
@@ -24,7 +24,11 @@ end
T = Tuple{f,args...}
ignore_sig(T) && return :(f(args...), Pullback{$T}(()))
- g = try _generate_pullback_via_decomposition(T) catch e e end
+ g = try
+ _generate_pullback_via_decomposition(T)
+ catch e
+ rethrow(CompileError(T,e))
+ end
g === nothing && return :(f(args...), Pullback{$T}((f,)))
meta, forw, _ = g
argnames!(meta, Symbol("#self#"), :ctx, :f, :args)
@@ -38,7 +42,7 @@ end
@generated function (j::Pullback{T})(Δ) where T
ignore_sig(T) && return :nothing
- g = try
+ g = try
_generate_pullback_via_decomposition(T)
catch e
rethrow(CompileError(T,e))
diff --git a/test/features.jl b/test/features.jl
index b1434aac7..766673350 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -397,6 +397,7 @@ function pow_try(x)
end
@test_broken gradient(pow_try, 1) == (2,)
+@test_throws Zygote.CompileError gradient(pow_try, 1)
function pow_simd(x, n)
r = 1
@@ -508,7 +509,7 @@ end
@test gradient(x -> sum(x .+ ones(2,2)), (1,2)) == ((2,2),)
@test gradient(x -> sum(x .+ ones(2,2)), (1,)) == ((4,),)
@test gradient(x -> sum(x .+ ones(2,1)), (1,2)) == ((1,1),)
-
+
# https://github.com/FluxML/Zygote.jl/issues/975
gt = gradient((x,p) -> prod(x .^ p), [3,4], (1,2))
gv = gradient((x,p) -> prod(x .^ p), [3,4], [1,2])
From aa44b5f5f9245d75d0273257ffbf100466726002 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 12 Jul 2021 20:53:16 -0400
Subject: [PATCH 137/490] tweak
---
src/compiler/interface.jl | 2 +-
src/lib/array.jl | 8 ++++----
2 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 3ba61ad0c..2d9cdce11 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -200,7 +200,7 @@ function Base.delete!(ps::Params, x)
end
Base.Broadcast.broadcasted(f, ps::Params) = broadcasted(f, ps.order)
-# Broadcast.broadcastable(ps::Params) = ps.order
+Base.Broadcast.broadcastable(ps::Params) = ps.order
@adjoint function Broadcast.broadcasted(f::Function, ps::Params)
f.(ps), _ -> throw(ArgumentError("Zygote.Params does not support broadcasting within gradients, try iteration `for p in ps`"))
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 246f12dff..d5f9557a4 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -73,14 +73,14 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
@adjoint getindex(::Type{T}, xs...) where {T} = T[xs...], dy -> (nothing, dy...)
@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
- _ -> error("Mutating arrays is not supported -- called setindex!(::$(typeof(xs)), ...)")
+ _ -> error("Mutating arrays is not supported -- called setindex!(::$(typeof(xs)), _...)")
@adjoint! copyto!(args...) = copyto!(args...),
- _ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), ...)")
+ _ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), _...)")
for f in [push!, pop!, pushfirst!, popfirst!]
- @eval @adjoint! $f(xs, x...) =
- push!(xs, x...), _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(xs)), ...)")
+ @eval @adjoint! $f(xs, x...) = $f(xs, x...),
+ _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(xs)), _...)")
end
# This is kind of bad, but at least we don't materialize the whole
From d7cd2ec8be5784bae32ef6276acb5fe6628d2397 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 12 Jul 2021 21:02:02 -0400
Subject: [PATCH 138/490] tweak
---
src/compiler/interface.jl | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 2d9cdce11..6ca8257d5 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -200,7 +200,6 @@ function Base.delete!(ps::Params, x)
end
Base.Broadcast.broadcasted(f, ps::Params) = broadcasted(f, ps.order)
-Base.Broadcast.broadcastable(ps::Params) = ps.order
@adjoint function Broadcast.broadcasted(f::Function, ps::Params)
f.(ps), _ -> throw(ArgumentError("Zygote.Params does not support broadcasting within gradients, try iteration `for p in ps`"))
From 7d40e94a20a31b8df3dc7a4d2614905f84659614 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 13 Jul 2021 20:21:07 -0400
Subject: [PATCH 139/490] v0.6.15
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 6cd2fccdc..8c7df235d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.14"
+version = "0.6.15"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From b33920efbe37acbc2507846a642b97628c4a1b76 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Wed, 14 Jul 2021 18:14:13 +0530
Subject: [PATCH 140/490] fix resolve message
---
src/lib/broadcast.jl | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 57bb7b38f..cdcd21547 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -247,7 +247,6 @@ end
end
@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
- using CUDA
const CuArrayStyle = CUDA.AbstractGPUArrayStyle
if isdefined(CUDA, :cufunc)
From 4d22edf4a27db0302badb12e1faec6d5ee31fd9d Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 14 Jul 2021 18:35:34 +0200
Subject: [PATCH 141/490] Update Project.toml
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 8c7df235d..986d0c10a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.15"
+version = "0.6.16"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 9203ea8ac035a257f4f8f3c7769581d8e6242c64 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 19 Jul 2021 10:42:07 +0200
Subject: [PATCH 142/490] construct Params with empty tuple
---
Project.toml | 2 +-
src/tools/idset.jl | 16 ++++++++++------
test/interface.jl | 5 +++++
3 files changed, 16 insertions(+), 7 deletions(-)
diff --git a/Project.toml b/Project.toml
index 986d0c10a..e209009ad 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.16"
+version = "0.6.17"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/tools/idset.jl b/src/tools/idset.jl
index d9f0ceb04..9f3566699 100644
--- a/src/tools/idset.jl
+++ b/src/tools/idset.jl
@@ -3,18 +3,22 @@ struct IdSet{T} <: AbstractSet{T}
IdSet{T}() where T = new(IdDict{T,Nothing}())
end
-Base.eltype(::IdSet{T}) where T = T
+IdSet(xs) = IdSet{eltype(xs)}(xs)
IdSet() = IdSet{Any}()
+function IdSet{T}(xs) where T
+ s = IdSet{T}()
+ for x in xs
+ push!(s, x)
+ end
+ return s
+end
+
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)
-
-IdSet{T}(xs) where T = push!(IdSet{T}(), xs...)
-
-IdSet(xs) = IdSet{eltype(xs)}(xs)
-
+Base.eltype(::IdSet{T}) where T = T
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
Base.similar(s::IdSet, T::Type) = IdSet{T}()
diff --git a/test/interface.jl b/test/interface.jl
index 18b0a9875..ff45e7ebd 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -82,6 +82,11 @@ using Zygote: Grads
@test ps isa Params
@test issetequal(ps, Set([y]))
end
+
+ @testset "constructor with empty args" begin
+ @test length(Params()) == 0
+ @test length(Params(())) == 0
+ end
end
@testset "Grads" begin
From 3df54fbc7d6d2d4fc68a1c190ef047fcc933d832 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 19 Jul 2021 11:00:21 +0200
Subject: [PATCH 143/490] use splatting
---
src/tools/idset.jl | 8 +-------
1 file changed, 1 insertion(+), 7 deletions(-)
diff --git a/src/tools/idset.jl b/src/tools/idset.jl
index 9f3566699..d8a600992 100644
--- a/src/tools/idset.jl
+++ b/src/tools/idset.jl
@@ -7,13 +7,7 @@ IdSet(xs) = IdSet{eltype(xs)}(xs)
IdSet() = IdSet{Any}()
-function IdSet{T}(xs) where T
- s = IdSet{T}()
- for x in xs
- push!(s, x)
- end
- return s
-end
+IdSet{T}(xs) = isempty(xs) ? IdSet{T}() : push!(IdSet{T}(), xs...)
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
From 8467723a14aef3b23fa0f6dc417b6c991ed66810 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 19 Jul 2021 11:02:04 +0200
Subject: [PATCH 144/490] use splatting
---
src/tools/idset.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/tools/idset.jl b/src/tools/idset.jl
index d8a600992..a0aa93df0 100644
--- a/src/tools/idset.jl
+++ b/src/tools/idset.jl
@@ -7,7 +7,7 @@ IdSet(xs) = IdSet{eltype(xs)}(xs)
IdSet() = IdSet{Any}()
-IdSet{T}(xs) = isempty(xs) ? IdSet{T}() : push!(IdSet{T}(), xs...)
+IdSet{T}(xs) where T = isempty(xs) ? IdSet{T}() : push!(IdSet{T}(), xs...)
Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
From a2a0393959239859b7673a5fc2177c2aecaaeabb Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Mon, 19 Jul 2021 11:02:53 +0200
Subject: [PATCH 145/490] more tests
---
test/interface.jl | 1 +
1 file changed, 1 insertion(+)
diff --git a/test/interface.jl b/test/interface.jl
index ff45e7ebd..0bee98321 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -86,6 +86,7 @@ using Zygote: Grads
@testset "constructor with empty args" begin
@test length(Params()) == 0
@test length(Params(())) == 0
+ @test length(Params([])) == 0
end
end
From 8ce8b32e52976340d81b70b40dc429a7f9a258d8 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas
Date: Mon, 19 Jul 2021 08:13:19 -0400
Subject: [PATCH 146/490] Add some downstream testing
Adds downstream testing to CI, with a few application packages that should flex the API a bit.
---
.github/workflows/Downstream.yml | 55 ++++++++++++++++++++++++++++++++
1 file changed, 55 insertions(+)
create mode 100644 .github/workflows/Downstream.yml
diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml
new file mode 100644
index 000000000..6565e82d5
--- /dev/null
+++ b/.github/workflows/Downstream.yml
@@ -0,0 +1,55 @@
+
+name: IntegrationTest
+on:
+ push:
+ branches: [master]
+ tags: [v*]
+ pull_request:
+
+jobs:
+ test:
+ name: ${{ matrix.package.repo }}/${{ matrix.package.group }}
+ runs-on: ${{ matrix.os }}
+ env:
+ GROUP: ${{ matrix.package.group }}
+ strategy:
+ fail-fast: false
+ matrix:
+ julia-version: [1]
+ os: [ubuntu-latest]
+ package:
+ - {user: FluxML, repo: Flux.jl, group: All}
+ - {user: FluxML, repo: NNlib.jl, group: All}
+ - {user: FluxML, repo: FastAI.jl, group: All}
+ - {user: FluxML, repo: GeometricFlux.jl, group: All}
+ - {user: SciML, repo: DiffEqFlux.jl, group: Layers}
+ - {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
+ steps:
+ - uses: actions/checkout@v2
+ - uses: julia-actions/setup-julia@v1
+ with:
+ version: ${{ matrix.julia-version }}
+ arch: x64
+ - uses: julia-actions/julia-buildpkg@latest
+ - name: Clone Downstream
+ uses: actions/checkout@v2
+ with:
+ repository: ${{ matrix.package.user }}/${{ matrix.package.repo }}
+ path: downstream
+ - name: Load this and run the downstream tests
+ shell: julia --color=yes --project=downstream {0}
+ run: |
+ using Pkg
+ try
+ # force it to use this PR's version of the package
+ Pkg.develop(PackageSpec(path=".")) # resolver may fail with main deps
+ Pkg.update()
+ Pkg.test() # resolver may fail with test time deps
+ catch err
+ err isa Pkg.Resolve.ResolverError || rethrow()
+ # If we can't resolve that means this is incompatible by SemVer and this is fine
+ # It means we marked this as a breaking change, so we don't need to worry about
+ # Mistakenly introducing a breaking change, as we have intentionally made one
+ @info "Not compatible with this release. No problem." exception=err
+ exit(0) # Exit immediately, as a success
+ end
From 05cebdc8badf4a145bb97ef386f806cefb21f572 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 20 Jul 2021 13:25:25 +0100
Subject: [PATCH 147/490] Add support for ChainRules optout
---
src/compiler/chainrules.jl | 53 ++++++++++++++++++++++++++++----------
test/chainrules.jl | 30 +++++++++++++++++++++
2 files changed, 69 insertions(+), 14 deletions(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index c4e72f07e..3698f21fb 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -4,8 +4,7 @@ end
ZygoteRuleConfig() = ZygoteRuleConfig(Context())
-const rrule_fallback_method = Base.which(rrule, Tuple{Any, Vararg{Any}})
-const rrule_redispatcher_method = Base.which(rrule, Tuple{RuleConfig, Any, Vararg{Any}})
+_is_rrule_redispatcher(m::Method) = m.sig == Tuple{typeof(rrule), RuleConfig, Vararg}
"""
has_chain_rrule(T)
@@ -18,19 +17,45 @@ such that if a suitable rule is defined later, the generated function will recom
"""
function has_chain_rrule(T)
config_T, arg_Ts = Iterators.peel(T.parameters)
- m_with_config = meta(Tuple{typeof(rrule), config_T, arg_Ts...})
- if m_with_config.method === rrule_redispatcher_method
- # it is being redispatched without config, so check it that hits the fallback
- m_without_config = meta(Tuple{typeof(rrule), arg_Ts...})
- if m_without_config.method === rrule_fallback_method
- # no rrule exists, return instance for m_with_config as that will be invalidated
- # directly if configured rule added, or indirectly if unconfigured rule added
- return false, m_with_config.instance
- end
+ configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...})
+ if _is_rrule_redispatcher(configured_rrule_m.method)
+ # it is being redispatched without config, so get the method it redispatches to
+ rrule_m = meta(Tuple{typeof(rrule), arg_Ts...})
+ no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), arg_Ts...})
+ else
+ # Not being redispatched
+ rrule_m = configured_rrule_m
+ no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...})
+ end
+
+ do_not_use_rrule = matching_cr_sig(no_rrule_m, rrule_m)
+ if do_not_use_rrule
+ # return instance for configured_rrule_m as that will be invalidated
+ # directly if configured rule added, or indirectly if unconfigured rule added
+ # Do not need an edge for `no_rrule` as no addition of methods to that can cause this
+ # decision to need to be revisited (only changes to `rrule`), since we are already not
+ # using the rrule, so not using more rules wouldn't change anything
+ return false, configured_rrule_m.instance
+ else
+ # otherwise found a rrule, no need to add any edges for `rrule`, as it will generate
+ # code with natural edges if a new method is defined there.
+ # We also do not need an edge to `no_rrule`, as any time a method is added to `no_rrule`
+ # a corresponding method is added to `rrule` (to return `nothing`), thus we will already
+ # be revisiting this decision when a new opt-out is added
+ return true, nothing
end
- # otherwise found a rrule, no need to add any edges, as it will generate code with
- # natural edges.
- return true, nothing
+end
+
+matching_cr_sig(t, s) = matching_cr_sig(t.method.sig, s.method.sig)
+matching_cr_sig(::DataType, ::UnionAll) = false
+matching_cr_sig(::UnionAll, ::DataType) = false
+matching_cr_sig(t::Type, s::Type) = type_tuple_tail(t) == type_tuple_tail(s)
+
+type_tuple_tail(d::DataType) = Tuple{d.parameters[2:end]...}
+function type_tuple_tail(d::UnionAll)
+ body = Base.unwrap_unionall(d)
+ body_tt = type_tuple_tail(body)
+ return Base.rewrap_unionall(body_tt, d)
end
"""
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 32bdd3799..3fb323e69 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -232,6 +232,36 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
aug_primal_val, _ = Zygote.pullback(x->StructForTestingTypeOnlyRRules(), 1.2)
@test aug_primal_val.x == 2.0
end
+
+ @testset "@opt_out" begin
+ oa_id(x) = x
+ oa_id_rrule_hitcount = Ref(0)
+ function ChainRulesCore.rrule(::typeof(oa_id), x::Any)
+ oa_id_rrule_hitcount[] += 1
+ oa_id_pullback(ȳ) = (NoTangent(), ȳ)
+ return oa_id(x), oa_id_pullback
+ end
+
+ @opt_out ChainRulesCore.rrule(::typeof(oa_id), x::AbstractArray)
+
+ # Hit one we haven't opted out
+ oa_id_rrule_hitcount[] = 0
+ oa_id_outer(x) = sum(oa_id(x))
+ @test (1.0,) == Zygote.gradient(oa_id_outer, π)
+ @test oa_id_rrule_hitcount[] == 1
+
+ # make sure don't hit the one we have opted out
+ oa_id_rrule_hitcount[] = 0
+ @test ([1.0],) == Zygote.gradient(oa_id_outer, [π])
+ @test oa_id_rrule_hitcount[] == 0
+
+ # Now try opting out After we have already used it
+ @opt_out ChainRulesCore.rrule(::typeof(oa_id), x::Real)
+ oa_id_rrule_hitcount[] = 0
+ oa_id_outer(x) = sum(oa_id(x))
+ @test (1.0,) == Zygote.gradient(oa_id_outer, π)
+ @test oa_id_rrule_hitcount[] == 0
+ end
end
@testset "ChainRulesCore.rrule_via_ad" begin
From 06aaae2dc8964154ade4fc2f6565d83f4bef77a9 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Wed, 21 Jul 2021 14:36:06 +0100
Subject: [PATCH 148/490] More explain how has_chain_rrule works
---
src/compiler/chainrules.jl | 35 +++++++++++++++++++++++++++++------
1 file changed, 29 insertions(+), 6 deletions(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 3698f21fb..6d43946bc 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -19,29 +19,52 @@ function has_chain_rrule(T)
config_T, arg_Ts = Iterators.peel(T.parameters)
configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...})
if _is_rrule_redispatcher(configured_rrule_m.method)
- # it is being redispatched without config, so get the method it redispatches to
+ # The config is not being used:
+ # it is being redispatched without config, so we need the method it redispatches to
rrule_m = meta(Tuple{typeof(rrule), arg_Ts...})
+ # Thus any no_rrule that might apply must also not have a config because if there was a
+ # no_rrule with a config that applied then there would also be a rrule with config that applied
no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), arg_Ts...})
else
- # Not being redispatched
+ # Not being redispatched: it does have a config
rrule_m = configured_rrule_m
+ # Thus any no_rrule that might apply must also have a config because if it applied
+ # it will be identical, and if it doesn't we don't care what it is.
no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...})
end
+ # To understand why we only need to check if the sigs match between no_rrule_m and rrule_m
+ # in order to decide if to use, one must consider the following facts:
+ # - for every method in `no_rrule` there is a identical one in `rrule` that returns nothing
+ # - this includes the general fallback `rrule(::Any...)=nothing`.
+ # - a configured rrule/no_rrule is always more specific than a otherwise equivalent unconfigured rrule/no_rrule
+ #
+ # Consider the following truth table, for what can occur:
+ # rrule: fallback, no_rrule: fallback => matches => do not use rrule.
+ # rrule: specific, no_rrule: fallback => !matches => do use rrule, as haven't opted out.
+ # rrule: fallback, no_rrule: specific => IMPOSSIBLE, every no_rule us identical to some rrule
+ # rrule: specific, no_rrule: specific => matches => do not use rrule as opted out
+ # rrule: specific, no_rrule: general => !matches => do use rrule as a more specific rrule takes preciedent over more general opted out
+ # rrule: general , no_rrule: specific => IMPOSSIBLE, every no_rule us identical to some rrule so can't have a more general rrule being hit, as the specific one would hit first
+ #
+ # Note that the fallback cases are the same outcome as the general cases as fallback is just most general.
+ # It can be seen that checking if it matches is the correct way to decide if we should ue the rrule or not.
+
+
do_not_use_rrule = matching_cr_sig(no_rrule_m, rrule_m)
if do_not_use_rrule
- # return instance for configured_rrule_m as that will be invalidated
+ # Return instance for configured_rrule_m as that will be invalidated
# directly if configured rule added, or indirectly if unconfigured rule added
# Do not need an edge for `no_rrule` as no addition of methods to that can cause this
# decision to need to be revisited (only changes to `rrule`), since we are already not
- # using the rrule, so not using more rules wouldn't change anything
+ # using the rrule, so not using more rules wouldn't change anything.
return false, configured_rrule_m.instance
else
- # otherwise found a rrule, no need to add any edges for `rrule`, as it will generate
+ # Otherwise found a rrule, no need to add any edges for `rrule`, as it will generate
# code with natural edges if a new method is defined there.
# We also do not need an edge to `no_rrule`, as any time a method is added to `no_rrule`
# a corresponding method is added to `rrule` (to return `nothing`), thus we will already
- # be revisiting this decision when a new opt-out is added
+ # be revisiting this decision when a new opt-out is added.
return true, nothing
end
end
From 9e18f12ad89fe676e6a2e9913290f31c4162f7da Mon Sep 17 00:00:00 2001
From: Sheehan Olver
Date: Wed, 21 Jul 2021 14:42:16 +0100
Subject: [PATCH 149/490] Fix UndefVarError
---
src/lib/array.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index d5f9557a4..1ad7eb6a7 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -75,7 +75,7 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
_ -> error("Mutating arrays is not supported -- called setindex!(::$(typeof(xs)), _...)")
-@adjoint! copyto!(args...) = copyto!(args...),
+@adjoint! copyto!(xs, args...) = copyto!(xs, args...),
_ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), _...)")
for f in [push!, pop!, pushfirst!, popfirst!]
From f878cf7cc1c96444d27db45c4e2b4277fde9848c Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 21 Jul 2021 17:59:53 +0200
Subject: [PATCH 150/490] fix copyto! error message
---
src/lib/array.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index d5f9557a4..a9a096b8e 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -76,7 +76,7 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
_ -> error("Mutating arrays is not supported -- called setindex!(::$(typeof(xs)), _...)")
@adjoint! copyto!(args...) = copyto!(args...),
- _ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), _...)")
+ _ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(args))..., _...)")
for f in [push!, pop!, pushfirst!, popfirst!]
@eval @adjoint! $f(xs, x...) = $f(xs, x...),
From dbde9a541ca5579518ed6220282c1e8242a4953b Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 23 Jul 2021 11:34:30 +0200
Subject: [PATCH 151/490] fix convert
---
src/compiler/chainrules.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 6d43946bc..3e1745c10 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -111,7 +111,7 @@ for T_outer in (:Tuple, :NamedTuple)
# than happy.
@eval @inline function wrap_chainrules_output(x::ChainRules.Tangent{P, T}) where {P, T<:$T_outer}
xp = map(wrap_chainrules_output, canonicalize(x))
- convert($T_outer, xp)
+ ChainRulesCore.backing(xp)
end
end
From aecec9be88003620e2e66675e6c155be9a191a90 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 23 Jul 2021 11:35:44 +0200
Subject: [PATCH 152/490] fix tests which were wrong to accommodate no
projection
---
test/complex.jl | 2 +-
test/features.jl | 4 ++--
test/gradcheck.jl | 4 ++--
3 files changed, 5 insertions(+), 5 deletions(-)
diff --git a/test/complex.jl b/test/complex.jl
index 54f99fd0f..6a0445b85 100644
--- a/test/complex.jl
+++ b/test/complex.jl
@@ -18,7 +18,7 @@ using Zygote, Test, LinearAlgebra
@test gradient(x -> real(logabsdet(x)[1]), [1 2im; 3im 4])[1] ≈ [4 3im; 2im 1]/10
# https://github.com/FluxML/Zygote.jl/issues/705
-@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ im .* exp.(1:3)
+@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)
fs_C_to_R = (real,
diff --git a/test/features.jl b/test/features.jl
index 766673350..b17f55b41 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -449,12 +449,12 @@ end
@test pullback(type_test)[1] == Complex{<:Real}
@testset "Pairs" begin
- @test (x->10*pairs((a=x, b=2))[1])'(100) === 10
+ @test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
@test (x->10*pairs((a=x, b=2))[2])'(100) === 0
foo(;kw...) = 1
@test gradient(() -> foo(a=1,b=2.0)) === ()
- @test (x->10*(x => 2)[1])'(100) === 10
+ @test (x->10*(x => 2)[1])'(100) === 10.0
@test (x->10*(x => 2)[2])'(100) === 0
end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 08dfd45db..9ffd04260 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -81,7 +81,7 @@ end
@test gradient(xs ->sum(xs .^ _pow), [4, -1]) == ([_pow*4^9, -10],)
@test gradient(x -> real((1+3im) * x^2), 5+7im) == (-32 - 44im,)
- @test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ (-234 + 2im)*log(5 - 7im)
+ @test gradient(p -> real((1+3im) * (5+7im)^p), 2)[1] ≈ real((-234 + 2im)*log(5 - 7im))
# D[(1+3I)x^p, p] /. {x->5+7I, p->2} // Conjugate
end
@@ -160,7 +160,7 @@ end
# https://github.com/FluxML/Zygote.jl/issues/376
_, back = Zygote._pullback(x->x[1]*im, randn(2))
- @test back(1.0)[2] == [-im, 0]
+ @test back(1.0)[2] == real([-im, 0])
# _droplike
@test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)
From 47d09d8ee6df2b3cfe17e9a6153e0464d81948e8 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 23 Jul 2021 12:49:42 +0200
Subject: [PATCH 153/490] add dependency on the PR that fixes the Dual error
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index e209009ad..9ac8ab690 100644
--- a/Project.toml
+++ b/Project.toml
@@ -28,7 +28,7 @@ ChainRulesCore = "0.10.4"
ChainRulesTestUtils = "0.7.1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
-ForwardDiff = "0.10"
+ForwardDiff = "0.10.20"
IRTools = "0.4"
MacroTools = "0.5"
NaNMath = "0.3"
From 6ae2ad5de0372c6d40ff2528c44ae7f1224d6e3a Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 23 Jul 2021 12:50:22 +0200
Subject: [PATCH 154/490] compat to CRC 1.0
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 9ac8ab690..c8d96bd7e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "0.8.12"
-ChainRulesCore = "0.10.4"
+ChainRulesCore = "1"
ChainRulesTestUtils = "0.7.1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
From 26286fc64c48149bb7a6c0b6e1108147ad1e3a5b Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 23 Jul 2021 13:31:51 +0200
Subject: [PATCH 155/490] mark ForwardDiff test broken
---
Project.toml | 2 +-
test/utils.jl | 13 +++++++++----
2 files changed, 10 insertions(+), 5 deletions(-)
diff --git a/Project.toml b/Project.toml
index c8d96bd7e..edcbfc5ed 100644
--- a/Project.toml
+++ b/Project.toml
@@ -28,7 +28,7 @@ ChainRulesCore = "1"
ChainRulesTestUtils = "0.7.1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
-ForwardDiff = "0.10.20"
+ForwardDiff = "0.10"
IRTools = "0.4"
MacroTools = "0.5"
NaNMath = "0.3"
diff --git a/test/utils.jl b/test/utils.jl
index 9a3d83ea5..3b461f82b 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -24,10 +24,15 @@ end
xs, y = randn(2,3), rand()
f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
- dx, dy = diaghessian(f34, xs, y)
- @test size(dx) == size(xs)
- @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
- @test dy ≈ hessian(y -> f34(xs,y), y)
+
+ function broken()
+ dx, dy = diaghessian(f34, xs, y) # This fails becase ProjectTo can't project a Dual onto a Float
+ c1 = size(dx) == size(xs)
+ c2 = vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
+ c3 = dy ≈ hessian(y -> f34(xs,y), y)
+ return all([c1, c2, c3])
+ end
+ @test_broken broken()
zs = randn(7,13) # test chunk mode
@test length(zs) > ForwardDiff.DEFAULT_CHUNK_THRESHOLD
From b89989a6a99cdbcf34fa777b4e148257e418808d Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Fri, 23 Jul 2021 13:49:55 +0100
Subject: [PATCH 156/490] Apply suggestions from code review
---
src/compiler/chainrules.jl | 2 +-
test/utils.jl | 10 ++++++++++
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 3e1745c10..99d907299 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -111,7 +111,7 @@ for T_outer in (:Tuple, :NamedTuple)
# than happy.
@eval @inline function wrap_chainrules_output(x::ChainRules.Tangent{P, T}) where {P, T<:$T_outer}
xp = map(wrap_chainrules_output, canonicalize(x))
- ChainRulesCore.backing(xp)
+ ChainRulesCore.backing(xp) # this is accessing ChainRulesCore internals, but it is prob safe enough, and it is fastest
end
end
diff --git a/test/utils.jl b/test/utils.jl
index 3b461f82b..ee01fa7ed 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -25,6 +25,16 @@ end
xs, y = randn(2,3), rand()
f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
+ # Follow is should work ones we workout what ForwardDiff should do when `Float64` is called on a `Dual`
+ # https://github.com/JuliaDiff/ForwardDiff.jl/pull/538
+ # else might need a custom overload of `(;;ChainRulesCore.ProjectTo)(::Dual)`
+ # When fixed uncomment the below and delete the broken function
+ #==
+ dx, dy = diaghessian(f34, xs, y)
+ @test size(dx) == size(xs)
+ @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
+ @test dy ≈ hessian(y -> f34(xs,y), y)
+ ==#
function broken()
dx, dy = diaghessian(f34, xs, y) # This fails becase ProjectTo can't project a Dual onto a Float
c1 = size(dx) == size(xs)
From 91536890929d2398539ef382db4c100c4d45175f Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Fri, 23 Jul 2021 13:50:50 +0100
Subject: [PATCH 157/490] Apply suggestions from code review
---
test/gradcheck.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 9ffd04260..eab959ddd 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -160,7 +160,7 @@ end
# https://github.com/FluxML/Zygote.jl/issues/376
_, back = Zygote._pullback(x->x[1]*im, randn(2))
- @test back(1.0)[2] == real([-im, 0])
+ @test back(1.0)[2] == real([-im, 0]) == [0, 0]
# _droplike
@test gradient(x -> sum(inv, x[1, :]'), ones(2, 2)) == ([-1 -1; 0 0],)
From e30fcf6196681ac410fce00db3300925d7f821fd Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Fri, 23 Jul 2021 19:58:08 +0100
Subject: [PATCH 158/490] rename test out out function
---
test/chainrules.jl | 37 ++++++++++++++++++-------------------
1 file changed, 18 insertions(+), 19 deletions(-)
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 3fb323e69..1ab2edc5a 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -234,33 +234,32 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
end
@testset "@opt_out" begin
- oa_id(x) = x
- oa_id_rrule_hitcount = Ref(0)
- function ChainRulesCore.rrule(::typeof(oa_id), x::Any)
- oa_id_rrule_hitcount[] += 1
- oa_id_pullback(ȳ) = (NoTangent(), ȳ)
- return oa_id(x), oa_id_pullback
+ oout_id(x) = x
+ oout_id_rrule_hitcount = Ref(0)
+ function ChainRulesCore.rrule(::typeof(oout_id), x::Any)
+ oout_id_rrule_hitcount[] += 1
+ oout_id_pullback(ȳ) = (NoTangent(), ȳ)
+ return oout_id(x), oout_id_pullback
end
- @opt_out ChainRulesCore.rrule(::typeof(oa_id), x::AbstractArray)
+ @opt_out ChainRulesCore.rrule(::typeof(oout_id), x::AbstractArray)
# Hit one we haven't opted out
- oa_id_rrule_hitcount[] = 0
- oa_id_outer(x) = sum(oa_id(x))
- @test (1.0,) == Zygote.gradient(oa_id_outer, π)
- @test oa_id_rrule_hitcount[] == 1
+ oout_id_rrule_hitcount[] = 0
+ oout_id_outer(x) = sum(oout_id(x))
+ @test (1.0,) == Zygote.gradient(oout_id_outer, π)
+ @test oout_id_rrule_hitcount[] == 1
# make sure don't hit the one we have opted out
- oa_id_rrule_hitcount[] = 0
- @test ([1.0],) == Zygote.gradient(oa_id_outer, [π])
- @test oa_id_rrule_hitcount[] == 0
+ oout_id_rrule_hitcount[] = 0
+ @test ([1.0],) == Zygote.gradient(oout_id_outer, [π])
+ @test oout_id_rrule_hitcount[] == 0
# Now try opting out After we have already used it
- @opt_out ChainRulesCore.rrule(::typeof(oa_id), x::Real)
- oa_id_rrule_hitcount[] = 0
- oa_id_outer(x) = sum(oa_id(x))
- @test (1.0,) == Zygote.gradient(oa_id_outer, π)
- @test oa_id_rrule_hitcount[] == 0
+ @opt_out ChainRulesCore.rrule(::typeof(oout_id), x::Real)
+ oout_id_rrule_hitcount[] = 0
+ @test (1.0,) == Zygote.gradient(oout_id_outer, π)
+ @test oout_id_rrule_hitcount[] == 0
end
end
From e1d481927d89004e2ff3a2fdaa11f8c6efc89cdc Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Fri, 23 Jul 2021 19:58:54 +0100
Subject: [PATCH 159/490] CR v1
---
Project.toml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/Project.toml b/Project.toml
index edcbfc5ed..2b595b150 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,9 +23,9 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "0.8.12"
+ChainRules = "1"
ChainRulesCore = "1"
-ChainRulesTestUtils = "0.7.1"
+ChainRulesTestUtils = "1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
ForwardDiff = "0.10"
From 01b75a733930d3ab506419d05cca54a6f58dc492 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Sun, 25 Jul 2021 14:27:07 +0530
Subject: [PATCH 160/490] update CompatHelper script
---
.github/workflows/CompatHelper.yml | 39 +++++++++++++++---------------
1 file changed, 20 insertions(+), 19 deletions(-)
diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml
index 0243c7062..1696bd76e 100644
--- a/.github/workflows/CompatHelper.yml
+++ b/.github/workflows/CompatHelper.yml
@@ -1,26 +1,27 @@
name: CompatHelper
-
on:
schedule:
- - cron: '00 * * * *'
- issues:
- types: [opened, reopened]
-
+ - cron: 0 0 * * *
+ workflow_dispatch:
jobs:
- build:
- runs-on: ${{ matrix.os }}
- strategy:
- matrix:
- julia-version: [1.2.0]
- julia-arch: [x86]
- os: [ubuntu-latest]
+ CompatHelper:
+ runs-on: ubuntu-latest
steps:
- - uses: julia-actions/setup-julia@latest
- with:
- version: ${{ matrix.julia-version }}
- - name: Pkg.add("CompatHelper")
- run: julia -e 'using Pkg; Pkg.add("CompatHelper")'
- - name: CompatHelper.main()
+ - name: "Install CompatHelper"
+ run: |
+ import Pkg
+ name = "CompatHelper"
+ uuid = "aa819f21-2bde-4658-8897-bab36330d9b7"
+ version = "2"
+ Pkg.add(; name, uuid, version)
+ shell: julia --color=yes {0}
+ - name: "Run CompatHelper"
+ run: |
+ import CompatHelper
+ CompatHelper.main()
+ shell: julia --color=yes {0}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- run: julia -e 'using CompatHelper; CompatHelper.main()'
+ COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
+ # COMPATHELPER_PRIV: ${{ secrets.COMPATHELPER_PRIV }}
+
From 13e277af1e123fbe1e21f567e7937a304ce4c86c Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Mon, 26 Jul 2021 18:47:26 +0100
Subject: [PATCH 161/490] Fix rrule_via_ad troubles
---
src/compiler/chainrules.jl | 3 ++-
test/chainrules.jl | 2 +-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 99d907299..15653c47b 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -100,7 +100,8 @@ is_kwfunc(k, ::Type{<:NamedTuple}, f, args...) = k===Core.kwftype(f)
Convert `x` from the differentials types ChainRules uses to the format Zygote uses internally.
"""
-@inline wrap_chainrules_output(x) = unthunk(x) # For now we are just not going to deal with thunks
+@inline wrap_chainrules_output(x) = x
+@inline wrap_chainrules_output(x::AbstractThunk) = wrap_chainrules_output(unthunk(x)) # For now we are just not going to deal with thunks
@inline wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 1ab2edc5a..30b758d04 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -304,7 +304,7 @@ end
ZygoteRuleConfig(), my_namedtuple, 1., 2., 3.; rrule_f=rrule_via_ad
)
test_rrule(
- ZygoteRuleConfig(), my_namedtuple, 1., (2.0, "str"), 3.; rrule_f=rrule_via_ad
+ ZygoteRuleConfig(), my_namedtuple, 1., (2.0, 2.4), 3.; rrule_f=rrule_via_ad
)
test_rrule(ZygoteRuleConfig(), sum, (1.0, 2.0, 3.0); rrule_f=rrule_via_ad)
test_rrule(
From 073d86942c8b7aafcbe3bc6d13632a9e3ddc0515 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Mon, 26 Jul 2021 19:08:42 +0100
Subject: [PATCH 162/490] bump versions
---
Project.toml | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
diff --git a/Project.toml b/Project.toml
index 2b595b150..9d623d263 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.6.17"
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
@@ -24,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1"
-ChainRulesCore = "1"
+ChainRulesCore = "1.0.1"
ChainRulesTestUtils = "1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
@@ -33,7 +34,7 @@ IRTools = "0.4"
MacroTools = "0.5"
NaNMath = "0.3"
Requires = "1.1"
-SpecialFunctions = "0.10, 1.0"
+SpecialFunctions = "1.6"
StatsFuns = "0.9.8"
ZygoteRules = "0.2.1"
julia = "1.3"
From 9ef6332829752cc6571badabf80b401f28e72f3c Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Mon, 26 Jul 2021 19:23:42 +0100
Subject: [PATCH 163/490] Remove direct dependency on CRTU
---
Project.toml | 1 -
1 file changed, 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 9d623d263..4a71b47cb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,7 +6,6 @@ version = "0.6.17"
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
-ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
From 1cc024e8840ed9585acd3e3f75cbc40ea64534ef Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 27 Jul 2021 09:51:04 +0100
Subject: [PATCH 164/490] Update test/utils.jl
Co-authored-by: Dhairya Gandhi
---
test/utils.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/utils.jl b/test/utils.jl
index ee01fa7ed..ecbfc2e14 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -25,7 +25,7 @@ end
xs, y = randn(2,3), rand()
f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
- # Follow is should work ones we workout what ForwardDiff should do when `Float64` is called on a `Dual`
+ # Following should work once we workout what ForwardDiff should do when `Float64` is called on a `Dual`
# https://github.com/JuliaDiff/ForwardDiff.jl/pull/538
# else might need a custom overload of `(;;ChainRulesCore.ProjectTo)(::Dual)`
# When fixed uncomment the below and delete the broken function
From 8e44dc7875cdba1df0289830f35f141288a69a0e Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 27 Jul 2021 12:35:07 +0100
Subject: [PATCH 165/490] renable diaghessian test
---
test/utils.jl | 14 --------------
1 file changed, 14 deletions(-)
diff --git a/test/utils.jl b/test/utils.jl
index ecbfc2e14..70a8ebd63 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -25,24 +25,10 @@ end
xs, y = randn(2,3), rand()
f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
- # Following should work once we workout what ForwardDiff should do when `Float64` is called on a `Dual`
- # https://github.com/JuliaDiff/ForwardDiff.jl/pull/538
- # else might need a custom overload of `(;;ChainRulesCore.ProjectTo)(::Dual)`
- # When fixed uncomment the below and delete the broken function
- #==
dx, dy = diaghessian(f34, xs, y)
@test size(dx) == size(xs)
@test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
@test dy ≈ hessian(y -> f34(xs,y), y)
- ==#
- function broken()
- dx, dy = diaghessian(f34, xs, y) # This fails becase ProjectTo can't project a Dual onto a Float
- c1 = size(dx) == size(xs)
- c2 = vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
- c3 = dy ≈ hessian(y -> f34(xs,y), y)
- return all([c1, c2, c3])
- end
- @test_broken broken()
zs = randn(7,13) # test chunk mode
@test length(zs) > ForwardDiff.DEFAULT_CHUNK_THRESHOLD
From f417fcb5487c2d171c279bce701067984e90da8e Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 27 Jul 2021 17:31:00 +0100
Subject: [PATCH 166/490] Fix Manifest for Docs
---
docs/Manifest.toml | 213 +++++++++++++++++++++++++++++++++++++++++++++
docs/Project.toml | 1 +
2 files changed, 214 insertions(+)
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
index eac8a8401..dfeb0b184 100644
--- a/docs/Manifest.toml
+++ b/docs/Manifest.toml
@@ -1,12 +1,72 @@
# This file is machine-generated - editing it directly is not advised
+[[AbstractFFTs]]
+deps = ["LinearAlgebra"]
+git-tree-sha1 = "485ee0867925449198280d4af84bdb46a2a404d0"
+uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
+version = "1.0.1"
+
+[[ArgTools]]
+uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f"
+
+[[Artifacts]]
+uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
+
[[Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
+[[ChainRules]]
+deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
+git-tree-sha1 = "346588c81effb94da6a30c1617e56af6a878e4d6"
+uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
+version = "1.0.1"
+
+[[ChainRulesCore]]
+deps = ["Compat", "LinearAlgebra", "SparseArrays"]
+git-tree-sha1 = "ad613c934ec3a3aa0ff19b91f15a16d56ed404b5"
+uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+version = "1.0.2"
+
+[[CommonSubexpressions]]
+deps = ["MacroTools", "Test"]
+git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
+uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
+version = "0.3.0"
+
+[[Compat]]
+deps = ["Base64", "Dates", "DelimitedFiles", "Distributed", "InteractiveUtils", "LibGit2", "Libdl", "LinearAlgebra", "Markdown", "Mmap", "Pkg", "Printf", "REPL", "Random", "SHA", "Serialization", "SharedArrays", "Sockets", "SparseArrays", "Statistics", "Test", "UUIDs", "Unicode"]
+git-tree-sha1 = "dc7dedc2c2aa9faf59a55c622760a25cbefbe941"
+uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
+version = "3.31.0"
+
+[[CompilerSupportLibraries_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
+
[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
+[[DelimitedFiles]]
+deps = ["Mmap"]
+uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
+
+[[DiffResults]]
+deps = ["StaticArrays"]
+git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
+uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
+version = "1.0.3"
+
+[[DiffRules]]
+deps = ["NaNMath", "Random", "SpecialFunctions"]
+git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9"
+uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
+version = "1.0.2"
+
+[[Distributed]]
+deps = ["Random", "Serialization", "Sockets"]
+uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
+
[[DocStringExtensions]]
deps = ["LibGit2"]
git-tree-sha1 = "a32185f5428d3986f47c2ab78b1f216d5e6cc96f"
@@ -19,39 +79,126 @@ git-tree-sha1 = "395fa1554c69735802bba37d9e7d9586fd44326c"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.24.11"
+[[Downloads]]
+deps = ["ArgTools", "LibCURL", "NetworkOptions"]
+uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
+
+[[FillArrays]]
+deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
+git-tree-sha1 = "8c8eac2af06ce35973c3eadb4ab3243076a408e7"
+uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
+version = "0.12.1"
+
+[[ForwardDiff]]
+deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
+git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
+uuid = "f6369f11-7733-5829-9624-2563aa707210"
+version = "0.10.18"
+
+[[IRTools]]
+deps = ["InteractiveUtils", "MacroTools", "Test"]
+git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c"
+uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
+version = "0.4.3"
+
[[InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+[[JLLWrappers]]
+deps = ["Preferences"]
+git-tree-sha1 = "642a199af8b68253517b80bd3bfd17eb4e84df6e"
+uuid = "692b3bcd-3c85-4b1f-b108-f13ce0eb3210"
+version = "1.3.0"
+
[[JSON]]
deps = ["Dates", "Mmap", "Parsers", "Unicode"]
git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
version = "0.21.1"
+[[LibCURL]]
+deps = ["LibCURL_jll", "MozillaCACerts_jll"]
+uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21"
+
+[[LibCURL_jll]]
+deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"]
+uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0"
+
[[LibGit2]]
deps = ["Base64", "NetworkOptions", "Printf", "SHA"]
uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
+[[LibSSH2_jll]]
+deps = ["Artifacts", "Libdl", "MbedTLS_jll"]
+uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8"
+
+[[Libdl]]
+uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
+
+[[LinearAlgebra]]
+deps = ["Libdl"]
+uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+
+[[LogExpFunctions]]
+deps = ["DocStringExtensions", "LinearAlgebra"]
+git-tree-sha1 = "7bd5f6565d80b6bf753738d2bc40a5dfea072070"
+uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
+version = "0.2.5"
+
[[Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
+[[MacroTools]]
+deps = ["Markdown", "Random"]
+git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
+uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
+version = "0.5.6"
+
[[Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
+[[MbedTLS_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1"
+
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"
+[[MozillaCACerts_jll]]
+uuid = "14a3606d-f60d-562e-9121-12d972cd8159"
+
+[[NaNMath]]
+git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
+uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
+version = "0.3.5"
+
[[NetworkOptions]]
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
+[[OpenSpecFun_jll]]
+deps = ["Artifacts", "CompilerSupportLibraries_jll", "JLLWrappers", "Libdl", "Pkg"]
+git-tree-sha1 = "13652491f6856acfd2db29360e1bbcd4565d04f1"
+uuid = "efe28fd5-8261-553b-a9e1-b2916fc3738e"
+version = "0.5.5+0"
+
[[Parsers]]
deps = ["Dates"]
git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "1.1.0"
+[[Pkg]]
+deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
+uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
+
+[[Preferences]]
+deps = ["TOML"]
+git-tree-sha1 = "00cfd92944ca9c760982747e9a1d0d5d86ab1e5a"
+uuid = "21216c6a-2e73-6563-6e65-726566657250"
+version = "1.2.2"
+
[[Printf]]
deps = ["Unicode"]
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
@@ -64,18 +211,84 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+[[Requires]]
+deps = ["UUIDs"]
+git-tree-sha1 = "4036a3bd08ac7e968e27c203d45f5fff15020621"
+uuid = "ae029012-a4dd-5104-9daa-d747884805df"
+version = "1.1.3"
+
[[SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+[[SharedArrays]]
+deps = ["Distributed", "Mmap", "Random", "Serialization"]
+uuid = "1a1011a3-84de-559e-8e89-a11a2f7dc383"
+
[[Sockets]]
uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
+[[SparseArrays]]
+deps = ["LinearAlgebra", "Random"]
+uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
+
+[[SpecialFunctions]]
+deps = ["ChainRulesCore", "LogExpFunctions", "OpenSpecFun_jll"]
+git-tree-sha1 = "508822dca004bf62e210609148511ad03ce8f1d8"
+uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
+version = "1.6.0"
+
+[[StaticArrays]]
+deps = ["LinearAlgebra", "Random", "Statistics"]
+git-tree-sha1 = "5b2f81eeb66bcfe379947c500aae773c85c31033"
+uuid = "90137ffa-7385-5640-81b9-e52037218182"
+version = "1.2.8"
+
+[[Statistics]]
+deps = ["LinearAlgebra", "SparseArrays"]
+uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
+
+[[TOML]]
+deps = ["Dates"]
+uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
+
+[[Tar]]
+deps = ["ArgTools", "SHA"]
+uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e"
+
[[Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+[[UUIDs]]
+deps = ["Random", "SHA"]
+uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
+
[[Unicode]]
uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
+
+[[Zlib_jll]]
+deps = ["Libdl"]
+uuid = "83775a58-1f1d-513f-b197-d71354ab007a"
+
+[[Zygote]]
+deps = ["AbstractFFTs", "ChainRules", "ChainRulesCore", "DiffRules", "Distributed", "FillArrays", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics", "ZygoteRules"]
+path = ".."
+uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
+version = "0.6.17"
+
+[[ZygoteRules]]
+deps = ["MacroTools"]
+git-tree-sha1 = "9e7a1e8ca60b742e508a315c17eef5211e7fbfd7"
+uuid = "700de1a5-db45-46bc-99cf-38207098b444"
+version = "0.2.1"
+
+[[nghttp2_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d"
+
+[[p7zip_jll]]
+deps = ["Artifacts", "Libdl"]
+uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0"
diff --git a/docs/Project.toml b/docs/Project.toml
index 1b9ab1f81..2a4c85433 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -1,5 +1,6 @@
[deps]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
[compat]
Documenter = "0.24"
From 882a939986871d834b98b3cbbf2fab4bba6183ff Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 27 Jul 2021 17:52:58 +0100
Subject: [PATCH 167/490] Fix doctests
---
src/compiler/interface.jl | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 6ca8257d5..e4db33471 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -58,11 +58,11 @@ See also [`withgradient`](@ref) to keep the value `f(args...)`,
and [`pullback`](@ref) for value and back-propagator.
```jldoctest; setup=:(using Zygote)
-julia> gradient(*, 2, 3, 5)
-(15, 10, 6)
+julia> gradient(*, 2.0, 3.0, 5.0)
+(15.0, 10.0, 6.0)
-julia> gradient(x -> sum(abs2,x), [7, 11, 13])
-([14, 22, 26],)
+julia> gradient(x -> sum(abs2,x), [7.0, 11.0, 13.0])
+([14.0, 22.0, 26.0],)
julia> gradient([7, 11], 0, 1) do x, y, d
p = size(x, d)
From 31150fa019030cb28deddfb0c5ffdff36fc7da0a Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Tue, 27 Jul 2021 22:30:02 +0530
Subject: [PATCH 168/490] whitespace
---
src/compiler/chainrules.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 15653c47b..d6a1894c2 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -28,8 +28,8 @@ function has_chain_rrule(T)
else
# Not being redispatched: it does have a config
rrule_m = configured_rrule_m
- # Thus any no_rrule that might apply must also have a config because if it applied
- # it will be identical, and if it doesn't we don't care what it is.
+ # Thus any no_rrule that might apply must also have a config because if it applied
+ # it will be identical, and if it doesn't we don't care what it is.
no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...})
end
From 5fb12bcc06b9ad002fb707e1c575da8f296b81be Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Tue, 27 Jul 2021 22:44:37 +0530
Subject: [PATCH 169/490] typos
---
src/compiler/chainrules.jl | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index d6a1894c2..6fcdcdf40 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -42,13 +42,13 @@ function has_chain_rrule(T)
# Consider the following truth table, for what can occur:
# rrule: fallback, no_rrule: fallback => matches => do not use rrule.
# rrule: specific, no_rrule: fallback => !matches => do use rrule, as haven't opted out.
- # rrule: fallback, no_rrule: specific => IMPOSSIBLE, every no_rule us identical to some rrule
+ # rrule: fallback, no_rrule: specific => IMPOSSIBLE, every no_rule is identical to some rrule
# rrule: specific, no_rrule: specific => matches => do not use rrule as opted out
# rrule: specific, no_rrule: general => !matches => do use rrule as a more specific rrule takes preciedent over more general opted out
# rrule: general , no_rrule: specific => IMPOSSIBLE, every no_rule us identical to some rrule so can't have a more general rrule being hit, as the specific one would hit first
#
- # Note that the fallback cases are the same outcome as the general cases as fallback is just most general.
- # It can be seen that checking if it matches is the correct way to decide if we should ue the rrule or not.
+ # Note that the fallback cases are the same outcome as the general cases as fallback is just most general.
+ # It can be seen that checking if it matches is the correct way to decide if we should use the rrule or not.
do_not_use_rrule = matching_cr_sig(no_rrule_m, rrule_m)
From 002c937db45ce25855eb2427781c19cd1d351bd0 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 27 Jul 2021 19:12:17 +0100
Subject: [PATCH 170/490] Update Project.toml
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 4a71b47cb..5d04edb68 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.17"
+version = "0.6.18"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 5b803d12f2d235755ba6f23a9017fc29fefd9985 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Tue, 27 Jul 2021 19:15:46 +0100
Subject: [PATCH 171/490] let CR1 rest as 0.6.18-DEV for a bit
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 5d04edb68..a1b9e5d0c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.18"
+version = "0.6.18-DEV"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 7cde196d9fca10b829d4da3792f4c0944bb9899a Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Thu, 29 Jul 2021 11:53:43 +0200
Subject: [PATCH 172/490] remove artifact warning
---
src/compiler/chainrules.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 6fcdcdf40..fab01e01c 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -194,7 +194,7 @@ z2d(x, ::Any) = x
z2d(::Nothing, ::Any) = NoTangent()
z2d(a::AbstractArray{<:Number}, primal::AbstractArray{T}) where T = a
z2d(a::AbstractArray, primal::AbstractArray{T}) where T = z2d.(a, primal)
-z2d(x::Union{AbstractZero, Tangent}, ::Any) = (difftype_warn(x); return x)
+z2d(x::Union{AbstractZero, Tangent}, ::Any) = return x
function z2d(t::Tuple, primal::Tuple)
tp::Tuple = map(z2d, t, primal)
primal_type = typeof(primal)
From 149bae5492e1f1e7e71903908d155a9107dad3fd Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Thu, 29 Jul 2021 12:57:40 +0100
Subject: [PATCH 173/490] add comment
---
src/compiler/chainrules.jl | 3 +++
1 file changed, 3 insertions(+)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index fab01e01c..fc02d68a8 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -194,6 +194,9 @@ z2d(x, ::Any) = x
z2d(::Nothing, ::Any) = NoTangent()
z2d(a::AbstractArray{<:Number}, primal::AbstractArray{T}) where T = a
z2d(a::AbstractArray, primal::AbstractArray{T}) where T = z2d.(a, primal)
+# Note: this should never be hit if we are converting things right, but it seems to be
+# happening in the wild for sufficiently weird functions/types.
+# This fixes most (all?) cases, but it would be good to find what we miss.
z2d(x::Union{AbstractZero, Tangent}, ::Any) = return x
function z2d(t::Tuple, primal::Tuple)
tp::Tuple = map(z2d, t, primal)
From 005448bf57913a4ed270fb9b63ba248219973890 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 31 Jul 2021 21:51:51 -0400
Subject: [PATCH 174/490] rm examples
---
examples/Manifest.toml | 394 ----------------------------------
examples/Project.toml | 5 -
examples/linear_regression.jl | 99 ---------
examples/mnist_mlp.jl | 107 ---------
examples/profiler.jl | 38 ----
5 files changed, 643 deletions(-)
delete mode 100644 examples/Manifest.toml
delete mode 100644 examples/Project.toml
delete mode 100644 examples/linear_regression.jl
delete mode 100644 examples/mnist_mlp.jl
delete mode 100644 examples/profiler.jl
diff --git a/examples/Manifest.toml b/examples/Manifest.toml
deleted file mode 100644
index 4ad93ad0d..000000000
--- a/examples/Manifest.toml
+++ /dev/null
@@ -1,394 +0,0 @@
-# This file is machine-generated - editing it directly is not advised
-
-[[AbstractFFTs]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "380e36c66edfa099cd90116b24c1ce8cafccac40"
-uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
-version = "0.4.1"
-
-[[AbstractTrees]]
-deps = ["Markdown", "Test"]
-git-tree-sha1 = "6621d9645702c1c4e6970cc6a3eae440c768000b"
-uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
-version = "0.2.1"
-
-[[Adapt]]
-deps = ["LinearAlgebra"]
-git-tree-sha1 = "fd04049c7dd78cfef0b06cdc1f0f181467655712"
-uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-version = "1.1.0"
-
-[[Base64]]
-uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
-
-[[BinDeps]]
-deps = ["Libdl", "Pkg", "SHA", "URIParser", "Unicode"]
-git-tree-sha1 = "1289b57e8cf019aede076edab0587eb9644175bd"
-uuid = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
-version = "1.0.2"
-
-[[BinaryProvider]]
-deps = ["Libdl", "Logging", "SHA"]
-git-tree-sha1 = "ecdec412a9abc8db54c0efc5548c64dfce072058"
-uuid = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
-version = "0.5.10"
-
-[[CEnum]]
-git-tree-sha1 = "215a9aa4a1f23fbd05b92769fdd62559488d70e9"
-uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
-version = "0.4.1"
-
-[[CUDAapi]]
-deps = ["Libdl", "Logging"]
-git-tree-sha1 = "e063efb91cfefd7e6afd92c435d01398107a500b"
-uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
-version = "1.2.0"
-
-[[CUDAdrv]]
-deps = ["CUDAapi", "Libdl", "Printf"]
-git-tree-sha1 = "9ce99b5732c70e06ed97c042187baed876fb1698"
-uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
-version = "3.1.0"
-
-[[CUDAnative]]
-deps = ["Adapt", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Logging", "Printf", "TimerOutputs"]
-git-tree-sha1 = "3d6427f28430730c0e4107d8f26c4943a9a142dc"
-uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
-version = "2.4.0"
-
-[[CodecZlib]]
-deps = ["BinaryProvider", "Libdl", "TranscodingStreams"]
-git-tree-sha1 = "05916673a2627dd91b4969ff8ba6941bc85a960e"
-uuid = "944b1d66-785c-5afd-91f1-9de20f533193"
-version = "0.6.0"
-
-[[ColorTypes]]
-deps = ["FixedPointNumbers", "Random"]
-git-tree-sha1 = "7b62b728a5f3dd6ee3b23910303ccf27e82fad5e"
-uuid = "3da002f7-5984-5a60-b8a6-cbb66c0b333f"
-version = "0.8.1"
-
-[[Colors]]
-deps = ["ColorTypes", "FixedPointNumbers", "InteractiveUtils", "Printf", "Reexport"]
-git-tree-sha1 = "c9c1845d6bf22e34738bee65c357a69f416ed5d1"
-uuid = "5ae59095-9a9b-59fe-a467-6f913c188581"
-version = "0.9.6"
-
-[[CommonSubexpressions]]
-deps = ["MacroTools", "Test"]
-git-tree-sha1 = "7b8a93dba8af7e3b42fecabf646260105ac373f7"
-uuid = "bbf7d656-a473-5ed7-a52c-81e309532950"
-version = "0.3.0"
-
-[[Conda]]
-deps = ["JSON", "VersionParsing"]
-git-tree-sha1 = "299304989a5e6473d985212c28928899c74e9421"
-uuid = "8f4d0f93-b110-5947-807f-2305c1781a2d"
-version = "1.5.2"
-
-[[CuArrays]]
-deps = ["AbstractFFTs", "Adapt", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
-git-tree-sha1 = "46b48742a84bb839e74215b7e468a4a1c6ba30f9"
-uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
-version = "1.2.1"
-
-[[DataAPI]]
-git-tree-sha1 = "ee400abb2298bd13bfc3df1c412ed228061a2385"
-uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
-version = "1.7.0"
-
-[[DataStructures]]
-deps = ["InteractiveUtils", "OrderedCollections"]
-git-tree-sha1 = "88d48e133e6d3dd68183309877eac74393daa7eb"
-uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
-version = "0.17.20"
-
-[[Dates]]
-deps = ["Printf"]
-uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
-
-[[DelimitedFiles]]
-deps = ["Mmap"]
-uuid = "8bb1440f-4735-579b-a4ab-409b98df4dab"
-
-[[DiffResults]]
-deps = ["StaticArrays"]
-git-tree-sha1 = "c18e98cba888c6c25d1c3b048e4b3380ca956805"
-uuid = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
-version = "1.0.3"
-
-[[DiffRules]]
-deps = ["NaNMath", "Random", "SpecialFunctions"]
-git-tree-sha1 = "214c3fcac57755cfda163d91c58893a8723f93e9"
-uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
-version = "1.0.2"
-
-[[Distributed]]
-deps = ["Random", "Serialization", "Sockets"]
-uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"
-
-[[ExprTools]]
-git-tree-sha1 = "555eab1f7c501166ba87eeb5d561e9f5e7d167d3"
-uuid = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
-version = "0.1.4"
-
-[[FFTW]]
-deps = ["AbstractFFTs", "BinaryProvider", "Conda", "Libdl", "LinearAlgebra", "Reexport", "Test"]
-git-tree-sha1 = "6c5b420da0b8c12098048561b8d58f81adea506f"
-uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
-version = "1.0.1"
-
-[[FillArrays]]
-deps = ["LinearAlgebra", "Random", "SparseArrays"]
-git-tree-sha1 = "de38b0253ade98340fabaf220f368f6144541938"
-uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
-version = "0.7.4"
-
-[[FixedPointNumbers]]
-git-tree-sha1 = "d14a6fa5890ea3a7e5dcab6811114f132fec2b4b"
-uuid = "53c48c17-4a7d-5ca2-90c5-79b7896eea93"
-version = "0.6.1"
-
-[[Flux]]
-deps = ["AbstractTrees", "Adapt", "CUDAapi", "CodecZlib", "Colors", "CuArrays", "DelimitedFiles", "Juno", "LinearAlgebra", "MacroTools", "NNlib", "Pkg", "Printf", "Random", "Reexport", "SHA", "Statistics", "StatsBase", "Tracker", "ZipFile"]
-git-tree-sha1 = "b5ebbd896dcd8ff19c6cb7297c4d323155b26bcf"
-uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-version = "0.9.0"
-
-[[ForwardDiff]]
-deps = ["CommonSubexpressions", "DiffResults", "DiffRules", "LinearAlgebra", "NaNMath", "Printf", "Random", "SpecialFunctions", "StaticArrays"]
-git-tree-sha1 = "e2af66012e08966366a43251e1fd421522908be6"
-uuid = "f6369f11-7733-5829-9624-2563aa707210"
-version = "0.10.18"
-
-[[GPUArrays]]
-deps = ["Adapt", "FFTW", "FillArrays", "LinearAlgebra", "Printf", "Random", "Serialization", "Test"]
-git-tree-sha1 = "8d74ced24448c52b539a23d107bd2424ee139c0f"
-uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
-version = "1.0.4"
-
-[[IRTools]]
-deps = ["InteractiveUtils", "MacroTools", "Test"]
-git-tree-sha1 = "95215cd0076a150ef46ff7928892bc341864c73c"
-uuid = "7869d1d1-7146-5819-86e3-90919afe41df"
-version = "0.4.3"
-
-[[InteractiveUtils]]
-deps = ["Markdown"]
-uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
-
-[[JSON]]
-deps = ["Dates", "Mmap", "Parsers", "Unicode"]
-git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4"
-uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
-version = "0.21.1"
-
-[[Juno]]
-deps = ["Base64", "Logging", "Media", "Profile", "Test"]
-git-tree-sha1 = "30d94657a422d09cb97b6f86f04f750fa9c50df8"
-uuid = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
-version = "0.7.2"
-
-[[LLVM]]
-deps = ["CEnum", "Libdl", "Printf", "Unicode"]
-git-tree-sha1 = "d9c6e1efcaa6c2fcd043da812a62b3e489a109a3"
-uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
-version = "1.7.0"
-
-[[LibGit2]]
-uuid = "76f85450-5226-5b5a-8eaa-529ad045b433"
-
-[[Libdl]]
-uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
-
-[[LinearAlgebra]]
-deps = ["Libdl"]
-uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-
-[[Logging]]
-uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
-
-[[MacroTools]]
-deps = ["Markdown", "Random"]
-git-tree-sha1 = "6a8a2a625ab0dea913aba95c11370589e0239ff0"
-uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
-version = "0.5.6"
-
-[[Markdown]]
-deps = ["Base64"]
-uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
-
-[[Media]]
-deps = ["MacroTools", "Test"]
-git-tree-sha1 = "75a54abd10709c01f1b86b84ec225d26e840ed58"
-uuid = "e89f7d12-3494-54d1-8411-f7d8b9ae1f27"
-version = "0.5.0"
-
-[[Missings]]
-deps = ["DataAPI"]
-git-tree-sha1 = "4ea90bd5d3985ae1f9a908bd4500ae88921c5ce7"
-uuid = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
-version = "1.0.0"
-
-[[Mmap]]
-uuid = "a63ad114-7e13-5084-954f-fe012c677804"
-
-[[NNlib]]
-deps = ["BinaryProvider", "Libdl", "LinearAlgebra", "Requires", "Statistics"]
-git-tree-sha1 = "d9f196d911f55aeaff11b11f681b135980783824"
-uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.6.6"
-
-[[NaNMath]]
-git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"
-uuid = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
-version = "0.3.5"
-
-[[OrderedCollections]]
-git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c"
-uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
-version = "1.4.1"
-
-[[Parsers]]
-deps = ["Dates"]
-git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc"
-uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
-version = "1.1.0"
-
-[[Pkg]]
-deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
-uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
-
-[[Printf]]
-deps = ["Unicode"]
-uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
-
-[[Profile]]
-deps = ["Printf"]
-uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"
-
-[[ProgressMeter]]
-deps = ["Distributed", "Printf"]
-git-tree-sha1 = "afadeba63d90ff223a6a48d2009434ecee2ec9e8"
-uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
-version = "1.7.1"
-
-[[REPL]]
-deps = ["InteractiveUtils", "Markdown", "Sockets"]
-uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
-
-[[Random]]
-deps = ["Serialization"]
-uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-
-[[Reexport]]
-deps = ["Pkg"]
-git-tree-sha1 = "7b1d07f411bc8ddb7977ec7f377b97b158514fe0"
-uuid = "189a3867-3050-52da-a836-e630ba90ab69"
-version = "0.2.0"
-
-[[Requires]]
-deps = ["Test"]
-git-tree-sha1 = "f6fbf4ba64d295e146e49e021207993b6b48c7d1"
-uuid = "ae029012-a4dd-5104-9daa-d747884805df"
-version = "0.5.2"
-
-[[SHA]]
-uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
-
-[[Serialization]]
-uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
-
-[[Sockets]]
-uuid = "6462fe0b-24de-5631-8697-dd941f90decc"
-
-[[SortingAlgorithms]]
-deps = ["DataStructures"]
-git-tree-sha1 = "2ec1962eba973f383239da22e75218565c390a96"
-uuid = "a2af1166-a08f-5f64-846c-94a0d3cef48c"
-version = "1.0.0"
-
-[[SparseArrays]]
-deps = ["LinearAlgebra", "Random"]
-uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
-
-[[SpecialFunctions]]
-deps = ["BinDeps", "BinaryProvider", "Libdl"]
-git-tree-sha1 = "3bdd374b6fd78faf0119b8c5d538788dbf910c6e"
-uuid = "276daf66-3868-5448-9aa4-cd146d93841b"
-version = "0.8.0"
-
-[[StaticArrays]]
-deps = ["LinearAlgebra", "Random", "Statistics"]
-git-tree-sha1 = "896d55218776ab8f23fb7b222a5a4a946d4aafc2"
-uuid = "90137ffa-7385-5640-81b9-e52037218182"
-version = "1.2.5"
-
-[[Statistics]]
-deps = ["LinearAlgebra", "SparseArrays"]
-uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
-
-[[StatsAPI]]
-git-tree-sha1 = "1958272568dc176a1d881acb797beb909c785510"
-uuid = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
-version = "1.0.0"
-
-[[StatsBase]]
-deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Missings", "Printf", "Random", "SortingAlgorithms", "SparseArrays", "Statistics", "StatsAPI"]
-git-tree-sha1 = "2f6792d523d7448bbe2fec99eca9218f06cc746d"
-uuid = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
-version = "0.33.8"
-
-[[Test]]
-deps = ["Distributed", "InteractiveUtils", "Logging", "Random"]
-uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
-
-[[TimerOutputs]]
-deps = ["ExprTools", "Printf"]
-git-tree-sha1 = "209a8326c4f955e2442c07b56029e88bb48299c7"
-uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.12"
-
-[[Tracker]]
-deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics", "Test"]
-git-tree-sha1 = "86929a5811dca5ce76c65a1d3fecda92d90c2e49"
-uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-version = "0.2.6"
-
-[[TranscodingStreams]]
-deps = ["Random", "Test"]
-git-tree-sha1 = "7c53c35547de1c5b9d46a4797cf6d8253807108c"
-uuid = "3bb67fe8-82b1-5028-8e26-92a6c54297fa"
-version = "0.9.5"
-
-[[URIParser]]
-deps = ["Unicode"]
-git-tree-sha1 = "53a9f49546b8d2dd2e688d216421d050c9a31d0d"
-uuid = "30578b45-9adc-5946-b283-645ec420af67"
-version = "0.4.1"
-
-[[UUIDs]]
-deps = ["Random", "SHA"]
-uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
-
-[[Unicode]]
-uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"
-
-[[VersionParsing]]
-git-tree-sha1 = "80229be1f670524750d905f8fc8148e5a8c4537f"
-uuid = "81def892-9a0e-5fdd-b105-ffc91e053289"
-version = "1.2.0"
-
-[[ZipFile]]
-deps = ["BinaryProvider", "Libdl", "Printf"]
-git-tree-sha1 = "7fbfbc51c186f0ccdbe091f32d3dff8608973f8e"
-uuid = "a5390f91-8eb1-5f08-bee0-b1d1ffed6cea"
-version = "0.8.4"
-
-[[Zygote]]
-deps = ["DiffRules", "ForwardDiff", "IRTools", "InteractiveUtils", "LinearAlgebra", "MacroTools", "NNlib", "NaNMath", "Random", "Requires", "SpecialFunctions", "Statistics"]
-git-tree-sha1 = "d3c2ae55d116b5360a73b1e88d1a974b446d933a"
-repo-rev = "ffc50480ff8f7662110bfb82b0b6d4f9cef6e59d"
-repo-url = "https://github.com/FluxML/Zygote.jl.git"
-uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.14+"
diff --git a/examples/Project.toml b/examples/Project.toml
deleted file mode 100644
index 541d5a4f5..000000000
--- a/examples/Project.toml
+++ /dev/null
@@ -1,5 +0,0 @@
-[deps]
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
-Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
-ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
-Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/examples/linear_regression.jl b/examples/linear_regression.jl
deleted file mode 100644
index 8b1e2cfe8..000000000
--- a/examples/linear_regression.jl
+++ /dev/null
@@ -1,99 +0,0 @@
-# Initialize environment in current directory
-@info("Ensuring example environment instantiated...")
-import Pkg
-Pkg.activate(@__DIR__)
-Pkg.instantiate()
-
-@info("Loading Zygote...")
-using Zygote, LinearAlgebra
-
-# This example will showcase how we do a simple linear fit with Zygote, making
-# use of complex datastructures, a home-grown stochastic gradient descent
-# optimizer, and some good old-fashioned math. We start with the problem
-# statement: We wish to learn the mapping `f(X) -> Y`, where `X` is a matrix
-# of vector observations, `f()` is a linear mapping function and `Y` is a
-# vector of scalar observations.
-
-# Because we like complex objects, we will define our linear regression as the
-# following object:
-mutable struct LinearRegression
- # These values will be implicitly learned
- weights::Matrix
- bias::Float64
-
- # These values will not be learned
- name::String
-end
-LinearRegression(nparams, name) = LinearRegression(randn(1, nparams), 0.0, name)
-
-# Our linear prediction looks very familiar; w*X + b
-function predict(model::LinearRegression, X)
- return model.weights * X .+ model.bias
-end
-
-# Our "loss" that must be minimized is the l2 norm between our current
-# prediction and our ground-truth Y
-function loss(model::LinearRegression, X, Y)
- return norm(predict(model, X) .- Y, 2)
-end
-
-
-# Our "ground truth" values (that we will learn, to prove that this works)
-weights_gt = [1.0, 2.7, 0.3, 1.2]'
-bias_gt = 0.4
-
-# Generate a dataset of many observations
-X = randn(length(weights_gt), 10000)
-Y = weights_gt * X .+ bias_gt
-
-# Add a little bit of noise to `X` so that we do not have an exact solution,
-# but must instead do a least-squares fit:
-X .+= 0.001.*randn(size(X))
-
-
-# Now we begin our "training loop", where we take examples from `X`,
-# calculate loss with respect to the corresponding entry in `Y`, find the
-# gradient upon our model, update the model, and continue. Before we jump
-# in, let's look at what `Zygote.gradient()` gives us:
-@info("Building model...")
-model = LinearRegression(size(X, 1), "Example")
-
-# Calculate gradient upon `model` for the first example in our training set
-@info("Calculating gradient (the first time can take a while to compile...)")
-grads = Zygote.gradient(model) do m
- return loss(m, X[:,1], Y[1])
-end
-
-# The `grads` object is a Tuple containing one element per argument to
-# `gradient()`, so we take the first one to get the gradient upon `model`:
-grads = grads[1]
-
-# Because our LinearRegression object is mutable, the gradient holds a
-# reference to it, which we peel via `grads[]`:
-grads = grads[]
-
-# We now get a `NamedTuple` so we can now do things like `grads.weight`. Let's
-# print it out, just to see what it looks like. Note that while `weights` and
-# `bias` have gradients, `name` just naturally has a gradient of `nothing`,
-# because it was not involved in the calculation of the output loss.
-@info grads
-
-# Let's define an update rule that will allow us to modify the weights
-# of our model a tad bit according to the gradients
-function sgd_update!(model::LinearRegression, grads, η = 0.001)
- model.weights .-= η .* grads.weights
- model.bias -= η * grads.bias
-end
-
-# Now let's do that for each example in our training set:
-@info("Running train loop for $(size(X,2)) iterations")
-for idx in 1:size(X, 2)
- grads = Zygote.gradient(m -> loss(m, X[:, idx], Y[idx]), model)[1][]
- sgd_update!(model, grads)
-end
-
-# Now let's look at how well we've approximated the ground truth weights/bias:
-@info("Ground truth weights: $(weights_gt)")
-@info("Learned weights: $(round.(model.weights; digits=3))")
-@info("Ground truth bias: $(bias_gt)")
-@info("Learned bias: $(round(model.bias; digits=3))")
diff --git a/examples/mnist_mlp.jl b/examples/mnist_mlp.jl
deleted file mode 100644
index 60c3e7e02..000000000
--- a/examples/mnist_mlp.jl
+++ /dev/null
@@ -1,107 +0,0 @@
-# Initialize environment in current directory
-@info("Ensuring example environment instantiated...")
-import Pkg
-Pkg.activate(@__DIR__)
-Pkg.instantiate()
-
-@info("Loading Zygote and Flux...")
-using Zygote, Flux, Random, Statistics
-using Flux.Data.MNIST
-
-# We're going to showcase how to use Zygote with Flux; we'll create a simple
-# Multi-Layer Perceptron network to do digit classification upon the MNIST
-# dataset. We start with some setup that is ripped straight from the Flux
-# model zoo:
-
-# First, we load the MNIST images and flatten them into a giant matrix:
-@info("Loading dataset...")
-X = hcat(float.(reshape.(MNIST.images(), :))...)
-
-# Load labels as well, one-hot encoding them
-Y = float.(Flux.onehotbatch(MNIST.labels(), 0:9))
-
-# Do the same for the test data/labels:
-X_test = hcat(float.(reshape.(MNIST.images(:test), :))...)
-Y_test = float.(Flux.onehotbatch(MNIST.labels(:test), 0:9))
-
-@info("Constructing MLP model...")
-model = Chain(
- Dense(28^2, 32, relu),
- Dense(32, 10),
- softmax,
-)
-
-# Until Flux drops Tracker as its default Automatic Differentiation library,
-# strip it out with this line:
-model = Flux.mapleaves(Flux.data, model)
-
-# Our loss is the classical multiclass crossentropy loss
-loss(model, X, Y) = Flux.crossentropy(model(X), Y)
-
-# Helper function to calculate accuracy of our model
-accuracy(model, X, Y) = mean(Flux.onecold(model(X)) .== Flux.onecold(Y))
-
-
-# Recursive zygote update method, this is the general recursion case:
-function zyg_update!(opt, model, updates)
- # If this `model` node has no fields, then just return it
- if nfields(model) == 0
- return model
- end
-
- # If it does have fields, recurse into them:
- for field_idx in 1:nfields(model)
- zyg_update!(opt, getfield(model, field_idx), getfield(updates, field_idx))
- end
-
- # In the end, return the `model`
- return model
-end
-# If the `updates` is set to `Nothing`, then just return `model`; this means
-# that there were no changes to be applied to this piece of the model.
-zyg_update!(opt, model, updates::Nothing) = model
-
-# If `model` is an `AbstractArray` and `updates` is too, then apply our Flux
-# optimizer to the incoming gradients and apply them to the model!
-function zyg_update!(opt, model::AbstractArray, updates::AbstractArray)
- # Sub off to Flux's ADAM optimizer
- Flux.Optimise.apply!(opt, model, updates)
- return model .-= updates
-end
-
-
-# We will train for a number of epochs, with minibatches, using the `ADAM`
-# optimizer to nudge our weights toward perfection.
-opt = ADAM(0.001)
-num_epochs = 10
-@info("Training for $(num_epochs) epochs...")
-for epoch_idx in 1:num_epochs
- # "global" here to dodgescoping issues with for loops at top-level
- global X, Y, model
-
- # Shuffle the data each epoch:
- perm = shuffle(1:size(X,2))
- X = X[:, perm]
- Y = Y[:, perm]
-
- # Iterate over batches
- batch_size = 512
- batch_idxs = 1:batch_size:(size(X,2) - batch_size)
- for bidx in batch_idxs
- # Calculate gradients upon the model for this batch
- grads = Zygote.gradient(model) do model
- return loss(model, X[:, bidx:bidx+batch_size],
- Y[:, bidx:bidx+batch_size])
- end
-
- # Peel outer Tuple to access gradient of first parameter
- grads = grads[1]
-
- # Apply recursive update to our model:
- zyg_update!(opt, model, grads)
- end
-
- # After each epoch, report our accuracy on the test set:
- acc = accuracy(model, X_test, Y_test)
- @info("[$(epoch_idx)] Accuracy: $(round(100*acc; digits=1))%")
-end
diff --git a/examples/profiler.jl b/examples/profiler.jl
deleted file mode 100644
index 513fd6933..000000000
--- a/examples/profiler.jl
+++ /dev/null
@@ -1,38 +0,0 @@
-# Initialize environment in current directory
-@info("Ensuring example environment instantiated...")
-import Pkg
-Pkg.activate(@__DIR__)
-Pkg.instantiate()
-
-@info("Loading Zygote...")
-using Zygote
-
-function f(x)
- for i = 1:5
- x = sin(cos(x))
- end
- return x
-end
-
-function loop(x, n)
- r = x/x
- for i = 1:n
- r *= f(x)
- end
- return sin(cos(r))
-end
-
-gradient(loop, 2, 3)
-
-Zygote.@profile loop(2, 3)
-
-function logsumexp(x::Array{Float64,1})
- A = maximum(x);
- ema = exp.(x .- A);
- sema = sum(ema);
- log(sema) + A;
-end
-
-gradient(logsumexp, rand(100))
-
-Zygote.@profile logsumexp(rand(100))
From d89dc341d82bcd66def879eb7cece70e4f549495 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 4 Jul 2021 09:46:37 -0400
Subject: [PATCH 175/490] use _pullback to avoid some dy -> (nothing,
back(dy)...)
---
src/lib/broadcast.jl | 18 ++++++------------
1 file changed, 6 insertions(+), 12 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index cdcd21547..32b4085df 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -75,23 +75,17 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
-@adjoint function broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number})
- z, back = pullback(*, x, y) # this uses dot(y,Δ) instead of Δ .* conj.(y)
- z, Δ -> (nothing, back(Δ)...)
-end
-@adjoint function broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number)
- z, back = pullback(*, x, y)
- z, Δ -> (nothing, back(Δ)...)
-end
+@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
+ _pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
+@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
+ _pullback(*, x, y)
@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
res = x ./ y
res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y)))
end
-@adjoint function broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number)
- z, back = pullback(/, x, y)
- z, Δ -> (nothing, back(Δ)...)
-end
+@adjoint broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) =
+ _pullback(/, x, y)
@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
y = Base.literal_pow.(^, x, exp)
From 29755a5bf7ed71e978da714ad3a175e7b7c9f8c0 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 4 Jul 2021 09:47:01 -0400
Subject: [PATCH 176/490] tidy up broadcast_forward
---
src/lib/broadcast.jl | 37 ++++++++++++++++++++-----------------
1 file changed, 20 insertions(+), 17 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 32b4085df..6afafbefa 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -174,10 +174,9 @@ _dual_safearg(x) = false
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
- return f.(args...), _->nothing
+ return (f.(args...), _ -> nothing)
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args)
- y, back = broadcast_forward(f, args...)
- return y, ȳ -> (nothing, nothing, back(ȳ)...)
+ return broadcast_forward(f, args...)
end
len = inclen(args)
y∂b = _broadcast((x...) -> _pullback(__context__, f, x...), args...)
@@ -189,7 +188,7 @@ _dual_safearg(x) = false
end
(nothing, accum_sum(dxs[1]), map(unbroadcast, args, Base.tail(dxs))...)
end
- y, ∇broadcasted
+ return y, ∇broadcasted
end
@adjoint function broadcasted(::AbstractArrayStyle{0}, f, args...)
@@ -231,28 +230,32 @@ function dual_function(f::F) where F
end
@inline function broadcast_forward(f, args::Vararg{Any,N}) where N
- T = Broadcast.combine_eltypes(f, args)
+ valN = Val(N)
out = dual_function(f).(args...)
eltype(out) <: Dual || return (out, _ -> nothing)
y = map(x -> x.value, out)
- _back(ȳ, i) = unbroadcast(args[i], ((a, b) -> a*b.partials[i]).(ȳ, out))
- back(ȳ) = ntuple(i -> _back(ȳ, i), N)
- return y, back
+ function bc_fwd_back(ȳ)
+ dargs = ntuple(valN) do i
+ unbroadcast(args[i], broadcast((y1, o1) -> y1 * o1.partials[i], ȳ, out))
+ end
+ (nothing, nothing, dargs...) # nothings for broadcasted & f
+ end
+ return y, bc_fwd_back
end
@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
const CuArrayStyle = CUDA.AbstractGPUArrayStyle
- if isdefined(CUDA, :cufunc)
- @eval @adjoint function broadcasted(::CuArrayStyle, f, args...)
- y, back = broadcast_forward(CUDA.cufunc(f), args...)
- y, ȳ -> (nothing, nothing, back(ȳ)...)
- end
+ if isdefined(CUDA, :cufunc) # CUDA < 3.0
+
+ @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
+ broadcast_forward(CUDA.cufunc(f), args...)
+
else # CUDA >= 3.0
- @eval @adjoint function broadcasted(::CuArrayStyle, f, args...)
- y, back = broadcast_forward(f, args...)
- y, ȳ -> (nothing, nothing, back(ȳ)...)
- end
+
+ @eval @adjoint function broadcasted(::CuArrayStyle, f, args...) =
+ broadcast_forward(f, args...)
+
end
@adjoint CUDA.CuArray{N,T}(xs::Array) where {N,T} =
From c43e48140cea59a56263d5aa02ddd71e13eb7996 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 4 Jul 2021 09:48:38 -0400
Subject: [PATCH 177/490] rm explicit case for CuArray
---
src/lib/broadcast.jl | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 6afafbefa..e189d9dea 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -251,10 +251,10 @@ end
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(CUDA.cufunc(f), args...)
- else # CUDA >= 3.0
-
- @eval @adjoint function broadcasted(::CuArrayStyle, f, args...) =
- broadcast_forward(f, args...)
+ # else CUDA >= 3.0 -- don't need cufunc(f), and ordinary broadcasting calls broadcast_forward when safe
+
+ # @eval @adjoint function broadcasted(::CuArrayStyle, f, args...) =
+ # broadcast_forward(f, args...)
end
From 5d6516adc36765ee6082c126aeb3944fddf7b070 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 4 Jul 2021 10:45:05 -0400
Subject: [PATCH 178/490] solve two warnings re CUDA
ArgumentError: Package Zygote does not have CUDA in its dependencies
WARNING: using CUDA.trim in module Zygote conflicts with an existing identifier
---
src/lib/broadcast.jl | 11 ++++++-----
1 file changed, 6 insertions(+), 5 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index e189d9dea..e9ae579a8 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -45,18 +45,18 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
end
-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-trim(x::Tuple, Δ) = ntuple(k -> Δ[k], length(x))
+_trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
+_trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)
unbroadcast(x::AbstractArray, x̄) =
size(x) == size(x̄) ? x̄ :
- length(x) == length(x̄) ? trim(x, x̄) :
- trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
+ length(x) == length(x̄) ? _trim(x, x̄) :
+ _trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
-unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
+unbroadcast(x::Tuple, x̄) = _trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
@@ -244,6 +244,7 @@ end
end
@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
+
const CuArrayStyle = CUDA.AbstractGPUArrayStyle
if isdefined(CUDA, :cufunc) # CUDA < 3.0
From ed3ec9bb19a46cb4d64bd452a9c73229217bf5f3 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 11 Jul 2021 21:24:09 -0400
Subject: [PATCH 179/490] use real.(x) never real(x)
---
src/lib/broadcast.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index e9ae579a8..df3c3dd16 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -100,10 +100,10 @@ end
end
@adjoint broadcasted(::typeof(conj), x::Numeric) =
- conj.(x), z̄ -> (nothing, conj.(z̄))
+ conj(x), z̄ -> (nothing, conj(z̄))
@adjoint broadcasted(::typeof(real), x::Numeric) =
- real.(x), z̄ -> (nothing, real.(z̄))
+ real(x), z̄ -> (nothing, real(z̄))
@adjoint broadcasted(::typeof(imag), x::Numeric) =
imag.(x), z̄ -> (nothing, im .* real.(z̄))
From 7b7c6902017293a1d1e9257f0679ea4f4d882a3d Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 12 Jul 2021 23:04:34 -0400
Subject: [PATCH 180/490] add examples from 1027
---
src/lib/array.jl | 2 +-
test/cuda.jl | 7 ++++++-
2 files changed, 7 insertions(+), 2 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 1ad7eb6a7..51e2ba6d3 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -8,7 +8,7 @@ using Distributed: pmap, AbstractWorkerPool
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)
-@nograd ones, zeros, Base.OneTo, Colon(), one, zero, sizehint!
+@nograd ones, zeros, Base.OneTo, Colon(), one, zero, sizehint!, count
@adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,)
diff --git a/test/cuda.jl b/test/cuda.jl
index 95bcdf373..1d07d39fd 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -9,7 +9,7 @@ CUDA.allowscalar(false)
@test gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2}
end
-@testset "basic bcasting" begin
+@testset "broadcasting" begin
a = Float32.(1:9)
a_gpu = a |> cu
@@ -24,6 +24,11 @@ end
g_gpu = gradient(x -> w(x), a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g
+
+ # https://github.com/FluxML/Zygote.jl/issues/1027
+ @test gradient(x -> sum(x .!= 0), a_gpu) == (nothing,)
+ g3 = gradient(x -> sum(x .^ 3) ./ count(x .> 3), a)[1]
+ @test cu(g3) ≈ gradient(x -> sum(x .^ 3) ./ sum(x .> 3), a_gpu)[1]
end
@testset "sum(f, x)" begin
From 08b1c1ea45321eddfe26f1f13258126cdec8177a Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Jul 2021 15:46:32 -0400
Subject: [PATCH 181/490] bools aren't diff
---
src/lib/broadcast.jl | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index df3c3dd16..00f536366 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -219,6 +219,7 @@ using ForwardDiff: Dual
dual(x, p) = x
dual(x::Real, p) = Dual(x, p)
+dual(x::Bool, p) = x
function dual_function(f::F) where F
function (args::Vararg{Any,N}) where N
From f83f20a59b87cb5dc35d5c6e6fb967e12953b9fe Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Jul 2021 16:52:54 -0400
Subject: [PATCH 182/490] un-comment CUDA broadcasting
---
src/lib/broadcast.jl | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 00f536366..b4e451aa9 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -253,10 +253,10 @@ end
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(CUDA.cufunc(f), args...)
- # else CUDA >= 3.0 -- don't need cufunc(f), and ordinary broadcasting calls broadcast_forward when safe
+ else CUDA >= 3.0 -- don't need cufunc(f), and ordinary broadcasting calls broadcast_forward when safe
- # @eval @adjoint function broadcasted(::CuArrayStyle, f, args...) =
- # broadcast_forward(f, args...)
+ @eval @adjoint function broadcasted(::CuArrayStyle, f, args...) =
+ broadcast_forward(f, args...)
end
From 6642fa7cc9aa2c7e727ee0f314c9d0c51965a55c Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Jul 2021 16:59:14 -0400
Subject: [PATCH 183/490] typo
---
src/lib/broadcast.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index b4e451aa9..4e3e3e83e 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -253,7 +253,7 @@ end
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(CUDA.cufunc(f), args...)
- else CUDA >= 3.0 -- don't need cufunc(f), and ordinary broadcasting calls broadcast_forward when safe
+ else # CUDA >= 3.0 -- don't need cufunc(f), and ordinary broadcasting calls broadcast_forward when safe
@eval @adjoint function broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(f, args...)
From 7e7ca7565a242b3b8188a6065df2a2e90c8fdea8 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Jul 2021 17:01:22 -0400
Subject: [PATCH 184/490] another typo
---
src/lib/broadcast.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 4e3e3e83e..bc317ae89 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -255,7 +255,7 @@ end
else # CUDA >= 3.0 -- don't need cufunc(f), and ordinary broadcasting calls broadcast_forward when safe
- @eval @adjoint function broadcasted(::CuArrayStyle, f, args...) =
+ @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(f, args...)
end
From 45ab282e88aa2b23ebb0f38cf7d69e3546826b04 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 27 Jul 2021 17:43:30 -0400
Subject: [PATCH 185/490] rm two dots
---
test/cuda.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/test/cuda.jl b/test/cuda.jl
index 1d07d39fd..8bb59629e 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -27,8 +27,8 @@ end
# https://github.com/FluxML/Zygote.jl/issues/1027
@test gradient(x -> sum(x .!= 0), a_gpu) == (nothing,)
- g3 = gradient(x -> sum(x .^ 3) ./ count(x .> 3), a)[1]
- @test cu(g3) ≈ gradient(x -> sum(x .^ 3) ./ sum(x .> 3), a_gpu)[1]
+ g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1]
+ @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1]
end
@testset "sum(f, x)" begin
From 2cca8fc8cae4247e8a5b0950909c44cbb43b00bc Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 1 Aug 2021 08:31:54 -0400
Subject: [PATCH 186/490] revert change to trim's name
---
src/lib/broadcast.jl | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index bc317ae89..3c4f7f215 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -45,18 +45,18 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
end
-_trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-_trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)
+trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
+trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)
unbroadcast(x::AbstractArray, x̄) =
size(x) == size(x̄) ? x̄ :
- length(x) == length(x̄) ? _trim(x, x̄) :
- _trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
+ length(x) == length(x̄) ? trim(x, x̄) :
+ trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
-unbroadcast(x::Tuple, x̄) = _trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
+unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
From 19862d1ff517baa5a69933a40eb7e522e3e6d92e Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 1 Aug 2021 09:46:52 -0400
Subject: [PATCH 187/490] comments
---
src/lib/broadcast.jl | 5 ++++-
test/cuda.jl | 6 +++---
2 files changed, 7 insertions(+), 4 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 3c4f7f215..bcd0cbc35 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -253,7 +253,10 @@ end
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(CUDA.cufunc(f), args...)
- else # CUDA >= 3.0 -- don't need cufunc(f), and ordinary broadcasting calls broadcast_forward when safe
+ else # CUDA >= 3.0 -- don't need cufunc(f).
+ # Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
+ # so perhaps this can be deleted? Possible edge case here:
+ # https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(f, args...)
diff --git a/test/cuda.jl b/test/cuda.jl
index 8bb59629e..9eebafc14 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -26,9 +26,9 @@ end
@test g_gpu |> collect ≈ g
# https://github.com/FluxML/Zygote.jl/issues/1027
- @test gradient(x -> sum(x .!= 0), a_gpu) == (nothing,)
- g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1]
- @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1]
+ @test gradient(x -> sum(x .!= 0), a_gpu) == (nothing,) # was MethodError: no method matching iterate(::Nothing)
+ g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
+ @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- Zygote v0.6.14, CUDA v3.3.0
end
@testset "sum(f, x)" begin
From 687adbce9b6cdf73d70756b4cafffb71c9c16d6c Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 1 Aug 2021 09:55:01 -0400
Subject: [PATCH 188/490] mark one broken, not a regression
---
test/cuda.jl | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/test/cuda.jl b/test/cuda.jl
index 9eebafc14..3999ace59 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -25,10 +25,12 @@ end
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g
- # https://github.com/FluxML/Zygote.jl/issues/1027
+ # https://github.com/FluxML/Zygote.jl/issues/1027 # status on Zygote v0.6.14, CUDA v3.3.0 in comments:
@test gradient(x -> sum(x .!= 0), a_gpu) == (nothing,) # was MethodError: no method matching iterate(::Nothing)
+ @test gradient(x -> sum(x .> 3), a_gpu) == (nothing,)
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
- @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- Zygote v0.6.14, CUDA v3.3.0
+ @test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
+ @test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]
end
@testset "sum(f, x)" begin
From 88a0182f96d54ec5bee4cd333610082a5f3ee931 Mon Sep 17 00:00:00 2001
From: WT
Date: Mon, 2 Aug 2021 17:27:41 +0100
Subject: [PATCH 189/490] Bump CR dep to 1.5
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index a1b9e5d0c..22b32dd6a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -23,7 +23,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "1"
+ChainRules = "1.5"
ChainRulesCore = "1.0.1"
ChainRulesTestUtils = "1"
DiffRules = "1.0"
From b7d7cc2e47df158a72dcd0be6af234fe60d0eea5 Mon Sep 17 00:00:00 2001
From: WT
Date: Mon, 2 Aug 2021 17:27:51 +0100
Subject: [PATCH 190/490] Remove redundant rules
---
src/lib/array.jl | 5 -----
1 file changed, 5 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 51e2ba6d3..387b955cb 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -3,22 +3,17 @@ using FillArrays: AbstractFill, getindex_value
using Base.Broadcast: broadcasted, broadcast_shape
using Distributed: pmap, AbstractWorkerPool
-@adjoint (::Type{T})(::UndefInitializer, args...) where T<:Array = T(undef, args...), Δ -> nothing
-
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)
@nograd ones, zeros, Base.OneTo, Colon(), one, zero, sizehint!, count
-@adjoint Base.vect(xs...) = Base.vect(xs...), Δ -> (Δ...,)
-
@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,)
@adjoint collect(x::Tuple) = collect(x), dy -> (Tuple(dy),)
@adjoint collect(x::AbstractArray) = collect(x), dy -> (dy,)
# Array Constructors
-@adjoint (::Type{T})(x::T) where T<:Array = T(x), ȳ -> (ȳ,)
@adjoint function (::Type{T})(x::Number, sz) where {T <: Fill}
back(Δ::AbstractArray) = (sum(Δ), nothing)
back(Δ::NamedTuple) = (Δ.value, nothing)
From ff3d28251e59bddca4cb3840eeec76affc3884c4 Mon Sep 17 00:00:00 2001
From: WT
Date: Mon, 2 Aug 2021 17:30:01 +0100
Subject: [PATCH 191/490] Remove more redundant rules
---
src/lib/array.jl | 11 -----------
1 file changed, 11 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 387b955cb..15b994564 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -101,17 +101,6 @@ end
@adjoint collect(x::Array) = collect(x), Δ -> (Δ,)
-@adjoint fill(x::Real, dims...) = fill(x, dims...), Δ->(sum(Δ), map(_->nothing, dims)...)
-
-@adjoint function circshift(A, shifts)
- circshift(A, shifts), Δ -> (circshift(Δ, map(-, shifts)), nothing)
-end
-
-@adjoint function reverse(x::AbstractArray, args...; kwargs...)
- _reverse(t) = reverse(t, args...; kwargs...)
- _reverse(x), Δ->(_reverse(Δ), map(_->nothing, args)...)
-end
-
@adjoint permutedims(xs) = permutedims(xs), Δ -> (permutedims(Δ),)
@adjoint permutedims(xs::AbstractVector) = permutedims(xs), Δ -> (vec(permutedims(Δ)),)
From b952f4976ce710323d470e0599722b234daa12ce Mon Sep 17 00:00:00 2001
From: willtebbutt
Date: Tue, 3 Aug 2021 11:59:03 +0100
Subject: [PATCH 192/490] Bump patch (#1052)
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index a1b9e5d0c..5d04edb68 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.18-DEV"
+version = "0.6.18"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 4e2ec8fb42d215fd2c66714f11e2a02b5b50b9ae Mon Sep 17 00:00:00 2001
From: WT
Date: Tue, 3 Aug 2021 13:37:31 +0100
Subject: [PATCH 193/490] Bump patch
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index dc81fbc0a..7f777adb4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.18"
+version = "0.6.19"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From ae822e09bf892d0fef6eadb5ffb1c8b6df5985df Mon Sep 17 00:00:00 2001
From: willtebbutt
Date: Tue, 3 Aug 2021 14:09:26 +0100
Subject: [PATCH 194/490] Update Project.toml
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 7f777adb4..8f0a00973 100644
--- a/Project.toml
+++ b/Project.toml
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.5"
-ChainRulesCore = "1.0.1"
+ChainRulesCore = "1.1"
ChainRulesTestUtils = "1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
From ae9e1c37b4b9286378cd126d40495775448b3407 Mon Sep 17 00:00:00 2001
From: st--
Date: Thu, 5 Aug 2021 15:38:28 +0300
Subject: [PATCH 195/490] Fix docstring cross-reference
---
src/lib/grad.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index 4ac7708b6..a522d685a 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -105,7 +105,7 @@ This reverse-mode Jacobian needs to evaluate the pullback once for each element
Doing so is usually only efficient when `length(y)` is small compared to `length(a)`,
otherwise forward mode is likely to be better.
-See also [`withjacobian`](@ref), `hessian`](@ref), [`hessian_reverse`](@ref).
+See also [`withjacobian`](@ref), [`hessian`](@ref), [`hessian_reverse`](@ref).
# Examples
From 05d0c2ae04f334a2ec61e42decfe1172d0f2e6e8 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 20 Aug 2021 19:20:34 +0200
Subject: [PATCH 196/490] Support kwargs in `rrule_via_ad` (#1055)
* allow kwargs
* add test for kwarg support
* bump patch
* unwrap closure
* first solution
* second solution kwf()
* Revert "second solution kwf()"
This reverts commit b53a381855963eebde2ed4274fb2059f9263e6bf.
* avoid creating closure unnecessarily
* short function
* first instead of only
---
Project.toml | 2 +-
src/compiler/chainrules.jl | 14 +++++++++++---
test/chainrules.jl | 7 +++++++
3 files changed, 19 insertions(+), 4 deletions(-)
diff --git a/Project.toml b/Project.toml
index 8f0a00973..7e19b5a9a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.19"
+version = "0.6.20"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index fc02d68a8..aaec3951f 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -174,10 +174,18 @@ As per [`chain_rrule`](@ref) but with support for kwargs.
return y, kw_zpullback
end
+function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs...)
+ # create a closure to work around _pullback not accepting kwargs
+ # but avoid creating a closure unnecessarily (pullbacks of closures do not infer)
+ y, pb = if !isempty(kwargs)
+ kwf() = first(f_args)(Base.tail(f_args)...; kwargs...)
+ _y, _pb = _pullback(config.context, kwf)
+ _y, Δ -> first(_pb(Δ)).f_args # `first` should be `only`
+ else
+ _pullback(config.context, f_args...)
+ end
-function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f, args...)
- y, pb = _pullback(config.context, f, args...)
- ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), (f, args...))
+ ad_pullback(Δ) = zygote2differential(pb(wrap_chainrules_output(Δ)), f_args)
return y, ad_pullback
end
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 30b758d04..2a76b081e 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -275,6 +275,13 @@ end
test_rrule(ZygoteRuleConfig(), getindex, rand(5), 3; rrule_f=rrule_via_ad)
end
+ @testset "kwargs" begin
+ test_rrule(
+ ZygoteRuleConfig(), sum, [1.0 2; 3 4];
+ rrule_f=rrule_via_ad, check_inferred=false, fkwargs=(;dims=1)
+ )
+ end
+
@testset "struct" begin
struct Foo
x
From 9b977a97c67cf07f37a076e4b8b2ee4a05e0cda8 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 6 Sep 2021 14:43:57 +0530
Subject: [PATCH 197/490] add buffer typevar
---
src/lib/broadcast.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index bcd0cbc35..8510ae2dd 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -263,8 +263,8 @@ end
end
- @adjoint CUDA.CuArray{N,T}(xs::Array) where {N,T} =
- CUDA.CuArray{N,T}(xs), Δ -> (convert(Array, Δ), )
+ @adjoint CUDA.CuArray{N,T,B}(xs::Array) where {N,T,B} =
+ CUDA.CuArray{N,T,B}(xs), Δ -> (convert(Array, Δ), )
@adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
placeholder = similar(xs)
From 1ef96e7477da29661c0d5c7b4c9d9e4ec2a079d7 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 6 Sep 2021 19:56:24 +0530
Subject: [PATCH 198/490] use type dispatch
---
src/lib/broadcast.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 8510ae2dd..b9d11aa79 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -263,8 +263,8 @@ end
end
- @adjoint CUDA.CuArray{N,T,B}(xs::Array) where {N,T,B} =
- CUDA.CuArray{N,T,B}(xs), Δ -> (convert(Array, Δ), )
+ @adjoint (::Type{CUDA.CuArray{T,N}})(xs::Array) where {T,N} =
+ CUDA.CuArray{T,N}(xs), Δ -> (convert(Array, Δ), )
@adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
placeholder = similar(xs)
From 807e5e5b3e99e8f5bb60e721b77d9a95261e62a9 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Mon, 6 Sep 2021 23:12:34 +0530
Subject: [PATCH 199/490] fix rule
---
src/lib/broadcast.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index b9d11aa79..1277f70c5 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -263,7 +263,7 @@ end
end
- @adjoint (::Type{CUDA.CuArray{T,N}})(xs::Array) where {T,N} =
+ @adjoint (::Type{<:CUDA.CuArray{T,N}})(xs::Array) where {T,N} =
CUDA.CuArray{T,N}(xs), Δ -> (convert(Array, Δ), )
@adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
From 0c080e786e05d907d92200d3f03e11f1374831dd Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Tue, 7 Sep 2021 21:48:01 +0530
Subject: [PATCH 200/490] Update src/lib/broadcast.jl
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
src/lib/broadcast.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 1277f70c5..2b2154e75 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -263,8 +263,8 @@ end
end
- @adjoint (::Type{<:CUDA.CuArray{T,N}})(xs::Array) where {T,N} =
- CUDA.CuArray{T,N}(xs), Δ -> (convert(Array, Δ), )
+ @adjoint (::Type{T})(xs::Array) where {T<:CUDA.CuArray} =
+ T(xs), Δ -> (convert(Array, Δ), )
@adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
placeholder = similar(xs)
From 649a6ac73fac3e3195342e10725319a4f5d924bc Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Tue, 7 Sep 2021 21:48:56 +0530
Subject: [PATCH 201/490] whitespaces
---
src/lib/broadcast.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 2b2154e75..de69e7a85 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -263,7 +263,7 @@ end
end
- @adjoint (::Type{T})(xs::Array) where {T<:CUDA.CuArray} =
+ @adjoint (::Type{T})(xs::Array) where {T <: CUDA.CuArray} =
T(xs), Δ -> (convert(Array, Δ), )
@adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
From f4536a87467d22432a401e260a617d6c6821fe91 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 8 Sep 2021 07:02:31 +0200
Subject: [PATCH 202/490] fix pair getfield adjoint
---
src/lib/base.jl | 25 ++++++++++++++-----------
test/features.jl | 29 +++++++++++++++++------------
2 files changed, 31 insertions(+), 23 deletions(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index 67f8b2c5e..d9e748f9c 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -118,23 +118,26 @@ end
# named tuple
@adjoint function pairs(t::NamedTuple{N}) where N
- pairs_namedtuple(dx::NamedTuple) = (dx.data,)
- function pairs_namedtuple(Δ::Dict)
- t0 = map(zero, t)
- for (idx, v) in Δ
- t0 = NamedTuple{N}(Base.setindex((t0...,), v, idx))
- end
- return (t0,)
+
+ pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
+
+ function pairs_namedtuple_pullback(Δ::Dict)
+ t0 = map(zero, t)
+ for (idx, v) in Δ
+ t0 = NamedTuple{N}(Base.setindex((t0...,), v, idx))
end
- return pairs(t), pairs_namedtuple
+ return (t0,)
+ end
+
+ return pairs(t), pairs_namedtuple_pullback
end
@adjoint function Base.getfield(p::Pair, i::Int)
- function pair_getfield(Δ)
- f, s = i == 1 ? (Δ, zero(p[2])) : (zero(p[1]), Δ)
+ function pair_getfield_pullback(Δ)
+ f, s = i == 1 ? (Δ, nothing) : (nothing, Δ)
return (first=f, second=s), nothing
end
- return getfield(p, i), pair_getfield
+ return getfield(p, i), pair_getfield_pullback
end
@adjoint Base.nameof(x::UnionAll) = nameof(x), _ -> (nothing,)
diff --git a/test/features.jl b/test/features.jl
index b17f55b41..7f9a1f70c 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -424,17 +424,17 @@ end
end
mutable struct MyMutable
- value::Float64
+ value::Float64
end
function foo!(m::MyMutable, x)
- m.value = x
+ m.value = x
end
function baz(args)
- m = MyMutable(0.)
- foo!(m, args...)
- m.value
+ m = MyMutable(0.)
+ foo!(m, args...)
+ m.value
end
let
@@ -449,13 +449,18 @@ end
@test pullback(type_test)[1] == Complex{<:Real}
@testset "Pairs" begin
- @test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
- @test (x->10*pairs((a=x, b=2))[2])'(100) === 0
- foo(;kw...) = 1
- @test gradient(() -> foo(a=1,b=2.0)) === ()
-
- @test (x->10*(x => 2)[1])'(100) === 10.0
- @test (x->10*(x => 2)[2])'(100) === 0
+ @test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
+ @test (x->10*pairs((a=x, b=2))[2])'(100) === 0
+ foo(;kw...) = 1
+ @test gradient(() -> foo(a=1,b=2.0)) === ()
+
+ @test (x->10*(x => 2)[1])'(100) === 10.0
+ @test (x->10*(x => 2)[2])'(100) === 0
+
+ @test gradient(x-> (:x => x)[2], 17) == (1,)
+
+ d = Dict(:x=>1.0, :y=>3.0);
+ @test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),)
end
# https://github.com/JuliaDiff/ChainRules.jl/issues/257
From b7ee5381822a7de3265223baaf8f688cda1ab2a1 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 8 Sep 2021 07:58:52 +0200
Subject: [PATCH 203/490] fix test
---
src/lib/array.jl | 1 +
test/features.jl | 2 +-
2 files changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 15b994564..c63f0f74a 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -246,6 +246,7 @@ end
@nograd workers
function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
+ @show g.f g.iter
y, b = ∇map(cx, g.f, g.iter)
back(::Nothing) = nothing
function back(ȳ)
diff --git a/test/features.jl b/test/features.jl
index 7f9a1f70c..f3931464d 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -455,7 +455,7 @@ end
@test gradient(() -> foo(a=1,b=2.0)) === ()
@test (x->10*(x => 2)[1])'(100) === 10.0
- @test (x->10*(x => 2)[2])'(100) === 0
+ @test (x->10*(x => 2)[2])'(100) === nothing
@test gradient(x-> (:x => x)[2], 17) == (1,)
From d9227ba07bdc36d74f804fbdd04aa251755cb0e5 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 8 Sep 2021 10:31:59 +0200
Subject: [PATCH 204/490] cleanup
---
src/lib/array.jl | 1 -
1 file changed, 1 deletion(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index c63f0f74a..15b994564 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -246,7 +246,6 @@ end
@nograd workers
function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
- @show g.f g.iter
y, b = ∇map(cx, g.f, g.iter)
back(::Nothing) = nothing
function back(ȳ)
From c658277c8b6208b33b70866fab56a4bc25955d3c Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Thu, 2 Sep 2021 16:15:26 +0100
Subject: [PATCH 205/490] Support functions that splat namedtuples as keyword
arguments
---
src/lib/base.jl | 9 +++++----
test/features.jl | 9 +++++++++
2 files changed, 14 insertions(+), 4 deletions(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index d9e748f9c..8c8799585 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -118,17 +118,18 @@ end
# named tuple
@adjoint function pairs(t::NamedTuple{N}) where N
-
- pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
+ pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
+
function pairs_namedtuple_pullback(Δ::Dict)
t0 = map(zero, t)
for (idx, v) in Δ
- t0 = NamedTuple{N}(Base.setindex((t0...,), v, idx))
+ ii = idx isa Integer ? idx : findfirst(==(idx), keys(t))
+ t0 = NamedTuple{N}(Base.setindex((t0...,), v, ii))
end
return (t0,)
end
-
+
return pairs(t), pairs_namedtuple_pullback
end
diff --git a/test/features.jl b/test/features.jl
index f3931464d..819ab3fd0 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -463,6 +463,15 @@ end
@test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),)
end
+@testset "kwarg splatting, pass in object" begin
+ g(; kwargs...) = kwargs[:x] * kwargs[:z]
+ h(somedata) = g(; somedata...)
+ @test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),)
+
+ # Currently broken because we fallback to ADing the `merge(::NamedTuple, itr)` which uses `push!`.
+ @test_broken gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) isa Any
+end
+
# https://github.com/JuliaDiff/ChainRules.jl/issues/257
@testset "Keyword Argument Passing" begin
struct Type1{VJP}
From 3172e1cd5a8de885e495008d0e7a73e5831fdcb5 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Thu, 2 Sep 2021 16:20:24 +0100
Subject: [PATCH 206/490] bump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 7e19b5a9a..0cc618af0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.20"
+version = "0.6.21"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From a6615f1a7be3e2cb6e240bac0f04c1c8e89abce4 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Thu, 2 Sep 2021 16:44:10 +0100
Subject: [PATCH 207/490] Support passing kwargs as splatted dict (by writing an
 adjoint for merge(namedtuple, dict))
---
src/lib/base.jl | 9 +++++++++
test/features.jl | 4 +---
2 files changed, 10 insertions(+), 3 deletions(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index 8c8799585..b7c2072a2 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -133,6 +133,15 @@ end
return pairs(t), pairs_namedtuple_pullback
end
+# For merge between NamedTuple and Dict, we will just convert the Dict to a NamedTuple.
+# and then call `pullback`, which should overall be pretty efficient code generated,
+# and it avoids trying to AD the problematic generic `merge(::NamedTuple, ::iter)` method which uses `push!`.
+if VERSION >= v"1.6"
+ @adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, NamedTuple(dict))
+else
+ @adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...))
+end
+
@adjoint function Base.getfield(p::Pair, i::Int)
function pair_getfield_pullback(Δ)
f, s = i == 1 ? (Δ, nothing) : (nothing, Δ)
diff --git a/test/features.jl b/test/features.jl
index 819ab3fd0..2cf7d1976 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -467,9 +467,7 @@ end
g(; kwargs...) = kwargs[:x] * kwargs[:z]
h(somedata) = g(; somedata...)
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),)
-
- # Currently broken because we fallback to ADing the `merge(::NamedTuple, itr)` which uses `push!`.
- @test_broken gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) isa Any
+ @test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),)
end
# https://github.com/JuliaDiff/ChainRules.jl/issues/257
From 76c27d809a85592526b1fa7dd1ec2384958a1889 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 9 Sep 2021 14:01:40 -0400
Subject: [PATCH 208/490] use slow broadcasting for 2nd order
---
src/lib/broadcast.jl | 2 +-
test/features.jl | 5 +++++
2 files changed, 6 insertions(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index bcd0cbc35..300be6f8b 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -175,7 +175,7 @@ _dual_safearg(x) = false
# Avoid generic broadcasting in two easy cases:
if T == Bool
return (f.(args...), _ -> nothing)
- elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args)
+ elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
return broadcast_forward(f, args...)
end
len = inclen(args)
diff --git a/test/features.jl b/test/features.jl
index f3931464d..111be3759 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -551,4 +551,9 @@ end
@test gradient((x,p) -> sum(x .^ p), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625]
@test gradient((x,p) -> sum(z -> z^p, x), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625]
@test gradient((x,p) -> mapreduce(z -> z^p, +, x), [1.0,2.0,4.0], -1)[1] ≈ [-1.0, -0.25, -0.0625]
+
+ # second order
+ @test gradient(x -> sum(gradient(y -> sum(y.^2), x)[1]), [1, 2])[1] ≈ [2, 2]
+ @test gradient(x -> sum(gradient(y -> sum(sin.(y)), x)[1]), [1, 2])[1] ≈ [-0.8414709848078965, -0.9092974268256817]
+ @test gradient(x -> sum(abs, gradient(y -> sum(log.(2 .* exp.(y)) .^ 2), x)[1]), [1, 2])[1] ≈ [2,2]
end
From 52f5fb27b2f42aba29377cfe6a697745785ee95a Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 9 Sep 2021 14:52:08 -0400
Subject: [PATCH 209/490] add short-circuit to rrule_via_ad
---
src/compiler/chainrules.jl | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index aaec3951f..4bf7da28a 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -175,6 +175,10 @@ As per [`chain_rrule`](@ref) but with support for kwargs.
end
function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs...)
+ # first check whether there is an `rrule` which handles this directly
+ direcct = rrule(config, f_args...; kwargs...)
+ direcct === nothing || return direcct
+
# create a closure to work around _pullback not accepting kwargs
# but avoid creating a closure unnecessarily (pullbacks of closures do not infer)
y, pb = if !isempty(kwargs)
From 3b50c5c7dd94516239d6a2f338c628f5798547cc Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 9 Sep 2021 19:30:29 -0400
Subject: [PATCH 210/490] test for shortcut
---
test/chainrules.jl | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 2a76b081e..ec13f6e96 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -325,6 +325,18 @@ end
test_rrule(ZygoteRuleConfig(), +, rand(3), rand(3); rrule_f=rrule_via_ad)
test_rrule(ZygoteRuleConfig(), *, rand(1, 3), rand(3); rrule_f=rrule_via_ad)
end
+
+ @testset "rules which call rrule_via_ad" begin
+ # since cbrt has a rule, this will test the shortcut:
+ test_rrule(ZygoteRuleConfig(), sum, cbrt, randn(5))
+ test_rrule(ZygoteRuleConfig(), sum, cbrt, randn(5); rrule_f=rrule_via_ad)
+
+ # but x -> cbrt(x) has no rule, so will be done by Zygote
+ test_rrule(ZygoteRuleConfig(), sum, x -> cbrt(x), randn(5))
+ test_rrule(ZygoteRuleConfig(), sum, x -> cbrt(x), randn(5); rrule_f=rrule_via_ad)
+
+ test_rrule(ZygoteRuleConfig(), identity∘sum, x -> cbrt(x), randn(5); rrule_f=rrule_via_ad, check_inferred=false)
+ end
end
@testset "FastMath support" begin
From 1b7dacc368e80d0b54c3d95431444571e526dc18 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 9 Sep 2021 19:38:33 -0400
Subject: [PATCH 211/490] bump
---
Project.toml | 2 +-
test/chainrules.jl | 2 --
2 files changed, 1 insertion(+), 3 deletions(-)
diff --git a/Project.toml b/Project.toml
index 7e19b5a9a..0cc618af0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.20"
+version = "0.6.21"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/test/chainrules.jl b/test/chainrules.jl
index ec13f6e96..b87a9ea3e 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -334,8 +334,6 @@ end
# but x -> cbrt(x) has no rule, so will be done by Zygote
test_rrule(ZygoteRuleConfig(), sum, x -> cbrt(x), randn(5))
test_rrule(ZygoteRuleConfig(), sum, x -> cbrt(x), randn(5); rrule_f=rrule_via_ad)
-
- test_rrule(ZygoteRuleConfig(), identity∘sum, x -> cbrt(x), randn(5); rrule_f=rrule_via_ad, check_inferred=false)
end
end
From 7f2fed9c58d4650c0295fd9aa0ee92e643a0f0c0 Mon Sep 17 00:00:00 2001
From: Lyndon White
Date: Fri, 10 Sep 2021 10:51:04 +0100
Subject: [PATCH 212/490] bump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 0cc618af0..476419f64 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.21"
+version = "0.6.22"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 57adb2d2ca919d937dc815e3e660a121feb1c160 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 10 Sep 2021 11:34:05 -0400
Subject: [PATCH 213/490] unbump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 476419f64..0cc618af0 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.22"
+version = "0.6.21"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 528e0be677d1feb9ccf6fc4ab298f4d8a106de10 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 21 Sep 2021 23:04:13 -0400
Subject: [PATCH 214/490] Use `ProjectTo` in broadcasting & `gradient` (#1044)
* use ProjectTo in broadcasting, etc
* separate methods for Params
* move after defn
* better dims handling in unbroadcast
* tidier
* tests
* more wrapping
* fix a test
* handle a few nothings
* fix more, including FFT tests
* tests
* one test
* tests
* tests
* tests
* these are fixed
* add Compat
* tests
* add tests for issues closed
* simplify, some doctests
* fix some tests
* less piracy
* adjoint
* piracy
* skip a test
* splat tests
* skip on 1.3
* simplify _project
* a typo
* tweak
* broken GPU test, unrelated
* unexpected pass
* only broken on 1.6
* let nothing through
* rm some broken things
* target 1.3 fix
* comments
* update for ProjectTo(::Any)
* fix a test
* Update test/utils.jl
Co-authored-by: Lyndon White
* Update src/lib/broadcast.jl
* cu tests
* v0.6.22
Co-authored-by: Lyndon White
---
Project.toml | 4 +-
README.md | 2 +-
src/compiler/chainrules.jl | 22 +++++++++
src/compiler/interface.jl | 23 ++++++---
src/lib/array.jl | 2 +-
src/lib/broadcast.jl | 19 ++++----
test/complex.jl | 33 ++++++++++++-
test/cuda.jl | 39 ++++++++++++++-
test/features.jl | 24 +++++++++-
test/forward/forward.jl | 3 +-
test/gradcheck.jl | 97 +++++++++++++++++++++-----------------
test/structures.jl | 1 +
test/utils.jl | 23 +++++----
13 files changed, 214 insertions(+), 78 deletions(-)
diff --git a/Project.toml b/Project.toml
index 0cc618af0..56d086f99 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.21"
+version = "0.6.22"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.5"
-ChainRulesCore = "1.1"
+ChainRulesCore = "1.6"
ChainRulesTestUtils = "1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
diff --git a/README.md b/README.md
index 8551bca87..6b2a6517d 100644
--- a/README.md
+++ b/README.md
@@ -18,7 +18,7 @@ julia> using Zygote
julia> f(x) = 5x + 3
julia> f(10), f'(10)
-(53, 5)
+(53, 5.0)
julia> @code_llvm f'(10)
define i64 @"julia_#625_38792"(i64) {
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 4bf7da28a..e879af3f8 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -123,11 +123,33 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR
"""
@inline wrap_chainrules_input(x) = x
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
+@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
xp = map(wrap_chainrules_input, xs)
ChainRules.Tangent{Any, typeof(xp)}(xp)
end
+"""
+ _project(x, dx)
+
+Uses `ChainRulesCore.ProjectTo` to standardise the gradient `dx` for type & shape.
+Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
+Safe to apply to arbitrary input.
+"""
+@inline function _project(x, dx)
+ wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
+end
+
+# Restore splatted arrays
+_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)))
+
+# Piracy:
+# wrap_chainrules_input doesn't handle array of Union{Int,Nothing}
+(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent()
+
+# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any}
+(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))
+
"""
ZBack{F}(back) <: Function
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index e4db33471..9dc934a49 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -68,15 +68,20 @@ julia> gradient([7, 11], 0, 1) do x, y, d
p = size(x, d)
sum(x.^p .+ y)
end
-([14.0, 22.0], 2, nothing)
+([14.0, 22.0], 2.0, nothing)
```
"""
function gradient(f, args...)
y, back = pullback(f, args...)
- return back(sensitivity(y))
+ grad = back(sensitivity(y))
+ isnothing(grad) ? nothing : map(_project, args, grad)
end
-Base.adjoint(f::Function) = x -> gradient(f, x)[1]
+# Base.adjoint(f::Function) = x -> gradient(f, x)[1] # piracy!
+Base.adjoint(f::Function) = x -> begin # still piracy! avoids projection for legacy reasons
+ y, back = pullback(f, x)
+ back(sensitivity(y))[1]
+end
"""
withgradient(f, args...)
@@ -95,7 +100,9 @@ true
"""
function withgradient(f, args...)
y, back = pullback(f, args...)
- (val = y, grad = back(sensitivity(y)))
+ grad = back(sensitivity(y))
+ results = isnothing(grad) ? map(_ -> nothing, args) : map(_project, args, grad)
+ (val=y, grad=results)
end
# Param-style wrappers
@@ -115,9 +122,9 @@ julia> g = gradient(Params([x, y])) do
Grads(...)
julia> g[x]
-2×3 Matrix{Int64}:
- 7 70 700
- 8 80 800
+2×3 Matrix{Float64}:
+ 7.0 70.0 700.0
+ 8.0 80.0 800.0
julia> haskey(g, z) # only x and y are parameters
false
@@ -144,6 +151,8 @@ Params(xs::Tuple) = Params(collect(xs))
@forward Params.order Base.iterate, Base.length, Base.getindex
@forward Params.params Base.in
+Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)
+
function Base.union!(ps::Params, itrs...)
foreach(itr -> foreach(x -> push!(ps, x), itr), itrs)
return ps
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 15b994564..9bec64b95 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -38,7 +38,7 @@ end
dxv = view(dx, inds...)
dxv .= accum.(dxv, _droplike(dy, dxv))
end
- return (dx, map(_->nothing, inds)...)
+ return (_project(x, dx), map(_->nothing, inds)...)
end
"""
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 446e919b1..4e7a3a1cc 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -45,18 +45,19 @@ function Base.reducedim_init(::typeof(identity), ::typeof(accum), A::AbstractArr
Base.reducedim_initarray(A, region, nothing, Union{Nothing,eltype(A)})
end
-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
-trim(x::Tuple, Δ) = NTuple{length(x)}(Δ)
-
-unbroadcast(x::AbstractArray, x̄) =
- size(x) == size(x̄) ? x̄ :
- length(x) == length(x̄) ? trim(x, x̄) :
- trim(x, accum_sum(x̄, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(x̄)+1, Val(ndims(x̄)))))
-
+function unbroadcast(x::AbstractArray, x̄)
+ N = ndims(x̄)
+ if length(x) == length(x̄)
+ _project(x, x̄) # ProjectTo handles reshape, offsets, structured matrices, row vectors
+ else
+ dims = ntuple(d -> size(x, d) == 1 ? d : ndims(x̄)+1, ndims(x̄))
+ _project(x, accum_sum(x̄; dims = dims))
+ end
+end
unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
-unbroadcast(x::Tuple, x̄) = trim(x, length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
+unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
diff --git a/test/complex.jl b/test/complex.jl
index 6a0445b85..1abd1303f 100644
--- a/test/complex.jl
+++ b/test/complex.jl
@@ -1,9 +1,13 @@
using Zygote, Test, LinearAlgebra
+@testset "basic" begin
+
@test gradient(x -> real(abs(x)*exp(im*angle(x))), 10+20im)[1] ≈ 1
@test gradient(x -> imag(real(x)+0.3im), 0.3)[1] ≈ 0
-@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ -1im
-@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] == 1im
+@test gradient(x -> imag(conj(x)+0.3im), 0.3 + 0im)[1] ≈ -1im
+@test gradient(x -> imag(conj(x)+0.3im), 0.3)[1] ≈ 0 # projected to zero
+@test gradient(x -> abs((imag(x)+0.3)), 0.3 + 0im)[1] ≈ 1im
+@test gradient(x -> abs((imag(x)+0.3)), 0.3)[1] ≈ 0
@test gradient(a -> real((a*conj(a))), 0.3im)[1] == 0.6im
@test gradient(a -> real((a.*conj(a))), 0.3im)[1] == 0.6im
@@ -21,6 +25,8 @@ using Zygote, Test, LinearAlgebra
@test gradient(x -> imag(sum(exp, x)), [1,2,3])[1] ≈ real(im .* exp.(1:3))
@test gradient(x -> imag(sum(exp, x)), [1+0im,2,3])[1] ≈ im .* exp.(1:3)
+end # @testset
+
fs_C_to_R = (real,
imag,
abs,
@@ -81,3 +87,26 @@ fs_C_to_C_non_holomorphic = (conj,
end
end
end
+
+@testset "issue 342" begin
+ @test Zygote.gradient(x->real(x + 2.0*im), 3.0) == (1.0,)
+ @test Zygote.gradient(x->imag(x + 2.0*im), 3.0) == (0.0,)
+end
+
+@testset "issue 402" begin
+ A = [1,2,3.0]
+ y, B_getindex = Zygote.pullback(x->getindex(x,2,1),Diagonal(A))
+ bA = B_getindex(1)[1]
+ @test bA isa Diagonal
+ @test bA == [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0]
+end
+
+@testset "issue #917" begin
+ function fun(v)
+ c = v[1:3] + v[4:6]*im
+ r = v[7:9]
+ sum(r .* abs2.(c)) # This would be calling my actual function depending on r and c
+ end
+ @test Zygote.hessian(fun, collect(1:9)) ≈ [14 0 0 0 0 0 2 0 0; 0 16 0 0 0 0 0 4 0; 0 0 18 0 0 0 0 0 6; 0 0 0 14 0 0 8 0 0; 0 0 0 0 16 0 0 10 0; 0 0 0 0 0 18 0 0 12; 2 0 0 8 0 0 0 0 0; 0 4 0 0 10 0 0 0 0; 0 0 6 0 0 12 0 0 0]
+end
+
diff --git a/test/cuda.jl b/test/cuda.jl
index 3999ace59..5cb1c8cdc 100644
--- a/test/cuda.jl
+++ b/test/cuda.jl
@@ -1,12 +1,20 @@
using CUDA
using Zygote: Grads
+using LinearAlgebra
using Random: randn!
CUDA.allowscalar(false)
# Test GPU movement inside the call to `gradient`
@testset "GPU movement" begin
r = rand(Float32, 3,3)
- @test gradient(x -> sum(cu(x)), r)[1] isa Array{Float32, 2}
+ @test gradient(x -> sum(cu(x)), r)[1] isa Matrix{Float32}
+ @test gradient(x -> sum(x->log(x), cu(x)), r)[1] isa Matrix
+ @test gradient((x,cy) -> sum(cu(x) * cy) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray
+ @test_skip gradient((x,cy) -> sum(cu(x[:,1])' * cy), r, cu(r))[2] isa CUDA.CuArray # generic_matmatmul!
+
+ # Other direction:
+ @test_skip gradient(x -> sum(Array(x)), cu(r))[1] isa CUDA.CuArray
+ @test_skip gradient((x,cy) -> sum(x * Array(cy)) + sum(cy'), r, cu(r))[2] isa CUDA.CuArray
end
@testset "broadcasting" begin
@@ -31,10 +39,19 @@ end
g3 = gradient(x -> sum(x .^ 3) / count(x .> 3), a)[1] # was Can't differentiate gc_preserve_end expression
@test_skip cu(g3) ≈ gradient(x -> sum(x .^ 3) / sum(x .> 3), a_gpu)[1] # was KernelException -- not fixed by PR #1018
@test cu(g3) ≈ gradient(x -> sum(x .^ 3) / count(x .> 3), a_gpu)[1]
+
+ # Projection: eltype preservation:
+ @test gradient(x -> 2.3 * sum(x.^4), a_gpu)[1] isa CuArray{Float32}
+ @test_skip gradient(x -> sum(x .* 5.6), a_gpu)[1] isa CUDA.CuArray{Float32} # dot(x::CuArray{Float64}, y::CuArray{Float32}) fallback
+ # structure restoration:
+ @test gradient(x -> sum(sqrt.(x)), a_gpu')[1] isa Adjoint # previously a matrix
+ @test gradient(x -> sum(exp.(x)), Diagonal(a_gpu))[1] isa Diagonal
+ # non-differentiables
+ @test gradient((x,y) -> sum(x.^2 .+ y'), a_gpu, a_gpu .> 0)[2] === nothing
end
@testset "sum(f, x)" begin
- a = Float32.([-1.5, -9.0, 2.4, -1.3, 0.01])
+ a = Float32[-1.5, -9.0, 2.4, -1.3, 0.01]
a_gpu = a |> cu
f(x) = sum(abs, x)
@@ -42,6 +59,18 @@ end
g_gpu = gradient(f, a_gpu)[1]
@test g_gpu isa CuArray
@test g_gpu |> collect ≈ g
+
+ f2(x) = sum(abs2, x) # sum(abs2, x) has its own rrule
+ g2 = gradient(f2, a)[1]
+ g2_gpu = gradient(f2, a_gpu)[1]
+ @test g2_gpu isa CuArray
+ @test g2_gpu |> collect ≈ g2
+
+ f3(x) = sum(y->y^3, x') # anonymous function
+ g3 = gradient(f3, a')[1]
+ g3_gpu = gradient(f3, a_gpu')[1]
+ @test g3_gpu isa Adjoint{Float32, <:CuArray{Float32, 1}} # preserves structure
+ @test g3_gpu |> collect ≈ g3
end
@testset "jacobian" begin
@@ -103,5 +132,11 @@ end
r = cu(rand(Float32, 3))
grads = (cu(ones(Float32, 3)), 1.f0)
@test gradient((x,y) -> sum(vcat(x,y)), r, 5) == grads
+
+ @test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[1] isa CUDA.CuArray{Float32}
+ @test gradient((x,y) -> sum(vcat(x,y)), r, Float64(5))[2] isa Float64 # projection
+
+ @test_skip gradient((x,y) -> sum(vcat(x,y)), 5f0, r)[2] isa CUDA.CuArray{Float32} # wrong order
+ @test_skip gradient((x,y) -> sum(vcat(x,y)), 1f0, r, 2f0, r)[2] isa CUDA.CuArray{Float32}
end
diff --git a/test/features.jl b/test/features.jl
index 8c460dc98..d683d0d94 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -176,9 +176,9 @@ end
@test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),)
-@test gradient(x -> x.re, 2+3im) == ((re = 1, im = nothing),)
+@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,)
-@test gradient(x -> x.re*x.im, 2+3im) == ((re = 3, im = 2),)
+@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,)
struct Bar{T}
a::T
@@ -262,6 +262,7 @@ D(f, x) = grad(f, x)[1]
@test D(x -> x*D(y -> x+y, 1), 1) == 1
@test D(x -> x*D(y -> x*y, 1), 4) == 8
+@test sin''(1.0) == -sin(1.0)
@test sin'''(1.0) == -cos(1.0)
f(x) = throw(DimensionMismatch("fubar"))
@@ -499,6 +500,25 @@ end
@test x[1] == x[2]
end
+@testset "splats" begin
+ @test gradient(x -> max(x...), [1,2,3])[1] == [0,0,1]
+ @test gradient(x -> min(x...), (1,2,3))[1] === (1.0, 0.0, 0.0)
+
+ @test gradient(x -> max(x...), [1 2; 3 4])[1] == [0 0; 0 1]
+ @test gradient(x -> max(x...), [1,2,3]')[1] == [0 0 1]
+
+ # https://github.com/FluxML/Zygote.jl/issues/599
+ @test gradient(w -> sum([w...]), [1,1])[1] isa AbstractVector
+
+ # https://github.com/FluxML/Zygote.jl/issues/866
+ f866(x) = reshape(x, fill(2, 2)...)
+ @test gradient(x->sum(f866(x)), rand(4))[1] == [1,1,1,1]
+
+ # https://github.com/FluxML/Zygote.jl/issues/731
+ f731(x) = sum([x' * x, x...])
+ @test_broken gradient(f731, ones(3)) # MethodError: no method matching +(::Tuple{Float64, Float64, Float64}, ::Vector{Float64})
+end
+
@testset "accumulation" begin
# from https://github.com/FluxML/Zygote.jl/issues/905
function net(x1)
diff --git a/test/forward/forward.jl b/test/forward/forward.jl
index 3ae0f6e3a..6aa9173ef 100644
--- a/test/forward/forward.jl
+++ b/test/forward/forward.jl
@@ -36,7 +36,8 @@ end == 1
x
end == 0
-@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1)[1]
+@test D(x -> abs(x+2im), 1) == gradient(x -> abs(x+2im), 1+0im)[1]
+@test real(D(x -> abs(x+2im), 1)) == gradient(x -> abs(x+2im), 1)[1] # ProjectTo means gradient here is real
using LinearAlgebra
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index eab959ddd..af49b7697 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -177,7 +177,7 @@ end
# Ensure that nothings work with non-numeric types.
_, back = Zygote.pullback(getindex, [randn(2) for _ in 1:3], [1])
- @test back([nothing]) == ([nothing for _ in 1:3], nothing)
+ @test back([nothing]) == (nothing, nothing)
end
@testset "view" begin
@@ -332,10 +332,10 @@ end
@test gradient(x -> sum(log, filter(iseven, x)), 1:10) ==
(map(x -> iseven(x) ? 1/x : 0, 1:10),)
@test gradient(x -> sum(abs2, im .+ filter(iseven, x)), 1:10) ==
- (map(x -> iseven(x) ? 2x+2im : 0, 1:10),)
+ (map(x -> iseven(x) ? 2x : 0, 1:10),)
+ # (map(x -> iseven(x) ? 2x+2im : 0, 1:10),)
end
-
@testset "mean" begin
@test gradtest(mean, rand(2, 3))
@@ -1157,10 +1157,10 @@ end
end
@testset "hvcat" begin
- @test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == (1,0,0,0)
- @test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == (0,0,1,0)
- @test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == (0,1,0,0)
- @test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == (0,0,0,1)
+ @test gradient(xs -> hvcat((2,2),xs...)[1,1], [1,2,3,4])[1] == [1,0,0,0]
+ @test gradient(xs -> hvcat((2,2),xs...)[2,1], [1,2,3,4])[1] == [0,0,1,0]
+ @test gradient(xs -> hvcat((2,2),xs...)[1,2], [1,2,3,4])[1] == [0,1,0,0]
+ @test gradient(xs -> hvcat((2,2),xs...)[2,2], [1,2,3,4])[1] == [0,0,0,1]
# https://github.com/FluxML/Zygote.jl/issues/513
@test gradient(x -> hvcat((2,2),1,2,3,x)[4], 4.0) == (1.0,)
end
@@ -1375,10 +1375,10 @@ using Zygote: Buffer
@test gs[1] ≈ map(x -> one.(x), p)
@test gs[2] ≈ one.(r)
- p = [rand(3,3), rand(3,3)] # redefine `p` after mutation
- gs = gradient(x -> sum(pop!(x)), p)
- @test length(gs[1]) == 2
- @test gs[1][1] == one.(p[1])
+ # p = [rand(3,3), rand(3,3)] # redefine `p` after mutation
+ # gs = gradient(x -> sum(pop!(x)), p)
+ # @test length(gs[1]) == 2
+ # @test gs[1][1] == one.(p[1])
end
end
@@ -1403,6 +1403,17 @@ end
end
@testset "AbstractFFTs" begin
+
+ # Many of these tests check a complex gradient to a function with real input. This is now
+ # clamped to real by ProjectTo, but to run the old tests, use here the old gradient function:
+ function oldgradient(f, args...)
+ y, back = Zygote.pullback(f, args...)
+ back(Zygote.sensitivity(y))
+ end
+ # Eventually these rules and tests will be moved to ChainRules.jl, at which point the tests
+ # can be updated to use real / complex consistently.
+ # https://github.com/JuliaMath/AbstractFFTs.jl/pull/58
+
findicateMat(i,j,n1,n2) = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:n1,
l=1:n2]
mirrorIndex(i,N) = i - 2*max(0,i - (N>>1+1))
@@ -1415,11 +1426,11 @@ end
indicateMat = [(k==i) && (l==j) ? 1.0 : 0.0 for k=1:size(X, 1),
l=1:size(X,2)]
# gradient of ifft(fft) must be (approximately) 1 (for various cases)
- @test gradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat
+ @test oldgradient((X)->real.(ifft(fft(X))[i, j]), X)[1] ≈ indicateMat
# same for the inverse
- @test gradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat
+ @test oldgradient((X̂)->real.(fft(ifft(X̂))[i, j]), X̂)[1] ≈ indicateMat
# same for rfft(irfft)
- @test gradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat)
+ @test oldgradient((X)->real.(irfft(rfft(X), size(X,1)))[i, j], X)[1] ≈ real.(indicateMat)
# rfft isn't actually surjective, so rffft(irfft) can't really be tested this way.
# the gradients are actually just evaluating the inverse transform on the
@@ -1438,22 +1449,22 @@ end
((K)->(irfft(K,sizeX[1])), 1/N * rfft(indicateMat),
zeros(size(X̂r)), plan_rfft(X), i, X̂r)]
for (trans, solRe, solIm, P, mI, evalX) in listOfSols
- @test gradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈
+ @test oldgradient((X)->real.(trans(X))[mI, j], evalX)[1] ≈
solRe
- @test gradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈
+ @test oldgradient((X)->imag.(trans(X))[mI, j], evalX)[1] ≈
solIm
if typeof(P) <:AbstractFFTs.Plan && maximum(trans .== [fft,rfft])
- @test gradient((X)->real.(P * X)[mI, j], evalX)[1] ≈
+ @test oldgradient((X)->real.(P * X)[mI, j], evalX)[1] ≈
solRe
- @test gradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈
+ @test oldgradient((X)->imag.(P * X)[mI, j], evalX)[1] ≈
solIm
elseif typeof(P) <: AbstractFFTs.Plan
- @test gradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈
+ @test oldgradient((X)->real.(P \ X)[mI, j], evalX)[1] ≈
solRe
# for whatever reason the rfft_plan doesn't handle this case well,
# even though irfft does
if eltype(evalX) <: Real
- @test gradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈
+ @test oldgradient((X)->imag.(P \ X)[mI, j], evalX)[1] ≈
solIm
end
end
@@ -1464,47 +1475,47 @@ end
x = [-0.353213 -0.789656 -0.270151; -0.95719 -1.27933 0.223982]
# check ffts for individual dimensions
for trans in (fft, ifft, bfft)
- @test gradient((x)->sum(abs.(trans(x))), x)[1] ≈
- gradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
+ @test oldgradient((x)->sum(abs.(trans(x))), x)[1] ≈
+ oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
# switch sum abs order
- @test gradient((x)->abs(sum((trans(x)))),x)[1] ≈
- gradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1]
+ @test oldgradient((x)->abs(sum((trans(x)))),x)[1] ≈
+ oldgradient( (x) -> abs(sum(trans(trans(x,1),2))), x)[1]
# dims parameter for the function
- @test gradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈
- gradient( (x) -> sum(abs.(trans(x))), x)[1]
+ @test oldgradient((x, dims)->sum(abs.(trans(x,dims))), x, (1,2))[1] ≈
+ oldgradient( (x) -> sum(abs.(trans(x))), x)[1]
# (1,2) should be the same as no index
- @test gradient( (x) -> sum(abs.(trans(x,(1,2)))), x)[1] ≈
- gradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
+ @test oldgradient( (x) -> sum(abs.(trans(x,(1,2)))), x)[1] ≈
+ oldgradient( (x) -> sum(abs.(trans(trans(x,1),2))), x)[1]
@test gradcheck(x->sum(abs.(trans(x))), x)
@test gradcheck(x->sum(abs.(trans(x, 2))), x)
end
- @test gradient((x)->sum(abs.(rfft(x))), x)[1] ≈
- gradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1]
- @test gradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈
- gradient( (x) -> sum(abs.(rfft(x))), x)[1]
+ @test oldgradient((x)->sum(abs.(rfft(x))), x)[1] ≈
+ oldgradient( (x) -> sum(abs.(fft(rfft(x,1),2))), x)[1]
+ @test oldgradient((x, dims)->sum(abs.(rfft(x,dims))), x, (1,2))[1] ≈
+ oldgradient( (x) -> sum(abs.(rfft(x))), x)[1]
# Test type stability of fft
x = randn(Float64,16)
P = plan_fft(x)
- @test typeof(gradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float64},1}
- @test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float64},1}
- @test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float64,1}
+ @test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float64},1}
+ @test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float64},1}
+ @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float64,1}
x = randn(Float64,16,16)
- @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
- @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
+ @test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float64},2}
+ @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float64,2}
x = randn(Float32,16)
P = plan_fft(x)
- @test typeof(gradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float32},1}
- @test typeof(gradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float32},1}
- @test typeof(gradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}
+ @test typeof(oldgradient(x->sum(abs2,ifft(fft(x))),x)[1]) == Array{Complex{Float32},1}
+ @test typeof(oldgradient(x->sum(abs2,P\(P*x)),x)[1]) == Array{Complex{Float32},1}
+ @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x),16)),x)[1]) == Array{Float32,1}
x = randn(Float32,16,16)
- @test typeof(gradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
- @test typeof(gradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
+ @test typeof(oldgradient(x->sum(abs2,ifft(fft(x,1),1)),x)[1]) == Array{Complex{Float32},2}
+ @test typeof(oldgradient(x->sum(abs2,irfft(rfft(x,1),16,1)),x)[1]) == Array{Float32,2}
end
@testset "FillArrays" begin
@@ -1668,7 +1679,7 @@ end
# check that type is not unnecessarily promoted
# https://github.com/FluxML/Zygote.jl/issues/663
@test gradient(norm, randn(Float32, 2, 2)) isa Tuple{Matrix{Float32}}
- @test gradient(norm, randn(Float32, 2, 2), 3) isa Tuple{Matrix{Float32},Float32}
+ @test gradient(norm, randn(Float32, 2, 2), 3) isa Tuple{Matrix{Float32},Float64}
@test gradient(norm, randn(Float32, 2, 2), 3f0) isa Tuple{Matrix{Float32},Float32}
@test gradient(norm, randn(ComplexF32, 2, 2), 3.5f0) isa Tuple{Matrix{ComplexF32},Float32}
diff --git a/test/structures.jl b/test/structures.jl
index 37c0e246a..5a951a621 100644
--- a/test/structures.jl
+++ b/test/structures.jl
@@ -53,6 +53,7 @@ struct A594 x::Float64 end
Y = randn(2,2)
∇ = gradient(g,X,Y)
@test ∇[1] == [(x = 2.0,); (x = 2.0,)]
+ @test vec(∇[1]) == [(x = 2.0,); (x = 2.0,)]
@test ∇[2] == [1 1; 1 1]
end
diff --git a/test/utils.jl b/test/utils.jl
index 70a8ebd63..b6d6ed018 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -19,16 +19,22 @@ using Zygote: hessian_dual, hessian_reverse
@test_throws Exception hess(identity, randn(2))
end
-@testset "diagonal hessian" begin
+VERSION > v"1.6-" && @testset "diagonal hessian" begin
@test diaghessian(x -> x[1]*x[2]^2, [1, pi]) == ([0, 2],)
- xs, y = randn(2,3), rand()
- f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
-
- dx, dy = diaghessian(f34, xs, y)
- @test size(dx) == size(xs)
- @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
- @test dy ≈ hessian(y -> f34(xs,y), y)
+ if VERSION > v"1.6-"
+ # Gradient of ^ may contain log(complex(...)), which interacts badly with Dual below Julia 1.6:
+ # julia> log(ForwardDiff.Dual(1,0) + 0im) # ERROR: StackOverflowError:
+ # https://github.com/JuliaDiff/ChainRules.jl/issues/525
+ # Fixed in 1.6 by: https://github.com/JuliaLang/julia/pull/36030
+ xs, y = randn(2,3), rand()
+ f34(xs, y) = xs[1] * (sum(xs .^ (1:3)') + y^4) # non-diagonal Hessian, two arguments
+
+ dx, dy = diaghessian(f34, xs, y)
+ @test size(dx) == size(xs)
+ @test vec(dx) ≈ diag(hessian(x -> f34(x,y), xs))
+ @test dy ≈ hessian(y -> f34(xs,y), y)
+ end
zs = randn(7,13) # test chunk mode
@test length(zs) > ForwardDiff.DEFAULT_CHUNK_THRESHOLD
@@ -67,6 +73,7 @@ end
j5 = jacobian((x,y) -> hcat(x[1], y), fill(pi), exp(1)) # zero-array
@test j5[1] isa Matrix
@test vec(j5[1]) == [1, 0]
+ @test j5[2] == [0, 1]
@test_throws ArgumentError jacobian(identity, [1,2,3+im])
@test_throws ArgumentError jacobian(sum, [1,2,3+im]) # scalar, complex
From cd177371506877ba093277adcac6af2bc86e065a Mon Sep 17 00:00:00 2001
From: willtebbutt
Date: Fri, 24 Sep 2021 11:05:32 +0100
Subject: [PATCH 215/490] Update README
The current README is a bit misleading in terms of performance, because Zygote really doesn't have good performance for control flow. I'm open to other suggestions for re-wording, but it seems reasonable that we temper what is currently there a bit.
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 6b2a6517d..866c6a617 100644
--- a/README.md
+++ b/README.md
@@ -29,7 +29,7 @@ top:
"Source-to-source" means that Zygote hooks into Julia's compiler, and generates the backwards pass for you – as if you had written it by hand.
-Without compromising on performance, Zygote supports the full flexibility and dynamism of the Julia language, including control flow, recursion, closures, structs, dictionaries, and more.
+Zygote supports the full flexibility and dynamism of the Julia language, including control flow, recursion, closures, structs, dictionaries, and more.
```julia
julia> fs = Dict("sin" => sin, "cos" => cos, "tan" => tan);
From 1f61e2c6ffde3cfc1f519fda9bcaeb9175b1031d Mon Sep 17 00:00:00 2001
From: WT
Date: Fri, 24 Sep 2021 22:07:39 +0100
Subject: [PATCH 216/490] Add ProjectTo method
---
src/compiler/chainrules.jl | 3 +++
test/chainrules.jl | 6 ++++++
2 files changed, 9 insertions(+)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index e879af3f8..2bc122a13 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -150,6 +150,9 @@ _project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)
# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any}
(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))
+# CRC likes Tangent{AbstractArray}, but Zygote makes Tangent{Any}
+(project::ProjectTo{AbstractArray})(dx::Tangent) = dx
+
"""
ZBack{F}(back) <: Function
diff --git a/test/chainrules.jl b/test/chainrules.jl
index b87a9ea3e..eaf05180a 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -335,6 +335,12 @@ end
test_rrule(ZygoteRuleConfig(), sum, x -> cbrt(x), randn(5))
test_rrule(ZygoteRuleConfig(), sum, x -> cbrt(x), randn(5); rrule_f=rrule_via_ad)
end
+
+ @testset "ProjectTo{AbstractArray}(::Tangent{Any})" begin
+ X = UpperHessenberg(randn(5, 5))
+ dX = Tangent{Any}(element=randn(5, 5))
+ @test ProjectTo(X)(dX) === dX
+ end
end
@testset "FastMath support" begin
From 54b0d90ec461d0bfe20ab3bd0d4a3f4fcaa9fcde Mon Sep 17 00:00:00 2001
From: WT
Date: Fri, 24 Sep 2021 22:07:54 +0100
Subject: [PATCH 217/490] Bump patch version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 56d086f99..758f2090b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.22"
+version = "0.6.23"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From b6bde08771e678fe78f9301b7539cc837f7ccb81 Mon Sep 17 00:00:00 2001
From: WT
Date: Fri, 24 Sep 2021 22:09:12 +0100
Subject: [PATCH 218/490] Note down issue in comment
---
test/chainrules.jl | 1 +
1 file changed, 1 insertion(+)
diff --git a/test/chainrules.jl b/test/chainrules.jl
index eaf05180a..80da51743 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -336,6 +336,7 @@ end
test_rrule(ZygoteRuleConfig(), sum, x -> cbrt(x), randn(5); rrule_f=rrule_via_ad)
end
+ # See https://github.com/FluxML/Zygote.jl/issues/1078
@testset "ProjectTo{AbstractArray}(::Tangent{Any})" begin
X = UpperHessenberg(randn(5, 5))
dX = Tangent{Any}(element=randn(5, 5))
From 79454f33779b2263ef8c9ba46831d38fe26bfb34 Mon Sep 17 00:00:00 2001
From: willtebbutt
Date: Fri, 24 Sep 2021 22:24:52 +0100
Subject: [PATCH 219/490] Update src/compiler/chainrules.jl
Co-authored-by: Lyndon White
---
src/compiler/chainrules.jl | 3 +++
1 file changed, 3 insertions(+)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 2bc122a13..8c3f1a84d 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -151,6 +151,9 @@ _project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)
(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))
# CRC likes Tangent{AbstractArray}, but Zygote makes Tangent{Any}
+# in particular this would hit https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2ec2549b73b22bc08f554dae864fb650cfb9c3d7/src/projection.jl#L139
+# if we were not losing track of the Primal in the Tangent
+# This type piracy is just giving up that safety check.
(project::ProjectTo{AbstractArray})(dx::Tangent) = dx
"""
From 993eb16606893b0eece567fa48e3ed7eac250c72 Mon Sep 17 00:00:00 2001
From: Gabriel Birnbaum
Date: Mon, 27 Sep 2021 08:24:06 +0200
Subject: [PATCH 220/490] added a dispatch for tuples to prevent LoadError
---
src/lib/base.jl | 1 +
1 file changed, 1 insertion(+)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index b7c2072a2..370924656 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -120,6 +120,7 @@ end
@adjoint function pairs(t::NamedTuple{N}) where N
pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
+ pairs_namedtuple_pullback(dx::Tuple) = (dx.data,)
function pairs_namedtuple_pullback(Δ::Dict)
t0 = map(zero, t)
From 14cedb5baf4588210e03326788e5cd5d9184b5ab Mon Sep 17 00:00:00 2001
From: Gabriel Birnbaum
Date: Tue, 28 Sep 2021 07:53:05 +0200
Subject: [PATCH 221/490] fixed dispatch
---
src/lib/base.jl | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index 370924656..6f590118c 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -120,7 +120,8 @@ end
@adjoint function pairs(t::NamedTuple{N}) where N
pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
- pairs_namedtuple_pullback(dx::Tuple) = (dx.data,)
+
+ pairs_namedtuple_pullback(dx::Tuple) = isempty(dx) ? (dx,) : (dx[1],)
function pairs_namedtuple_pullback(Δ::Dict)
t0 = map(zero, t)
From 6f41395df6a9c19a9b980e782b9dfff4f5bb372f Mon Sep 17 00:00:00 2001
From: Gabriel Birnbaum
Date: Tue, 28 Sep 2021 08:30:27 +0200
Subject: [PATCH 222/490] handle all tuples, not just empty ones
---
src/lib/base.jl | 10 ++++++++--
1 file changed, 8 insertions(+), 2 deletions(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index 6f590118c..c1a7cd38f 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -120,8 +120,14 @@ end
@adjoint function pairs(t::NamedTuple{N}) where N
pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
-
- pairs_namedtuple_pullback(dx::Tuple) = isempty(dx) ? (dx,) : (dx[1],)
+
+ function pairs_namedtuple_pullback(dx::Tuple)
+ t0 = map(zero, t)
+ for (i, v) in enumerate(dx)
+ t0 = NamedTuple{N}(Base.setindex((t0...,), v, i))
+ end
+ return (t0,)
+ end
function pairs_namedtuple_pullback(Δ::Dict)
t0 = map(zero, t)
From 8acd8de258c12271df82630cf3975599a87bc5ff Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 28 Sep 2021 12:56:08 -0400
Subject: [PATCH 223/490] make OneElement constructor safer
---
src/lib/array.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 9bec64b95..4f04809b7 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -50,7 +50,7 @@ struct OneElement{T,N,I,A} <: AbstractArray{T,N}
val::T
ind::I
axes::A
- OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A} where {N} = new{T,N,I,A}(val, ind, axes)
+ OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
end
Base.size(A::OneElement) = map(length, A.axes)
Base.axes(A::OneElement) = A.axes
From 4e439224a49461ecb43161e72c8bc1421077d809 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 28 Sep 2021 19:57:50 -0400
Subject: [PATCH 224/490] Remove incorrect `push!` and `pop!` gradients (#1025)
* fix push + pop gradient for vector of arrays, add real tests
* tweak
* allow only trivial gradients in push!(::Params) etc.
* generalise, and fail
* fix
* rm gradients which don't work
* rm unused methods from push(IdSet) gradient
* restrict push error to arrays, rm adjoint for params
* Update test/features.jl
Co-authored-by: Brian Chen
Co-authored-by: Brian Chen
---
src/compiler/interface.jl | 14 --------------
src/lib/array.jl | 23 ++---------------------
test/gradcheck.jl | 16 ----------------
3 files changed, 2 insertions(+), 51 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 9dc934a49..e210e65b6 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -183,20 +183,6 @@ function Base.push!(ps::Params, x)
return ps
end
-@adjoint! function Base.push!(xs::IdSet, x...)
- l = length(x)
- push!(xs, x...), Δ -> begin
- (Δ, ntuple(_ -> nothing, l)...)
- end
-end
-
-@adjoint! function Base.push!(xs::Params, x::AbstractArray{T}...) where T
- sz_x = size.(x)
- push!(xs, x...), Δ -> begin
- (Δ, map(x -> Ones{T}(x...), sz_x)...)
- end
-end
-
Base.push!(ps::Params, x...) = (foreach(x -> push!(ps, x), x); ps)
function Base.delete!(ps::Params, x)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 9bec64b95..035a1b239 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -74,27 +74,8 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
_ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), _...)")
for f in [push!, pop!, pushfirst!, popfirst!]
- @eval @adjoint! $f(xs, x...) = $f(xs, x...),
- _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(xs)), _...)")
-end
-
-# This is kind of bad, but at least we don't materialize the whole
-# array. Prefer to use `Buffer`
-# function _pullback(cx::Context, ::typeof(push!), xs::AbstractVector{<:AbstractArray}, x::AbstractArray{T}...) where T
-@adjoint! function push!(xs::AbstractVector{<:AbstractArray}, x::AbstractArray{T}...) where T
- sz_xs = size.(xs)
- sz_x = size.(x)
- push!(xs, x...), Δ -> begin
- (Δ, map(x -> Ones{T}(x...), sz_x)...)
- end
-end
-
-@adjoint! function pop!(xs::AbstractVector{<:AbstractArray{T}}) where T
- sz_xs = size.(xs)
- op = pop!(xs)
- op, Δ -> begin
- ([Ones{T}(sz...) for sz in sz_xs], )
- end
+ @eval @adjoint! $f(x::AbstractVector, ys...) = $f(x, ys...),
+ _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(x)), _...)")
end
# General
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index af49b7697..87fe5f46f 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1365,22 +1365,6 @@ using Zygote: Buffer
prod(copy(b))
end == (3,)
- @testset "Limited Mutation" begin
- p = [rand(3,3), rand(3,3)]
- r = rand(5,5)
-
- # TODO: ngradient cannot handle Vector{Array}
- gs = gradient((p,x) -> sum(sum.(push!(p,x))), p, r)
- @test length(p[end]) == length(gs[1][end])
- @test gs[1] ≈ map(x -> one.(x), p)
- @test gs[2] ≈ one.(r)
-
- # p = [rand(3,3), rand(3,3)] # redefine `p` after mutation
- # gs = gradient(x -> sum(pop!(x)), p)
- # @test length(gs[1]) == 2
- # @test gs[1][1] == one.(p[1])
- end
-
end
@testset "FillArrays" begin
From 355296e84f8b21c0f22204b0804af51653016137 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Wed, 29 Sep 2021 19:32:38 +0200
Subject: [PATCH 225/490] Update Project.toml
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 758f2090b..aa48d5b68 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.23"
+version = "0.6.24"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 818e3e43a37130c91f36679bd785e841bdf08da5 Mon Sep 17 00:00:00 2001
From: Gabriel Birnbaum
Date: Thu, 30 Sep 2021 10:28:10 +0200
Subject: [PATCH 226/490] clean up tuple dispatch
---
src/lib/base.jl | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index c1a7cd38f..55d660279 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -122,10 +122,7 @@ end
pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
function pairs_namedtuple_pullback(dx::Tuple)
- t0 = map(zero, t)
- for (i, v) in enumerate(dx)
- t0 = NamedTuple{N}(Base.setindex((t0...,), v, i))
- end
+ t0 = isempty(dx) ? () : NamedTuple{N}(values(dx))
return (t0,)
end
From 7794f810f03cb5db639d3f1ba5cf9eb15238c011 Mon Sep 17 00:00:00 2001
From: WT
Date: Fri, 1 Oct 2021 18:00:58 +0100
Subject: [PATCH 227/490] map type stability
---
src/lib/array.jl | 36 ++++++++++++++++--------------------
test/gradcheck.jl | 22 ++++++++++++++++++++++
2 files changed, 38 insertions(+), 20 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 035a1b239..e4c567cd9 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -180,27 +180,23 @@ _tryreverse(m, x) = x
_tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x)
for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
- @eval function $∇mapfunc(cx, f::F, args...) where {F}
+ @eval function $∇mapfunc(cx, f::F, args::Vararg{Any, N}) where {F, N}
ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
- if isempty(ys_and_backs)
- ys_and_backs, _ -> nothing
- else
- ys = map(first, ys_and_backs)
- ys, function (Δ)
- isnothing(Δ) && return nothing
- if Base.issingletontype(F) && length(args) == 1
- Δarg = $mapfunc(((_,pb), δ) -> last(pb(δ)), ys_and_backs, Δ) # No unzip needed
- (nothing, Δarg)
- elseif Base.issingletontype(F) # Ensures `f` is pure: nothing captured & no state
- Δargs = unzip($mapfunc(((_,pb), δ) -> Base.tail(pb(δ)), ys_and_backs, Δ))
- (nothing, Δargs...)
- else
- # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
- Δf_and_args_zipped = $mapfunc(((_,pb), δ) -> pb(δ), _tryreverse($mapfunc, ys_and_backs, Δ)...)
- Δf_and_args = unzip(_tryreverse($mapfunc, Δf_and_args_zipped))
- Δf = reduce(accum, Δf_and_args[1])
- (Δf, Δf_and_args[2:end]...)
- end
+ ys = map(first, ys_and_backs)
+ ys, function (Δ)
+ isnothing(Δ) && return nothing
+ if Base.issingletontype(F) && length(args) == 1
+ Δarg = $mapfunc(((_,pb), δ) -> last(pb(δ)), ys_and_backs, Δ) # No unzip needed
+ (nothing, Δarg)
+ elseif Base.issingletontype(F) # Ensures `f` is pure: nothing captured & no state
+ Δargs = _unzip($mapfunc(((_,pb), δ) -> Base.tail(pb(δ)), ys_and_backs, Δ), Val(N))
+ (nothing, Δargs...)
+ else
+ # Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
+ Δf_and_args_zipped = $mapfunc(((_,pb), δ) -> pb(δ), _tryreverse($mapfunc, ys_and_backs, Δ)...)
+ Δf_and_args = _unzip(_tryreverse($mapfunc, Δf_and_args_zipped), Val(N + 1))
+ Δf = reduce(accum, Δf_and_args[1]; init=nothing)
+ (Δf, Δf_and_args[2:end]...)
end
end
end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 87fe5f46f..287a83093 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -288,6 +288,28 @@ for mapfunc in [map,pmap]
Δy = randn(3)
@test first(pb((Δy..., ))) ≈ first(pb(Δy))
end
+
+ @testset "empty tuples" begin
+ out, pb = Zygote.pullback(map, -, ())
+ @test pb(out) === (nothing, ())
+
+ out, pb = Zygote.pullback(map, +, (), ())
+ @test pb(()) === (nothing, (), ())
+
+ function build_foo(z)
+ foo(x) = x * z
+ return foo
+ end
+ out, pb = Zygote.pullback(map, build_foo(5.0), ())
+ @test pb(()) === (nothing, ())
+ end
+end
+
+# Check that map infers correctly. pmap still doesn't infer.
+@testset "map inference" begin
+ @inferred Zygote._pullback(Zygote.Context(), map, sin, Float64[])
+ out, pb = Zygote._pullback(Zygote.Context(), map, sin, Float64[])
+ @inferred pb(Float64[])
end
@testset "Alternative Pmap Dispatch" begin
From 635682def7a850054a0eedde9e1837ee840613bb Mon Sep 17 00:00:00 2001
From: WT
Date: Fri, 1 Oct 2021 18:02:24 +0100
Subject: [PATCH 228/490] Bump patch
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index aa48d5b68..f6373cf19 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.24"
+version = "0.6.25"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 7f274ac174cf4ad4b8a3c47e3d3b2f66fcad274e Mon Sep 17 00:00:00 2001
From: WT
Date: Sat, 2 Oct 2021 13:31:04 +0100
Subject: [PATCH 229/490] Extra tests and bug fix
---
src/lib/array.jl | 10 ++++++++--
test/gradcheck.jl | 27 ++++++++++++++++++++++++---
2 files changed, 32 insertions(+), 5 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index e4c567cd9..0838a4628 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -179,6 +179,12 @@ end
_tryreverse(m, x) = x
_tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x)
+# Sometimes a pullback doesn't return a full vector of nothings, but rather returns only a
+# single nothing to say "all arguments have zero cotangent". This function is needed to
+# account for that inside the pullback for map.
+last_or_nothing(::Nothing) = nothing
+last_or_nothing(x) = last(x)
+
for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
@eval function $∇mapfunc(cx, f::F, args::Vararg{Any, N}) where {F, N}
ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
@@ -186,10 +192,10 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
ys, function (Δ)
isnothing(Δ) && return nothing
if Base.issingletontype(F) && length(args) == 1
- Δarg = $mapfunc(((_,pb), δ) -> last(pb(δ)), ys_and_backs, Δ) # No unzip needed
+ Δarg = $mapfunc(((_,pb), δ) -> last_or_nothing(pb(δ)), ys_and_backs, Δ) # No unzip needed
(nothing, Δarg)
elseif Base.issingletontype(F) # Ensures `f` is pure: nothing captured & no state
- Δargs = _unzip($mapfunc(((_,pb), δ) -> Base.tail(pb(δ)), ys_and_backs, Δ), Val(N))
+ Δargs = _unzip($mapfunc(((_,pb), δ) -> tailmemaybe(pb(δ)), ys_and_backs, Δ), Val(N))
(nothing, Δargs...)
else
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 287a83093..9ff9a641a 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -303,13 +303,34 @@ for mapfunc in [map,pmap]
out, pb = Zygote.pullback(map, build_foo(5.0), ())
@test pb(()) === (nothing, ())
end
+
+ @testset "Vector{Nothing} cotangent" begin
+ out, pb = Zygote.pullback(map, -, randn(5))
+ Δ = Vector{Nothing}(nothing, 5)
+ @test pb(Δ)[2] isa Vector{Nothing}
+
+ out, pb = Zygote.pullback(map, +, randn(5), randn(5))
+ @test pb(Δ)[2] isa Vector{Nothing}
+ @test pb(Δ)[3] isa Vector{Nothing}
+ end
end
# Check that map infers correctly. pmap still doesn't infer.
@testset "map inference" begin
- @inferred Zygote._pullback(Zygote.Context(), map, sin, Float64[])
- out, pb = Zygote._pullback(Zygote.Context(), map, sin, Float64[])
- @inferred pb(Float64[])
+ @testset "$name" for (name, f, ȳ, xs) in [
+ ("unary empty vector", sin, Float64[], (Float64[], )),
+ ("unary vector", sin, randn(3), (randn(3), )),
+ ("unary empty tuple", sin, (), ((), )),
+ ("unary tuple", sin, (randn(), randn()), ((randn(), randn()), )),
+ ("binary empty vector", +, Float64[], (Float64[], Float64[])),
+ ("binary vector", +, randn(2), (randn(2), randn(2))),
+ ("binary empty tuple", +, (), ((), ())),
+ ("binary tuple", +, (randn(), randn()), ((randn(), randn()), (randn(), randn()))),
+ ]
+ @inferred Zygote._pullback(Zygote.Context(), map, f, xs...)
+ y, pb = Zygote._pullback(Zygote.Context(), map, f, xs...)
+ @inferred pb(ȳ)
+ end
end
@testset "Alternative Pmap Dispatch" begin
From 576af80e018a7e9bc778f31d471253a7e88b8ffb Mon Sep 17 00:00:00 2001
From: WT
Date: Sat, 2 Oct 2021 13:35:00 +0100
Subject: [PATCH 230/490] Additional Vector{Nothing} cotangent test
---
test/gradcheck.jl | 13 ++++++++++++-
1 file changed, 12 insertions(+), 1 deletion(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 9ff9a641a..e34369f7c 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -305,13 +305,24 @@ for mapfunc in [map,pmap]
end
@testset "Vector{Nothing} cotangent" begin
- out, pb = Zygote.pullback(map, -, randn(5))
Δ = Vector{Nothing}(nothing, 5)
+
+ # Unary stateless
+ out, pb = Zygote.pullback(map, -, randn(5))
@test pb(Δ)[2] isa Vector{Nothing}
+ # Binary stateless
out, pb = Zygote.pullback(map, +, randn(5), randn(5))
@test pb(Δ)[2] isa Vector{Nothing}
@test pb(Δ)[3] isa Vector{Nothing}
+
+ # Stateful
+ function build_foo(z)
+ foo(x) = x * z
+ return foo
+ end
+ out, pb = Zygote.pullback(map, build_foo(5.0), randn(5))
+ @test pb(Δ)[2] isa Vector{Nothing}
end
end
From 36572ae946acff72c12cc4bdac0d7150911e61f5 Mon Sep 17 00:00:00 2001
From: WT
Date: Sat, 2 Oct 2021 13:36:05 +0100
Subject: [PATCH 231/490] Fix typo
---
src/lib/array.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 0838a4628..4b899d571 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -179,7 +179,7 @@ end
_tryreverse(m, x) = x
_tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x)
-# Sometimes a pullback doesn't return a full vector of nothings, but rather returns only a
+# Sometimes a pullback doesn't return a Tuple, but rather returns only a
# single nothing to say "all arguments have zero cotangent". This function is needed to
# account for that inside the pullback for map.
last_or_nothing(::Nothing) = nothing
From 4f7d5d1aacc7e64b35e5039a78411641daa6a875 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 29 Sep 2021 10:07:43 -0400
Subject: [PATCH 232/490] fix 1086
---
src/lib/broadcast.jl | 4 +++-
test/features.jl | 5 +++++
2 files changed, 8 insertions(+), 1 deletion(-)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 4e7a3a1cc..3affebd92 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -72,7 +72,9 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
broadcast(+, xs...), ȳ -> (nothing, map(x -> unbroadcast(x, ȳ), xs)...)
@adjoint broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y,
- Δ -> (nothing, unbroadcast(x, Δ), -unbroadcast(y, Δ))
+ Δ -> (nothing, unbroadcast(x, Δ), _minus(unbroadcast(y, Δ)))
+_minus(Δ) = -Δ
+_minus(::Nothing) = nothing
@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
diff --git a/test/features.jl b/test/features.jl
index d683d0d94..3115a455c 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -570,6 +570,11 @@ end
@test gradient(x -> sum(_f.(x)), [1,2,3]) == ([0.5, 0.5, 0.5],)
@test gradient(x -> sum(map(_f, x)), [1,2,3]) == ([0.5, 0.5, 0.5],)
+ # with Bool
+ @test gradient(x -> sum(1 .- (x .> 0)), randn(5)) == (nothing,)
+ @test gradient(x -> sum((y->1-y).(x .> 0)), randn(5)) == (nothing,)
+ @test gradient(x -> sum(x .- (x .> 0)), randn(5)) == ([1,1,1,1,1],)
+
@test gradient(x -> sum(x ./ [1,2,4]), [1,2,pi]) == ([1.0, 0.5, 0.25],)
@test gradient(x -> sum(map(/, x, [1,2,4])), [1,2,pi]) == ([1.0, 0.5, 0.25],)
From 0cb031a3181fd9f8b63abb9c0b2629a1cb576a0d Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 2 Oct 2021 12:34:52 -0400
Subject: [PATCH 233/490] 0.6.26
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index f6373cf19..15ecc3e73 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.25"
+version = "0.6.26"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From eb347fe62a8743e30c1de7a318812996cc8ea8dc Mon Sep 17 00:00:00 2001
From: Gabriel Birnbaum
Date: Tue, 5 Oct 2021 09:13:12 +0200
Subject: [PATCH 234/490] minimize changes
Co-authored-by: Carlo Lucibello
---
src/lib/base.jl | 6 +-----
1 file changed, 1 insertion(+), 5 deletions(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index 55d660279..472ee85a5 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -121,11 +121,7 @@ end
pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
- function pairs_namedtuple_pullback(dx::Tuple)
- t0 = isempty(dx) ? () : NamedTuple{N}(values(dx))
- return (t0,)
- end
-
+pairs_namedtuple_pullback(dx::Tuple{}) = (NamedTuple(),)
function pairs_namedtuple_pullback(Δ::Dict)
t0 = map(zero, t)
for (idx, v) in Δ
From 63a9a543ffc0e8e4198803af28464aa2a4587e0f Mon Sep 17 00:00:00 2001
From: Gabriel Birnbaum
Date: Tue, 5 Oct 2021 09:14:44 +0200
Subject: [PATCH 235/490] fix indentation
---
src/lib/base.jl | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index 472ee85a5..ac7df59a2 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -121,7 +121,8 @@ end
pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
-pairs_namedtuple_pullback(dx::Tuple{}) = (NamedTuple(),)
+ pairs_namedtuple_pullback(dx::Tuple{}) = (NamedTuple(),)
+
function pairs_namedtuple_pullback(Δ::Dict)
t0 = map(zero, t)
for (idx, v) in Δ
From a3f8dc4986005f532ad72a35afad74b684bb6289 Mon Sep 17 00:00:00 2001
From: Simeon David Schaub
Date: Tue, 5 Oct 2021 11:18:22 -0400
Subject: [PATCH 236/490] WIP: improve inference for getproperty
This has regressed quite a bit due to #848. With this PR, we should be able to get back the same performance as before, assuming there is no custom implementation or pullback for `getproperty`. Still need to add tests.
---
src/Zygote.jl | 1 +
src/lib/literal_getproperty.jl | 82 ++++++++++++++++++++++++++++++++++
test/compiler.jl | 47 +++++++++++++++++++
3 files changed, 130 insertions(+)
create mode 100644 src/lib/literal_getproperty.jl
diff --git a/src/Zygote.jl b/src/Zygote.jl
index ae023213c..85b71359f 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -33,6 +33,7 @@ include("compiler/show.jl")
include("lib/grad.jl")
include("lib/lib.jl")
+include("lib/literal_getproperty.jl")
include("lib/number.jl")
include("lib/base.jl")
include("lib/array.jl")
diff --git a/src/lib/literal_getproperty.jl b/src/lib/literal_getproperty.jl
new file mode 100644
index 000000000..1959e9462
--- /dev/null
+++ b/src/lib/literal_getproperty.jl
@@ -0,0 +1,82 @@
+# Mostly copied over from Cassette in `src/overdub.jl`
+# Return `Reflection` for signature `sigtypes` and `world`, if possible. Otherwise, return `nothing`.
+function reflect(@nospecialize(sigtypes::Tuple), world::UInt = typemax(UInt))
+ if length(sigtypes) > 2 && sigtypes[1] === typeof(invoke)
+ @assert sigtypes[3] <: Type{<:Tuple}
+ sigtypes = (sigtypes[2], sigtypes[3].parameters[1].parameters...)
+ end
+ # This works around a subtyping bug. Basically, callers can deconstruct upstream
+ # `UnionAll` types in such a way that results in a type with free type variables, in
+ # which case subtyping can just break.
+ #
+ # God help you if you try to use a type parameter here (e.g. `::Type{S} where S<:Tuple`)
+ # instead of this nutty workaround, because the compiler can just rewrite `S` into
+ # whatever it thinks is "type equal" to the actual provided value. In other words, if
+ # `S` is defined as e.g. `f(::Type{S}) where S`, and you call `f(T)`, you should NOT
+ # assume that `S === T`. If you did, SHAME ON YOU. It doesn't matter that such an
+ # assumption holds true for essentially all other kinds of values. I haven't counted in
+ # a while, but I'm pretty sure I have ~40+ hellish years of Julia experience, and this
+ # still catches me every time. Who even uses this crazy language?
+ S = Tuple{map(s -> Core.Compiler.has_free_typevars(s) ? typeof(s.parameters[1]) : s, sigtypes)...}
+ (S.parameters[1]::DataType).name.module === Core.Compiler && return nothing
+ _methods = Base._methods_by_ftype(S, -1, world)
+ method_index = 0
+ for i in 1:length(_methods)
+ if _methods[i][1] === S
+ method_index = i
+ break
+ end
+ end
+ method_index === 0 && return nothing
+ type_signature, raw_static_params, method = _methods[method_index]
+ method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
+ method_signature = method.sig
+ static_params = Any[raw_static_params...]
+ return method_instance, method_signature, static_params
+end
+
+
+# ugly hack to make differentiating `getproperty` infer a lot better
+@generated function _pullback(cx::AContext, ::typeof(literal_getproperty), x, ::Val{f}) where f
+ sig(x) = Tuple{x, typeof(f)}
+ rrule_sig(x) = Tuple{typeof(getproperty), x, typeof(f)}
+ pb_sig(x) = Tuple{cx, typeof(getproperty), x, typeof(f)}
+
+ # either `getproperty` has a custom implementation or `_pullback(cx, getproperty, x, f)`
+ # / `rrule(getproperty, x, f) is overloaded directly
+ is_getfield_fallback = which(getproperty, sig(x)) == which(getproperty, sig(Any)) &&
+ which(_pullback, pb_sig(x)) == which(_pullback, pb_sig(Any)) &&
+ which(rrule, rrule_sig(x)) == which(rrule, rrule_sig(Any))
+
+ #ccall(:jl_safe_printf, Cvoid, (Cstring,), "$is_getfield_fallback: $x\n")
+
+ if is_getfield_fallback
+ # just copy pullback of `literal_getfield`
+ mi, _sig, sparams = reflect((typeof(_pullback), cx, typeof(literal_getfield), x, Val{f}))
+ ci = copy(Core.Compiler.retrieve_code_info(mi))
+
+ # we need to change the second arg to `_pullback` from `literal_getproperty` to
+ # `literal_getfield`
+ Meta.partially_inline!(
+ ci.code, Any[_pullback, Core.SlotNumber(2), literal_getfield],
+ _sig, sparams, 0, 0, :propagate,
+ )
+ ci.inlineable = true
+
+ # backedge for `_pullback`, see https://docs.julialang.org/en/v1/devdocs/ast/#MethodInstance
+ # this will cause a backedge to this particular MethodInstance to be attached to
+ # `_pullback(cx, getproperty, x, f)`
+ mi_pb_getproperty, _, _ = reflect((typeof(_pullback), pb_sig(x).parameters...))
+ mi_getproperty, _, _ = reflect((typeof(getproperty), sig(x).parameters...))
+ mi_rrule, _, _ = reflect((typeof(rrule), rrule_sig(x).parameters...))
+ ci.edges = Core.MethodInstance[mi, mi_pb_getproperty, mi_getproperty, mi_rrule]
+
+ return ci
+ else
+ # nothing to optimize here, need to recurse into `getproperty`
+ return quote
+ Base.@_inline_meta
+ _pullback(cx, getproperty, x, $(QuoteNode(f)))
+ end
+ end
+end
diff --git a/test/compiler.jl b/test/compiler.jl
index af8e6ccb7..71e49ded4 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -144,5 +144,52 @@ end
@test Zygote.gradient(sumall, ms) == ((a = 2, b = 2),)
end
+using ChainRulesCore
+
+function _Gaussian(suffix::Symbol)
+ name = gensym(Symbol(:Gaussian_, suffix))
+ return @eval begin
+ struct $name{Tm, TP}
+ m::Tm
+ P::TP
+ end
+ $name
+ end
+end
+
+@testset "inference for `getproperty`" begin
+ Gaussian = _Gaussian(:getproperty)
+ g = Gaussian(randn(3), randn(3, 3))
+ y, back = @inferred pullback(x -> x.m, g)
+ @test y == getfield(g, :m)
+ @test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
+ @test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)
+
+ Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
+ y, back = pullback(x -> x.m, g)
+ @test y == 2getfield(g, :m)
+ @test back([1., 0, 0]) == ((m = [2.0, 0.0, 0.0], P = nothing),)
+
+
+ Gaussian = _Gaussian(:pullback)
+ g = Gaussian(randn(3), randn(3, 3))
+ y, back = @inferred pullback(x -> x.m, g)
+
+ Zygote._pullback(::typeof(getproperty), g::Gaussian, s::Symbol) = 3getfield(g, s), Δ -> (nothing, (; ((:m, :P) .=> nothing)..., s => 3Δ), nothing)
+ y, back = pullback(x -> x.m, g)
+ @test_broken y == 3getfield(g, :m)
+ @test_broken back([1., 0, 0]) == ((m = [3.0, 0.0, 0.0], P = nothing),)
+
+
+ Gaussian = _Gaussian(:rrule)
+ g = Gaussian(randn(3), randn(3, 3))
+ y, back = @inferred pullback(x -> x.m, g)
+
+ ChainRulesCore.rrule(::typeof(getproperty), g::Gaussian, s::Symbol) = 4getfield(g, s), Δ -> (NoTangent(), Tangent{typeof(g)}(; s => 4Δ), NoTangent())
+ y, back = pullback(x -> x.m, g)
+ @test y == 4getfield(g, :m)
+ @test back([1., 0, 0]) == ((m = [4.0, 0.0, 0.0], P = nothing),)
+end
+
# issue 897
@test gradient(x -> sum(norm, collect(eachcol(x))), ones(3, 400))[1] ≈ fill(0.5773502691896258, 3, 400)
From 70ab7c1e7346192265d6fbcc92c4a31cf2678e92 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 7 Oct 2021 07:37:02 +0200
Subject: [PATCH 237/490] Update Project.toml
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 15ecc3e73..d8cf865a2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.26"
+version = "0.6.27"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From e245dee827f802d54b10a0c00c69083b6038de47 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 7 Oct 2021 13:46:16 +0200
Subject: [PATCH 238/490] Update Project.toml
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index d8cf865a2..15ecc3e73 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.27"
+version = "0.6.26"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From bdbb36979ba458f04706c07891e3047afd22c167 Mon Sep 17 00:00:00 2001
From: Simeon David Schaub
Date: Tue, 12 Oct 2021 12:20:22 -0400
Subject: [PATCH 239/490] address review comments
---
test/compiler.jl | 14 +++++++++++---
1 file changed, 11 insertions(+), 3 deletions(-)
diff --git a/test/compiler.jl b/test/compiler.jl
index 71e49ded4..eec71e53d 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -162,6 +162,7 @@ end
g = Gaussian(randn(3), randn(3, 3))
y, back = @inferred pullback(x -> x.m, g)
@test y == getfield(g, :m)
+ # This type instability is due to the handling of non-bitstypes in `accum_param`
@test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
@test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)
@@ -175,10 +176,10 @@ end
g = Gaussian(randn(3), randn(3, 3))
y, back = @inferred pullback(x -> x.m, g)
- Zygote._pullback(::typeof(getproperty), g::Gaussian, s::Symbol) = 3getfield(g, s), Δ -> (nothing, (; ((:m, :P) .=> nothing)..., s => 3Δ), nothing)
+ Zygote._pullback(::Zygote.AContext, ::typeof(getproperty), g::Gaussian, s::Symbol) = 3getfield(g, s), Δ -> (nothing, (; ((:m, :P) .=> nothing)..., s => 3Δ), nothing)
y, back = pullback(x -> x.m, g)
- @test_broken y == 3getfield(g, :m)
- @test_broken back([1., 0, 0]) == ((m = [3.0, 0.0, 0.0], P = nothing),)
+ @test y == 3getfield(g, :m)
+ @test back([1., 0, 0]) == ((m = [3.0, 0.0, 0.0], P = nothing),)
Gaussian = _Gaussian(:rrule)
@@ -189,6 +190,13 @@ end
y, back = pullback(x -> x.m, g)
@test y == 4getfield(g, :m)
@test back([1., 0, 0]) == ((m = [4.0, 0.0, 0.0], P = nothing),)
+
+
+ Gaussian = _Gaussian(:bitstype)
+ g = Gaussian(randn(), randn())
+ y, back = @inferred pullback(x -> x.m, g)
+ @test y == getfield(g, :m)
+ @test @inferred(back(1.0)) == ((m = 1.0, P = nothing),)
end
# issue 897
From 1d189fb72cb0341b6045057368eed8d3583c4c8a Mon Sep 17 00:00:00 2001
From: Simeon David Schaub
Date: Tue, 12 Oct 2021 12:27:31 -0400
Subject: [PATCH 240/490] bump patch version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 15ecc3e73..d8cf865a2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.26"
+version = "0.6.27"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 2b35bc0d3fff63779f56db3d26daaab8e197fdc2 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 15 Oct 2021 09:21:56 -0400
Subject: [PATCH 241/490] wrap_chainrules_input for mutable struct
---
Project.toml | 4 ++--
src/compiler/chainrules.jl | 2 ++
test/features.jl | 14 ++++++++++++++
3 files changed, 18 insertions(+), 2 deletions(-)
diff --git a/Project.toml b/Project.toml
index d8cf865a2..7d4a197ea 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.27"
+version = "0.6.28"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -24,7 +24,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.5"
-ChainRulesCore = "1.6"
+ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 8c3f1a84d..34b6e637b 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -128,6 +128,8 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR
xp = map(wrap_chainrules_input, xs)
ChainRules.Tangent{Any, typeof(xp)}(xp)
end
+# For mutable types, including x=Ref(1), Zygote makes Ref{Any}(::NamedTuple)
+@inline wrap_chainrules_input(x::Ref) = wrap_chainrules_input(x[])
"""
_project(x, dx)
diff --git a/test/features.jl b/test/features.jl
index 3115a455c..ab68b4bd3 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -443,6 +443,20 @@ let
@test back(1.) == ((1.0,),)
end
+@testset "mutable struct, including Ref" begin
+ # Zygote's representation is Base.RefValue{Any}((value = 7.0,)), but the
+ # map to ChainRules types and back normalises to (value = 7.0,) same as struct:
+ @test gradient(x -> x.value^2 + x.value, MyMutable(3)) === ((value = 7.0,),)
+
+ # Same for Ref. This doesn't seem to affect `pow_mut` test in this file.
+ @test gradient(x -> x.x^2 + x.x, Ref(3)) === ((x = 7.0,),)
+ @test gradient(x -> real(x.x^2 + im * x.x), Ref(4)) === ((x = 8.0,),)
+
+ # Broadcasting over Ref is handled specially. Tested elsehwere too.
+ @test gradient(x -> sum(sum, x .* [1,2,3]), Ref([4,5])) == ((x = [6.0, 6.0],),)
+ @test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
+end
+
function type_test()
Complex{<:Real}
end
From 5ae5b4f2933e87923a567f13e1c298e26b954716 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 16 Oct 2021 18:43:46 -0400
Subject: [PATCH 242/490] wrap_chainrules_input for arrays of Ref (#1103)
* wrap_chainrules_input for arrays of Ref
* z2d too, for rrule_via_ad
* test from https://github.com/JuliaDiff/ChainRulesCore.jl/issues/440
* add test from https://github.com/JuliaDiff/ChainRules.jl/issues/537
* more tests related to CRC types
* union nothing, fix one case
* comments
---
Project.toml | 2 +-
src/compiler/chainrules.jl | 15 ++++++--
src/lib/broadcast.jl | 1 +
test/features.jl | 7 ++++
test/gradcheck.jl | 74 ++++++++++++++++++++++++++++++++++++++
5 files changed, 96 insertions(+), 3 deletions(-)
diff --git a/Project.toml b/Project.toml
index 7d4a197ea..04196b602 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.28"
+version = "0.6.29"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 34b6e637b..9bfde430a 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -115,6 +115,8 @@ for T_outer in (:Tuple, :NamedTuple)
ChainRulesCore.backing(xp) # this is accessing ChainRulesCore internals, but it is prob safe enough, and it is fastest
end
end
+# Could `reinterpret` instead of broadcasting here -- TODO
+@inline wrap_chainrules_output(xs::AbstractArray{<:ChainRules.Tangent}) = wrap_chainrules_output.(xs)
"""
wrap_chainrules_input(x)
@@ -130,6 +132,11 @@ Convert `x` from the format Zygote uses internally to differentials types ChainR
end
# For mutable types, including x=Ref(1), Zygote makes Ref{Any}(::NamedTuple)
@inline wrap_chainrules_input(x::Ref) = wrap_chainrules_input(x[])
+# Could `reinterpret` instead of broadcasting here -- TODO
+@inline wrap_chainrules_input(xs::AbstractArray{<:Ref}) = wrap_chainrules_input.(xs)
+@inline wrap_chainrules_input(xs::AbstractArray{<:Union{Nothing, <:Ref}}) = wrap_chainrules_input.(xs) # no test invented for this
+@inline wrap_chainrules_input(xs::AbstractArray{<:NamedTuple}) = wrap_chainrules_input.(xs)
+@inline wrap_chainrules_input(xs::AbstractArray{<:Union{Nothing, <:NamedTuple}}) = wrap_chainrules_input.(xs)
"""
_project(x, dx)
@@ -139,6 +146,8 @@ Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
Safe to apply to arbitrary input.
"""
@inline function _project(x, dx)
+ # Note that this use of `wrap_chainrules_input` has the primal `x`, so could
+ # avoid making `Tangent{Any}`, perhaps via `zygote2differential` -- TODO.
wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
end
@@ -224,9 +233,9 @@ function ChainRulesCore.rrule_via_ad(config::ZygoteRuleConfig, f_args...; kwargs
end
"""
- zygote2differential(x)
+ zygote2differential(dx, primal)
-Convert input `x` from the Zygote format to the ChainRules differential types.
+Convert input `dx` from the Zygote format to the ChainRules differential types.
"""
zygote2differential(x, primal) = z2d(x, primal)
zygote2differential(::Nothing, ::Any) = NoTangent()
@@ -235,6 +244,7 @@ zygote2differential(t::Tuple, primal) = (@warn "primal should be a tuple, not $p
z2d(x, ::Any) = x
z2d(::Nothing, ::Any) = NoTangent()
z2d(a::AbstractArray{<:Number}, primal::AbstractArray{T}) where T = a
+# Could probably `reinterpret` instead of broadcasting here -- TODO
z2d(a::AbstractArray, primal::AbstractArray{T}) where T = z2d.(a, primal)
# Note: this should never be hit if we are converting things right, but it seems to be
# happening in the wild for sufficiently weird functions/types.
@@ -254,3 +264,4 @@ function z2d(t::NamedTuple, primal)
tp::NamedTuple = map(z2d, complete_t, primals)
return canonicalize(Tangent{primal_type, typeof(tp)}(tp))
end
+z2d(dx::Ref, primal) = z2d(dx[], primal) # mutable structs
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 3affebd92..8833436a0 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -58,6 +58,7 @@ unbroadcast(x::Number, x̄) = accum_sum(x̄)
unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
+unbroadcast(x::Tuple, x̄::Nothing) = nothing
unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
diff --git a/test/features.jl b/test/features.jl
index ab68b4bd3..795879fed 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -452,6 +452,13 @@ end
@test gradient(x -> x.x^2 + x.x, Ref(3)) === ((x = 7.0,),)
@test gradient(x -> real(x.x^2 + im * x.x), Ref(4)) === ((x = 8.0,),)
+ # Array of mutables:
+ @test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
+ @test gradient(x -> sum(abs2∘getindex, x), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
+
+ @test gradient(x -> (getindex.(x).^2)[1], Ref.(1:3))[1][1] == (x=2.0,) # rest are (x = 0.0,), but nothing would be OK too
+ @test gradient(x -> (prod.(getindex.(x)))[1], Ref.(eachcol([1 2; 3 4])))[1][1] == (x = [3.0, 1.0],)
+
# Broadcasting over Ref is handled specially. Tested elsehwere too.
@test gradient(x -> sum(sum, x .* [1,2,3]), Ref([4,5])) == ((x = [6.0, 6.0],),)
@test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index e34369f7c..ac7893cfa 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -245,6 +245,11 @@ end
@test gradtest(x->fill(first(x), N), randn(rng, 1))
@test gradtest(x->fill(first(x), N, M), randn(rng, 1))
@test gradtest(x->fill(first(x), N, M, P), randn(rng, 1))
+
+ # fill(struct, ...) handled by ChainRules after
+ # https://github.com/FluxML/Zygote.jl/pull/1051
+ @test gradient(x -> fill(x, 3)[1][1], (1,2)) === ((1.0, nothing),)
+ @test gradient(x -> fill(x, 3)[1].a, (a=1, b=2)) == ((a=1.0, b=nothing),) # 1 not 1.0
end
@testset "circshift" begin
@@ -344,6 +349,20 @@ end
end
end
+@testset "map and tuples" begin
+ # arrays of tuples, ChainRules's Tangent should not escape
+ @test gradient(x -> sum(map(first, x)), [(1,2), (3,4)]) == ([(1.0, nothing), (1.0, nothing)],)
+ @test gradient(x -> sum(first, x), [(1,2), (3,4)]) == ([(1.0, nothing), (1.0, nothing)],)
+
+ @test gradient(x -> map(+, x, (1,2,3))[1], (4,5,6)) == ((1.0, nothing, nothing),)
+ @test gradient(x -> map(+, x, [1,2,3])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
+ @test_broken gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],) # Gradient [1.0, 0.0, 0.0] should be a tuple, since v0.6.0 at least
+
+ # mismatched lengths, should zip
+ @test_broken gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),) # BoundsError: attempt to access 3-element Vector{Float64} at index [4]
+ @test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),) # DimensionMismatch("variable with size(x) == (4,) cannot have a gradient with size(dx) == (3,)
+end
+
@testset "Alternative Pmap Dispatch" begin
cache_and_map(f,xs...) = pmap(f, CachingPool(workers()), xs...; batch_size = 1)
@test gradtest(xs -> sum(cache_and_map(x -> x^2, xs)), rand(2,3))
@@ -1783,3 +1802,58 @@ end
# https://github.com/FluxML/Zygote.jl/issues/996
a = rand(3)
@test Zygote.gradient(x->sum(x .+ rand.()), a) == (ones(3),)
+
+@testset "CRC issue 440" begin
+ # https://github.com/JuliaDiff/ChainRulesCore.jl/issues/440
+ f(x,y) = sum(sum, [[x[i],y[i]] for i=1:length(x)])
+ g(x,y) = sum(sum, [(x[i],y[i]) for i=1:length(x)])
+ @test gradient(f, rand(3), rand(3)) == ([1.0, 1.0, 1.0], [1.0, 1.0, 1.0])
+ @test gradient(g, rand(3), rand(3)) == ([1.0, 1.0, 1.0], [1.0, 1.0, 1.0])
+end
+
+@testset "CR issue 537" begin
+ # https://github.com/JuliaDiff/ChainRules.jl/issues/537
+ struct BV{F,T}
+ A::F
+ α::T
+ end
+ function Base.:*(c, km::BV)
+ new_A = c*km.A
+ other_params = getfield.([km], propertynames(km))[2:end]
+ BV(new_A, other_params...)
+ end
+ function (bv::BV)(V_app, ox::Bool; kT::Real = 0.026)
+ local exp_arg
+ if ox
+ exp_arg = (bv.α .* V_app) ./ kT
+ else
+ exp_arg = -((1 .- bv.α) .* V_app) ./ kT
+ end
+ bv.A .* exp.(exp_arg)
+ end
+ Zygote.@adjoint function BV{T,S}(A, α) where {T,S}
+ BV(A, α), Δ -> begin
+ (Δ.A, Δ.α)
+ end
+ end
+ bv = BV(1.0, 0.1)
+ I_vals, V = rand(81), rand(81)
+
+ g2 = gradient(V, bv) do V, bv
+ res = fill(bv, length(V))
+ r1 = map((m,v) -> m(v, true), res, V)
+ r2 = map((m,v) -> m(v, false), res, V)
+ sum(r1 .- r2)
+ end
+ @test size(g2[1]) == size(V)
+ @test g2[2] isa NamedTuple
+ @test g2[2].A isa Number
+
+ g1 = gradient(bv, V) do bv, V
+ res = map(x -> x * bv, V)
+ sum(x -> x.A, res)
+ end
+ @test g1[1] isa NamedTuple
+ @test g1[1].A isa Number
+ @test size(g1[2]) == size(V)
+end
From 4edde590303d95605d410066d46ba05e1d3a0843 Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Sun, 17 Oct 2021 06:48:10 +0200
Subject: [PATCH 243/490] remove FastAI and GeometricFlux
---
.github/workflows/Downstream.yml | 2 --
1 file changed, 2 deletions(-)
diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml
index 6565e82d5..c40bc6617 100644
--- a/.github/workflows/Downstream.yml
+++ b/.github/workflows/Downstream.yml
@@ -20,8 +20,6 @@ jobs:
package:
- {user: FluxML, repo: Flux.jl, group: All}
- {user: FluxML, repo: NNlib.jl, group: All}
- - {user: FluxML, repo: FastAI.jl, group: All}
- - {user: FluxML, repo: GeometricFlux.jl, group: All}
- {user: SciML, repo: DiffEqFlux.jl, group: Layers}
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
steps:
From 218eefb340c3dc47404b3cb27ecee5303dc650bb Mon Sep 17 00:00:00 2001
From: lassepe
Date: Sun, 24 Oct 2021 20:30:25 +0200
Subject: [PATCH 244/490] First stab at copy and copy! for Grads
---
src/compiler/interface.jl | 17 +++++++++++++++++
src/tools/buffer.jl | 6 ++++++
src/tools/idset.jl | 4 ++++
test/interface.jl | 20 ++++++++++++++++++++
4 files changed, 47 insertions(+)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index e210e65b6..9965c3469 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -237,6 +237,12 @@ function copy!(x::AbstractVector, ps::Params)
ps
end
+function copy!(ps_dst::Params, ps_src::Params)
+ copy!(ps_dst.order, ps_src.order)
+ copy!(ps_dst.params, ps_src.params)
+ ps_dst
+end
+
"""
Grads(...)
@@ -299,6 +305,17 @@ function copy!(x::AbstractVector, gs::Grads)
x
end
+function copy!(gs_dst::Grads, gs_src::Grads)
+ copy!(gs_dst.grads, gs_src.grads)
+ copy!(gs_dst.params, gs_src.params)
+ gs_dst
+end
+
+function Base.copy(gs::Grads)
+ gs_new = Grads(IdDict(), Params())
+ copy!(gs_new, gs)
+end
+
broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)
broadcasted(f, a::Numeric, gs::Grads) = map(x -> f(a, x), gs)
diff --git a/src/tools/buffer.jl b/src/tools/buffer.jl
index 9409a74bc..5d6fdd9bb 100644
--- a/src/tools/buffer.jl
+++ b/src/tools/buffer.jl
@@ -39,6 +39,12 @@ mutable struct Buffer{T,A<:AbstractArray{T}}
freeze::Bool
end
+function Base.copy!(b_dst::Buffer, b_src::Buffer)
+ b_dst.data = b_src.data
+ b_dst.freeze = b_src.freeze
+ b_dst
+end
+
Buffer(xs::AbstractArray, args...) =
Buffer(similar(xs, args...), false)
diff --git a/src/tools/idset.jl b/src/tools/idset.jl
index a0aa93df0..d8072e18b 100644
--- a/src/tools/idset.jl
+++ b/src/tools/idset.jl
@@ -15,6 +15,10 @@ Base.in(x, s::IdSet) = haskey(s.dict, x)
Base.eltype(::IdSet{T}) where T = T
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
Base.similar(s::IdSet, T::Type) = IdSet{T}()
+function Base.empty!(s::IdSet)
+ empty!(s.dict)
+ s
+end
@forward IdSet.dict Base.length
diff --git a/test/interface.jl b/test/interface.jl
index 0bee98321..5b126a765 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -32,6 +32,10 @@ using Zygote: Grads
x = [0, 0, 0]
copy!(x, ps)
@test x == [1, 2, 3]
+
+ ps_src = Params([[1, 2], [3]])
+ ps_dst = Params([4][5])
+ ps_dst = ps_src
end
@testset "broadcast" begin
@@ -132,6 +136,22 @@ end
@test_throws ArgumentError gs1 .+ gs4
end
+ @testset "copy" begin
+ w, b = rand(2), rand(2)
+ x1, x2 = rand(2), rand(2)
+
+ gs1 = gradient(() -> sum(w .* x1), Params([w]))
+ gs2 = gradient(() -> sum(w .* x2), Params([w]))
+
+ gs_new = copy(gs1)
+ copy!(gs2, gs1)
+
+ # TODO: these tests are currently broken because `Base.iseqeual` is not doing useful things
+ # for `Grads` right now.
+ @test_broken gs1 == gs_new
+ @test_broken gs2 == gs1
+ end
+
@testset "map and broadcast" begin
w = rand(2)
x1 = rand(2)
From 55b4381df6ae9488f1c6c7e04b8d10b9a6df727f Mon Sep 17 00:00:00 2001
From: lassepe
Date: Mon, 25 Oct 2021 10:15:21 +0200
Subject: [PATCH 245/490] Implement copy in terms of merge!
---
src/compiler/interface.jl | 16 +++++-----------
src/tools/buffer.jl | 6 ------
test/interface.jl | 7 +++----
3 files changed, 8 insertions(+), 21 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 9965c3469..89285539f 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -237,12 +237,6 @@ function copy!(x::AbstractVector, ps::Params)
ps
end
-function copy!(ps_dst::Params, ps_src::Params)
- copy!(ps_dst.order, ps_src.order)
- copy!(ps_dst.params, ps_src.params)
- ps_dst
-end
-
"""
Grads(...)
@@ -305,15 +299,15 @@ function copy!(x::AbstractVector, gs::Grads)
x
end
-function copy!(gs_dst::Grads, gs_src::Grads)
- copy!(gs_dst.grads, gs_src.grads)
- copy!(gs_dst.params, gs_src.params)
+function Base.merge!(gs_dst::Grads, gs_src::Grads)
+ union!(gs_dst.params, gs_src.params)
+ map!(copy, gs_dst, gs_src)
gs_dst
end
function Base.copy(gs::Grads)
- gs_new = Grads(IdDict(), Params())
- copy!(gs_new, gs)
+ gs_new = Grads(IdDict(), gs.params)
+ merge!(gs_new, gs)
end
broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)
diff --git a/src/tools/buffer.jl b/src/tools/buffer.jl
index 5d6fdd9bb..9409a74bc 100644
--- a/src/tools/buffer.jl
+++ b/src/tools/buffer.jl
@@ -39,12 +39,6 @@ mutable struct Buffer{T,A<:AbstractArray{T}}
freeze::Bool
end
-function Base.copy!(b_dst::Buffer, b_src::Buffer)
- b_dst.data = b_src.data
- b_dst.freeze = b_src.freeze
- b_dst
-end
-
Buffer(xs::AbstractArray, args...) =
Buffer(similar(xs, args...), false)
diff --git a/test/interface.jl b/test/interface.jl
index 5b126a765..0175151a4 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -146,10 +146,9 @@ end
gs_new = copy(gs1)
copy!(gs2, gs1)
- # TODO: these tests are currently broken because `Base.iseqeual` is not doing useful things
- # for `Grads` right now.
- @test_broken gs1 == gs_new
- @test_broken gs2 == gs1
+ #TODO: a bit of a hacky workaround here, would be nice if we could compare gradients directly
+ @test collect(gs1) == collect(gs_new)
+ @test collect(gs2) == collect(gs1)
end
@testset "map and broadcast" begin
From bbc6b362d12d30b9d544d4651f530d96947f0415 Mon Sep 17 00:00:00 2001
From: Lasse Peters
Date: Mon, 25 Oct 2021 10:34:06 +0200
Subject: [PATCH 246/490] Update idset.jl
Get rid of `Base.empty!` for `IdSet`.
---
src/tools/idset.jl | 4 ----
1 file changed, 4 deletions(-)
diff --git a/src/tools/idset.jl b/src/tools/idset.jl
index d8072e18b..a0aa93df0 100644
--- a/src/tools/idset.jl
+++ b/src/tools/idset.jl
@@ -15,10 +15,6 @@ Base.in(x, s::IdSet) = haskey(s.dict, x)
Base.eltype(::IdSet{T}) where T = T
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
Base.similar(s::IdSet, T::Type) = IdSet{T}()
-function Base.empty!(s::IdSet)
- empty!(s.dict)
- s
-end
@forward IdSet.dict Base.length
From d51de8ac102fe41811fcd557c6ccac7b7d34d2d5 Mon Sep 17 00:00:00 2001
From: Lasse Peters
Date: Mon, 25 Oct 2021 10:41:49 +0200
Subject: [PATCH 247/490] Remove redundant test
---
test/interface.jl | 4 ----
1 file changed, 4 deletions(-)
diff --git a/test/interface.jl b/test/interface.jl
index 0175151a4..c419dbfcd 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -32,10 +32,6 @@ using Zygote: Grads
x = [0, 0, 0]
copy!(x, ps)
@test x == [1, 2, 3]
-
- ps_src = Params([[1, 2], [3]])
- ps_dst = Params([4][5])
- ps_dst = ps_src
end
@testset "broadcast" begin
From 912e60eeb8790c111dfbde216a7c14becc5cb7ab Mon Sep 17 00:00:00 2001
From: lassepe
Date: Mon, 25 Oct 2021 11:06:54 +0200
Subject: [PATCH 248/490] Fix tests
---
test/interface.jl | 26 +++++++++++++++++---------
1 file changed, 17 insertions(+), 9 deletions(-)
diff --git a/test/interface.jl b/test/interface.jl
index c419dbfcd..3651b6adb 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -133,18 +133,26 @@ end
end
@testset "copy" begin
- w, b = rand(2), rand(2)
- x1, x2 = rand(2), rand(2)
+ w, b = rand(2), rand(2)
+ x1, x2 = rand(2), rand(2)
+
+ _, back = pullback(() -> sum(w .* x1), Params([w]))
- gs1 = gradient(() -> sum(w .* x1), Params([w]))
- gs2 = gradient(() -> sum(w .* x2), Params([w]))
+ g1 = back(1)
+ g1_w = g1[w]
+ g2 = back(nothing)
+ @test isnothing(g1[w])
+ @test isnothing(g2[w])
- gs_new = copy(gs1)
- copy!(gs2, gs1)
+ g3 = back(1) |> copy
+ g4 = back(nothing)
+ @test !isnothing(g3[w])
+ @test g3[w] == g1_w
+ @test isnothing(g4[w])
- #TODO: a bit of a hacky workaround here, would be nice if we could compare gradients directly
- @test collect(gs1) == collect(gs_new)
- @test collect(gs2) == collect(gs1)
+ #TODO: a bit of a hacky workaround here, would be nice if we could compare gradients directly
+ g3_copy = copy(g3)
+ @test collect(g3_copy) == collect(g3)
end
@testset "map and broadcast" begin
From 73d1d742f17af18f0c6b6329d87803e3e9e8e95f Mon Sep 17 00:00:00 2001
From: lassepe
Date: Mon, 25 Oct 2021 12:50:38 +0200
Subject: [PATCH 249/490] Add tests for merge!(::Grads, ::Grads)
---
src/compiler/interface.jl | 2 +-
test/interface.jl | 31 ++++++++++++++++++++-----------
2 files changed, 21 insertions(+), 12 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 89285539f..e186a134c 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -301,7 +301,7 @@ end
function Base.merge!(gs_dst::Grads, gs_src::Grads)
union!(gs_dst.params, gs_src.params)
- map!(copy, gs_dst, gs_src)
+ merge!(gs_dst.grads, gs_src.grads)
gs_dst
end
diff --git a/test/interface.jl b/test/interface.jl
index 3651b6adb..6029d51ee 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -150,7 +150,6 @@ end
@test g3[w] == g1_w
@test isnothing(g4[w])
- #TODO: a bit of a hacky workaround here, would be nice if we could compare gradients directly
g3_copy = copy(g3)
@test collect(g3_copy) == collect(g3)
end
@@ -173,16 +172,26 @@ end
end
@testset "dictionary interface" begin
- w, b, x = rand(2), rand(2), rand(2)
- ps = Params([w, b])
- gs = gradient(() -> sum(tanh.(w .* x .+ b)), ps)
-
- @test issetequal(keys(gs), ps)
- @test length(values(gs)) == 2
- @test length(pairs(gs)) == 2
- k, v = first(pairs(gs))
- @test k === first(ps)
- @test v === gs[first(ps)]
+ w1, b1, x1 = rand(2), rand(2), rand(2)
+ ps1 = Params([w1, b1])
+ gs1 = gradient(() -> sum(tanh.(w1 .* x1 .+ b1)), ps1)
+
+ @test issetequal(keys(gs1), ps1)
+ @test length(values(gs1)) == 2
+ @test length(pairs(gs1)) == 2
+ k, v = first(pairs(gs1))
+ @test k === first(ps1)
+ @test v === gs1[first(ps1)]
+
+ w2, b2, x2 = rand(2), rand(2), rand(2)
+ ps2 = Params([w2, b2])
+ gs2 = gradient(() -> sum(tanh.(w2 .* x2 .+ b2)), ps2)
+
+ keys1 = keys(gs1) |> collect |> copy
+ values1 = values(gs1) |> collect |> copy
+ gs_merged = merge!(gs1, gs2)
+ @test collect(keys(gs_merged)) == union(keys1, keys(gs2))
+ @test collect(values(gs_merged)) == union(values1, values(gs2))
end
@testset "iteration" begin
From f9b9fbb6f5c3110a1d469552ad9a2443f9be5a75 Mon Sep 17 00:00:00 2001
From: lassepe
Date: Mon, 25 Oct 2021 13:04:59 +0200
Subject: [PATCH 250/490] merge! with multiple other Grads objects
---
src/compiler/interface.jl | 10 ++++++----
test/interface.jl | 40 ++++++++++++++++++++++++++++-----------
2 files changed, 35 insertions(+), 15 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index e186a134c..dc8cb8d18 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -299,15 +299,17 @@ function copy!(x::AbstractVector, gs::Grads)
x
end
-function Base.merge!(gs_dst::Grads, gs_src::Grads)
+function Base.merge!(gs_dst::Grads, gs_srcs::Grads...)
+ for gs_src in gs_srcs
union!(gs_dst.params, gs_src.params)
merge!(gs_dst.grads, gs_src.grads)
- gs_dst
+ end
+ gs_dst
end
function Base.copy(gs::Grads)
- gs_new = Grads(IdDict(), gs.params)
- merge!(gs_new, gs)
+ gs_new = Grads(IdDict(), gs.params)
+ merge!(gs_new, gs)
end
broadcasted(f, gs::Grads, gss::ADictOrGrads...) = map(f, gs, gss...)
diff --git a/test/interface.jl b/test/interface.jl
index 6029d51ee..ad1c7f46c 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -172,26 +172,44 @@ end
end
@testset "dictionary interface" begin
+ w, b, x = rand(2), rand(2), rand(2)
+ ps = Params([w, b])
+ gs = gradient(() -> sum(tanh.(w .* x .+ b)), ps)
+
+ @test issetequal(keys(gs), ps)
+ @test length(values(gs)) == 2
+ @test length(pairs(gs)) == 2
+ k, v = first(pairs(gs))
+ @test k === first(ps)
+ @test v === gs[first(ps)]
+ end
+
+ @testset "merge" begin
w1, b1, x1 = rand(2), rand(2), rand(2)
ps1 = Params([w1, b1])
gs1 = gradient(() -> sum(tanh.(w1 .* x1 .+ b1)), ps1)
- @test issetequal(keys(gs1), ps1)
- @test length(values(gs1)) == 2
- @test length(pairs(gs1)) == 2
- k, v = first(pairs(gs1))
- @test k === first(ps1)
- @test v === gs1[first(ps1)]
-
w2, b2, x2 = rand(2), rand(2), rand(2)
ps2 = Params([w2, b2])
gs2 = gradient(() -> sum(tanh.(w2 .* x2 .+ b2)), ps2)
- keys1 = keys(gs1) |> collect |> copy
- values1 = values(gs1) |> collect |> copy
+ w3, b3, x3 = rand(2), rand(2), rand(2)
+ ps3 = Params([w3, b3])
+ gs3 = gradient(() -> sum(tanh.(w3 .* x3 .+ b3)), ps3)
+
+ # merging with a single other Grads object
+ keys1 = keys(gs1)
+ values1 = values(gs1)
gs_merged = merge!(gs1, gs2)
- @test collect(keys(gs_merged)) == union(keys1, keys(gs2))
- @test collect(values(gs_merged)) == union(values1, values(gs2))
+ @test issetequal(keys(gs_merged), union(keys1, keys(gs2)))
+ @test issetequal(values(gs_merged), union(values1, values(gs2)))
+ @test length(pairs(gs_merged)) == 4
+
+ # merging with multiple other Grads objects
+ gs_merged = merge!(gs1, gs2, gs3)
+ @test issetequal(keys(gs_merged), union(keys1, keys(gs2), keys(gs3)))
+ @test issetequal(values(gs_merged), union(values1, values(gs2), values(gs3)))
+ @test length(pairs(gs_merged)) == 6
end
@testset "iteration" begin
From 60f53e709d8b5bc052a20fb4fcf0228004aa4723 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 26 Oct 2021 04:26:41 -0400
Subject: [PATCH 251/490] Iterators. Product, Filter, enumerate, zip (including
inside map) (#785)
* enumerate, Filter, Product
* zip
* more zip
* tweaking filter locally
* allow for nothing in Iterators.Product
* allow nothing in enumerate, and tidy up
* two more cases
* fix map gradient to allow for early ending & mixed shapes
* more cases for enumerate, zip
* fixes for map
* share code map + zip
* try something re map
* overall testset
* one more restore
* add some info messages
* silence some warnings
* three now pass
* explain what the weird printout is for
* early stopping was different before 1.5
* comments
* project, too
Co-authored-by: Michael Abbott
---
src/lib/array.jl | 64 ++++++++++++++++++++++++++++++---
test/features.jl | 83 ++++++++++++++++++++++++++++++++++++++++++
test/gradcheck.jl | 91 ++++++++++++++++++++++++++++++++++++++++++-----
test/runtests.jl | 7 ++++
test/tools.jl | 9 +++--
5 files changed, 236 insertions(+), 18 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index a44edf7f6..7734ad5ca 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -179,6 +179,13 @@ end
_tryreverse(m, x) = x
_tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x)
+# With mismatched lengths, map stops early. With mismatched shapes, it makes a vector.
+# So we keep axes(x) to restore gradient dx to its full length & correct shape.
+_tryaxes(x) = axes(x)
+_tryaxes(x::Tuple) = Val(length(x))
+_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax)
+_restore(dx, ::Val{N}) where {N} = length(dx) < N ? ntuple(i -> get(dx,i,nothing), N) : NTuple{N}(dx)
+
# Sometimes a pullback doesn't return a Tuple, but rather returns only a
# single nothing to say "all arguments have zero cotangent". This function is needed to
# account for that inside the pullback for map.
@@ -189,22 +196,27 @@ for (mapfunc,∇mapfunc) in [(:map,:∇map),(:pmap,:∇pmap)]
@eval function $∇mapfunc(cx, f::F, args::Vararg{Any, N}) where {F, N}
ys_and_backs = $mapfunc((args...) -> _pullback(cx, f, args...), args...)
ys = map(first, ys_and_backs)
- ys, function (Δ)
- isnothing(Δ) && return nothing
+ arg_ax = map(_tryaxes, args)
+ function map_back(Δ)
if Base.issingletontype(F) && length(args) == 1
Δarg = $mapfunc(((_,pb), δ) -> last_or_nothing(pb(δ)), ys_and_backs, Δ) # No unzip needed
(nothing, Δarg)
- elseif Base.issingletontype(F) # Ensures `f` is pure: nothing captured & no state
- Δargs = _unzip($mapfunc(((_,pb), δ) -> tailmemaybe(pb(δ)), ys_and_backs, Δ), Val(N))
+ elseif Base.issingletontype(F)
+ # Ensures `f` is pure: nothing captured & no state.
+ unzipped = _unzip($mapfunc(((_,pb), δ) -> tailmemaybe(pb(δ)), ys_and_backs, Δ), Val(N))
+ Δargs = map(_restore, unzipped, arg_ax)
(nothing, Δargs...)
else
# Apply pullbacks in reverse order. Needed for correctness if `f` is stateful.
Δf_and_args_zipped = $mapfunc(((_,pb), δ) -> pb(δ), _tryreverse($mapfunc, ys_and_backs, Δ)...)
Δf_and_args = _unzip(_tryreverse($mapfunc, Δf_and_args_zipped), Val(N + 1))
Δf = reduce(accum, Δf_and_args[1]; init=nothing)
- (Δf, Δf_and_args[2:end]...)
+ Δargs = map(_restore, Δf_and_args[2:end], arg_ax)
+ (Δf, Δargs...)
end
end
+ map_back(::Nothing) = nothing
+ return ys, map_back
end
@eval @adjoint function $mapfunc(f, args::Union{AbstractArray,Tuple}...)
@@ -254,6 +266,48 @@ end
end
end
+# Iterators
+
+@adjoint function enumerate(xs)
+ back(::AbstractArray{Nothing}) = nothing
+ back(dy::NamedTuple{(:itr,)}) = tuple(dy.itr)
+ back(diys) = (map(last, diys),)
+ enumerate(xs), back
+end
+
+@adjoint Iterators.Filter(f, x) = pullback(filter, f, collect(x))
+
+_ndims(::Base.HasShape{d}) where {d} = d
+_ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)) : 1
+
+@adjoint function Iterators.product(xs...)
+ back(::AbstractArray{Nothing}) = nothing
+ back(dy::NamedTuple{(:iterators,)}) = dy.iterators
+ function back(dy::AbstractArray)
+ d = 1
+ ntuple(length(xs)) do n
+ first(dy)[n] === nothing && return nothing
+ nd = _ndims(xs[n])
+      dims = ntuple(i -> i<d ? i : i+nd, ndims(dy)-nd)
+      d += nd
+      init = zero.(first(dy)[n])  # allows for tuples, which accum can add
+      red = mapreduce(StaticGetter{n}(), accum, dy; dims=dims, init=init)
+      return reshape(red, axes(xs[n]))
+    end
+  end
+  Iterators.product(xs...), back
+end
+
+@adjoint function Iterators.Zip(xs)
+  back(dy::NamedTuple{(:is,)}) = tuple(dy.is)
+  back(dy::AbstractArray) = ntuple(length(xs)) do d
+    dx = map(StaticGetter{d}(), dy)
+    _restore(dx, _tryaxes(xs[d]))
+  end |> tuple
+ Iterators.Zip(xs), back
+end
+
# Reductions
@adjoint function sum(xs::AbstractArray; dims = :)
if dims === (:)
diff --git a/test/features.jl b/test/features.jl
index 795879fed..545db0279 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -492,6 +492,89 @@ end
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),)
end
+@testset "Iterators" begin
+ # enumerate
+ @test gradient(1:5) do xs
+ sum([x^i for (i,x) in enumerate(xs)])
+ end == ([1, 4, 27, 256, 3125],)
+
+ @test gradient([1,10,100]) do xs
+ sum([xs[i]^i for (i,x) in enumerate(xs)])
+ end == ([1, 2 * 10^1, 3 * 100^2],)
+
+ @test gradient([1,10,100]) do xs
+ sum((xs[i]^i for (i,x) in enumerate(xs))) # same without collect
+ end == ([1, 2 * 10^1, 3 * 100^2],)
+
+ # zip
+ if VERSION >= v"1.5"
+ # On Julia 1.4 and earlier, [x/y for (x,y) in zip(10:14, 1:10)] is a DimensionMismatch,
+ # while on 1.5 - 1.7 it stops early.
+
+ @test gradient(10:14, 1:10) do xs, ys
+ sum([x/y for (x,y) in zip(xs, ys)])
+ end[2] ≈ vcat(.-(10:14) ./ (1:5).^2, zeros(5))
+
+ @test_broken gradient(10:14, 1:10) do xs, ys
+ sum(x/y for (x,y) in zip(xs, ys)) # same without collect
+ # Here @adjoint function Iterators.Zip(xs) gets dy = (is = (nothing, nothing),)
+ end[2] ≈ vcat(.-(10:14) ./ (1:5).^2, zeros(5))
+ end
+
+ bk_z = pullback((xs,ys) -> sum([abs2(x*y) for (x,y) in zip(xs,ys)]), [1,2], [3im,4im])[2]
+ @test bk_z(1.0)[1] isa AbstractVector{<:Real} # projection
+
+ # Iterators.Filter
+ @test gradient(2:9) do xs
+ sum([x^2 for x in xs if iseven(x)])
+ end == ([4, 0, 8, 0, 12, 0, 16, 0],)
+
+ @test gradient(2:9) do xs
+ sum(x^2 for x in xs if iseven(x)) # same without collect
+ end == ([4, 0, 8, 0, 12, 0, 16, 0],)
+
+ # Iterators.Product
+ @test gradient(1:10, 3:7) do xs, ys
+ sum([x^2+y for x in xs, y in ys])
+ end == (10:10:100, fill(10, 5))
+
+ @test_broken gradient(1:10, 3:7) do xs, ys
+ sum(x^2+y for x in xs, y in ys) # same without collect
+ # Here @adjoint function Iterators.product(xs...) gets dy = (iterators = (nothing, nothing),)
+ end == (10:10:100, fill(10, 5))
+
+ # Repeat that test without sum(iterator) -- also receives dy = (iterators = (nothing, nothing),)
+ function prod_acc(xs, ys)
+ out = 0
+ # for (x,y) in Iterators.product(xs, ys)
+ # out += x^2+y
+ for xy in Iterators.product(xs, ys)
+ out += xy[1]^2 + xy[2]
+ end
+ out
+ end
+ @test prod_acc(1:10, 3:7) == sum(x^2+y for x in 1:10, y in 3:7)
+ gradient(prod_acc, 1:10, 3:7) == (nothing, nothing) # sadly
+ @test_broken gradient(prod_acc, 1:10, 3:7) == (10:10:100, fill(10, 5))
+
+ @test gradient(rand(2,3)) do A
+ sum([A[i,j] for i in 1:1, j in 1:2])
+ end == ([1 1 0; 0 0 0],)
+
+ @test gradient(ones(3,5), 1:7) do xs, ys
+ sum([x+y for x in xs, y in ys])
+ end == (fill(7, 3,5), fill(15, 7))
+
+ bk_p = pullback((xs,ys) -> sum([x/y for x in xs, y in ys]), Diagonal([3,4,5]), [6,7]')[2]
+ @test bk_p(1.0)[1] isa Diagonal # projection
+ @test bk_p(1.0)[2] isa Adjoint
+
+ # Iterators.Product with enumerate
+ @test gradient([2 3; 4 5]) do xs
+ sum([x^i+y for (i,x) in enumerate(xs), y in xs])
+ end == ([8 112; 36 2004],)
+end
+
# https://github.com/JuliaDiff/ChainRules.jl/issues/257
@testset "Keyword Argument Passing" begin
struct Type1{VJP}
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index ac7893cfa..66e558869 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -44,26 +44,27 @@ end
Random.seed!(0)
-@testset "println, show, print" begin
+@testset "println, show, string, etc" begin
function foo(x)
Base.show(x)
Base.print(x)
+ Base.print(stdout, x)
Base.println(x)
+ Base.println(stdout, x)
Core.show(x)
Core.print(x)
Core.println(x)
return x
end
+ println("The following printout is from testing that `print` doesn't upset gradients:")
@test gradtest(foo, [5.0])
-end
-@testset "string, repr" begin
- function foo(x)
+ function bar(x)
string(x)
repr(x)
return x
end
- @test gradtest(foo, [5.0])
+ @test gradtest(bar, [5.0])
end
@@ -356,11 +357,11 @@ end
@test gradient(x -> map(+, x, (1,2,3))[1], (4,5,6)) == ((1.0, nothing, nothing),)
@test gradient(x -> map(+, x, [1,2,3])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
- @test_broken gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],) # Gradient [1.0, 0.0, 0.0] should be a tuple, since v0.6.0 at least
+ @test gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],)
# mismatched lengths, should zip
- @test_broken gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),) # BoundsError: attempt to access 3-element Vector{Float64} at index [4]
- @test_broken gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),) # DimensionMismatch("variable with size(x) == (4,) cannot have a gradient with size(dx) == (3,)
+ @test gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
+ @test gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),)
end
@testset "Alternative Pmap Dispatch" begin
@@ -383,6 +384,23 @@ end
@test gradient(x -> sum(map(f, x)), 1:10) == (10:-1:1,)
end
+@testset "vararg map" begin
+ # early stop
+ if VERSION >= v"1.5"
+ # In Julia 1.4 and earlier, map(*,rand(5),[1,2,3]) is a DimensionMismatch
+ @test gradient(x -> sum(map(*,x,[1,2,3])), rand(5)) == ([1,2,3,0,0],)
+ end
+ @test gradient(x -> sum(map(*,x,(1,2,3))), rand(5)) == ([1,2,3,0,0],)
+ @test gradient(x -> sum(map(*,x,[1,2,3])), Tuple(rand(5))) == ((1.0, 2.0, 3.0, nothing, nothing),)
+
+ # mixed shapes
+ @test gradient((x,y) -> sum(map(*,x,y)), [1,2,3,4], [1 2; 3 4]) == ([1,3,2,4], [1 3; 2 4])
+ @test gradient((x,y) -> sum(map(*,x,y)), [1,2,3], [1 2; 3 4]) == ([1,3,2], [1 3; 2 0])
+ @test gradient((x,y) -> sum(map(*,x,y)), (1,2,3), [1 2; 3 4]) == ((1,3,2), [1 3; 2 0])
+ @test gradient((x,y) -> sum(map(*,x,y)), [1,2,3,4,5], [1 2; 3 4]) == ([1,3,2,4,0], [1 3; 2 4])
+ @test gradient((x,y) -> sum(map(*,x,y)), (1,2,3,4,5), [1 2; 3 4]) == ((1,3,2,4,nothing), [1 3; 2 4])
+end
+
@testset "sort" begin
@test gradtest(sort, 5)
correct = [
@@ -1748,6 +1766,63 @@ end
gradient(x->norm(x*[1im 1]), 1.23)
end
+@testset "zip & Iterators.product" begin
+ # roughly from https://github.com/FluxML/Zygote.jl/issues/221
+ d = rand(7)
+ @test gradient(rand(11)) do s
+ tot = 0
+ for (a, b) in zip(s, d)
+ tot += 13a + 17b
+ end
+ tot
+ end == ([13, 13, 13, 13, 13, 13, 13, 0, 0, 0, 0],)
+
+ @test gradient([1,2,3,4], [1 2; 3 4]) do x, y # mismatched shapes
+ tot = 0
+ for (a,b) in zip(x,y)
+ tot += a * b
+ end
+ tot
+ end == ([1, 3, 2, 4], [1 3; 2 4]) # Δy is a matrix
+
+ @test gradient((1,2,3), [1 2; 3 4]) do x, y # ... and lengths, and a tuple
+ tot = 0
+ for (a,b) in zip(x,y)
+ tot += a * b
+ end
+ tot
+ end == ((1, 3, 2), [1 3; 2 0]) # map stops early, Δy reshaped to a matrix
+
+  # similar for enumerate -- tests NamedTuple adjoint
+ @test gradient([2,3,4]) do x
+ tot = 0
+ for (i, x) in enumerate(x)
+ tot += x^i
+ end
+ tot
+ end == ([1, 6, 3 * 4^2],)
+
+ # and for Iterators.product
+ @test gradient([3,4,5], [6,7,8]) do x, y
+ tot = 0
+ for (a,b) in Iterators.product(x, y)
+ tot += a^2 + 10b
+ end
+ tot
+ end == ([18, 24, 30], [30, 30, 30])
+
+ @test gradient([3,4], [1,2,3]) do x, y
+ tot = 0
+ for ab in Iterators.product(x, y)
+ tot += *(ab...)
+ end
+ tot
+ end == ([6,6], [7,7,7])
+
+ # from https://github.com/FluxML/Zygote.jl/pull/785#issuecomment-740562889
+ @test gradient(A -> sum([A[i,j] for i in 1:3, j in 1:3]), ones(3,3)) == (ones(3,3),)
+end
+
# https://github.com/FluxML/Zygote.jl/issues/804
@testset "Unused comprehension" begin
# Comprehension is used.
diff --git a/test/runtests.jl b/test/runtests.jl
index 022727fbe..d1b34da77 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,10 +3,13 @@ using Zygote: gradient, ZygoteRuleConfig
using CUDA
using CUDA: has_cuda
+@testset "all" begin # Overall testset ensures it keeps running after failure
+
if has_cuda()
@testset "CUDA tests" begin
include("cuda.jl")
end
+ @info "CUDA tests have run"
else
@warn "CUDA not found - Skipping CUDA Tests"
end
@@ -31,6 +34,7 @@ end
@testset "Features" begin
include("features.jl")
+ @info "features.jl done"
end
@testset "Forward" begin
@@ -43,6 +47,7 @@ end
@testset "ChainRules" begin
include("chainrules.jl")
+ @info "chainrules.jl done"
end
@testset "Gradients" begin
@@ -56,3 +61,5 @@ end
@testset "Compiler" begin
include("compiler.jl")
end
+
+end # @testset "all"
diff --git a/test/tools.jl b/test/tools.jl
index 717612284..77b268646 100644
--- a/test/tools.jl
+++ b/test/tools.jl
@@ -48,17 +48,16 @@ end
end
function Tester(p)
- @show Zygote.isderiving(p)
+ # @show Zygote.isderiving(p)
cpu_offload = Zygote.isderiving(p) ? 0.0 : 0.2
Tester(cpu_offload)
end
- function f(p)
+ function f56(p)
sum(Tester(p).cpu_offload .* p)
end
- p = [1.0]
- gs = gradient(f, p)
- @test gs[1] == [0.]
+ gs56 = gradient(f56, [1.0])
+ @test gs56[1] == [0.]
end
From f2bb45d232e4eb1d5ecbcd0b119d91041dd3e5ad Mon Sep 17 00:00:00 2001
From: ST John
Date: Thu, 4 Nov 2021 16:15:53 +0200
Subject: [PATCH 252/490] remove `@adjoint function cholesky`
---
src/lib/array.jl | 29 -----------------------------
1 file changed, 29 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 15b994564..b35883e59 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -540,35 +540,6 @@ end
@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A),
Δ -> (convert(S, Δ),)
-@adjoint function cholesky(Σ::Real)
- C = cholesky(Σ)
- return C, Δ::NamedTuple->(Δ.factors[1, 1] / (2 * C.U[1, 1]),)
-end
-
-@adjoint function cholesky(Σ::Diagonal; check = true)
- C = cholesky(Σ, check = check)
- return C, Δ::NamedTuple -> begin
- issuccess(C) || throw(PosDefException(C.info))
- return Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing
- end
-end
-
-# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
-@adjoint function cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}; check = true)
- C = cholesky(Σ, check = check)
- return C, function(Δ::NamedTuple)
- issuccess(C) || throw(PosDefException(C.info))
- U, Ū = C.U, Δ.factors
- Σ̄ = similar(U.data)
- Σ̄ = mul!(Σ̄, Ū, U')
- Σ̄ = copytri!(Σ̄, 'U')
- Σ̄ = ldiv!(U, Σ̄)
- Σ̄ = BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
- Σ̄[diagind(Σ̄)] ./= 2
- return (UpperTriangular(Σ̄),)
- end
-end
-
@adjoint function lyap(A::AbstractMatrix, C::AbstractMatrix)
X = lyap(A, C)
return X, function (X̄)
From 4ed3a86db708a27bfe0afd5aeaa6408dd8d43a3e Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 7 Nov 2021 16:15:36 -0500
Subject: [PATCH 253/490] Insert `_project` into `getproperty`'s gradient, and
then improve `z2d` etc. to restore stability (#1104)
* insert _project into getproperty
* use zygote2differential in _project
* improve type-stability of zygote2differential
* fix 1 test and break 2
* 2 not broken in fact
* skip inference test
* skip more inference tests
* improve inference for 1.6
* skip a test on 1.6
* skip 2
* handle nothings
* re-enable some inference tests on 1.6
* arrays of abstract tangents, and NamedTuple tests
* reverse dispatch for wrap_chainrules_input
* fix a typo
* fix more notation
* restore a test
* add DynamicPPL.jl
* fix a test
* try removing piracy
* restore some piracy, tidy
* reinterpret
* reinterpret
* collapse nothings
* DistributionsAD too
* collapse zeros in z2d
* comments
* indents
* change one comment
---
.github/workflows/Downstream.yml | 2 +
Project.toml | 2 +-
src/compiler/chainrules.jl | 154 +++++++++++++++++++++++--------
src/lib/lib.jl | 3 +-
test/chainrules.jl | 24 +++++
test/compiler.jl | 4 +-
test/features.jl | 37 +++++++-
test/gradcheck.jl | 12 +++
test/runtests.jl | 88 +++++++++---------
9 files changed, 239 insertions(+), 87 deletions(-)
diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml
index c40bc6617..308754b0a 100644
--- a/.github/workflows/Downstream.yml
+++ b/.github/workflows/Downstream.yml
@@ -20,6 +20,8 @@ jobs:
package:
- {user: FluxML, repo: Flux.jl, group: All}
- {user: FluxML, repo: NNlib.jl, group: All}
+ - {user: TuringLang, repo: DynamicPPL.jl, group: All}
+ - {user: TuringLang, repo: DistributionsAD.jl, group: Zygote}
- {user: SciML, repo: DiffEqFlux.jl, group: Layers}
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
steps:
diff --git a/Project.toml b/Project.toml
index 04196b602..a8ea16a25 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.29"
+version = "0.6.30"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 9bfde430a..b3157f289 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -115,28 +115,61 @@ for T_outer in (:Tuple, :NamedTuple)
ChainRulesCore.backing(xp) # this is accessing ChainRulesCore internals, but it is prob safe enough, and it is fastest
end
end
-# Could `reinterpret` instead of broadcasting here -- TODO
-@inline wrap_chainrules_output(xs::AbstractArray{<:ChainRules.Tangent}) = wrap_chainrules_output.(xs)
+wrap_chainrules_output(dxs::AbstractArray{<:Number}) = dxs
+wrap_chainrules_output(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
+wrap_chainrules_output(dxs::AbstractArray) = map(wrap_chainrules_output, dxs)
+#=
+# As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers
+@inline function wrap_chainrules_output(dxs::AbstractArray{<:ChainRules.Tangent{<:Any, B}}) where {B}
+ if isbitstype(B)
+ # B is the backing type. It still contains NoTangent etc, which need converting to Nothing
+ reinterpret(wrap_chainrules_output(B), dxs)
+ else
+ map(wrap_chainrules_output, dxs)
+ end
+end
+wrap_chainrules_output(::Type{<:AbstractZero}) = Nothing
+wrap_chainrules_output(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_output(T)}
+@generated function wrap_chainrules_output(::Type{T}) where T<:Tuple
+ inner = map(wrap_chainrules_output, T.parameters)
+ :(Tuple{$(inner...)})
+end
+=#
"""
- wrap_chainrules_input(x)
+ wrap_chainrules_input(dx)
-Convert `x` from the format Zygote uses internally to differentials types ChainRules uses.
+Convert `dx` from the format Zygote uses internally to differentials types ChainRules uses.
"""
-@inline wrap_chainrules_input(x) = x
+@inline wrap_chainrules_input(dx) = dx
@inline wrap_chainrules_input(::Nothing) = ChainRules.ZeroTangent()
+@inline wrap_chainrules_input(::Tuple{Vararg{Nothing}}) = ChainRules.ZeroTangent()
@inline wrap_chainrules_input(::AbstractArray{Nothing}) = ChainRules.ZeroTangent()
-@inline function wrap_chainrules_input(xs::Union{Tuple, NamedTuple})
- xp = map(wrap_chainrules_input, xs)
- ChainRules.Tangent{Any, typeof(xp)}(xp)
+@inline function wrap_chainrules_input(dxs::Union{Tuple, NamedTuple})
+ xp = map(wrap_chainrules_input, dxs)
+ # This produces Tangent{Any} since it does not get to see the primal, `x`.
+ ChainRulesCore.Tangent{Any, typeof(xp)}(xp)
end
# For mutable types, including x=Ref(1), Zygote makes Ref{Any}(::NamedTuple)
-@inline wrap_chainrules_input(x::Ref) = wrap_chainrules_input(x[])
-# Could `reinterpret` instead of broadcasting here -- TODO
-@inline wrap_chainrules_input(xs::AbstractArray{<:Ref}) = wrap_chainrules_input.(xs)
-@inline wrap_chainrules_input(xs::AbstractArray{<:Union{Nothing, <:Ref}}) = wrap_chainrules_input.(xs) # no test invented for this
-@inline wrap_chainrules_input(xs::AbstractArray{<:NamedTuple}) = wrap_chainrules_input.(xs)
-@inline wrap_chainrules_input(xs::AbstractArray{<:Union{Nothing, <:NamedTuple}}) = wrap_chainrules_input.(xs)
+@inline wrap_chainrules_input(dx::Ref) = wrap_chainrules_input(dx[])
+# For arrays, whitelist the safe ones, but always look inside Any[]:
+@inline wrap_chainrules_input(dxs::AbstractArray{<:Number}) = dxs
+@inline wrap_chainrules_input(dxs::AbstractArray{<:AbstractArray{<:Number}}) = dxs
+@inline wrap_chainrules_input(dxs::AbstractArray) = map(wrap_chainrules_input, dxs)
+
+#=
+# Could `reinterpret` instead here? See issue 1112.
+# One easy case, might be this:
+@inline wrap_chainrules_input(xs::Base.ReinterpretArray{<:NamedTuple, <:Tangent}) = parent(xs)
+
+# This is for `z2d` reinterpret below:
+wrap_chainrules_input(::Type{Nothing}) = NoTangent
+wrap_chainrules_input(::Type{NamedTuple{L,T}}) where {L,T} = NamedTuple{L,wrap_chainrules_input(T)}
+@generated function wrap_chainrules_input(::Type{T}) where T<:Tuple
+ inner = map(wrap_chainrules_input, T.parameters)
+ :(Tuple{$(inner...)})
+end
+=#
"""
_project(x, dx)
@@ -146,21 +179,13 @@ Also handles some Zygote-specific corrections, such as `x::Array, dx::Tuple`.
Safe to apply to arbitrary input.
"""
@inline function _project(x, dx)
- # Note that this use of `wrap_chainrules_input` has the primal `x`, so could
- # avoid making `Tangent{Any}`, perhaps via `zygote2differential` -- TODO.
- wrap_chainrules_output(ProjectTo(x)(wrap_chainrules_input(dx)))
+ wrap_chainrules_output(ProjectTo(x)(zygote2differential(dx, x)))
end
# Restore splatted arrays
_project(x::AbstractArray, dx::Tuple) = _project(x, reshape(collect(dx), axes(x)))
# Piracy:
-# wrap_chainrules_input doesn't handle array of Union{Int,Nothing}
-(::ChainRulesCore.ProjectTo)(::Nothing) = ChainRulesCore.NoTangent()
-
-# CRC likes Tangent{<:Complex}, but Zygote makes Tangent{Any}
-(project::ProjectTo{<:Complex})(dx::Tangent) = project(Complex(dx.re, dx.im))
-
# CRC likes Tangent{AbstractArray}, but Zygote makes Tangent{Any}
# in particular this would hit https://github.com/JuliaDiff/ChainRulesCore.jl/blob/2ec2549b73b22bc08f554dae864fb650cfb9c3d7/src/projection.jl#L139
# if we were not losing track of the Primal in the Tangent
@@ -236,32 +261,85 @@ end
zygote2differential(dx, primal)
Convert input `dx` from the Zygote format to the ChainRules differential types.
+This is similar to `wrap_chainrules_input(dx)`, but because it gets `primal::T`,
+it can turn `NamedTuple`s into `Tangent{T}(...)` not `Tangent{Any}(...)`.
"""
zygote2differential(x, primal) = z2d(x, primal)
zygote2differential(::Nothing, ::Any) = NoTangent()
zygote2differential(t::Tuple, primal::Tuple) = map(z2d, t, primal)
zygote2differential(t::Tuple, primal) = (@warn "primal should be a tuple, not $primal"; return t)
-z2d(x, ::Any) = x
+
z2d(::Nothing, ::Any) = NoTangent()
-z2d(a::AbstractArray{<:Number}, primal::AbstractArray{T}) where T = a
-# Could probably `reinterpret` instead of broadcasting here -- TODO
-z2d(a::AbstractArray, primal::AbstractArray{T}) where T = z2d.(a, primal)
+z2d(::Tuple{Vararg{Nothing}}, ::Tuple) = NoTangent() # collapse all-zero case
+z2d(dx, ::Any) = dx
+z2d(dx::AbstractArray{<:Number}, primal::AbstractArray) = dx
+z2d(dx::AbstractArray{<:AbstractArray{<:Number}}, primal::AbstractArray) = dx
+z2d(dx::AbstractArray, primal::AbstractArray) = map(z2d, dx, primal)
+#=
+# As an optimisation, we can convert by `reinterpret` for bitstypes, e.g. arrays of tuples of numbers
+function z2d(dx::AbstractArray{S}, primal::AbstractArray{P}) where {S,P}
+ if isbitstype(S)
+ T = wrap_chainrules_input(S)
+ reinterpret(Tangent{P,T}, dx)
+ else
+ map(z2d, dx, primal)
+ end
+end
+=#
+
# Note: this should never be hit if we are converting things right, but it seems to be
# happening in the wild for sufficiently weird functions/types.
# This fixes most (all?) cases, but it would be good to find what we miss.
z2d(x::Union{AbstractZero, Tangent}, ::Any) = return x
-function z2d(t::Tuple, primal::Tuple)
- tp::Tuple = map(z2d, t, primal)
- primal_type = typeof(primal)
- return canonicalize(Tangent{primal_type, typeof(tp)}(tp))
+
+function z2d(delta::Tuple, primal::Tuple)
+ backing = map(z2d, delta, primal)
+ if backing isa Tuple{Vararg{AbstractZero}}
+ return NoTangent() # collapse all-zero case
+ else
+ return canonicalize(Tangent{typeof(primal), typeof(backing)}(backing))
+ end
end
-function z2d(t::NamedTuple, primal)
- primal_type = typeof(primal)
- fnames = fieldnames(primal_type)
- complete_t = NamedTuple{fnames}(fn in keys(t) ? t[fn] : nothing for fn in fnames)
- primals = NamedTuple{fnames}(getfield(primal, fn) for fn in fnames)
- tp::NamedTuple = map(z2d, complete_t, primals)
- return canonicalize(Tangent{primal_type, typeof(tp)}(tp))
+# Dict handling in Zygote is a mess... should this become a `Tangent{Dict,Dict}` ?
+# Right now it uses a NamedTuple but not for fields of the AbstractDict struct
+z2d(dx::NamedTuple, primal::AbstractDict) = dx
+
+function z2d(delta::NamedTuple, primal::T) where T # arbitrary struct
+ fnames = fieldnames(T)
+ deltas = map(n -> get(delta, n, nothing), fnames)
+ primals = map(n -> getfield(primal, n), fnames)
+ inner = map(z2d, deltas, primals) # recurse into fields
+ if inner isa Tuple{Vararg{AbstractZero}}
+ return NoTangent() # collapse all-zero case
+ else
+ backing = NamedTuple{fnames}(inner)
+ return canonicalize(Tangent{T, typeof(backing)}(backing))
+ end
end
+
+# Dict case matches signature for ambiguity reasons:
+z2d(dx::NamedTuple{L,S}, primal::AbstractDict) where {L,S<:Tuple{Vararg{Union{Number,Nothing}}}} = dx
+# On Julia <= 1.6, this fixes easy cases which do not require recursion into fields, e.g.
+# @inferred Zygote.z2d((re=1, im=nothing), 3.0+im)
+@generated function z2d(delta::NamedTuple{L,S}, primal::T) where {L,S<:Tuple{Vararg{Union{Number,Nothing}}}, T}
+ fnames = fieldnames(T)
+ deltas = map(fnames) do n
+ i = findfirst(isequal(n), L)
+ if i == nothing || S.parameters[i] == Nothing
+ :(NoTangent())
+ else
+ :(delta.$n)
+ end
+ end
+ if all(d -> d == :(NoTangent()), deltas)
+ return :(NoTangent()) # collapse all-zero case
+ else
+ return quote
+ backing = NamedTuple{$fnames}(($(deltas...),))
+ Tangent{$T, typeof(backing)}(backing)
+ end
+ end
+end
+
z2d(dx::Ref, primal) = z2d(dx[], primal) # mutable structs
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index 96422d78c..f154ecd2a 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -227,7 +227,8 @@ end
function back(Δ)
accum_param(__context__, val, Δ) === nothing && return
if isimmutable(x)
- ((; nt_nothing(x)..., pair(Val(f), Δ, x)...), nothing)
+ dx = (; nt_nothing(x)..., pair(Val(f), Δ, x)...)
+ (_project(x, dx), nothing)
else
dx = grad_mut(__context__, x)
dx[] = (; dx[]..., pair(Val(f), accum(getfield(dx[], f), Δ))...)
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 80da51743..e34be0fa1 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -349,3 +349,27 @@ end
@fastmath x^2.0
end == (4.0,)
end
+
+@testset "zygote2differential inference" begin
+ @test @inferred(Zygote.z2d(1.0, 2.0)) isa Real
+ @test @inferred(Zygote.z2d([1,2,3], [4,5,6])) isa Vector
+ @test @inferred(Zygote.z2d((1, 2.0, 3+4im), (5, 6.0, 7+8im))) isa Tangent{<:Tuple}
+
+ # Below Julia 1.7, these need a @generated version to be inferred:
+ @test @inferred(Zygote.z2d((re=1,), 3.0+im)) isa Tangent{ComplexF64}
+ @test @inferred(Zygote.z2d((re=1, im=nothing), 3.0+im)) isa Tangent{ComplexF64}
+
+ # collapse nothings
+ @test @inferred(Zygote.z2d((nothing,), (1,))) === NoTangent()
+ @test @inferred(Zygote.z2d((nothing, nothing), (1,2))) === NoTangent()
+
+ # To test the generic case, we need a struct within a struct.
+ nested = Tangent{Base.RefValue{ComplexF64}}(; x=Tangent{ComplexF64}(; re=1, im=NoTangent()),)
+ if VERSION > v"1.7-"
+ @test @inferred(Zygote.z2d((; x=(; re=1)), Ref(3.0+im))) == nested
+ @test @inferred(Zygote.z2d((; x=(; re=nothing)), Ref(3.0+im))) === NoTangent()
+ else
+ @test Zygote.z2d((; x=(; re=1)), Ref(3.0+im)) == nested
+ @test Zygote.z2d((; x=(; re=nothing)), Ref(3.0+im)) === NoTangent()
+ end
+end
diff --git a/test/compiler.jl b/test/compiler.jl
index eec71e53d..bc37d271e 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -163,7 +163,9 @@ end
y, back = @inferred pullback(x -> x.m, g)
@test y == getfield(g, :m)
# This type instability is due to the handling of non-bitstypes in `accum_param`
- @test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
+ if VERSION > v"1.7-"
+ @test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
+ end
@test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)
Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
diff --git a/test/features.jl b/test/features.jl
index 545db0279..839e98cc4 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -176,9 +176,13 @@ end
@test gradient(t -> t[1]*t[2], (2, 3)) == ((3, 2),)
-@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,)
+@test gradient(x -> x.re, 2+3im) === (1.0 + 0.0im,) # one NamedTuple
+@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,) # two, different fields
+@test gradient(x -> x.re*x.im + x.re, 2+3im) == (4.0 + 2.0im,) # three, with accumulation
-@test gradient(x -> x.re*x.im, 2+3im) == (3.0 + 2.0im,)
+@test gradient(x -> abs2(x * x.re), 4+5im) == (456.0 + 160.0im,) # gradient participates
+@test gradient(x -> abs2(x * real(x)), 4+5im) == (456.0 + 160.0im,) # function not getproperty
+@test gradient(x -> abs2(x * getfield(x, :re)), 4+5im) == (456.0 + 160.0im,)
struct Bar{T}
a::T
@@ -418,6 +422,11 @@ end
@test gradient((x,y,z) -> sum((x,y,z)[1:2]), 7, 8.8, 9.9) == (1.0, 1.0, nothing)
@test gradient((x,y,z) -> sum((x,y,z)[[1,2,1]]), 1,2,3) == (2, 1, nothing)
+
+ @test gradient(xs -> sum(x -> x[2], xs), [(1,2,3), (4,5,6)]) == ([(nothing, 1.0, nothing), (nothing, 1.0, nothing)],)
+ @test gradient(xs -> sum(x -> prod(x[2:3]), xs), [(1,2,3), (4,5,6)]) == ([(nothing, 3.0, 2.0), (nothing, 6.0, 5.0)],)
+ @test gradient(xs -> sum(first, xs), fill((4,3),2)) == ([(1.0, nothing), (1.0, nothing)],)
+ @test gradient(xs -> sum(x -> abs2(x[1]), xs), fill((4,3),2)) == ([(8.0, nothing), (8.0, nothing)],)
end
@testset "@timed" begin
@@ -452,6 +461,13 @@ end
@test gradient(x -> x.x^2 + x.x, Ref(3)) === ((x = 7.0,),)
@test gradient(x -> real(x.x^2 + im * x.x), Ref(4)) === ((x = 8.0,),)
+ # Field access of contents:
+ @test gradient(x -> abs2(x.x) + 7 * x.x.re, Ref(1+im)) == ((x = 9.0 + 2.0im,),)
+ @test_broken gradient(x -> abs2(x[1].x) + 7 * x[1].x.re, [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],)
+ @test_broken gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) # worked on 0.6.0, 0.6.20
+
+ @test_broken gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = 9.0 + 2.0im,),) # gives nothing, same in 0.6.0
+
# Array of mutables:
@test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
@test gradient(x -> sum(abs2∘getindex, x), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
@@ -464,6 +480,17 @@ end
@test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
end
+@testset "NamedTuples" begin
+ @test gradient(x -> x.a, (a=1, b=2)) == ((a = 1, b = nothing),)
+ @test gradient(x -> x[1].a, [(a=1, b=2)]) == ([(a = 1, b = nothing)],)
+ @test gradient(x -> x[1].a, [(a=1, b=2), (a=3, b=4)]) == ([(a = 1, b = nothing), nothing],)
+
+ # Mix with Ref
+ @test gradient(x -> x[].a, Ref((a=1, b=2))) == ((x = (a = 1, b = nothing),),)
+ @test gradient(x -> x[1][].a, [Ref((a=1, b=2)), Ref((a=3, b=4))]) == ([(x = (a = 1, b = nothing),), nothing],)
+ @test gradient(x -> x[1].a, [(a=1, b=2), "three"]) == ([(a = 1, b = nothing), nothing],)
+end
+
function type_test()
Complex{<:Real}
end
@@ -692,4 +719,10 @@ end
@test gradient(x -> sum(gradient(y -> sum(y.^2), x)[1]), [1, 2])[1] ≈ [2, 2]
@test gradient(x -> sum(gradient(y -> sum(sin.(y)), x)[1]), [1, 2])[1] ≈ [-0.8414709848078965, -0.9092974268256817]
@test gradient(x -> sum(abs, gradient(y -> sum(log.(2 .* exp.(y)) .^ 2), x)[1]), [1, 2])[1] ≈ [2,2]
+
+ # getproperty, Tangents, etc
+ @test gradient(xs -> sum((x->x.im^2).(xs)), [1+2im,3])[1] == [4im, 0]
+ @test gradient(xs -> sum((x->x.im^2), xs), [1+2im,3])[1] == [4im, 0]
+ @test gradient(xs -> sum(map(x->x.im^2, xs)), [1+2im,3])[1] == [4im, 0]
+ @test gradient(xs -> mapreduce(x->x.im^2, +, xs), [1+2im,3])[1] == [4im, 0]
end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 66e558869..ef958da48 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1878,6 +1878,18 @@ end
a = rand(3)
@test Zygote.gradient(x->sum(x .+ rand.()), a) == (ones(3),)
+@testset "Zygote 660" begin
+ # https://github.com/FluxML/Zygote.jl/pull/660
+ function example(x,N)
+ ax = axes(x)
+ extraAxe = ax[2+N:end]
+ filledLoc = fill(1, N)
+ return x[:, filledLoc..., extraAxe...]
+ end
+ y, back = pullback(example, randn(5,3,4,3), 2)
+ @test back(zero(y).=1) isa Tuple{Array{Float64,4}, Nothing}
+end
+
@testset "CRC issue 440" begin
# https://github.com/JuliaDiff/ChainRulesCore.jl/issues/440
f(x,y) = sum(sum, [[x[i],y[i]] for i=1:length(x)])
diff --git a/test/runtests.jl b/test/runtests.jl
index d1b34da77..17ebb3997 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -5,61 +5,61 @@ using CUDA: has_cuda
@testset "all" begin # Overall testset ensures it keeps running after failure
-if has_cuda()
- @testset "CUDA tests" begin
- include("cuda.jl")
+ if has_cuda()
+ @testset "CUDA tests" begin
+ include("cuda.jl")
+ end
+ @info "CUDA tests have run"
+ else
+ @warn "CUDA not found - Skipping CUDA Tests"
end
- @info "CUDA tests have run"
-else
- @warn "CUDA not found - Skipping CUDA Tests"
-end
-@testset "Interface" begin
- include("interface.jl")
-end
+ @testset "Interface" begin
+ include("interface.jl")
+ end
-@testset "Tools" begin
- include("tools.jl")
-end
+ @testset "Tools" begin
+ include("tools.jl")
+ end
-@testset "Utils" begin
- include("utils.jl")
-end
+ @testset "Utils" begin
+ include("utils.jl")
+ end
-@testset "lib" begin
- include("lib/number.jl")
- include("lib/lib.jl")
- include("lib/array.jl")
-end
+ @testset "lib" begin
+ include("lib/number.jl")
+ include("lib/lib.jl")
+ include("lib/array.jl")
+ end
-@testset "Features" begin
- include("features.jl")
- @info "features.jl done"
-end
+ @testset "Features" begin
+ include("features.jl")
+ @info "features.jl done"
+ end
-@testset "Forward" begin
- include("forward/forward.jl")
-end
+ @testset "Forward" begin
+ include("forward/forward.jl")
+ end
-@testset "Data Structures" begin
- include("structures.jl")
-end
+ @testset "Data Structures" begin
+ include("structures.jl")
+ end
-@testset "ChainRules" begin
- include("chainrules.jl")
- @info "chainrules.jl done"
-end
+ @testset "ChainRules" begin
+ include("chainrules.jl")
+ @info "chainrules.jl done"
+ end
-@testset "Gradients" begin
- include("gradcheck.jl")
-end
+ @testset "Gradients" begin
+ include("gradcheck.jl")
+ end
-@testset "Complex" begin
- include("complex.jl")
-end
+ @testset "Complex" begin
+ include("complex.jl")
+ end
-@testset "Compiler" begin
- include("compiler.jl")
-end
+ @testset "Compiler" begin
+ include("compiler.jl")
+ end
end # @testset "all"
From 8b9916a844958cf9cb2d808689759a03fecb8a3c Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Tue, 23 Nov 2021 00:06:50 +0000
Subject: [PATCH 254/490] CompatHelper: bump compat for "SpecialFunctions" to
"2"
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index a8ea16a25..d7c0809fe 100644
--- a/Project.toml
+++ b/Project.toml
@@ -33,7 +33,7 @@ IRTools = "0.4"
MacroTools = "0.5"
NaNMath = "0.3"
Requires = "1.1"
-SpecialFunctions = "1.6"
+SpecialFunctions = "1.6, 2"
StatsFuns = "0.9.8"
ZygoteRules = "0.2.1"
julia = "1.3"
From 9e6f18262c5fa95a30f2c0120af95742221d4cd9 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 23 Nov 2021 16:11:20 -0500
Subject: [PATCH 255/490] Fix buffer
---
src/lib/buffer.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl
index 4d332c16d..f686340e6 100644
--- a/src/lib/buffer.jl
+++ b/src/lib/buffer.jl
@@ -46,7 +46,7 @@ _pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::Abstract
@adjoint function copy(b::Buffer)
copy(b), function (b̄)
- grad_mut(__context__, b)[:] = b̄
+ grad_mut(__context__, b)[:] .= b̄
return
end
end
From bc6fd2f20bc7c91c1d1258b582bb32a232ff73f6 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 24 Nov 2021 20:19:13 -0500
Subject: [PATCH 256/490] Fix `specialize_method` for 1.8 (#1124)
* change Core.Compiler.specialize_method for 1.8
* bump version
---
Project.toml | 2 +-
src/lib/literal_getproperty.jl | 6 +++++-
2 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/Project.toml b/Project.toml
index a8ea16a25..717c707ab 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.30"
+version = "0.6.31"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/lib/literal_getproperty.jl b/src/lib/literal_getproperty.jl
index 1959e9462..c13f7a89b 100644
--- a/src/lib/literal_getproperty.jl
+++ b/src/lib/literal_getproperty.jl
@@ -29,7 +29,11 @@ function reflect(@nospecialize(sigtypes::Tuple), world::UInt = typemax(UInt))
end
method_index === 0 && return nothing
type_signature, raw_static_params, method = _methods[method_index]
- method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
+ if VERSION < v"1.8-"
+ method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params, false)
+ else
+ method_instance = Core.Compiler.specialize_method(method, type_signature, raw_static_params; preexisting=false)
+ end
method_signature = method.sig
static_params = Any[raw_static_params...]
return method_instance, method_signature, static_params
From 0d80a08633bdcc610406d514d47f8c23792bf085 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 25 Nov 2021 13:43:46 -0500
Subject: [PATCH 257/490] Make tests pass on 1.8 (#1125)
* fix chainrules tests on 1.8
* bump
* rm comments
---
Project.toml | 4 ++--
test/chainrules.jl | 22 +++++++++++++++-------
2 files changed, 17 insertions(+), 9 deletions(-)
diff --git a/Project.toml b/Project.toml
index 717c707ab..a04795b53 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.31"
+version = "0.6.32"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -29,7 +29,7 @@ ChainRulesTestUtils = "1"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
ForwardDiff = "0.10"
-IRTools = "0.4"
+IRTools = "0.4.4"
MacroTools = "0.5"
NaNMath = "0.3"
Requires = "1.1"
diff --git a/test/chainrules.jl b/test/chainrules.jl
index e34be0fa1..bc32c879d 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -1,4 +1,6 @@
using ChainRulesCore, ChainRulesTestUtils, Zygote
+using Zygote: ZygoteRuleConfig
+
@testset "ChainRules integration" begin
@testset "ChainRules basics" begin
cr_inner_demo_rrule_hitcount = Ref(0)
@@ -265,13 +267,12 @@ end
@testset "ChainRulesCore.rrule_via_ad" begin
@testset "basic" begin
- # broken because Zygoye compresses `(NoTangent(), NoTangent())` into just NoTangent()
- # which ChainRulesTestUtils does not think is valid:
- @test_broken(rrule_via_ad(ZygoteRuleConfig(), round, 2.2) isa Tuple{NoTangent,NoTangent})
- # uncomment below when/if above is fixed
- # test_rrule(ZygoteRuleConfig(), round, 2.2; rrule_f=rrule_via_ad)
+ # Not marked as tests since perhaps ZeroTangent would be better.
+ rrule_via_ad(ZygoteRuleConfig(), round, 2.2)[2](1) == (NoTangent(), 0.0)
+ # But test_rrule is happy:
+ test_rrule(ZygoteRuleConfig(), round, 2.2; rrule_f=rrule_via_ad)
- test_rrule(ZygoteRuleConfig(), vcat, rand(3), rand(4); rrule_f=rrule_via_ad, check_inferred=false)
+ test_rrule(ZygoteRuleConfig(), vcat, rand(3), rand(4); rrule_f=rrule_via_ad)
test_rrule(ZygoteRuleConfig(), getindex, rand(5), 3; rrule_f=rrule_via_ad)
end
@@ -313,10 +314,13 @@ end
test_rrule(
ZygoteRuleConfig(), my_namedtuple, 1., (2.0, 2.4), 3.; rrule_f=rrule_via_ad
)
- test_rrule(ZygoteRuleConfig(), sum, (1.0, 2.0, 3.0); rrule_f=rrule_via_ad)
+ test_rrule(
+ ZygoteRuleConfig(), sum, (1.0, 2.0, 3.0); rrule_f=rrule_via_ad, check_inferred=false
+ )
test_rrule(
ZygoteRuleConfig(), sum, (a=1.0, b=2.0); rrule_f=rrule_via_ad, check_inferred=false
)
+ # There is at present no rrule for sum(::Tuple), so those are testing zygote directly.
end
@testset "arrays" begin
@@ -348,6 +352,10 @@ end
@test gradient(2.0) do x
@fastmath x^2.0
end == (4.0,)
+
+ @test gradient(2) do x
+ @fastmath log(x)
+ end == (1/2,)
end
@testset "zygote2differential inference" begin
From 86d1ba6cf9312d60e6321fc835610600eaf264c4 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sun, 5 Dec 2021 16:14:12 -0800
Subject: [PATCH 258/490] Fix incorrect `@forward`ing of `Base.in` on `Params`
---
src/compiler/interface.jl | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index dc8cb8d18..c52842945 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -149,7 +149,8 @@ Params(ps::Params) = ps
Params(xs::Tuple) = Params(collect(xs))
@forward Params.order Base.iterate, Base.length, Base.getindex
-@forward Params.params Base.in
+
+Base.in(ps::Params, x) = x in ps.params
Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)
From a152aaaa632568cba4aaa661f824c4c639ce3321 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sun, 5 Dec 2021 22:38:03 -0800
Subject: [PATCH 259/490] Add test for `in(x, ::Params)`
---
test/interface.jl | 8 ++++++++
1 file changed, 8 insertions(+)
diff --git a/test/interface.jl b/test/interface.jl
index ad1c7f46c..23afdfb1b 100644
--- a/test/interface.jl
+++ b/test/interface.jl
@@ -1,6 +1,14 @@
using Zygote: Grads
@testset "Params" begin
+ @testset "in" begin
+ w = rand(2,3)
+ b = rand(2)
+ ps = Params([w])
+ @test w ∈ ps
+ @test b ∉ ps
+ end
+
@testset "delete!" begin
w = rand(2,3)
b = rand(2)
From abad4e18efe6ba393a26c235d96a7db3b98f1feb Mon Sep 17 00:00:00 2001
From: pakk-minidose <56652555+pakk-minidose@users.noreply.github.com>
Date: Mon, 13 Dec 2021 16:24:06 +0100
Subject: [PATCH 260/490] Fix Zygote.jl#1135
---
src/compiler/interface.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index c52842945..88bdf96bd 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -235,7 +235,7 @@ function copy!(x::AbstractVector, ps::Params)
x[i+1:i+length(p)] .= vec(p)
i += length(p)
end
- ps
+ x
end
"""
From 60bb9703aaa97a6629bc95858d9e8b4bf4cb4285 Mon Sep 17 00:00:00 2001
From: pakk-minidose <56652555+pakk-minidose@users.noreply.github.com>
Date: Mon, 13 Dec 2021 16:26:32 +0100
Subject: [PATCH 261/490] Changed copy! returned value to first argument
---
src/compiler/interface.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 88bdf96bd..18c0bd8eb 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -288,7 +288,7 @@ function copy!(gs::Grads, x::AbstractVector)
gs[p] .= reshape(x[i+1:i+length(p)], size(p))
i += length(p)
end
- x
+ gs
end
function copy!(x::AbstractVector, gs::Grads)
From ff7e3248de1db1281b62b01b3c09d162c2732c22 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 13 Dec 2021 11:45:14 -0500
Subject: [PATCH 262/490] Add test for Buffer when it stores arrays
---
src/lib/buffer.jl | 11 ++++++++++-
test/gradcheck.jl | 17 +++++++++++++++++
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl
index f686340e6..49b3ab7f0 100644
--- a/src/lib/buffer.jl
+++ b/src/lib/buffer.jl
@@ -45,8 +45,17 @@ _pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::Abstract
_pullback(cx, copyto!, b, x)
@adjoint function copy(b::Buffer)
- copy(b), function (b̄)
+ res = copy(b)
+
+ function copy_sensitivity(b̄)
+ grad_mut(__context__, b)[:] .= vec(b̄)
+ return
+ end
+
+ function copy_sensitivity(b̄::Tuple)
grad_mut(__context__, b)[:] .= b̄
return
end
+
+ return res, copy_sensitivity
end
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index ef958da48..2d058c037 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1456,6 +1456,23 @@ using Zygote: Buffer
prod(copy(b))
end == (3,)
+ # Buffer storing arrays test
+ W1 = ones(3, 3)
+ W2 = ones(3, 3)
+ x = ones(3, 1)
+
+ function buffer_arrays(W1, W2, x)
+ b = Buffer([])
+ push!(b, W1 * x)
+ push!(b, W2 * x)
+ return sum(vcat(copy(b)...))
+ end
+
+ ∇W1, ∇W2, ∇x = gradient((W1, W2, x) -> buffer_arrays(W1, W2, x), W1, W2, x)
+
+ @test ∇W1 == [1.0 1.0 1.0; 1.0 1.0 1.0; 1.0 1.0 1.0]
+ @test ∇W2 == [1.0 1.0 1.0; 1.0 1.0 1.0; 1.0 1.0 1.0]
+ @test ∇x == [6.0; 6.0; 6.0;;]
end
@testset "FillArrays" begin
From c494ea2a60e7918c8a3224f0fa4f7cac42540115 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Mon, 13 Dec 2021 12:16:51 -0500
Subject: [PATCH 263/490] Dims :(
---
test/gradcheck.jl | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 2d058c037..90f0a4b4a 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1470,9 +1470,9 @@ using Zygote: Buffer
∇W1, ∇W2, ∇x = gradient((W1, W2, x) -> buffer_arrays(W1, W2, x), W1, W2, x)
- @test ∇W1 == [1.0 1.0 1.0; 1.0 1.0 1.0; 1.0 1.0 1.0]
- @test ∇W2 == [1.0 1.0 1.0; 1.0 1.0 1.0; 1.0 1.0 1.0]
- @test ∇x == [6.0; 6.0; 6.0;;]
+ @test ∇W1 == W1
+ @test ∇W2 == W2
+ @test ∇x == 6 .* x
end
@testset "FillArrays" begin
From a9126567fb347e37380c46b5c61c6230200cada9 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 14 Dec 2021 11:14:59 -0500
Subject: [PATCH 264/490] Might be vector
---
src/lib/buffer.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl
index 49b3ab7f0..8bd374b41 100644
--- a/src/lib/buffer.jl
+++ b/src/lib/buffer.jl
@@ -52,7 +52,7 @@ _pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::Abstract
return
end
- function copy_sensitivity(b̄::Tuple)
+ function copy_sensitivity(b̄::Union{Tuple,Vector{T}}) where {T<:AbstractArray}
grad_mut(__context__, b)[:] .= b̄
return
end
From 19ba3952523e0e15b91ba359ca7b12851b34fd8b Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 16 Dec 2021 13:52:03 -0500
Subject: [PATCH 265/490] Update src/lib/buffer.jl
Co-authored-by: Dhairya Gandhi
---
src/lib/buffer.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl
index 8bd374b41..e62a70041 100644
--- a/src/lib/buffer.jl
+++ b/src/lib/buffer.jl
@@ -52,7 +52,7 @@ _pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::Abstract
return
end
- function copy_sensitivity(b̄::Union{Tuple,Vector{T}}) where {T<:AbstractArray}
+ function copy_sensitivity(b̄::Union{Tuple,AbstractVector{T}}) where {T<:AbstractArray}
grad_mut(__context__, b)[:] .= b̄
return
end
From 8c88c8d571ac7f1073d1de867c62cc55a6fae618 Mon Sep 17 00:00:00 2001
From: Dhairya Gandhi
Date: Fri, 17 Dec 2021 04:09:18 +0530
Subject: [PATCH 266/490] Tag for SpecialFunctions@2
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index e4bd32f73..98585b9fb 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.32"
+version = "0.6.33"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From e7c240d0ef7baa665edb692e81cbaac54df4ed51 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sat, 25 Dec 2021 12:35:49 -0800
Subject: [PATCH 267/490] Update Buildkite config for 1.6 LTS and 1.7
Changes borrowed from Flux's setup.
---
.buildkite/pipeline.yml | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml
index c96109781..6d3a048a2 100644
--- a/.buildkite/pipeline.yml
+++ b/.buildkite/pipeline.yml
@@ -1,18 +1,18 @@
steps:
- - label: "GPU integration - julia 1.5"
+ - label: "GPU integration - julia v1.6"
plugins:
- JuliaCI/julia#v1:
- version: "1.5"
+ version: "1.6"
- JuliaCI/julia-test#v1: ~
agents:
queue: "juliagpu"
cuda: "*"
timeout_in_minutes: 60
- - label: "GPU integration - julia 1.6"
+ - label: "GPU integration - julia v1"
plugins:
- JuliaCI/julia#v1:
- version: '1.6'
+ version: "1"
- JuliaCI/julia-test#v1: ~
agents:
queue: "juliagpu"
From 3a63df8edb3b613107761ff829ca61ed393ce2dd Mon Sep 17 00:00:00 2001
From: Joe Greener
Date: Sat, 8 Jan 2022 19:53:23 +0000
Subject: [PATCH 268/490] Downstream test for Molly.jl (#1145)
---
.github/workflows/Downstream.yml | 1 +
1 file changed, 1 insertion(+)
diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml
index 308754b0a..09f4f2f5d 100644
--- a/.github/workflows/Downstream.yml
+++ b/.github/workflows/Downstream.yml
@@ -24,6 +24,7 @@ jobs:
- {user: TuringLang, repo: DistributionsAD.jl, group: Zygote}
- {user: SciML, repo: DiffEqFlux.jl, group: Layers}
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
+ - {user: JuliaMolSim, repo: Molly.jl, group: Zygote}
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
From 95d61fc317fcbac8438971794c433e609922282d Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 21 Jan 2022 14:43:02 -0500
Subject: [PATCH 269/490] v0.6.34
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 98585b9fb..4b88d68de 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.33"
+version = "0.6.34"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 23175c62b4112b6069610439db45cc4aeb471f2c Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Fri, 21 Jan 2022 20:15:51 -0800
Subject: [PATCH 270/490] Add codecov badge
---
README.md | 1 +
1 file changed, 1 insertion(+)
diff --git a/README.md b/README.md
index 866c6a617..35e90fc53 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@
[![CI Testing](https://github.com/FluxML/Zygote.jl/workflows/CI/badge.svg)](https://github.com/FluxML/Zygote.jl/actions)
+[![Coverage](https://codecov.io/gh/FluxML/Zygote.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/FluxML/Zygote.jl)
[![Dev Docs](https://img.shields.io/badge/docs-dev-blue.svg)](https://fluxml.ai/Zygote.jl/dev)
`] add Zygote`
From 5c0ecf41e008ad57aa573144a3fd0d2cb7f691f4 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Fri, 21 Jan 2022 21:15:03 -0800
Subject: [PATCH 271/490] Update codecov action and only run on stable linux
---
.github/workflows/ci.yml | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index eb544c2ec..bab7876a5 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -57,9 +57,9 @@ jobs:
JULIA_PKG_SERVER: ""
#continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
- uses: julia-actions/julia-processcoverage@v1
- #continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
- - uses: codecov/codecov-action@v1
- #continue-on-error: ${{ matrix.version == 'nightly' }} # comment out to report nightly failures
+ if: matrix.version == '1' && matrix.os == 'ubuntu-latest'
+ - uses: codecov/codecov-action@v2
+ if: matrix.version == '1' && matrix.os == 'ubuntu-latest'
with:
file: lcov.info
docs:
From 8bdfc180ea8da332f2894d5fb7db14715a633ca3 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Thu, 17 Feb 2022 00:07:59 +0000
Subject: [PATCH 272/490] CompatHelper: bump compat for "NaNMath" to "1"
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 4b88d68de..caeba7782 100644
--- a/Project.toml
+++ b/Project.toml
@@ -31,7 +31,7 @@ FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
ForwardDiff = "0.10"
IRTools = "0.4.4"
MacroTools = "0.5"
-NaNMath = "0.3"
+NaNMath = "0.3, 1"
Requires = "1.1"
SpecialFunctions = "1.6, 2"
StatsFuns = "0.9.8"
From 6b5729936c006828400df522a17ec8451dba3ab8 Mon Sep 17 00:00:00 2001
From: Samuel Buercklin
Date: Thu, 17 Feb 2022 22:51:23 -0500
Subject: [PATCH 273/490] `ntuple` for `_restore` regardless of length (#1163)
* removed ternary in _restore
* added test for _restore with ntuple fix
* added gradient test, remove lib/array test
---
src/lib/array.jl | 2 +-
test/gradcheck.jl | 16 ++++++++++++++++
2 files changed, 17 insertions(+), 1 deletion(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 7734ad5ca..35c678310 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -184,7 +184,7 @@ _tryreverse(m::typeof(map), x::Union{AbstractVector, Tuple}) = reverse(x)
_tryaxes(x) = axes(x)
_tryaxes(x::Tuple) = Val(length(x))
_restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax)
-_restore(dx, ::Val{N}) where {N} = length(dx) < N ? ntuple(i -> get(dx,i,nothing), N) : NTuple{N}(dx)
+_restore(dx, ::Val{N}) where {N} = ntuple(i -> get(dx,i,nothing), N)
# Sometimes a pullback doesn't return a Tuple, but rather returns only a
# single nothing to say "all arguments have zero cotangent". This function is needed to
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 90f0a4b4a..67e51ec19 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1961,3 +1961,19 @@ end
@test g1[1].A isa Number
@test size(g1[2]) == size(V)
end
+
+@testset "Zygote #1162" begin
+ function zygote1162(as, bs)
+ results = [f1162(a, b) for (a, b) in zip(as, bs)]
+ return results[2][1] + results[2][2]
+ end
+ function f1162(a, b)
+ return [a^2, b^2]
+ end
+
+ as = (1.0, 2.0, 3.0)
+ bs = (4.0, 5.0, 6.0)
+
+ g = Zygote.gradient(zygote1162, as, bs)
+ @test g == ((nothing, 2*as[2], nothing), (nothing, 2*bs[2], nothing))
+end
From acedd2883a927712d120a60482ef5255d2d0de7d Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
<41898282+github-actions[bot]@users.noreply.github.com>
Date: Thu, 17 Feb 2022 22:55:10 -0500
Subject: [PATCH 274/490] CompatHelper: bump compat for "FillArrays" to "0.13"
(#1165)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 4b88d68de..f2e13f384 100644
--- a/Project.toml
+++ b/Project.toml
@@ -27,7 +27,7 @@ ChainRules = "1.5"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.0"
-FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12"
+FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
ForwardDiff = "0.10"
IRTools = "0.4.4"
MacroTools = "0.5"
From 36698f74048e79c7a2c542e2cd90a802ea5e1bf4 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 18 Feb 2022 11:19:56 +0000
Subject: [PATCH 275/490] update README to reference ChainRulesCore to define
custom gradients
---
README.md | 13 +++++++++----
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/README.md b/README.md
index 35e90fc53..c945f1ee2 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,8 @@ top:
"Source-to-source" means that Zygote hooks into Julia's compiler, and generates the backwards pass for you – as if you had written it by hand.
-Zygote supports the full flexibility and dynamism of the Julia language, including control flow, recursion, closures, structs, dictionaries, and more.
+Zygote supports the flexibility and dynamism of the Julia language, including control flow, recursion, closures, structs, dictionaries, and more.
+Mutation and exception handling are currently not supported.
```julia
julia> fs = Dict("sin" => sin, "cos" => cos, "tan" => tan);
@@ -40,14 +41,18 @@ sin
0.5403023058681398
```
-Defining custom gradients is a cinch, and errors have good stacktraces.
+Zygote benefits from using the [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) ruleset.
+Custom gradients can be defined by extending the [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl)'s `rrule`:
```julia
-julia> using Zygote: @adjoint
+julia> using ChainRulesCore
julia> add(a, b) = a + b
-julia> @adjoint add(a, b) = add(a, b), Δ -> (Δ, Δ)
+julia> function ChainRulesCore.rrule(::typeof(add), a, b)
+ add_pb(dy) = (NoTangent(), dy, dy)
+ return add(a, b), add_pb
+ end
```
To support large machine learning models with many parameters, Zygote can differentiate implicitly-used parameters, as opposed to just function arguments.
From 87872e708573e60d10815dd1b2fb8473f41f9efd Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 20 Feb 2022 22:16:05 -0500
Subject: [PATCH 276/490] Broadcast rule for types (#1171)
* broadcast rule for type
* test on sparse arrays
---
Project.toml | 1 +
src/lib/broadcast.jl | 3 +++
test/gradcheck.jl | 6 +++++-
3 files changed, 9 insertions(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index e15f61f9f..a54f14977 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,6 +17,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 8833436a0..78816e4f7 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -144,6 +144,9 @@ end
end
end
+@adjoint broadcasted(::Type{T}, x::Numeric) where T =
+ T.(x), ȳ -> (nothing, _project(x, ȳ),)
+
# General Fallback
# ================
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 67e51ec19..7c45d26d2 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1,4 +1,4 @@
-using Zygote, Test, Random, LinearAlgebra, Statistics, FillArrays,
+using Zygote, Test, Random, LinearAlgebra, Statistics, SparseArrays, FillArrays,
AbstractFFTs, FFTW, Distances
using Zygote: gradient
using Base.Broadcast: broadcast_shape
@@ -1406,6 +1406,10 @@ end
@test all(gradient((x,y) -> sum(x .* y), 5, [1,2]) .≈ (3, [5, 5]))
@test all(gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) .≈ ([12, 12], [3 3 3]))
@test all(gradient((x,y) -> sum(x ./ y), [1,2], 5) .≈ ([0.2, 0.2], -0.12))
+
+ sm = sprand(5, 5, 0.5)
+ @test gradient(x -> sum(abs2, Float32.(x)), sm)[1] ≈ gradient(x -> sum(abs2, x), Matrix{Float32}(sm))[1]
+ @test gradient(x -> real(sum(ComplexF32.(x) .+ 1 .+ im)), sm)[1] isa SparseMatrixCSC{Float64}
end
using Zygote: Buffer
From e56375e08ac6191f32713f4d63c087655a2b4fd3 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 22 Feb 2022 00:45:27 -0500
Subject: [PATCH 277/490] v0.6.35
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index a54f14977..ce5c48d1c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.34"
+version = "0.6.35"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 8b9d67dbc29ef438c65e9ae92ce5c68c2188cd6b Mon Sep 17 00:00:00 2001
From: Lorenzo Van Munoz <66997677+lxvm@users.noreply.github.com>
Date: Tue, 22 Feb 2022 16:39:02 -0800
Subject: [PATCH 278/490] Fix adjoint Iterators.product behavior with nothing
(#1170)
* Fix adjoint Iterators.product behavior with nothing
* Apply suggestions from code review
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
* add Iterators.product adjoint tests
* Update test/lib/array.jl
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
src/lib/array.jl | 2 +-
test/lib/array.jl | 14 +++++++++++++-
2 files changed, 14 insertions(+), 2 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 35c678310..6f5386c32 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -286,10 +286,10 @@ _ndims(x) = Base.IteratorSize(x) isa Base.HasShape ? _ndims(Base.IteratorSize(x)
function back(dy::AbstractArray)
d = 1
ntuple(length(xs)) do n
- first(dy)[n] === nothing && return nothing
nd = _ndims(xs[n])
dims = ntuple(i -> isum(sin, Diagonal(x)), ones(2); rrule_f=rrule_via_ad, check_inferred=false)
test_rrule(ZygoteRuleConfig(), x->sum(sin, Diagonal(x)), rand(3); rrule_f=rrule_via_ad, check_inferred=false)
+
+@testset "adjoints of Iterators.product" begin
+ y, back = _pullback(Iterators.product, 1:5, 1:3, 1:2)
+ @test back(collect(y)) == (nothing, [6.0, 12.0, 18.0, 24.0, 30.0], [10.0, 20.0, 30.0], [15.0, 30.0])
+ @test back([(nothing, j, k) for i in 1:5, j in 1:3, k in 1:2]) == (nothing, nothing, [10.0, 20.0, 30.0], [15.0, 30.0])
+ @test back([(i, nothing, k) for i in 1:5, j in 1:3, k in 1:2]) == (nothing, [6.0, 12.0, 18.0, 24.0, 30.0], nothing, [15.0, 30.0])
+ @test back([(i, j, nothing) for i in 1:5, j in 1:3, k in 1:2]) == (nothing, [6.0, 12.0, 18.0, 24.0, 30.0], [10.0, 20.0, 30.0], nothing)
+
+ # This was wrong before https://github.com/FluxML/Zygote.jl/pull/1170
+ @test gradient(x -> sum([y[2] * y[3] for y in Iterators.product(x, x, x, x)]), [1,2,3,4])[1] ≈ [320, 320, 320, 320]
+ @test gradient(x -> sum(y[2] * y[3] for y in Iterators.product(x, x, x, x)), [1,2,3,4])[1] ≈ [320, 320, 320, 320]
+end
From 2a2095cccc31af5dfe47981033b7c99be0aafb01 Mon Sep 17 00:00:00 2001
From: James Atkins
Date: Thu, 10 Mar 2022 04:46:40 +0000
Subject: [PATCH 279/490] Fix error in example (#1176)
---
docs/src/adjoints.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md
index 1d6ddc527..34aa5d18f 100644
--- a/docs/src/adjoints.md
+++ b/docs/src/adjoints.md
@@ -68,7 +68,7 @@ julia> mygradient(sin, 0.5)
The rest of this section contains more technical detail. It can be skipped if you only need an intuition for pullbacks; you generally won't need to worry about it as a user.
-If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = sin(x)'*ȳ`; the conjugation is usually moot but gives the correct behaviour for complex code. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints".
+If ``x`` and ``y`` are vectors, ``\frac{\partial y}{\partial x}`` becomes a Jacobian. Importantly, because we are implementing reverse mode we actually left-multiply the Jacobian, i.e. `v'J`, rather than the more usual `J*v`. Transposing `v` to a row vector and back `(v'J)'` is equivalent to `J'v` so our gradient rules actually implement the *adjoint* of the Jacobian. This is relevant even for scalar code: the adjoint for `y = sin(x)` is `x̄ = cos(x)'*ȳ`; the conjugation is usually moot but gives the correct behaviour for complex code. "Pullbacks" are therefore sometimes called "vector-Jacobian products" (VJPs), and we refer to the reverse mode rules themselves as "adjoints".
Zygote has many adjoints for non-mathematical operations such as for indexing and data structures. Though these can still be seen as linear functions of vectors, it's not particularly enlightening to implement them with an actual matrix multiply. In these cases it's easiest to think of the adjoint as a kind of inverse. For example, the gradient of a function that takes a tuple to a struct (e.g. `y = Complex(a, b)`) will generally take a struct to a tuple (`(ȳ.re, ȳ.im)`). The gradient of a `getindex` `y = x[i...]` is a `setindex!` `x̄[i...] = ȳ`, etc.
From 843a52d6a069fdaadc6559d0db084a73eb6058f7 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 10 Mar 2022 00:41:02 -0500
Subject: [PATCH 280/490] Restrict type broadcast rule to numbers (#1179)
* restrict type broadcast to number
* add a test
---
Project.toml | 2 +-
src/lib/broadcast.jl | 2 +-
test/gradcheck.jl | 8 ++++++++
3 files changed, 10 insertions(+), 2 deletions(-)
diff --git a/Project.toml b/Project.toml
index ce5c48d1c..0a1fa489e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.35"
+version = "0.6.36"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 78816e4f7..f2b0e0709 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -144,7 +144,7 @@ end
end
end
-@adjoint broadcasted(::Type{T}, x::Numeric) where T =
+@adjoint broadcasted(::Type{T}, x::Numeric) where {T<:Number} =
T.(x), ȳ -> (nothing, _project(x, ȳ),)
# General Fallback
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 7c45d26d2..b3b1e2969 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1407,9 +1407,17 @@ end
@test all(gradient((x,y) -> sum(x .* y), [1,2], [3 4 5]) .≈ ([12, 12], [3 3 3]))
@test all(gradient((x,y) -> sum(x ./ y), [1,2], 5) .≈ ([0.2, 0.2], -0.12))
+ # https://github.com/FluxML/Zygote.jl/pull/1171
sm = sprand(5, 5, 0.5)
@test gradient(x -> sum(abs2, Float32.(x)), sm)[1] ≈ gradient(x -> sum(abs2, x), Matrix{Float32}(sm))[1]
@test gradient(x -> real(sum(ComplexF32.(x) .+ 1 .+ im)), sm)[1] isa SparseMatrixCSC{Float64}
+
+ # https://github.com/FluxML/Zygote.jl/issues/1178
+ function f1179(x)
+ fs = Ref.(x)
+ getindex.(fs)
+ end
+ @test gradient(sum∘f1179, ones(2)) == ([2.0, 2.0],)
end
using Zygote: Buffer
From 403c3ae0bf66c513b26d5ea0707eda37dd6125a0 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sat, 12 Mar 2022 14:04:37 -0800
Subject: [PATCH 281/490] Use Base.IdSet
GitHub tells me this has been available since at least 0.7.
---
src/Zygote.jl | 1 -
src/compiler/interface.jl | 4 ++--
src/profiler/Profile.jl | 4 ++--
src/tools/idset.jl | 21 ---------------------
4 files changed, 4 insertions(+), 26 deletions(-)
delete mode 100644 src/tools/idset.jl
diff --git a/src/Zygote.jl b/src/Zygote.jl
index 85b71359f..d537efbb7 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -18,7 +18,6 @@ export rrule_via_ad
const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}
-include("tools/idset.jl")
include("tools/buffer.jl")
include("tools/builtins.jl")
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 18c0bd8eb..38bc328e0 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -1,7 +1,7 @@
using InteractiveUtils
using InteractiveUtils: typesof
using Core: Typeof
-import Base: copy!
+import Base: copy!, IdSet
import Base.Broadcast: broadcasted, materialize!
mutable struct Context <: AContext
@@ -144,7 +144,7 @@ struct Params
end
Params() = Params(Buffer([], false), IdSet())
-Params(xs) = Params(Buffer(xs, false), IdSet(xs))
+Params(xs) = Params(Buffer(xs, false), IdSet{Any}(xs))
Params(ps::Params) = ps
Params(xs::Tuple) = Params(collect(xs))
diff --git a/src/profiler/Profile.jl b/src/profiler/Profile.jl
index a351a7df1..d2a16aeff 100644
--- a/src/profiler/Profile.jl
+++ b/src/profiler/Profile.jl
@@ -1,7 +1,7 @@
module Profile
using Requires
-using ..Zygote: Pullback, IdSet, meta, stacklines
+using ..Zygote: Pullback, meta, stacklines
function loc(f)
# TODO perhaps find most general method
@@ -36,7 +36,7 @@ function mem(x, seen)
sum(x -> mem(x, seen), fields(x))
end
-mem(x) = mem(x, IdSet())
+mem(x) = mem(x, Base.IdSet())
struct Node
func::Symbol
diff --git a/src/tools/idset.jl b/src/tools/idset.jl
deleted file mode 100644
index a0aa93df0..000000000
--- a/src/tools/idset.jl
+++ /dev/null
@@ -1,21 +0,0 @@
-struct IdSet{T} <: AbstractSet{T}
- dict::IdDict{T,Nothing}
- IdSet{T}() where T = new(IdDict{T,Nothing}())
-end
-
-IdSet(xs) = IdSet{eltype(xs)}(xs)
-
-IdSet() = IdSet{Any}()
-
-IdSet{T}(xs) where T = isempty(xs) ? IdSet{T}() : push!(IdSet{T}(), xs...)
-
-Base.push!(s::IdSet{T}, x::T) where T = (s.dict[x] = nothing; s)
-Base.delete!(s::IdSet{T}, x::T) where T = (delete!(s.dict, x); s)
-Base.in(x, s::IdSet) = haskey(s.dict, x)
-Base.eltype(::IdSet{T}) where T = T
-Base.collect(s::IdSet) = Base.collect(keys(s.dict))
-Base.similar(s::IdSet, T::Type) = IdSet{T}()
-
-@forward IdSet.dict Base.length
-
-Base.iterate(s::IdSet, st...) = iterate(keys(s.dict), st...)
From c45fa66eb66eb2dabb0f3fddc123efbfbe29807b Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Mon, 14 Mar 2022 12:39:51 +0100
Subject: [PATCH 282/490] fixing type ambiguity of unbroadcast
---
Project.toml | 2 +-
src/lib/broadcast.jl | 12 +++++++-----
2 files changed, 8 insertions(+), 6 deletions(-)
diff --git a/Project.toml b/Project.toml
index 0a1fa489e..19ee30869 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.36"
+version = "0.6.37"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index f2b0e0709..6dbfdb829 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -59,6 +59,8 @@ unbroadcast(x::Tuple{<:Any}, x̄) = (accum_sum(x̄),)
unbroadcast(x::Base.RefValue, x̄) = (x=accum_sum(x̄),)
unbroadcast(x::Tuple, x̄) = NTuple{length(x)}(length(x) == length(x̄) ? x̄ : accum_sum(x̄; dims=2:ndims(x̄))) # case length(x) > 1
unbroadcast(x::Tuple, x̄::Nothing) = nothing
+# fixing issue #1184, not duplicate method, since the above allows for an empty tuple
+unbroadcast(x::Tuple{<:Any}, x̄::Nothing) = nothing
unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
@@ -81,7 +83,7 @@ _minus(::Nothing) = nothing
Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
_pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
-@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
+@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
_pullback(*, x, y)
@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
@@ -181,7 +183,7 @@ _dual_safearg(x) = false
T = Broadcast.combine_eltypes(f, args)
# Avoid generic broadcasting in two easy cases:
if T == Bool
- return (f.(args...), _ -> nothing)
+ return (f.(args...), _ -> nothing)
elseif T <: Real && isconcretetype(T) && _dual_purefun(F) && all(_dual_safearg, args) && !isderiving()
return broadcast_forward(f, args...)
end
@@ -260,7 +262,7 @@ end
@eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
broadcast_forward(CUDA.cufunc(f), args...)
- else # CUDA >= 3.0 -- don't need cufunc(f).
+ else # CUDA >= 3.0 -- don't need cufunc(f).
# Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
# so perhaps this can be deleted? Possible edge case here:
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
@@ -277,14 +279,14 @@ end
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end
-
+
# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
-
+
@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
From e47114be5bf02d2bceb8351a2178f0bab978ef70 Mon Sep 17 00:00:00 2001
From: Sebastian Ament
Date: Wed, 16 Mar 2022 10:40:56 +0100
Subject: [PATCH 283/490] added type ambiguity test
---
test/gradcheck.jl | 35 ++++++++++++++++++++++-------------
1 file changed, 22 insertions(+), 13 deletions(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index b3b1e2969..ac0dd28bf 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -360,7 +360,7 @@ end
@test gradient(x -> map(+, x, (1,2,3))[1], [4,5,6]) == ([1,0,0],)
# mismatched lengths, should zip
- @test gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
+ @test gradient(x -> map(+, x, [1,2,3,99])[1], (4,5,6)) == ((1.0, 0.0, 0.0),)
@test gradient(x -> map(+, x, [1,2,3])[1], (4,5,6,99)) == ((1.0, 0.0, 0.0, nothing),)
end
@@ -1386,7 +1386,7 @@ end
end
@testset "broadcast" begin
- # Before https://github.com/FluxML/Zygote.jl/pull/1001 this gave [1 1 1; 1 0 1; 1 1 -1]
+ # Before https://github.com/FluxML/Zygote.jl/pull/1001 this gave [1 1 1; 1 0 1; 1 1 -1]
@test gradient(x -> sum(sin.(x)), Diagonal([0,pi/2,pi]))[1] ≈ [1 0 0; 0 0 0; 0 0 -1]
a = rand(3)
@@ -1487,17 +1487,6 @@ using Zygote: Buffer
@test ∇x == 6 .* x
end
-@testset "FillArrays" begin
- @test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1])
- @test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing
- @test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing
- @test gradcheck(x->Fill(x[], 5).value, [0.1])
- @test gradcheck(x->FillArrays.getindex_value(Fill(x[], 5)), [0.1])
-
- @test first(Zygote.pullback(Ones{Float32}, 10)) isa Ones{Float32}
- @test first(Zygote.pullback(Zeros{Float32}, 10)) isa Zeros{Float32}
-end
-
@testset "AbstractArray Addition / Subtraction / Negation" begin
rng, M, N, P = MersenneTwister(123567), 3, 7, 11
A, B = randn(rng, M, N, P), randn(rng, M, N, P)
@@ -1623,6 +1612,16 @@ end
end
@testset "FillArrays" begin
+
+ @test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1])
+ @test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing
+ @test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing
+ @test gradcheck(x->Fill(x[], 5).value, [0.1])
+ @test gradcheck(x->FillArrays.getindex_value(Fill(x[], 5)), [0.1])
+
+ @test first(Zygote.pullback(Ones{Float32}, 10)) isa Ones{Float32}
+ @test first(Zygote.pullback(Zeros{Float32}, 10)) isa Zeros{Float32}
+
rng, M, N = MersenneTwister(123456), 7, 11
x, y = randn(rng), randn(rng)
@test Zygote.gradient(x->sum(Fill(x, N)), x)[1] == N
@@ -1989,3 +1988,13 @@ end
g = Zygote.gradient(zygote1162, as, bs)
@test g == ((nothing, 2*as[2], nothing), (nothing, 2*bs[2], nothing))
end
+
+@testset "Zygote #1184" begin
+ n, d = 3, 2
+ x = [randn(d) for _ in 1:n]
+
+ f = sin
+ g(x) = sum.((f,), x)
+ h(x) = sum(abs2, g(x))
+ @test gradient(h, x)[1] isa typeof(x)
+end
From 3928ab9e6dd6dae3bf6df7882ec444aba1102628 Mon Sep 17 00:00:00 2001
From: Aman Sharma <76823502+arcAman07@users.noreply.github.com>
Date: Sun, 3 Apr 2022 15:38:46 +0530
Subject: [PATCH 284/490] Ton of doctests added (#1194)
* Ton of doctests added to index.md
* Ton of doctests added to index.md
* Ton of doctests added to index.md
* Ton of doctests added to index.md
* Ton of doctests added to index.md
* Ton of doctests added to index.md
* outdated example fixed
* outdated example fixed
* outdated example fixed
* outdated example fixed
* doctests added to adjoints.md
* doctests added to adjoints.md
* doctests added to adjoints.md
* doctests added to adjoints.md
* outdated example updated
* More doctests added
* More doctests added
* doctest added and checked properly
---
docs/make.jl | 1 +
docs/src/adjoints.md | 28 +++++++++++++++-------------
docs/src/complex.md | 12 +++++++-----
docs/src/index.md | 20 ++++++++++----------
4 files changed, 33 insertions(+), 28 deletions(-)
diff --git a/docs/make.jl b/docs/make.jl
index 9e5dbb3f6..9d2f549c9 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -8,6 +8,7 @@ using Documenter, Zygote
makedocs(
sitename="Zygote",
+ doctest = true,
pages = [
"Home" => "index.md",
"Custom Adjoints" => "adjoints.md",
diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md
index 34aa5d18f..2ee094e78 100644
--- a/docs/src/adjoints.md
+++ b/docs/src/adjoints.md
@@ -18,7 +18,9 @@ The `@adjoint` macro is an important part of Zygote's interface; customising you
`gradient` is really just syntactic sugar around the more fundamental function `pullback`.
-```julia
+```jldoctest adjoints
+julia> using Zygote
+
julia> y, back = Zygote.pullback(sin, 0.5);
julia> y
@@ -55,7 +57,7 @@ julia> cos(0.5)
More generally
-```julia
+```jldoctest adjoints
julia> function mygradient(f, x...)
_, back = Zygote.pullback(f, x...)
back(1)
@@ -76,15 +78,15 @@ Zygote has many adjoints for non-mathematical operations such as for indexing an
We can extend Zygote to a new function with the `@adjoint` function.
-```julia
-julia> mul(a, b) = a*b
+```jldoctest adjoints
+julia> mul(a, b) = a*b;
julia> using Zygote: @adjoint
julia> @adjoint mul(a, b) = mul(a, b), c̄ -> (c̄*b, c̄*a)
julia> gradient(mul, 2, 3)
-(3, 2)
+(3.0, 2.0)
```
It might look strange that we write `mul(a, b)` twice here. In this case we want to call the normal `mul` function for the pullback pass, but you may also want to modify the pullback pass (for example, to capture intermediate results in the pullback).
@@ -152,7 +154,7 @@ We usually use custom adjoints to add gradients that Zygote can't derive itself
### Gradient Hooks
-```julia
+```jldoctest adjoints
julia> hook(f, x) = x
hook (generic function with 1 method)
@@ -161,17 +163,17 @@ julia> @adjoint hook(f, x) = x, x̄ -> (nothing, f(x̄))
`hook` doesn't seem that interesting, as it doesn't do anything. But the fun part is in the adjoint; it's allowing us to apply a function `f` to the gradient of `x`.
-```julia
+```jldoctest adjoints
julia> gradient((a, b) -> hook(-, a)*b, 2, 3)
-(-3, 2)
+(-3.0, 2.0)
```
We could use this for debugging or modifying gradients (e.g. gradient clipping).
-```julia
+```jldoctest adjoints
julia> gradient((a, b) -> hook(ā -> @show(ā), a)*b, 2, 3)
-ā = 3
-(3, 2)
+ā = 3.0
+(3.0, 2.0)
```
Zygote provides both `hook` and `@showgrad` so you don't have to write these yourself.
@@ -180,7 +182,7 @@ Zygote provides both `hook` and `@showgrad` so you don't have to write these you
A more advanced example is checkpointing, in which we save memory by re-computing the pullback pass of a function during the backwards pass. To wit:
-```julia
+```jldoctest adjoints
julia> checkpoint(f, x) = f(x)
checkpoint (generic function with 1 method)
@@ -192,7 +194,7 @@ julia> gradient(x -> checkpoint(sin, x), 1)
If a function has side effects we'll see that the pullback pass happens twice, as expected.
-```julia
+```jldoctest adjoints
julia> foo(x) = (println(x); sin(x))
foo (generic function with 1 method)
diff --git a/docs/src/complex.md b/docs/src/complex.md
index 4013bb112..2c82bf8b6 100644
--- a/docs/src/complex.md
+++ b/docs/src/complex.md
@@ -4,30 +4,32 @@ Complex numbers add some difficulty to the idea of a "gradient". To talk about `
If `f` returns a real number, things are fairly straightforward. For ``c = x + yi`` and ``z = f(c)``, we can define the adjoint ``\bar c = \frac{\partial z}{\partial x} + \frac{\partial z}{\partial y}i = \bar x + \bar y i`` (note that ``\bar c`` means gradient, and ``c'`` means conjugate). It's exactly as if the complex number were just a pair of reals `(re, im)`. This works out of the box.
-```julia
+```jldoctest complex
+julia> using Zygote
+
julia> gradient(c -> abs2(c), 1+2im)
-(2 + 4im,)
+(2.0 + 4.0im,)
```
However, while this is a very pragmatic definition that works great for gradient descent, it's not quite aligned with the mathematical notion of the derivative: i.e. ``f(c + \epsilon) \approx f(c) + \bar c \epsilon``. In general, such a ``\bar c`` is not possible for complex numbers except when `f` is *holomorphic* (or *analytic*). Roughly speaking this means that the function is defined over `c` as if it were a normal real number, without exploiting its complex structure – it can't use `real`, `imag`, `conj`, or anything that depends on these like `abs2` (`abs2(x) = x*x'`). (This constraint also means there's no overlap with the Real case above; holomorphic functions always return complex numbers for complex input.) But most "normal" numerical functions – `exp`, `log`, anything that can be represented by a Taylor series – are fine.
Fortunately it's also possible to get these derivatives; they are the conjugate of the gradients for the real part.
-```julia
+```jldoctest complex
julia> gradient(x -> real(log(x)), 1+2im)[1] |> conj
0.2 - 0.4im
```
We can check that this function is holomorphic – and thus that the gradient we got out is sensible – by checking the Cauchy-Riemann equations. In other words this should give the same answer:
-```julia
+```jldoctest complex
julia> -im*gradient(x -> imag(log(x)), 1+2im)[1] |> conj
0.2 - 0.4im
```
Notice that this fails in a non-holomorphic case, `f(x) = log(x')`:
-```julia
+```jldoctest complex
julia> gradient(x -> real(log(x')), 1+2im)[1] |> conj
0.2 - 0.4im
diff --git a/docs/src/index.md b/docs/src/index.md
index 1eec9768f..3476d5e7d 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -18,18 +18,18 @@ Zygote is easy to understand since, at its core, it has a one-function API (`pul
`gradient` calculates derivatives. For example, the derivative of ``3x^2 + 2x + 1`` is ``6x + 2``, so when `x = 5`, `dx = 32`.
-```julia
+```jldoctest index
julia> using Zygote
julia> gradient(x -> 3x^2 + 2x + 1, 5)
-(32,)
+(32.0,)
```
`gradient` returns a tuple, with a gradient for each argument to the function.
-```julia
+```jldoctest index
julia> gradient((a, b) -> a*b, 2, 3)
-(3, 2)
+(3.0, 2.0)
```
This will work equally well if the arguments are arrays, structs, or any other Julia type, but the function should return a scalar (like a loss or objective ``l``, if you're doing optimisation / ML).
@@ -48,7 +48,7 @@ julia> gradient(x -> 3x^2 + 2x + 1, 1//4)
Control flow is fully supported, including recursion.
-```julia
+```jldoctest index
julia> function pow(x, n)
r = 1
for i = 1:n
@@ -59,26 +59,26 @@ julia> function pow(x, n)
pow (generic function with 1 method)
julia> gradient(x -> pow(x, 3), 5)
-(75,)
+(75.0,)
julia> pow2(x, n) = n <= 0 ? 1 : x*pow2(x, n-1)
pow2 (generic function with 1 method)
julia> gradient(x -> pow2(x, 3), 5)
-(75,)
+(75.0,)
```
Data structures are also supported, including mutable ones like dictionaries. Arrays are currently immutable, though [this may change](https://github.com/FluxML/Zygote.jl/pull/75) in future.
-```julia
+```jldoctest index
julia> d = Dict()
-Dict{Any,Any} with 0 entries
+Dict{Any, Any}()
julia> gradient(5) do x
d[:x] = x
d[:x] * d[:x]
end
-(10,)
+(10.0,)
julia> d[:x]
5
From cb5f279faf5947cf54cfe2404870df2c61114718 Mon Sep 17 00:00:00 2001
From: Aman
Date: Wed, 6 Apr 2022 22:23:47 +0530
Subject: [PATCH 285/490] Fixing spelling error in the docs
---
docs/src/adjoints.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md
index 2ee094e78..b38c07a14 100644
--- a/docs/src/adjoints.md
+++ b/docs/src/adjoints.md
@@ -9,7 +9,7 @@
This page exists to descibe how Zygote works, and how adjoints can be directly defined for Zygote.
Defining adjoints this way does not make them accessible to other AD systems, but does let you do things that directly depend on how Zygote works.
- It allows for specific definitions of adjoints that are only defined for Zgyote (which might work differently to more generic definitions defined for all AD).
+ It allows for specific definitions of adjoints that are only defined for Zygote (which might work differently to more generic definitions defined for all AD).
The `@adjoint` macro is an important part of Zygote's interface; customising your backwards pass is not only possible but widely used and encouraged. While there are specific utilities available for common things like gradient clipping, understanding adjoints will give you the most flexibility. We first give a bit more background on what these pullback things are.
From c560ed199a8b72793285a4f076868288a0a3bea9 Mon Sep 17 00:00:00 2001
From: Christian Rorvik
Date: Wed, 13 Apr 2022 17:16:51 +0200
Subject: [PATCH 286/490] Fix type stability of Params.order
---
src/compiler/interface.jl | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 38bc328e0..0319d5268 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -138,8 +138,8 @@ gradient
Container for implicit parameters, used when differentiating
a zero-argument funtion `() -> loss(A, B)` with respect to `A, B`.
"""
-struct Params
- order::Buffer # {Any, Vector{Any}}
+struct Params{B <: Buffer}
+ order::B
params::IdSet{Any} # TODO store ids only
end
From b15eff13557de05279f3fd1f82734b218647ec88 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Wed, 13 Apr 2022 23:38:07 +0200
Subject: [PATCH 287/490] Handle `ChainRulesCore.NotImplemented`
---
Project.toml | 2 +-
src/compiler/chainrules.jl | 1 +
test/chainrules.jl | 12 ++++++++++++
3 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 19ee30869..9d531789d 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.37"
+version = "0.6.38"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index b3157f289..99d8f4652 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -106,6 +106,7 @@ Convert `x` from the differentials types ChainRules uses to the format Zygote us
# Zygote convention: even if many AbstractZero partials (i.e. multi-input function), make just 1 nothing.
@inline wrap_chainrules_output(x::Tuple{Vararg{ChainRules.AbstractZero}}) = nothing
@inline wrap_chainrules_output(x::ChainRules.AbstractZero) = nothing
+@inline wrap_chainrules_output(x::ChainRulesCore.NotImplemented) = nothing
for T_outer in (:Tuple, :NamedTuple)
# we create separate methods rather than using a `Union` + an `if` so that we avoid a
# branch that changes output type, because nested AD on that kinda thing makes Zygote less
diff --git a/test/chainrules.jl b/test/chainrules.jl
index bc32c879d..94ab9584a 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -263,6 +263,18 @@ using Zygote: ZygoteRuleConfig
@test (1.0,) == Zygote.gradient(oout_id_outer, π)
@test oout_id_rrule_hitcount[] == 0
end
+
+ # issue #1204
+ @testset "NotImplemented" begin
+ f_notimplemented(x) = x
+ @scalar_rule f_notimplemented(x) @not_implemented("not implemented :(")
+ @test Zygote.gradient(f_notimplemented, 0.1) === (nothing,)
+ @test Zygote.gradient(x -> f_notimplemented(x[1]), 0.1) === (nothing,)
+ if isdefined(Base, :only)
+ @test Zygote.gradient(x -> f_notimplemented(only(x)), (0.1,)) === (nothing,)
+ @test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
+ end
+ end
end
@testset "ChainRulesCore.rrule_via_ad" begin
From 6299e5e7cac4a6e182fe01b2663552240a7a46d1 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sun, 17 Apr 2022 08:31:57 -0700
Subject: [PATCH 288/490] Actually correct `Base.in(x, ps::Params)`
Despite triumphant claims of victory, the [original PR](https://github.com/FluxML/Zygote.jl/pull/1130) still got the order of arguments wrong and essentially manually `@macroexpand`ed part of the `@forward`. This PR properly fixes that snafu.
---
src/compiler/interface.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index 0319d5268..d5428e97e 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -150,7 +150,7 @@ Params(xs::Tuple) = Params(collect(xs))
@forward Params.order Base.iterate, Base.length, Base.getindex
-Base.in(ps::Params, x) = x in ps.params
+Base.in(x, ps::Params) = x in ps.params
Base.map(::typeof(_project), args::Tuple{Params}, grad) = grad # skip _project in gradient(f, ::Params)
From 6980a17a2689318e0a874b10dcc302466880c1cd Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Tue, 26 Apr 2022 06:58:34 -0700
Subject: [PATCH 289/490] Remove `cat` adjoint in favour of ChainRules
---
src/lib/array.jl | 13 -------------
1 file changed, 13 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 6f5386c32..f492af9e6 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -116,19 +116,6 @@ pull_block_horz(sz, Δ, A::AbstractMatrix) = Δ[:, sz-size(A, 2)+1:sz]
end
@adjoint hcat(xs::Number...) = hcat(xs...), Δ -> (Δ...,)
-@adjoint function cat(Xs...; dims)
- cat(Xs...; dims = dims), Δ -> begin
- start = ntuple(_ -> 0, ndims(Δ))
- catdims = Base.dims2cat(dims)
- dXs = map(Xs) do x
- move = ntuple(d -> (d<=length(catdims) && catdims[d]) ? size(x,d) : 0, ndims(Δ))
- x_in_Δ = ntuple(d -> (d<=length(catdims) && catdims[d]) ? (start[d]+1:start[d]+move[d]) : Colon(), ndims(Δ))
- start = start .+ move
- dx = Δ[x_in_Δ...]
- end
- end
-end
-
@adjoint function repeat(xs; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs)))
repeat(xs, inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs)
From 375c7dbcc9ee0fa2cbb8524c5ff7be9d70cc270d Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Tue, 26 Apr 2022 12:04:13 -0400
Subject: [PATCH 290/490] Bump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 9d531789d..343069107 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.38"
+version = "0.6.39"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 74f6b8f0751c8b9242fb27822d4a8d712436d023 Mon Sep 17 00:00:00 2001
From: lassepe
Date: Tue, 26 Apr 2022 20:27:41 +0200
Subject: [PATCH 291/490] Make chunk threshold configurable
---
src/lib/forward.jl | 22 ++++++++++------------
1 file changed, 10 insertions(+), 12 deletions(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 13957895c..8b96b417a 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -39,12 +39,9 @@ function forward_jacobian(f, x, ::Val{N}) where N
return y, J
end
-function forward_jacobian(f, x)
- if length(x) < ForwardDiff.DEFAULT_CHUNK_THRESHOLD
- forward_jacobian(f, x, Val(length(x)))
- else
- forward_jacobian(f, x, Val(ForwardDiff.DEFAULT_CHUNK_THRESHOLD))
- end
+function forward_jacobian(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD)
+ chunk_size = min(length(x), chunk_threshold)
+ forward_jacobian(f, x, Val(chunk_size))
end
vec_scalar(x) = vec(x)
@@ -82,10 +79,11 @@ function forward_diag(f, x::AbstractArray)
end
"""
- forwarddiff(f, x) -> f(x)
+ forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD) -> f(x)
Runs `f(x)` as usual, but instructs Zygote to differentiate `f` using forward
-mode, rather than the usual reverse mode.
+mode, rather than the usual reverse mode. The `chunk_threshold` argument controls
+the maximum chunk size (c.f. ForwardDiff documentation).
Forward mode takes time linear in `length(x)` but only has constant memory
overhead, and is very efficient for scalars, so in some cases this can be a
@@ -130,11 +128,11 @@ gradient(2, 3) do a, b
end
```
"""
-forwarddiff(f, x) = f(x)
+forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD) = f(x)
-@adjoint function forwarddiff(f, x)
- y, J = forward_jacobian(f, x)
- return y, ȳ -> (nothing, reshape_scalar(x, J*vec_scalar(ȳ)))
+@adjoint function forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD)
+ y, J = forward_jacobian(f, x; chunk_threshold)
+ return y, ȳ -> (nothing, reshape_scalar(x, J * vec_scalar(ȳ)))
end
# Use this to allow second derivatives -- this is forward-over-forward,
From c3b17f736674abbe24c494fa79712716f0a7a23f Mon Sep 17 00:00:00 2001
From: lassepe
Date: Tue, 26 Apr 2022 20:39:16 +0200
Subject: [PATCH 292/490] Add tests for forward mode chunking
---
test/features.jl | 10 ++++++++++
1 file changed, 10 insertions(+)
diff --git a/test/features.jl b/test/features.jl
index 839e98cc4..729352f33 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -375,6 +375,16 @@ end == (1,)
forwarddiff(x -> x^2, x)
end == (10,)
+@testset "Gradient chunking" begin
+ for chunk_threshold in 1:10:100
+ x = [1:100;]
+ @test gradient(x) do x
+ Zygote.forwarddiff(x -> x' * x, x; chunk_threshold)
+ end == (2 * x,)
+ end
+end
+
+
@test gradient(1) do x
if true
elseif true
From 0b0ea3d5c2ba62f4fd7b694060d687722df48373 Mon Sep 17 00:00:00 2001
From: lassepe
Date: Tue, 26 Apr 2022 20:56:43 +0200
Subject: [PATCH 293/490] Fix julia 1.3 threshold
---
src/lib/forward.jl | 2 +-
test/features.jl | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 8b96b417a..9cdbf1adb 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -131,7 +131,7 @@ end
forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD) = f(x)
@adjoint function forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD)
- y, J = forward_jacobian(f, x; chunk_threshold)
+ y, J = forward_jacobian(f, x; chunk_threshold = chunk_threshold)
return y, ȳ -> (nothing, reshape_scalar(x, J * vec_scalar(ȳ)))
end
diff --git a/test/features.jl b/test/features.jl
index 729352f33..cdfe7329e 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -379,7 +379,7 @@ end == (10,)
for chunk_threshold in 1:10:100
x = [1:100;]
@test gradient(x) do x
- Zygote.forwarddiff(x -> x' * x, x; chunk_threshold)
+ Zygote.forwarddiff(x -> x' * x, x; chunk_threshold = chunk_threshold)
end == (2 * x,)
end
end
From 5771195764a1e37ca8521fadc7d2f2ea55e4e830 Mon Sep 17 00:00:00 2001
From: lassepe
Date: Wed, 27 Apr 2022 13:55:03 +0200
Subject: [PATCH 294/490] Formatting
---
src/lib/forward.jl | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 9cdbf1adb..6d433b405 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -131,11 +131,11 @@ end
forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD) = f(x)
@adjoint function forwarddiff(f, x; chunk_threshold = ForwardDiff.DEFAULT_CHUNK_THRESHOLD)
- y, J = forward_jacobian(f, x; chunk_threshold = chunk_threshold)
- return y, ȳ -> (nothing, reshape_scalar(x, J * vec_scalar(ȳ)))
+ y, J = forward_jacobian(f, x; chunk_threshold = chunk_threshold)
+ return y, ȳ -> (nothing, reshape_scalar(x, J * vec_scalar(ȳ)))
end
-# Use this to allow second derivatives -- this is forward-over-forward,
+# Use this to allow second derivatives -- this is forward-over-forward,
# see https://github.com/FluxML/Zygote.jl/issues/769 for a forward-over-reverse proposal
@adjoint ForwardDiff.gradient(f, x) = pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
@adjoint ForwardDiff.jacobian(f, x) = pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
From c44bac5c5cabc2bc73fffaa637c260d41043f3cc Mon Sep 17 00:00:00 2001
From: Soeren Schoenbrod
Date: Thu, 28 Apr 2022 08:44:13 +0200
Subject: [PATCH 295/490] Correct spelling mistake
---
docs/src/adjoints.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/src/adjoints.md b/docs/src/adjoints.md
index b38c07a14..45f0662b3 100644
--- a/docs/src/adjoints.md
+++ b/docs/src/adjoints.md
@@ -2,7 +2,7 @@
!!! note "Prefer to use ChainRulesCore to define custom adjoints"
Zygote supports the use of [ChainRulesCore](http://www.juliadiff.org/ChainRulesCore.jl/stable/) to define custom sensitivities.
- It is prefered to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Zygote.
+ It is preferred to define the custom sensitivities using `ChainRulesCore.rrule` as they will work for many AD systems, not just Zygote.
These sensitivities can be added in your own package, or for Base/StdLib functions they can be added to [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl/).
To define custom sensitivities using ChainRulesCore, define a `ChainRulesCore.rrule(f, args...; kwargs...)`. Head to [ChainRules project's documentation](https://www.juliadiff.org/ChainRulesCore.jl/stable/) for more information.
**If you are defining your custom adjoints using ChainRulesCore then you do not need to read this page**, and can consider it as documenting a legacy feature.
From 2dc86f554235c05a470b4776bad2807b85c33df5 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sat, 7 May 2022 11:08:28 -0700
Subject: [PATCH 296/490] Use `setglobal!` on nightly
This should address a CI failure.
---
src/lib/lib.jl | 8 ++++++--
1 file changed, 6 insertions(+), 2 deletions(-)
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index f154ecd2a..f11a74214 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -81,8 +81,12 @@ unwrap(ref, x) = x
end
function global_set(ref, val)
- ccall(:jl_set_global, Cvoid, (Any, Any, Any),
- ref.mod, ref.name, val)
+ @static if VERSION < v"1.9.0-DEV.265"
+ ccall(:jl_set_global, Cvoid, (Any, Any, Any),
+ ref.mod, ref.name, val)
+ else
+ setglobal!(ref.mod, ref.name, val)
+ end
end
@adjoint! function global_set(ref, x)
From 89a1caab7f56849f9f9e1f43d59b352b5b3966fe Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Wed, 23 Feb 2022 17:08:36 +0100
Subject: [PATCH 297/490] Fix deprecations in DiffRules 1.4
---
Project.toml | 6 ++++--
src/forward/number.jl | 28 +++++++++++++++-------------
2 files changed, 19 insertions(+), 15 deletions(-)
diff --git a/Project.toml b/Project.toml
index 343069107..97af0610c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.39"
+version = "0.6.40"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -27,10 +28,11 @@ AbstractFFTs = "0.5, 1.0"
ChainRules = "1.5"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
-DiffRules = "1.0"
+DiffRules = "1.4"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
ForwardDiff = "0.10"
IRTools = "0.4.4"
+LogExpFunctions = "0.3"
MacroTools = "0.5"
NaNMath = "0.3, 1"
Requires = "1.1"
diff --git a/src/forward/number.jl b/src/forward/number.jl
index db88af656..322f76f88 100644
--- a/src/forward/number.jl
+++ b/src/forward/number.jl
@@ -1,21 +1,23 @@
-using DiffRules, SpecialFunctions, NaNMath
+using DiffRules, SpecialFunctions, NaNMath, LogExpFunctions
using Base.FastMath: fast_op, make_fastmath
# TODO use CSE here
-for (M, f, arity) in DiffRules.diffrules()
- arity == 1 || continue
- dx = DiffRules.diffrule(M, f, :x)
- @eval begin
- @tangent $M.$f(x::Number) = $M.$f(x), ẋ -> ẋ * $dx
+for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing)
+ if !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
+ @warn "$M.$f is not available and hence rule for it can not be defined"
+ continue # Skip rules for methods not defined in the current scope
end
-end
-
-for (M, f, arity) in DiffRules.diffrules()
- arity == 2 || continue
- da, db = DiffRules.diffrule(M, f, :a, :b)
- @eval begin
- @tangent $M.$f(a::Number, b::Number) = $M.$f(a, b), (ȧ, ḃ) -> ȧ*$da + ḃ*$db
+ if arity == 1
+ dx = DiffRules.diffrule(M, f, :x)
+ @eval begin
+ @tangent $M.$f(x::Number) = $M.$f(x), ẋ -> ẋ * $dx
+ end
+ elseif arity == 2
+ da, db = DiffRules.diffrule(M, f, :a, :b)
+ @eval begin
+ @tangent $M.$f(a::Number, b::Number) = $M.$f(a, b), (ȧ, ḃ) -> ȧ*$da + ḃ*$db
+ end
end
end
From e4a9e7ceaf87dac0e271c1c7b1b3b6a7d3f1a5b2 Mon Sep 17 00:00:00 2001
From: Mason Protter
Date: Mon, 9 May 2022 19:57:07 -0600
Subject: [PATCH 298/490] Remove unnecessary generated functions (#1220)
* Remove unnecessary generated functions
* fix typo
---
src/tools/builtins.jl | 17 +++++++----------
1 file changed, 7 insertions(+), 10 deletions(-)
diff --git a/src/tools/builtins.jl b/src/tools/builtins.jl
index 6d0daf57c..07c3a3217 100644
--- a/src/tools/builtins.jl
+++ b/src/tools/builtins.jl
@@ -1,17 +1,14 @@
-@generated function __new__(T, args...)
- quote
- Base.@_inline_meta
- $(Expr(:new, :T, [:(args[$i]) for i = 1:length(args)]...))
- end
+macro __new__(T, args...)
+ esc(Expr(:new, T, args...))
end
-@generated function __splatnew__(T, args)
- quote
- Base.@_inline_meta
- $(Expr(:splatnew, :T, :args))
- end
+macro __splatnew__(T, args)
+ esc(Expr(:splatnew, T, args))
end
+@inline __new__(T, args...) = @__splatnew__(T, args)
+@inline __splatnew__(T, args) = @__splatnew__(T, args)
+
literal_getindex(x, ::Val{i}) where i = getindex(x, i)
literal_indexed_iterate(x, ::Val{i}) where i = Base.indexed_iterate(x, i)
literal_indexed_iterate(x, ::Val{i}, state) where i = Base.indexed_iterate(x, i, state)
From 885a904ed958c74cdaa2af7a971a6b5a2da908a7 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 9 May 2022 21:59:18 -0400
Subject: [PATCH 299/490] v0.6.40
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 343069107..666b2fd7c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.39"
+version = "0.6.40"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From b6651a317448e90ba5cd61b03a0ae19176cb8a56 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Tue, 10 May 2022 16:57:52 +0200
Subject: [PATCH 300/490] Make LogExpFunctions a proper dependency and clean
adjoints
---
Project.toml | 9 ++++-----
src/Zygote.jl | 2 +-
src/lib/logexpfunctions.jl | 25 ++-----------------------
3 files changed, 7 insertions(+), 29 deletions(-)
diff --git a/Project.toml b/Project.toml
index 666b2fd7c..3f1837806 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.40"
+version = "0.6.41"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -13,6 +13,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -31,11 +32,11 @@ DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
ForwardDiff = "0.10"
IRTools = "0.4.4"
+LogExpFunctions = "0.3.1"
MacroTools = "0.5"
NaNMath = "0.3, 1"
Requires = "1.1"
SpecialFunctions = "1.6, 2"
-StatsFuns = "0.9.8"
ZygoteRules = "0.2.1"
julia = "1.3"
@@ -45,9 +46,7 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
-LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
-StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
-test = ["ChainRulesTestUtils", "CUDA", "Distances", "FFTW", "FiniteDifferences", "LogExpFunctions", "Test"]
+test = ["ChainRulesTestUtils", "CUDA", "Distances", "FFTW", "FiniteDifferences", "Test"]
diff --git a/src/Zygote.jl b/src/Zygote.jl
index d537efbb7..a42dd38c1 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -41,8 +41,8 @@ include("lib/broadcast.jl")
include("lib/forward.jl")
include("lib/utils.jl")
include("lib/range.jl")
+include("lib/logexpfunctions.jl")
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("lib/distances.jl")
-@init @require LogExpFunctions="2ab3a3ac-af41-5b50-aa03-7779005ae688" include("lib/logexpfunctions.jl")
# we need to define this late, so that the genfuncs see lib.jl
# Move using statements out of this file to help with sysimage building
diff --git a/src/lib/logexpfunctions.jl b/src/lib/logexpfunctions.jl
index 1e5e4c0b6..4bdda0d6c 100644
--- a/src/lib/logexpfunctions.jl
+++ b/src/lib/logexpfunctions.jl
@@ -1,5 +1,4 @@
-using .LogExpFunctions: xlogx, xlogy, logistic, logit, log1psq, log1pexp,
- logsumexp, logaddexp, logsubexp
+using LogExpFunctions: xlogx, xlogy, logistic, log1pexp, logsumexp, logaddexp, logsubexp
using Base.Broadcast: broadcasted
@adjoint function xlogx(x::Real)
@@ -20,30 +19,10 @@ function ∇xlogx(x::Numeric)
return result, dx
end
-@adjoint function logistic(x::Real)
- y = logistic(x)
- return y, Δ->(Δ * y * (1 - y),)
-end
-
-@adjoint logit(x::Real) = logit(x), Δ->(Δ / (x * (1 - x)),)
-
-@adjoint log1psq(x::Real) = log1psq(x), Δ->(Δ * 2x / (1 + abs2(x)),)
-
-@adjoint function log1pexp(x::Real)
- dx = ∂log1pexp(x)
- return log1pexp(x), δ -> (δ * dx,)
-end
@adjoint function broadcasted(::typeof(log1pexp), x::Numeric)
- dx = ∂log1pexp.(x)
+ dx = logistic.(x)
return log1pexp.(x), δ -> (nothing, unbroadcast(x, δ .* dx))
end
-∂log1pexp(x::Real) = x < 18.0 ? logistic(x) : x < 33.3 ? one(x) - exp(-x) : oftype(exp(x), 1)
-∂log1pexp(x::Float32) = x < 9f0 ? logistic(x) : x < 16f0 ? one(x) - exp(-x) : oftype(exp(x), 1)
-
-@adjoint function logsumexp(X::AbstractArray{<:Real}; dims=:)
- lse = logsumexp(X; dims=dims)
- return lse, Δ -> (Δ .* exp.(X .- lse),)
-end
@adjoint function xlogy(x::Real, y::Real)
result, dx, dy = ∇xlogy(x, y)
From 5c028dd8388f7194a1dc0d0238c5dfc90a3afb41 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 10 May 2022 21:05:32 -0400
Subject: [PATCH 301/490] add warnings to forwarddiff
---
src/lib/forward.jl | 28 ++++++++++++++++++++++++----
test/utils.jl | 9 +++++++++
2 files changed, 33 insertions(+), 4 deletions(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index 6d433b405..d6d846d52 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -137,8 +137,28 @@ end
# Use this to allow second derivatives -- this is forward-over-forward,
# see https://github.com/FluxML/Zygote.jl/issues/769 for a forward-over-reverse proposal
-@adjoint ForwardDiff.gradient(f, x) = pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
-@adjoint ForwardDiff.jacobian(f, x) = pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
+@adjoint function ForwardDiff.gradient(f, x)
+ F = typeof(f)
+ Base.issingletontype(F) || @warn """`ForwardDiff.gradient(f, x)` within Zygote cannot track gradients with respect to `f`
+ typeof(f) = $F is not a singleton type""" # maxlog=1 _id=hash(F)
+ pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
+end
+
+@adjoint function ForwardDiff.jacobian(f::F, x) where F
+ Base.issingletontype(F) || @warn """`ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`
+ typeof(f) = $F is not a singleton type""" # maxlog=1 _id=hash(F)
+ pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
+end
+
+@adjoint function ForwardDiff.derivative(f::F, x) where F
+ Base.issingletontype(F) || @warn """`ForwardDiff.derivative(f, x)` within Zygote cannot track gradients with respect to `f`
+ typeof(f) = $F is not a singleton type""" maxlog=1 _id=hash(F)
+ pullback(forwarddiff, x -> ForwardDiff.derivative(f, x), x)
+end
+
+@adjoint function ForwardDiff.hessian(f::F, x) where F
+ Base.issingletontype(F) || @warn """`ForwardDiff.hessian(f, x)` within Zygote cannot track gradients with respect to `f`
+ typeof(f) = $F is not a singleton type""" maxlog=1 _id=hash(F)
+ pullback(forwarddiff, x -> ForwardDiff.hessian(f, x), x)
+end
-@adjoint ForwardDiff.derivative(f, x) = pullback(forwarddiff, x -> ForwardDiff.derivative(f, x), x)
-@adjoint ForwardDiff.hessian(f, x) = pullback(forwarddiff, x -> ForwardDiff.hessian(f, x), x)
diff --git a/test/utils.jl b/test/utils.jl
index b6d6ed018..40b2e85b7 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -133,4 +133,13 @@ using ForwardDiff
g3(x) = sum(abs2,ForwardDiff.jacobian(f,x))
out,back = Zygote.pullback(g3,[2.0,3.2])
@test back(1.0)[1] == ForwardDiff.gradient(g3,[2.0,3.2])
+
+ # From https://github.com/FluxML/Zygote.jl/issues/1218
+ f1218(x::AbstractVector,y::AbstractVector) = sum(x)*sum(y)
+ gradf1218(x,y) = ForwardDiff.gradient(x->f1218(x,y), x)[1]
+ x = [0.1]
+ y = rand(5)
+ @test ForwardDiff.gradient(y->gradf1218(x,y), y) == ones(5)
+ # this returns (nothing,) -- now prints a warning
+ @test_broken Zygote.gradient(y->gradf1218(x,y), y) == ones(5)
end
From fc56b9df8d47a8dda89bffe4422855b95dae143a Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 10 May 2022 23:00:18 -0400
Subject: [PATCH 302/490] change wording, display once
---
src/lib/forward.jl | 20 ++++++++++++--------
1 file changed, 12 insertions(+), 8 deletions(-)
diff --git a/src/lib/forward.jl b/src/lib/forward.jl
index d6d846d52..3ee3f7d1c 100644
--- a/src/lib/forward.jl
+++ b/src/lib/forward.jl
@@ -139,26 +139,30 @@ end
# see https://github.com/FluxML/Zygote.jl/issues/769 for a forward-over-reverse proposal
@adjoint function ForwardDiff.gradient(f, x)
F = typeof(f)
- Base.issingletontype(F) || @warn """`ForwardDiff.gradient(f, x)` within Zygote cannot track gradients with respect to `f`
- typeof(f) = $F is not a singleton type""" # maxlog=1 _id=hash(F)
+ Base.issingletontype(F) || @warn """`ForwardDiff.gradient(f, x)` within Zygote cannot track gradients with respect to `f`,
+ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
+ typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.gradient(f, x), x)
end
@adjoint function ForwardDiff.jacobian(f::F, x) where F
- Base.issingletontype(F) || @warn """`ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`
- typeof(f) = $F is not a singleton type""" # maxlog=1 _id=hash(F)
+ Base.issingletontype(F) || @warn """`ForwardDiff.jacobian(f, x)` within Zygote cannot track gradients with respect to `f`,
+ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
+ typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.jacobian(f, x), x)
end
@adjoint function ForwardDiff.derivative(f::F, x) where F
- Base.issingletontype(F) || @warn """`ForwardDiff.derivative(f, x)` within Zygote cannot track gradients with respect to `f`
- typeof(f) = $F is not a singleton type""" maxlog=1 _id=hash(F)
+ Base.issingletontype(F) || @warn """`ForwardDiff.derivative(f, x)` within Zygote cannot track gradients with respect to `f`,
+ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
+ typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.derivative(f, x), x)
end
@adjoint function ForwardDiff.hessian(f::F, x) where F
- Base.issingletontype(F) || @warn """`ForwardDiff.hessian(f, x)` within Zygote cannot track gradients with respect to `f`
- typeof(f) = $F is not a singleton type""" maxlog=1 _id=hash(F)
+ Base.issingletontype(F) || @warn """`ForwardDiff.hessian(f, x)` within Zygote cannot track gradients with respect to `f`,
+ and `f` appears to be a closure, or a struct with fields (according to `issingletontype(typeof(f))`).
+ typeof(f) = $F""" maxlog=1 _id=hash(F)
pullback(forwarddiff, x -> ForwardDiff.hessian(f, x), x)
end
From 0f3d8430b181745c3698657780327f2313f926c1 Mon Sep 17 00:00:00 2001
From: DomCRose
Date: Fri, 20 May 2022 19:17:07 +0100
Subject: [PATCH 303/490] Fix non-holomorphic tests
---
test/complex.jl | 14 ++++++++++----
1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/test/complex.jl b/test/complex.jl
index 1abd1303f..54c49f299 100644
--- a/test/complex.jl
+++ b/test/complex.jl
@@ -76,14 +76,20 @@ fs_C_to_C_non_holomorphic = (conj,
z->im*abs2(z),
z->z'z,
z->conj(z)*z^2,
+ z->imag(z)^2+real(sin(z))^3*1im,
)
@testset "C->C non-holomorphic" begin
- for f in (fs_C_to_C_holomorphic...,fs_C_to_C_holomorphic...)
+ for f in (fs_C_to_C_holomorphic...,fs_C_to_C_non_holomorphic...)
for z in (1.0+2.0im, -2.0+pi*im)
- grad_zygote = gradient(real∘f, z)[1]
+ grad_zygote_r = gradient(real∘f, z)[1]
+ grad_zygote_i = gradient(imag∘f, z)[1]
ε = 1e-8
- grad_fd = real(f(z+ε)-f(z))/ε + im*real(f(z+ε*im)-f(z))/ε
- @test abs(grad_zygote - grad_fd) < sqrt(ε)
+ grad_fd_r = real(f(z+ε)-f(z))/ε + im*real(f(z+ε*im)-f(z))/ε
+ grad_fd_i = imag(f(z+ε)-f(z))/ε + im*imag(f(z+ε*im)-f(z))/ε
+ # Check derivative of both real and imaginary parts of f as these may differ
+ # for non-holomorphic functions
+ @test abs(grad_zygote_r - grad_fd_r) < sqrt(ε)
+ @test abs(grad_zygote_i - grad_fd_i) < sqrt(ε)
end
end
end
From 22b6963ecd6373eee8fd176e4c12c5c40f3008af Mon Sep 17 00:00:00 2001
From: DomCRose
Date: Sun, 22 May 2022 20:32:22 +0100
Subject: [PATCH 304/490] Separate (non-)holomorphic tests
---
test/complex.jl | 15 ++++++++++-----
1 file changed, 10 insertions(+), 5 deletions(-)
diff --git a/test/complex.jl b/test/complex.jl
index 54c49f299..d73cac65d 100644
--- a/test/complex.jl
+++ b/test/complex.jl
@@ -58,12 +58,17 @@ fs_C_to_C_holomorphic = (cos,
@testset "C->C holomorphic" begin
for f in fs_C_to_C_holomorphic
for z in (1.0+2.0im, -2.0+pi*im)
- grad_zygote = gradient(real∘f, z)[1]
+ grad_zygote_r = gradient(real∘f, z)[1]
+ grad_zygote_i = gradient(imag∘f, z)[1]
ε = 1e-8
grad_fd_r = (f(z+ε)-f(z))/ε
- grad_fd_i = (f(z+ε*im)-f(z))/(ε*im)
- @assert abs(grad_fd_r - grad_fd_i) < sqrt(ε) # check the function is indeed holomorphic
- @test abs(grad_zygote - conj(grad_fd_r)) < sqrt(ε)
+ grad_fd_i = (f(z + ε * im) - f(z)) / (ε * im)
+ # check the function is indeed holomorphic
+ @assert abs(grad_fd_r - grad_fd_i) < sqrt(ε)
+ # check Zygote derivatives agree with holomorphic definition
+ @test abs(grad_zygote_r + im*grad_zygote_i) < sqrt(ε)
+ # check derivative agrees with finite differences
+ @test abs(grad_zygote_r - conj(grad_fd_r)) < sqrt(ε)
end
end
end
@@ -79,7 +84,7 @@ fs_C_to_C_non_holomorphic = (conj,
z->imag(z)^2+real(sin(z))^3*1im,
)
@testset "C->C non-holomorphic" begin
- for f in (fs_C_to_C_holomorphic...,fs_C_to_C_non_holomorphic...)
+ for f in fs_C_to_C_non_holomorphic
for z in (1.0+2.0im, -2.0+pi*im)
grad_zygote_r = gradient(real∘f, z)[1]
grad_zygote_i = gradient(imag∘f, z)[1]
From 9f417db5bc214dd4a60ff0f68bf266bde51cd804 Mon Sep 17 00:00:00 2001
From: DomCRose
Date: Tue, 24 May 2022 11:12:32 +0100
Subject: [PATCH 305/490] Edit holomorphic gradient test to use approx
---
test/complex.jl | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/test/complex.jl b/test/complex.jl
index d73cac65d..efb1e06dd 100644
--- a/test/complex.jl
+++ b/test/complex.jl
@@ -63,11 +63,11 @@ fs_C_to_C_holomorphic = (cos,
ε = 1e-8
grad_fd_r = (f(z+ε)-f(z))/ε
grad_fd_i = (f(z + ε * im) - f(z)) / (ε * im)
- # check the function is indeed holomorphic
+ # Check the function is indeed holomorphic
@assert abs(grad_fd_r - grad_fd_i) < sqrt(ε)
- # check Zygote derivatives agree with holomorphic definition
- @test abs(grad_zygote_r + im*grad_zygote_i) < sqrt(ε)
- # check derivative agrees with finite differences
+ # Check Zygote derivatives agree with holomorphic definition
+ @test grad_zygote_r ≈ -im*grad_zygote_i
+ # Check derivative agrees with finite differences
@test abs(grad_zygote_r - conj(grad_fd_r)) < sqrt(ε)
end
end
From b9caf3f5b1dccd7448e59e0d59d89374e039fa0c Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sat, 4 Jun 2022 17:16:39 -0700
Subject: [PATCH 306/490] Allow accumulating distinct Dicts
---
src/lib/base.jl | 4 ++--
test/features.jl | 15 ++++++++++++++-
2 files changed, 16 insertions(+), 3 deletions(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index ac7df59a2..b476eb175 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -24,8 +24,8 @@ grad_mut(d::IdDict) = IdDict()
# TODO perhaps look up mutable gradients in `pullback`
function accum(a::AbstractDict, b::AbstractDict)
- @assert a === b
- return a
+ a === b && return a # Mutating case
+ return merge(a, b)
end
@adjoint function getindex(d::AbstractDict, k)
diff --git a/test/features.jl b/test/features.jl
index cdfe7329e..838645401 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -1,5 +1,6 @@
-using Zygote, Test
+using Zygote, Test, LinearAlgebra
using Zygote: Params, gradient, forwarddiff
+using FillArrays: Fill
@testset "gradient checkpointing" begin
@@ -676,6 +677,18 @@ end
loss(x) = sum(abs2, net(x))
@test gradient(loss, ones(10,10))[1] == fill(131072, 10, 10)
@test 150_000_000 > @allocated gradient(loss, ones(1000,1000))
+
+ # https://github.com/FluxML/Zygote.jl/issues/1233
+ function defensiveupdate(d, a)
+ nd = deepcopy(d)
+ nd[1] = d[1] * a
+ return nd
+ end
+ d = Dict(i => ones(1) for i in 1:2)
+ @test gradient(d) do d
+ nd = defensiveupdate(d, 5)
+ return sum(nd[1]) + sum(nd[2])
+ end[1] == Dict(1 => Fill(5, 1), 2 => Fill(1, 1))
end
@testset "tricky broadcasting" begin
From 0fa305d21c5d6495ded84f9c2a9e614038788599 Mon Sep 17 00:00:00 2001
From: Sam Anklesaria
Date: Fri, 10 Jun 2022 16:08:51 -0500
Subject: [PATCH 307/490] Fix #1241
---
src/lib/base.jl | 5 ++++-
test/lib/base.jl | 13 +++++++++++++
test/runtests.jl | 1 +
3 files changed, 18 insertions(+), 1 deletion(-)
create mode 100644 test/lib/base.jl
diff --git a/src/lib/base.jl b/src/lib/base.jl
index ac7df59a2..fa71d8906 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -29,11 +29,14 @@ function accum(a::AbstractDict, b::AbstractDict)
end
@adjoint function getindex(d::AbstractDict, k)
- d[k], function (Δ)
+ val = d[k]
+ function dict_getindex_pullback(Δ)
+ accum_param(__context__, val, Δ) === nothing && return
grad = grad_mut(__context__, d)
grad[k] = accum(get(grad, k, nothing), Δ)
return (grad, nothing)
end
+ val, dict_getindex_pullback
end
@adjoint! function setindex!(d::AbstractDict, v, k)
diff --git a/test/lib/base.jl b/test/lib/base.jl
new file mode 100644
index 000000000..99ff446ce
--- /dev/null
+++ b/test/lib/base.jl
@@ -0,0 +1,13 @@
+@testset "base.jl" begin
+ @testset "dict_param" begin
+ d = Dict{String, Vector{Float64}}("key"=>ones(4))
+ fn() = d["key"][2]
+ result1 = gradient(fn, Params([d["key"]]))[d["key"]]
+
+ x = d["key"]
+ fn2() = x[2]
+ result2 = gradient(fn2, Params([x]))[x]
+
+ @test result1 == result2
+ end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index 17ebb3997..fe5590efd 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -29,6 +29,7 @@ using CUDA: has_cuda
@testset "lib" begin
include("lib/number.jl")
include("lib/lib.jl")
+ include("lib/base.jl")
include("lib/array.jl")
end
From 99d89b09b5190b02aebf0026c3462c06bfd78a83 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sat, 11 Jun 2022 09:17:05 -0700
Subject: [PATCH 308/490] tweak test name
---
test/lib/base.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/lib/base.jl b/test/lib/base.jl
index 99ff446ce..5186483da 100644
--- a/test/lib/base.jl
+++ b/test/lib/base.jl
@@ -1,5 +1,5 @@
@testset "base.jl" begin
- @testset "dict_param" begin
+ @testset "Dict getindex with implicit params" begin
d = Dict{String, Vector{Float64}}("key"=>ones(4))
fn() = d["key"][2]
result1 = gradient(fn, Params([d["key"]]))[d["key"]]
From e82f24ff2aa6449844ade41c94f241b3f6d39bf5 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 17 Jun 2022 17:38:47 +0200
Subject: [PATCH 309/490] deprecate dropgrad
---
docs/src/utils.md | 1 -
src/Zygote.jl | 1 +
src/deprecated.jl | 15 +++++++++++++++
src/lib/utils.jl | 13 -------------
4 files changed, 16 insertions(+), 14 deletions(-)
create mode 100644 src/deprecated.jl
diff --git a/docs/src/utils.md b/docs/src/utils.md
index b7e779185..d04f83140 100644
--- a/docs/src/utils.md
+++ b/docs/src/utils.md
@@ -18,7 +18,6 @@ Zygote.withgradient
Zygote.withjacobian
Zygote.@showgrad
Zygote.hook
-Zygote.dropgrad
Zygote.Buffer
Zygote.forwarddiff
Zygote.ignore
diff --git a/src/Zygote.jl b/src/Zygote.jl
index a42dd38c1..b1ca50aa9 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -18,6 +18,7 @@ export rrule_via_ad
const Numeric{T<:Number} = Union{T, AbstractArray{<:T}}
+include("deprecated.jl")
include("tools/buffer.jl")
include("tools/builtins.jl")
diff --git a/src/deprecated.jl b/src/deprecated.jl
new file mode 100644
index 000000000..1df630017
--- /dev/null
+++ b/src/deprecated.jl
@@ -0,0 +1,15 @@
+"""
+ dropgrad(x) -> x
+
+Drop the gradient of `x`.
+
+ julia> gradient(2, 3) do a, b
+ dropgrad(a)*b
+ end
+ (nothing, 2)
+"""
+function dropgrad end
+
+@adjoint dropgrad(x) = dropgrad(x), _ -> nothing
+
+Base.@deprecate dropgrad(x) ChainRulesCore.ignore_derivatives(x)
diff --git a/src/lib/utils.jl b/src/lib/utils.jl
index 86e6fff8c..e5d8baeee 100644
--- a/src/lib/utils.jl
+++ b/src/lib/utils.jl
@@ -1,16 +1,3 @@
-"""
- dropgrad(x) -> x
-
-Drop the gradient of `x`.
-
- julia> gradient(2, 3) do a, b
- dropgrad(a)*b
- end
- (nothing, 2)
-"""
-dropgrad(x) = x
-@adjoint dropgrad(x) = dropgrad(x), _ -> nothing
-
"""
ignore() do
...
From 99149d4c345a6ea4ddf5a994e5c897381031bcc0 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 17 Jun 2022 17:48:58 +0200
Subject: [PATCH 310/490] deprecate ignore
---
src/deprecated.jl | 36 ++++++++++++++++++++++++++++++++++++
src/lib/utils.jl | 32 --------------------------------
2 files changed, 36 insertions(+), 32 deletions(-)
diff --git a/src/deprecated.jl b/src/deprecated.jl
index 1df630017..9bc808511 100644
--- a/src/deprecated.jl
+++ b/src/deprecated.jl
@@ -13,3 +13,39 @@ function dropgrad end
@adjoint dropgrad(x) = dropgrad(x), _ -> nothing
Base.@deprecate dropgrad(x) ChainRulesCore.ignore_derivatives(x)
+
+
+"""
+ ignore() do
+ ...
+ end
+
+Tell Zygote to ignore a block of code. Everything inside the `do` block will run
+on the forward pass as normal, but Zygote won't try to differentiate it at all.
+This can be useful for e.g. code that does logging of the forward pass.
+
+Obviously, you run the risk of incorrect gradients if you use this incorrectly.
+"""
+function ignore end
+
+@adjoint ignore(f) = ignore(f), _ -> nothing
+
+Base.@deprecate ignore(f) ChainRulesCore.ignore_derivatives(f)
+
+"""
+ @ignore (...)
+
+Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`.
+Example:
+
+```julia-repl
+julia> f(x) = (y = Zygote.@ignore x; x * y);
+julia> f'(1)
+1
+```
+"""
+macro ignore(ex)
+ return :(Zygote.ignore() do
+ $(esc(ex))
+ end)
+end
diff --git a/src/lib/utils.jl b/src/lib/utils.jl
index e5d8baeee..72c60a961 100644
--- a/src/lib/utils.jl
+++ b/src/lib/utils.jl
@@ -1,35 +1,3 @@
-"""
- ignore() do
- ...
- end
-
-Tell Zygote to ignore a block of code. Everything inside the `do` block will run
-on the forward pass as normal, but Zygote won't try to differentiate it at all.
-This can be useful for e.g. code that does logging of the forward pass.
-
-Obviously, you run the risk of incorrect gradients if you use this incorrectly.
-"""
-ignore(f) = f()
-@adjoint ignore(f) = ignore(f), _ -> nothing
-
-"""
- @ignore (...)
-
-Tell Zygote to ignore an expression. Equivalent to `ignore() do (...) end`.
-Example:
-
-```julia-repl
-julia> f(x) = (y = Zygote.@ignore x; x * y);
-julia> f'(1)
-1
-```
-"""
-macro ignore(ex)
- return :(Zygote.ignore() do
- $(esc(ex))
- end)
-end
-
"""
hook(x̄ -> ..., x) -> x
From e647fc30b301904ed52ba329624a23176757d59d Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 17 Jun 2022 17:55:08 +0200
Subject: [PATCH 311/490] move tests for deprecated functionality into
deprecated.jl
---
test/deprecated.jl | 10 ++++++++++
test/gradcheck.jl | 8 --------
test/runtests.jl | 4 ++++
3 files changed, 14 insertions(+), 8 deletions(-)
create mode 100644 test/deprecated.jl
diff --git a/test/deprecated.jl b/test/deprecated.jl
new file mode 100644
index 000000000..ffc4994c7
--- /dev/null
+++ b/test/deprecated.jl
@@ -0,0 +1,10 @@
+@test_deprecated dropgrad(1)
+@test_deprecated ignore(1)
+@test_deprecated Zygote.@ignore x=1
+
+@test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
+@test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
+@test gradient(1) do x
+ y = Zygote.@ignore x
+ x * y
+end == (1,)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index ac0dd28bf..268c1734e 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -1681,14 +1681,6 @@ end
@test gradient(x -> findfirst(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findlast(ismissing, x), [1, missing]) == (nothing,)
@test gradient(x -> findall(ismissing, x)[1], [1, missing]) == (nothing,)
-
-
- @test gradient(x -> Zygote.ignore(() -> x*x), 1) == (nothing,)
- @test gradient(x -> Zygote.@ignore(x*x), 1) == (nothing,)
- @test gradient(1) do x
- y = Zygote.@ignore x
- x * y
- end == (1,)
end
@testset "fastmath" begin
diff --git a/test/runtests.jl b/test/runtests.jl
index fe5590efd..565ad182f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -14,6 +14,10 @@ using CUDA: has_cuda
@warn "CUDA not found - Skipping CUDA Tests"
end
+ @testset "deprecated.jl" begin
+ include("deprecated.jl")
+ end
+
@testset "Interface" begin
include("interface.jl")
end
From 38bf316766ed552ecd850fde1ae9e19b295d1db9 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Fri, 17 Jun 2022 18:08:14 +0200
Subject: [PATCH 312/490] remove ignore from docs
---
docs/src/utils.md | 1 -
1 file changed, 1 deletion(-)
diff --git a/docs/src/utils.md b/docs/src/utils.md
index d04f83140..ce9c3e778 100644
--- a/docs/src/utils.md
+++ b/docs/src/utils.md
@@ -20,7 +20,6 @@ Zygote.@showgrad
Zygote.hook
Zygote.Buffer
Zygote.forwarddiff
-Zygote.ignore
Zygote.checkpointed
```
From 131c5c82a9c653a836f1545cbac9c687ab7507f8 Mon Sep 17 00:00:00 2001
From: ST John
Date: Fri, 17 Jun 2022 19:28:36 +0300
Subject: [PATCH 313/490] increase ChainRules lower bound to 1.35.3
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index ae5c974d9..8de31f1a4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "1.5"
+ChainRules = "1.35.3"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.4"
From 644a5dd874f769bca7d6e4aa314e08f36bf4bd27 Mon Sep 17 00:00:00 2001
From: ST John
Date: Sat, 18 Jun 2022 10:47:23 +0300
Subject: [PATCH 314/490] bump julia compat to 1.6
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 8de31f1a4..a35854934 100644
--- a/Project.toml
+++ b/Project.toml
@@ -38,7 +38,7 @@ NaNMath = "0.3, 1"
Requires = "1.1"
SpecialFunctions = "1.6, 2"
ZygoteRules = "0.2.1"
-julia = "1.3"
+julia = "1.6"
[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
From 206a6000397cc47ebf3c2b5c3c79c201e6eca310 Mon Sep 17 00:00:00 2001
From: ST John
Date: Sat, 18 Jun 2022 10:58:30 +0300
Subject: [PATCH 315/490] remove failing test
---
test/gradcheck.jl | 1 -
1 file changed, 1 deletion(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index ac0dd28bf..0b024a51f 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -654,7 +654,6 @@ end
g(X) = cholesky(X * X' + I)
@test Zygote.pullback(g, X)[2]((factors=LowerTriangular(X),))[1] ≈
Zygote.pullback(g, X)[2]((factors=Matrix(LowerTriangular(X)),))[1]
- @test_throws PosDefException Zygote.pullback(X -> cholesky(X, check = false), X)[2]((factors=X,))
# https://github.com/FluxML/Zygote.jl/issues/932
@test gradcheck(rand(5, 5), rand(5)) do A, x
From 984a25c01adca0e9b197c235a986dcb9ed5c1396 Mon Sep 17 00:00:00 2001
From: ST John
Date: Sat, 18 Jun 2022 11:08:55 +0300
Subject: [PATCH 316/490] add Hermitian cholesky test
---
test/gradcheck.jl | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 0b024a51f..0ce7f6ae4 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -819,6 +819,18 @@ end
@test back′(C̄)[1] isa Diagonal
@test diag(back′(C̄)[1]) ≈ diag(back(C̄)[1])
end
+ @testset "cholesky - Hermitian" begin
+ rng, N = MersenneTwister(123456), 3
+ A = randn(rng, N, N) + im * randn(rng, N, N)
+ H = Hermitian(A * A' + I)
+ Hmat = Matrix(H)
+ y, back = Zygote.pullback(cholesky, Hmat)
+ y′, back′ = Zygote.pullback(cholesky, H)
+ C̄ = (factors=randn(rng, N, N),)
+ @test back′(C̄)[1] isa Hermitian
+ @test gradtest(B->cholesky(Hermitian(B)).U, A * A' + I)
+ @test gradtest(B->logdet(cholesky(Hermitian(B))), A * A' + I)
+ end
end
@testset "lyap" begin
From d13be2e84358e449e45ea7e3b86e8db793fd70ea Mon Sep 17 00:00:00 2001
From: st--
Date: Sat, 18 Jun 2022 21:48:44 +0300
Subject: [PATCH 317/490] Update test/gradcheck.jl
Co-authored-by: David Widmann
---
test/gradcheck.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 0ce7f6ae4..46e40b87e 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -821,7 +821,7 @@ end
end
@testset "cholesky - Hermitian" begin
rng, N = MersenneTwister(123456), 3
- A = randn(rng, N, N) + im * randn(rng, N, N)
+ A = randn(rng, Complex{Float64}, N, N)
H = Hermitian(A * A' + I)
Hmat = Matrix(H)
y, back = Zygote.pullback(cholesky, Hmat)
From c8df3f07d326437d35a31f0da60190388e9dbc14 Mon Sep 17 00:00:00 2001
From: ST John
Date: Sat, 18 Jun 2022 21:49:26 +0300
Subject: [PATCH 318/490] bump julia minimum version in github action ci.yml
---
.github/workflows/ci.yml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index bab7876a5..887c985c8 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- - '1.3' # Replace this with the minimum Julia version that your package supports.
+ - '1.6' # Replace this with the minimum Julia version that your package supports.
- '1' # automatically expands to the latest stable 1.x release of Julia
- 'nightly'
os:
From 7ce5705b4cd481d46df5416536dc172375000cbe Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Sun, 19 Jun 2022 22:35:38 +0530
Subject: [PATCH 319/490] Make array mutation error nicer
---
src/lib/array.jl | 22 +++++++++++++++++++---
1 file changed, 19 insertions(+), 3 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index f492af9e6..93f6ba12b 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -67,15 +67,31 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
@adjoint getindex(::Type{T}, xs...) where {T} = T[xs...], dy -> (nothing, dy...)
+_throw_mutation_error(details) = error("""
+Mutating arrays is not supported -- $details
+This error occurs when you ask Zygote to differentiate operations that change
+the elements of arrays in place. Some common examples:
+- setting values (x .= ...)
+- appending values (push!(x, v))
+- popping values (pop!(x))
+- calling mutating functions (mul!(C, A, B))
+NOTE: non-mutating functions may use mutation under the hood
+ for performance or code-reuse.
+Possible fixes:
+- avoid mutating operations (preferred)
+- hide the mutation from Zygote by wrapping the mutating call in a custom rrule
+ (https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
+""")
+
@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
- _ -> error("Mutating arrays is not supported -- called setindex!(::$(typeof(xs)), _...)")
+ _ -> _throw_mutation_error("called setindex!(::$(typeof(xs)), _...)")
@adjoint! copyto!(xs, args...) = copyto!(xs, args...),
- _ -> error("Mutating arrays is not supported -- called copyto!(::$(typeof(xs)), _...)")
+ _ -> _throw_mutation_error("called copyto!(::$(typeof(xs)), _...)")
for f in [push!, pop!, pushfirst!, popfirst!]
@eval @adjoint! $f(x::AbstractVector, ys...) = $f(x, ys...),
- _ -> error("Mutating arrays is not supported -- called $($f)(::$(typeof(x)), _...)")
+ _ -> _throw_mutation_error("called $($f)(::$(typeof(x)), _...)")
end
# General
From fc945ba4322b0d9b7aff12b9382f00c35897ad24 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Sun, 19 Jun 2022 22:52:07 +0530
Subject: [PATCH 320/490] Minor improvements to mutation error
---
src/lib/array.jl | 12 +++++++-----
1 file changed, 7 insertions(+), 5 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 93f6ba12b..d1d542b8e 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -67,8 +67,8 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
@adjoint getindex(::Type{T}, xs...) where {T} = T[xs...], dy -> (nothing, dy...)
-_throw_mutation_error(details) = error("""
-Mutating arrays is not supported -- $details
+_throw_mutation_error(f, args...) = error("""
+Mutating arrays is not supported -- called $f($(join(map(typeof, args), ", ")), ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place. Some common examples:
- setting values (x .= ...)
@@ -81,17 +81,19 @@ Possible fixes:
- avoid mutating operations (preferred)
- hide the mutation from Zygote by wrapping the mutating call in a custom rrule
(https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
+- if the mutation is coming from within a package (i.e. not user code),
+ then open an issue on Zygote.jl (https://github.com/FluxML/Zygote.jl/issues)
""")
@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
- _ -> _throw_mutation_error("called setindex!(::$(typeof(xs)), _...)")
+ _ -> _throw_mutation_error(setindex!, xs)
@adjoint! copyto!(xs, args...) = copyto!(xs, args...),
- _ -> _throw_mutation_error("called copyto!(::$(typeof(xs)), _...)")
+ _ -> _throw_mutation_error(copyto!, xs)
for f in [push!, pop!, pushfirst!, popfirst!]
@eval @adjoint! $f(x::AbstractVector, ys...) = $f(x, ys...),
- _ -> _throw_mutation_error("called $($f)(::$(typeof(x)), _...)")
+ _ -> _throw_mutation_error($f, x)
end
# General
From 815e8dd76f3618640ad9305f54ce85fc3a634d8f Mon Sep 17 00:00:00 2001
From: ST John
Date: Mon, 20 Jun 2022 13:01:04 +0300
Subject: [PATCH 321/490] fix test
---
test/gradcheck.jl | 20 +++++++++++++++++---
1 file changed, 17 insertions(+), 3 deletions(-)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 46e40b87e..182f2b666 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -819,7 +819,7 @@ end
@test back′(C̄)[1] isa Diagonal
@test diag(back′(C̄)[1]) ≈ diag(back(C̄)[1])
end
- @testset "cholesky - Hermitian" begin
+ @testset "cholesky - Hermitian{Complex}" begin
rng, N = MersenneTwister(123456), 3
A = randn(rng, Complex{Float64}, N, N)
H = Hermitian(A * A' + I)
@@ -827,9 +827,23 @@ end
y, back = Zygote.pullback(cholesky, Hmat)
y′, back′ = Zygote.pullback(cholesky, H)
C̄ = (factors=randn(rng, N, N),)
+ @test only(back′(C̄)) isa Hermitian
+ # gradtest does not support complex gradients, even though the pullback exists
+ d = only(back(C̄))
+ d′ = only(back′(C̄))
+ @test (d + d')/2 ≈ d′
+ end
+ @testset "cholesky - Hermitian{Real}" begin
+ rng, N = MersenneTwister(123456), 3
+ A = randn(rng, N, N)
+ H = Hermitian(A * A' + I)
+ Hmat = Matrix(H)
+ y, back = Zygote.pullback(cholesky, Hmat)
+ y′, back′ = Zygote.pullback(cholesky, H)
+ C̄ = (factors=randn(rng, N, N),)
@test back′(C̄)[1] isa Hermitian
- @test gradtest(B->cholesky(Hermitian(B)).U, A * A' + I)
- @test gradtest(B->logdet(cholesky(Hermitian(B))), A * A' + I)
+ @test gradtest(B->cholesky(Hermitian(B)).U, Hmat)
+ @test gradtest(B->logdet(cholesky(Hermitian(B))), Hmat)
end
end
From 3239330ccf42add76a817d57c46cefb19a8ec0f1 Mon Sep 17 00:00:00 2001
From: David Widmann
Date: Mon, 20 Jun 2022 12:10:02 +0200
Subject: [PATCH 322/490] Use CR adjoint for `logdet(::Cholesky)` (#1226)
* Use CR adjoint for `logdet(::Cholesky)`
* Update Project.toml
* Update ci.yml
* Update Project.toml
---
.github/workflows/ci.yml | 2 +-
Project.toml | 4 ++--
src/lib/array.jl | 6 ------
3 files changed, 3 insertions(+), 9 deletions(-)
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index bab7876a5..887c985c8 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -17,7 +17,7 @@ jobs:
fail-fast: false
matrix:
version:
- - '1.3' # Replace this with the minimum Julia version that your package supports.
+ - '1.6' # Replace this with the minimum Julia version that your package supports.
- '1' # automatically expands to the latest stable 1.x release of Julia
- 'nightly'
os:
diff --git a/Project.toml b/Project.toml
index ae5c974d9..bdb6f327b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "1.5"
+ChainRules = "1.33"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.4"
@@ -38,7 +38,7 @@ NaNMath = "0.3, 1"
Requires = "1.1"
SpecialFunctions = "1.6, 2"
ZygoteRules = "0.2.1"
-julia = "1.3"
+julia = "1.6"
[extras]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
diff --git a/src/lib/array.jl b/src/lib/array.jl
index f492af9e6..548159766 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -741,12 +741,6 @@ end
end
end
-@adjoint function logdet(C::Cholesky)
- return logdet(C), function(Δ)
- return ((uplo=nothing, info=nothing, factors=Diagonal(2 .* Δ ./ diag(C.factors))),)
- end
-end
-
@adjoint function Matrix(S::UniformScaling, i::Integer, j::Integer)
return Matrix(S, i, j), Δ -> ((λ=tr(Δ),), nothing, nothing)
end
From 45389145459d7f4ca892c735abf86b14ff4dc6cd Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Mon, 20 Jun 2022 17:09:21 +0200
Subject: [PATCH 323/490] add a note to docs
---
docs/src/utils.md | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/docs/src/utils.md b/docs/src/utils.md
index ce9c3e778..25b5954e4 100644
--- a/docs/src/utils.md
+++ b/docs/src/utils.md
@@ -13,6 +13,10 @@ Zygote also provides a set of helpful utilities. These are all "user-level" tool
in other words you could have written them easily yourself, but they live in
Zygote for convenience.
+See `ChainRules.ignore_derivatives` if you want to exclude some of your code from the
+gradient calculation. This replaces previous Zygote-specific `ignore` and `dropgrad`
+functionality.
+
```@docs
Zygote.withgradient
Zygote.withjacobian
From 4bb6b4dd4a4b6eb0e40126587a7a170216c97448 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Thu, 23 Jun 2022 12:15:42 +0200
Subject: [PATCH 324/490] deprecate Zygote.@nograd
---
Project.toml | 2 +-
src/Zygote.jl | 2 +-
src/deprecated.jl | 16 ++++++++++++++++
src/forward/lib.jl | 2 +-
src/lib/array.jl | 7 -------
src/lib/base.jl | 4 ----
src/lib/buffer.jl | 2 +-
src/lib/grad.jl | 12 ------------
src/lib/lib.jl | 2 --
src/lib/number.jl | 3 ---
test/lib/number.jl | 6 +++---
11 files changed, 23 insertions(+), 35 deletions(-)
diff --git a/Project.toml b/Project.toml
index a35854934..14ed91b5c 100644
--- a/Project.toml
+++ b/Project.toml
@@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "1.35.3"
+ChainRules = "1.37"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.4"
diff --git a/src/Zygote.jl b/src/Zygote.jl
index b1ca50aa9..8a51b14fd 100644
--- a/src/Zygote.jl
+++ b/src/Zygote.jl
@@ -58,7 +58,7 @@ include("profiler/Profile.jl")
end
@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" begin
- @nograd Colors.ColorTypes._parameter_upper_bound
+ @non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)
end
using InteractiveUtils
diff --git a/src/deprecated.jl b/src/deprecated.jl
index 9bc808511..6fe88b5b1 100644
--- a/src/deprecated.jl
+++ b/src/deprecated.jl
@@ -49,3 +49,19 @@ macro ignore(ex)
$(esc(ex))
end)
end
+
+using MacroTools: @q
+
+macro nograd(ex)
+ Base.depwarn(
+ "`Zygote.@nograd myfunc` is deprecated, use `ChainRulesCore.@non_differentiable myfunc(::Any...)` instead.",
+ :nograd
+ )
+ isexpr(ex, :tuple) || (ex = Expr(:tuple, ex))
+ blk = @q begin end
+ for f in ex.args
+ back = MacroTools.@q _ -> ($__source__; nothing)
+ push!(blk.args, :(@inline Zygote._pullback(::Context, ::Core.Typeof($(esc(f))), args...) = $(esc(f))(args...), $back))
+ end
+ return blk
+end
diff --git a/src/forward/lib.jl b/src/forward/lib.jl
index b297dab41..a5518fd5d 100644
--- a/src/forward/lib.jl
+++ b/src/forward/lib.jl
@@ -9,7 +9,7 @@ end
# TODO figure out why this made a test fail
zerolike(x::Union{Module,Type}) = nothing
-# TODO: `@nograd` and `@linear`
+# TODO: `@non_differentiable` and `@linear`
@tangent zerolike(x) = zerolike(x), _ -> zerolike(x)
@tangent one(x::Number) = one(x), _ -> zero(x)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index bbe13669d..70790f757 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -6,8 +6,6 @@ using Distributed: pmap, AbstractWorkerPool
@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)
-@nograd ones, zeros, Base.OneTo, Colon(), one, zero, sizehint!, count
-
@adjoint copy(x::AbstractArray) = copy(x), ȳ -> (ȳ,)
@adjoint collect(x::Tuple) = collect(x), dy -> (Tuple(dy),)
@@ -222,11 +220,6 @@ end
end
end
-for t in subtypes(AbstractWorkerPool)
- @nograd t
-end
-@nograd workers
-
function _pullback(cx::AContext, ::typeof(collect), g::Base.Generator)
y, b = ∇map(cx, g.f, g.iter)
back(::Nothing) = nothing
diff --git a/src/lib/base.jl b/src/lib/base.jl
index fa71d8906..79dfb77b6 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -49,8 +49,6 @@ end
# Channels
-@nograd Channel
-
grad_mut(ch::Channel) = Channel(ch.sz_max)
@adjoint! function put!(ch::Channel, x)
@@ -157,8 +155,6 @@ end
@adjoint Base.nameof(x::UnionAll) = nameof(x), _ -> (nothing,)
-@nograd typeintersect
-
# Base.Fix1 and Base.Fix2: https://github.com/FluxML/Zygote.jl/issues/957
@adjoint function (g::Base.Fix1)(y)
f = g.f
diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl
index e62a70041..b3aef17f0 100644
--- a/src/lib/buffer.jl
+++ b/src/lib/buffer.jl
@@ -1,7 +1,7 @@
grad_mut(b::Buffer) = fill!(similar(b.data, Any), nothing)
grad_mut(b::Buffer{T}) where T<:Number = fill!(similar(b.data, float(T)), 0)
-@nograd Buffer
+@non_differentiable Buffer(::Any...)
@adjoint function getindex(b::Buffer, i...)
b[i...], function (Δ)
diff --git a/src/lib/grad.jl b/src/lib/grad.jl
index a522d685a..38347b312 100644
--- a/src/lib/grad.jl
+++ b/src/lib/grad.jl
@@ -1,15 +1,3 @@
-using MacroTools: @q
-
-macro nograd(ex)
- isexpr(ex, :tuple) || (ex = Expr(:tuple, ex))
- blk = @q begin end
- for f in ex.args
- back = MacroTools.@q _ -> ($__source__; nothing)
- push!(blk.args, :(@inline Zygote._pullback(::Context, ::Core.Typeof($(esc(f))), args...) = $(esc(f))(args...), $back))
- end
- return blk
-end
-
macro which(ex)
@capture(ex, f_(args__)) || error("Zygote.@which f(args...)")
:(InteractiveUtils.@which adjoint(Context(), $(esc(f)), $(esc.(args)...)))
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index f11a74214..22bda1e19 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -38,8 +38,6 @@ function accum(x::RefValue, y::RefValue)
end
# Core functions
-@nograd eps, Base.eval, Core.TypeVar, Core.UnionAll, Symbol
-
@adjoint deepcopy(x) = deepcopy(x), ȳ -> (ȳ,)
@adjoint (::Type{V})(x...) where V<:Val = V(x...), _ -> nothing
diff --git a/src/lib/number.jl b/src/lib/number.jl
index 4097863c8..296852dbc 100644
--- a/src/lib/number.jl
+++ b/src/lib/number.jl
@@ -1,6 +1,3 @@
-
-@nograd floor, ceil, trunc, round, div
-
@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
Base.literal_pow(^,x,Val(p)),
Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)
diff --git a/test/lib/number.jl b/test/lib/number.jl
index d69241655..ae7b1cb75 100644
--- a/test/lib/number.jl
+++ b/test/lib/number.jl
@@ -1,7 +1,7 @@
@testset "nograds" begin
- @test gradient(floor, 1) === nothing
- @test gradient(ceil, 1) === nothing
- @test gradient(round, 1) === nothing
+ @test gradient(floor, 1) === (0.0,)
+ @test gradient(ceil, 1) === (0.0,)
+ @test gradient(round, 1) === (0.0,)
@test gradient(hash, 1) === nothing
@test gradient(div, 1, 2) === nothing
end #testset
From e9b119743819fe9bc4c244d0b22f27c6a62a15e1 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Fri, 24 Jun 2022 12:03:38 +0200
Subject: [PATCH 325/490] Reduce length for error message and add detailed docs
---
docs/make.jl | 1 +
docs/src/limitations.md | 148 ++++++++++++++++++++++++++++++++++++++++
src/compiler/reverse.jl | 11 ++-
src/lib/array.jl | 15 ++--
4 files changed, 162 insertions(+), 13 deletions(-)
create mode 100644 docs/src/limitations.md
diff --git a/docs/make.jl b/docs/make.jl
index 9d2f549c9..ff8cb28d1 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -11,6 +11,7 @@ makedocs(
doctest = true,
pages = [
"Home" => "index.md",
+ "Limitations" => "limitations.md",
"Custom Adjoints" => "adjoints.md",
"Utilities" => "utils.md",
"Complex Differentiation" => "complex.md",
diff --git a/docs/src/limitations.md b/docs/src/limitations.md
new file mode 100644
index 000000000..0908b0882
--- /dev/null
+++ b/docs/src/limitations.md
@@ -0,0 +1,148 @@
+# Limitations
+
+Zygote aims to support differentiating any code you might write in Julia, but it still has a few limitations. Notably, you might encounter errors when trying to differentiate:
+- array mutation
+- `try`/`catch` statements
+- "foreign call" expressions
+
+In this section, we will introduce examples where each of these errors occurs as well as possible work-arounds.
+
+## Array mutation
+
+Array mutation is by far the most commonly encountered Zygote limitation. Unfortunately, supporting it natively in Zygote is tricky, though it may happen eventually. For now, let's focus on what counts as mutation, and how to fix it.
+
+Here we define a simple mutating function, `f!`, which modifies the elements of its input argument, `x`, in place.
+```julia
+function f!(x)
+ x .= 2 .* x
+
+ return x
+end
+```
+Let's see what happens when we differentiate `f!`
+```julia
+julia> gradient(rand(3)) do x
+ sum(f!(x))
+ end
+ERROR: Mutating arrays is not supported -- called copyto!(Vector{Float64}, ...)
+This error occurs when you ask Zygote to differentiate operations that change
+the elements of arrays in-place (e.g. setting values with x .= ...)
+
+Possible fixes:
+- avoid mutating operations (preferred)
+- or read the documentation and solutions for this error
+ https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation
+
+Stacktrace:
+ ...
+```
+We got an error message and a long stacktrace. The error informs us that our code performs array mutation by calling `copyto!` (we might not have directly called this function, but it is being invoked somewhere in the call stack). We see that our code includes `x .= ...` which is given as an example of array mutation. Other examples of mutating operations include:
+- setting values (`x .= ...`)
+- appending/popping values (`push!(x, v)` / `pop!(x)`)
+- calling mutating functions (`mul!(C, A, B)`)
+
+!!! warning
+
+ Non-mutating functions may also use mutation under the hood. This can be done for performance reasons or code re-use.
+
+```julia
+function g!(x, y)
+ x .= 2 .* y
+
+ return x
+end
+g(y) = g!(similar(y), y)
+```
+Here `g` is a "non-mutating function," and it indeed does not mutate `y`, its only argument. But it still allocates a new array and calls `g!` on this array which will result in a mutating operation. You may encounter such functions when working with another package.
+
+Specifically for array mutation, we can use [`Zygote.Buffer`](@ref) to re-write our function. For example, let's fix the function `g!` above.
+```julia
+function g!(x, y)
+ x .= 2 .* y
+
+ return x
+end
+
+function g(y)
+ x = Zygote.Buffer(y) # Buffer supports syntax like similar
+ g!(x, y)
+ return copy(x) # this step makes the Buffer immutable (w/o actually copying)
+end
+
+julia> gradient(rand(3)) do y
+ sum(g(y))
+ end
+([2.0, 2.0, 2.0],)
+```
+
+## Try-catch statements
+
+Any expressions involving `try`/`catch` statements is not supported.
+```julia
+function tryme(x)
+ try
+ 2 * x
+ catch e
+ throw(e)
+ end
+end
+
+julia> gradient(rand(3)) do x
+ sum(tryme(x))
+ end
+ERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported.
+Refer to the Zygote documentation for fixes.
+https://fluxml.ai/Zygote.jl/dev/limitations.html#try-catch-statements-1
+
+Stacktrace:
+ ...
+```
+Here `tryme` uses a `try`/`catch` statement, and Zygote throws an error when trying to differentiate it as expected. `try`/`catch` expressions are used for error handling, but they are less common in Julia compared to some other languages.
+
+## Foreign call expressions
+
+Foreign call expressions refer to expressions that call external libraries such as code written in C or Fortran. You may want to read more about these calls in the [Julia documentation](https://docs.julialang.org/en/v1/manual/calling-c-and-fortran-code/). Scientific computing libraries in Julia may call established C or Fortran libraries under the hood. Since the underlying code for a foreign call expression is not in Julia, it is not possible for Zygote to differentiate this expression.
+
+Below, we define a function that calls a standard C function, `clock`. This function returns the Unix clock as an `Int32`.
+```julia
+julia> jclock(x) = ccall(:clock, Int32, ()) * 2
+jclock (generic function with 1 method)
+
+julia> jclock(2)
+30921278
+
+julia> gradient(jclock, rand())
+ERROR: Can't differentiate foreigncall expression
+You might want to check the Zygote limitations documentation.
+https://fluxml.ai/Zygote.jl/dev/limitations.html
+
+Stacktrace:
+ ...
+```
+`jclock` will multiply the result of our C function by an argument. When we try to differentiate with respect to this argument, we get an `foreigncall` error.
+
+## Solutions
+
+For all of the errors above, the suggested solutions are similar. You have the following possible work arounds available (in order of preference):
+1. avoid the error-inducing operation (e.g. do not use mutating functions)
+2. define a [custom `ChainRulesCore.rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
+3. open an [issue on Zygote](https://github.com/FluxML/Zygote.jl/issues)
+
+Avoiding the operation is simple, just don't do it! If you are using a mutating function, try to use a non-mutating variant. If you are using `try`/`catch` statements, try to use more graceful error handling such as returning `nothing` or another sentinel value. Recall that array mutation can also be avoided by using [`Zygote.Buffer`](@ref) as discussed above.
+
+Sometimes, we cannot avoid expressions that Zygote cannot differentiate, but we may be able to manually derive a gradient. In these cases, you can write [a custom `rrule`](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) using ChainRules.jl. Please refer to the linked ChainRules documentation for how to do this. _This solution is the only solution available for foreign call expressions._ Below, we provide a custom `rrule` for `jclock`.
+```julia
+jclock(x) = ccall(:clock, Int32, ()) * x
+
+function ChainRulesCore.rrule(::typeof(jclock), x)
+ y = jclock(x)
+ pb(ȳ) = (ChainRulesCore.NoTangent(), ȳ * y)
+
+ return y, pb
+end
+
+julia> gradient(jclock, rand())
+(674298.4243400148,)
+```
+
+Lastly, if the code causing problems can be fixed, but it is package code instead of your code, then you should open an issue. For functions built into Julia or its standard libraries, you can open an issue with Zygote.jl or ChainRules.jl. For functions in other packages, you can open an issue with the corresponding package issue tracker.
diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl
index e746684f7..b6dec7caa 100644
--- a/src/compiler/reverse.jl
+++ b/src/compiler/reverse.jl
@@ -118,7 +118,10 @@ function instrument(ir::IR)
if isexpr(ex, :foreigncall, :isdefined)
continue
elseif isexpr(ex, :enter, :leave)
- error("try/catch is not supported.")
+ error("""try/catch is not supported.
+ Refer to the Zygote documentation for fixes.
+ https://fluxml.ai/Zygote.jl/dev/limitations.html#Try-catch-statements-1
+ """)
elseif isexpr(ex, :(=))
@assert ex.args[1] isa GlobalRef
pr[v] = xcall(Zygote, :global_set, QuoteNode(ex.args[1]), ex.args[2])
@@ -277,7 +280,11 @@ function adjoint(pr::Primal)
grads[ex.val] = grads[v]
elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo)
elseif isexpr(ex)
- push!(rb, stmt(xcall(Base, :error, "Can't differentiate $(ex.head) expression"),
+ push!(rb, stmt(xcall(Base, :error, """
+ Can't differentiate $(ex.head) expression.
+ You might want to check the Zygote limitations documentation.
+ https://fluxml.ai/Zygote.jl/dev/limitations.html
+ """),
line = b[v].line))
else # A literal value
continue
diff --git a/src/lib/array.jl b/src/lib/array.jl
index d1d542b8e..1f1e12d94 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -70,19 +70,12 @@ _droplike(dy::Union{LinearAlgebra.Adjoint, LinearAlgebra.Transpose}, dxv::Abstra
_throw_mutation_error(f, args...) = error("""
Mutating arrays is not supported -- called $f($(join(map(typeof, args), ", ")), ...)
This error occurs when you ask Zygote to differentiate operations that change
-the elements of arrays in place. Some common examples:
-- setting values (x .= ...)
-- appending values (push!(x, v))
-- popping values (pop!(x))
-- calling mutating functions (mul!(C, A, B))
-NOTE: non-mutating functions may use mutation under the hood
- for performance or code-reuse.
+the elements of arrays in place (e.g. setting values with x .= ...)
+
Possible fixes:
- avoid mutating operations (preferred)
-- hide the mutation from Zygote by wrapping the mutating call in a custom rrule
- (https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html)
-- if the mutation is coming from within a package (i.e. not user code),
- then open an issue on Zygote.jl (https://github.com/FluxML/Zygote.jl/issues)
+- or read the documentation and solutions for this error
+ https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation-1
""")
@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
From a4eac890f41e7072ef5b571da101b8f498e49b2c Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sat, 25 Jun 2022 18:40:31 -0600
Subject: [PATCH 326/490] rm rules for `maximum`, `minimum`, `dropdims` (#1250)
* rm rules for maximum, minimum, dropdims
* add test
* typo
---
src/lib/array.jl | 24 ------------------------
test/gradcheck.jl | 6 ++++++
2 files changed, 6 insertions(+), 24 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index bbe13669d..a37aa7787 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -313,35 +313,11 @@ end
sum(xs, dims = dims), Δ -> (nothing,)
end
-
function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs)
y, ȳ -> (nothing, back(ȳ)...)
end
-@adjoint function maximum(xs::AbstractArray; dims = :)
- max, i = findmax(xs, dims = dims)
- max, function (Δ)
- Δ isa Real && abs(Δ) <= sqrt(eps(float(Δ))) && return nothing
- Δ′ = zero(xs)
- Δ′[i] = Δ
- return (Δ′,)
- end
-end
-
-@adjoint function minimum(xs::AbstractArray; dims = :)
- min, i = findmin(xs, dims = dims)
- min, function (Δ)
- Δ′ = zero(xs)
- Δ′[i] = Δ
- return (Δ′,)
- end
-end
-
-@adjoint function dropdims(xs::AbstractArray; dims)
- dropdims(xs, dims = dims), Δ -> (reshape(Δ, size(xs)...),)
-end
-
@adjoint real(x::AbstractArray) = real(x), r̄ -> (real(r̄),)
@adjoint conj(x::AbstractArray) = conj(x), r̄ -> (conj(r̄),)
@adjoint imag(x::AbstractArray) = imag(x), ī -> (complex.(0, real.(ī)),)
diff --git a/test/gradcheck.jl b/test/gradcheck.jl
index 30c62eb3e..e37e0ea15 100644
--- a/test/gradcheck.jl
+++ b/test/gradcheck.jl
@@ -501,6 +501,12 @@ end
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
@test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9]
+
+ # issue 1224, second order
+ f1244(w, x) = sum(maximum((w * x).^2, dims=1))
+ g1244(w, x) = sum(gradient(f1244, w, x)[2].^2)
+ h1244(w, x) = gradient(g1244, w, x)[2]
+ @test h1244([1 2 3; 4 5 6.0], [7,8,9.0]) ≈ [300608, 375760, 450912]
end
@testset "minimum" begin
From b00ff49abf9ace8af90ec4eb8e6b3a169e194586 Mon Sep 17 00:00:00 2001
From: Saransh
Date: Mon, 27 Jun 2022 14:33:15 +0530
Subject: [PATCH 327/490] Run doctests only once
---
docs/make.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/docs/make.jl b/docs/make.jl
index 9d2f549c9..98f9333f2 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -8,7 +8,7 @@ using Documenter, Zygote
makedocs(
sitename="Zygote",
- doctest = true,
+ doctest = false,
pages = [
"Home" => "index.md",
"Custom Adjoints" => "adjoints.md",
From fe6ff51c43f7bdf92790d056b2f6ed2c717c99a3 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Tue, 28 Jun 2022 18:02:20 +0100
Subject: [PATCH 328/490] More intro on mutation
---
docs/src/limitations.md | 6 ++++--
1 file changed, 4 insertions(+), 2 deletions(-)
diff --git a/docs/src/limitations.md b/docs/src/limitations.md
index 0908b0882..f74304e97 100644
--- a/docs/src/limitations.md
+++ b/docs/src/limitations.md
@@ -9,9 +9,11 @@ In this section, we will introduce examples where each of these errors occurs as
## Array mutation
-Array mutation is by far the most commonly encountered Zygote limitation. Unfortunately, supporting it natively in Zygote is tricky, though it may happen eventually. For now, let's focus on what counts as mutation, and how to fix it.
+Array mutation is by far the most commonly encountered Zygote limitation.
-Here we define a simple mutating function, `f!`, which modifies the elements of its input argument, `x`, in place.
+Automatic differentiation (AD) systems like Zygote are built on basic principles of calculus where we encounter _pure_ functions. This means that the function, ``y = f(x)``, does not modify ``x`` and only produces the output ``y`` based on ``x``. If we have a chain of functions, such as ``y = h(g(f(x)))``, we can apply the chain rule to differentiate it. AD systems are built to programmatically apply the chain rule to a series of function calls. Unfortunately, typical programs do not behave this way. We might allocate some memory, `x`, then call a function `y = f!(x)` that modifies `x` to produce the output `y`. This mutating behavior is a _side-effect_ of `f!`. Side-effects are difficult for AD systems to handle, because the must track changes to mutated variables and store older versions of the variable. For these reasons, Zygote does not handle array mutation for now.
+
+Let's explore this with a more concrete example. Here we define a simple mutating function, `f!`, which modifies the elements of its input argument, `x`, in place.
```julia
function f!(x)
x .= 2 .* x
From 7604288b9898d31a32eefcc7a23ac04da820e94e Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 29 Jun 2022 21:33:09 -0600
Subject: [PATCH 329/490] rm rules for `eachslice`, `cumsum` (#1253)
* rm rules for eachslice, cumsum
* bump
* bound chainrules
* bump
---
Project.toml | 4 ++--
src/lib/array.jl | 31 -------------------------------
2 files changed, 2 insertions(+), 33 deletions(-)
diff --git a/Project.toml b/Project.toml
index a35854934..b0d881a56 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.41"
+version = "0.6.42"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "1.35.3"
+ChainRules = "1.36.2"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.4"
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 92658e1e0..d1057b560 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -351,37 +351,6 @@ _backvar(xs, Δ, N::Int, mean) = (convert(eltype(xs), 2/N) .* Δ .* (xs .- mean)
return s, Δ -> _backvar(xs, Δ ./ (2 .* s), corrected, mean, dims)
end
-@adjoint function cumsum(xs::AbstractVector; dims::Integer = 1)
- dims == 1 || return copy(xs), Δ -> (Δ,)
- cumsum(xs), Δ -> (reverse(cumsum(reverse(Δ))),)
-end
-@adjoint function cumsum(xs::AbstractArray; dims::Integer)
- dims <= ndims(xs) || return copy(xs), Δ -> (Δ,)
- cumsum(xs; dims=dims), Δ -> begin
- (reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims),)
- end
-end
-
-@adjoint eachrow(x::AbstractVecOrMat) = collect(eachrow(x)), dys -> ∇eachslice(dys, x, 1)
-@adjoint eachcol(x::AbstractVecOrMat) = collect(eachcol(x)), dys -> ∇eachslice(dys, x, 2)
-@adjoint eachslice(x::AbstractArray; dims::Integer) =
- collect(eachslice(x; dims=dims)), dys -> ∇eachslice(dys, x, dims)
-
-function ∇eachslice(dys, x::AbstractArray, dim::Integer) where {TX}
- i1 = findfirst(dy -> dy isa AbstractArray, dys)
- i1 === nothing && return (zero(x),) # all slices get nothing
- T = promote_type(eltype(dys[i1]), eltype(x))
- dx = similar(x, T)
- for i in axes(x, dim)
- if dys[i] isa AbstractArray
- copyto!(selectdim(dx,dim,i), dys[i])
- else
- selectdim(dx,dim,i) .= 0
- end
- end
- (dx,)
-end
-
# LinearAlgebra
# =============
From 4777767737b4c95d2cea842933c5b2edae2771b2 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Wed, 29 Jun 2022 21:34:19 -0600
Subject: [PATCH 330/490] un-bump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index b0d881a56..9920da241 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.42"
+version = "0.6.41"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 6336b60ab392ea4ade0f914db7f37453a97f1a42 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 1 Jul 2022 10:26:15 -0600
Subject: [PATCH 331/490] rm rules for Statistics (#1252)
---
src/lib/array.jl | 18 ------------------
1 file changed, 18 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index d1057b560..0eef0c64a 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -333,24 +333,6 @@ end
@adjoint conj(x::AbstractArray) = conj(x), r̄ -> (conj(r̄),)
@adjoint imag(x::AbstractArray) = imag(x), ī -> (complex.(0, real.(ī)),)
-@adjoint function mean(xs::AbstractArray; dims = :)
- return mean(xs, dims=dims), Δ -> (_backmean(xs,Δ,dims),)
-end
-_backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
-_backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(xs,i),*,dims)
-
-@adjoint function Statistics.var(xs::AbstractArray; corrected::Bool=true, dims=:, mean=mean(xs, dims=dims))
- return Statistics.var(xs; corrected=corrected, mean=mean, dims=dims), Δ -> _backvar(xs, Δ, corrected, mean, dims)
-end
-_backvar(xs, Δ, corrected::Bool, mean, dims) = _backvar(xs, Δ, mapreduce(i -> size(xs,i),*,dims) - corrected, mean)
-_backvar(xs, Δ, corrected::Bool, mean, ::Colon) = _backvar(xs, Δ, length(xs) - corrected, mean)
-_backvar(xs, Δ, N::Int, mean) = (convert(eltype(xs), 2/N) .* Δ .* (xs .- mean),)
-
-@adjoint function Statistics.std(xs::AbstractArray; corrected::Bool=true, dims=:, mean=mean(xs, dims=dims))
- s = Statistics.std(xs; corrected=corrected, mean=mean, dims=dims)
- return s, Δ -> _backvar(xs, Δ ./ (2 .* s), corrected, mean, dims)
-end
-
# LinearAlgebra
# =============
From ed84d53a97df7991c8688f797c0989e62101fdee Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 12 Jul 2022 23:23:48 -0400
Subject: [PATCH 332/490] rm `adjoint` & `transpose` adjoints (#1259)
* rm adjoint + transpose adjoint
* restore parent adjoints
---
src/lib/array.jl | 27 +--------------------------
1 file changed, 1 insertion(+), 26 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 0eef0c64a..e4079eb89 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -337,33 +337,8 @@ end
# LinearAlgebra
# =============
-@adjoint function transpose(x)
- back(Δ) = (transpose(Δ),)
- back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
- return transpose(x), back
-end
-
-@adjoint function LinearAlgebra.Transpose(x)
- back(Δ) = (LinearAlgebra.Transpose(Δ),)
- back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
- return LinearAlgebra.Transpose(x), back
-end
-
-
-@adjoint function Base.adjoint(x)
- back(Δ) = (Δ',)
- back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
- return x', back
-end
-
-@adjoint function LinearAlgebra.Adjoint(x)
- back(Δ) = (LinearAlgebra.Adjoint(Δ),)
- back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
- return LinearAlgebra.Adjoint(x), back
-end
-
@adjoint parent(x::LinearAlgebra.Adjoint) = parent(x), ȳ -> (LinearAlgebra.Adjoint(ȳ),)
-@adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),)
+@adjoint parent(x::LinearAlgebra.Transpose) = parent(x), ȳ -> (LinearAlgebra.Transpose(ȳ),)
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
From 995778d0520113b0f22bf9230c6a2ee5ba8ef459 Mon Sep 17 00:00:00 2001
From: "Ziyi (Francis) Yin" <54320031+ziyiyin97@users.noreply.github.com>
Date: Tue, 19 Jul 2022 14:47:53 -0400
Subject: [PATCH 333/490] fix the link (#1265)
---
docs/src/limitations.md | 6 +++---
src/compiler/reverse.jl | 4 ++--
src/lib/array.jl | 2 +-
3 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/docs/src/limitations.md b/docs/src/limitations.md
index f74304e97..5b15afac3 100644
--- a/docs/src/limitations.md
+++ b/docs/src/limitations.md
@@ -33,7 +33,7 @@ the elements of arrays in-place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
- https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation
+ https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
...
@@ -94,7 +94,7 @@ julia> gradient(rand(3)) do x
end
ERROR: Compiling Tuple{typeof(tryme), Vector{Float64}}: try/catch is not supported.
Refer to the Zygote documentation for fixes.
-https://fluxml.ai/Zygote.jl/dev/limitations.html#try-catch-statements-1
+https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
...
@@ -116,7 +116,7 @@ julia> jclock(2)
julia> gradient(jclock, rand())
ERROR: Can't differentiate foreigncall expression
You might want to check the Zygote limitations documentation.
-https://fluxml.ai/Zygote.jl/dev/limitations.html
+https://fluxml.ai/Zygote.jl/latest/limitations
Stacktrace:
...
diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl
index b6dec7caa..ba10ea5b1 100644
--- a/src/compiler/reverse.jl
+++ b/src/compiler/reverse.jl
@@ -120,7 +120,7 @@ function instrument(ir::IR)
elseif isexpr(ex, :enter, :leave)
error("""try/catch is not supported.
Refer to the Zygote documentation for fixes.
- https://fluxml.ai/Zygote.jl/dev/limitations.html#Try-catch-statements-1
+ https://fluxml.ai/Zygote.jl/latest/limitations
""")
elseif isexpr(ex, :(=))
@assert ex.args[1] isa GlobalRef
@@ -283,7 +283,7 @@ function adjoint(pr::Primal)
push!(rb, stmt(xcall(Base, :error, """
Can't differentiate $(ex.head) expression.
You might want to check the Zygote limitations documentation.
- https://fluxml.ai/Zygote.jl/dev/limitations.html
+ https://fluxml.ai/Zygote.jl/latest/limitations
"""),
line = b[v].line))
else # A literal value
diff --git a/src/lib/array.jl b/src/lib/array.jl
index e4079eb89..9496fdf32 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -75,7 +75,7 @@ the elements of arrays in place (e.g. setting values with x .= ...)
Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
- https://fluxml.ai/Zygote.jl/dev/limitations.html#Array-mutation-1
+ https://fluxml.ai/Zygote.jl/latest/limitations
""")
@adjoint! setindex!(xs::AbstractArray, x...) = setindex!(xs, x...),
From 5ffbd43f70d85ed53ab5ca2cb4f281158414706f Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 26 Jul 2022 15:38:50 -0400
Subject: [PATCH 334/490] Replace `@require CUDA` with `using GPUArraysCore`
(#1272)
* require GPUArrays instead of CUDA
* more
* change to unconditionally load GPUArraysCore
* add GPUArrays dep
* trivial trigger commit
---
Project.toml | 4 ++++
src/lib/broadcast.jl | 29 +++++++++--------------------
2 files changed, 13 insertions(+), 20 deletions(-)
diff --git a/Project.toml b/Project.toml
index 9920da241..15f08ad2b 100644
--- a/Project.toml
+++ b/Project.toml
@@ -10,6 +10,8 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" # not loaded, just a version bound
+GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
IRTools = "7869d1d1-7146-5819-86e3-90919afe41df"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -31,6 +33,8 @@ ChainRulesTestUtils = "1"
DiffRules = "1.4"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
ForwardDiff = "0.10"
+GPUArrays = "8.4.2" # not loaded, just a version bound
+GPUArraysCore = "0.1.1"
IRTools = "0.4.4"
LogExpFunctions = "0.3.1"
MacroTools = "0.5"
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 6dbfdb829..b3c16e823 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -253,43 +253,32 @@ end
return y, bc_fwd_back
end
-@init @require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
+using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve git blame
- const CuArrayStyle = CUDA.AbstractGPUArrayStyle
-
- if isdefined(CUDA, :cufunc) # CUDA < 3.0
-
- @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
- broadcast_forward(CUDA.cufunc(f), args...)
-
- else # CUDA >= 3.0 -- don't need cufunc(f).
# Ordinary broadcasting calls broadcast_forward anyway when certain its' safe,
# so perhaps this can be deleted? Possible edge case here:
# https://github.com/FluxML/Zygote.jl/pull/1018#issuecomment-873629415
+ @adjoint broadcasted(::AbstractGPUArrayStyle, f, args...) =
+ broadcast_forward(f, args...)
- @eval @adjoint broadcasted(::CuArrayStyle, f, args...) =
- broadcast_forward(f, args...)
-
- end
-
- @adjoint (::Type{T})(xs::Array) where {T <: CUDA.CuArray} =
+ @adjoint (::Type{T})(xs::Array) where {T <: AbstractGPUArray} =
T(xs), Δ -> (convert(Array, Δ), )
- @adjoint function sum(xs::CUDA.AbstractGPUArray; dims = :)
+ @adjoint function sum(xs::AbstractGPUArray; dims = :)
placeholder = similar(xs)
sum(xs, dims = dims), Δ -> (placeholder .= Δ,)
end
# Make sure sum(f, ::CuArray) uses broadcase through forward-mode defined above
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
- @adjoint function sum(f, xs::CUDA.AbstractGPUArray; kws...)
+ @adjoint function sum(f, xs::AbstractGPUArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
end
- @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:CUDA.AbstractGPUArray}
+ @adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
Base.convert(T, xs), Δ -> (nothing, Base.convert(Array, Δ),)
end
- @eval pull_block_vert(sz, Δ::CUDA.CuArray, A::Number) = CUDA.@allowscalar Δ[sz]
-end
+ pull_block_vert(sz, Δ::AbstractGPUArray, A::Number) = @allowscalar Δ[sz]
+
From c822e9e77fa76647ba2a39896cbcdee604e9aa9f Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Fri, 29 Jul 2022 19:59:10 -0700
Subject: [PATCH 335/490] Bump version
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 55581f932..1ac51616e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.41"
+version = "0.6.42"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From be5b47fad5fc9c0a3e22f239ec8517df60ffdf4c Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Wed, 22 Jun 2022 19:35:04 -0700
Subject: [PATCH 336/490] Improved type stability with explicit params
We can disable accumulating (implicit) parameters to the gradient cache
in explicit mode. This can dramatically improve type stability because
`accum_param` will return a `Union{Nothing, [grad type]}` otherwise.
---
src/compiler/interface.jl | 34 ++++++++++++++++++++++++++--------
src/lib/array.jl | 4 ++--
src/lib/broadcast.jl | 12 ++++++++----
src/lib/lib.jl | 3 ++-
test/compiler.jl | 17 +++++++++++------
5 files changed, 49 insertions(+), 21 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index d5428e97e..f429102f6 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -4,11 +4,11 @@ using Core: Typeof
import Base: copy!, IdSet
import Base.Broadcast: broadcasted, materialize!
-mutable struct Context <: AContext
+mutable struct Context{I} <: AContext
cache::Union{IdDict{Any,Any},Nothing}
end
-Context() = Context(nothing)
+Context() = Context{false}(nothing)
cache(cx::Context) = cx.cache === nothing ? (cx.cache = IdDict()) : cx.cache
@@ -36,10 +36,28 @@ _pullback(f, args...) = _pullback(Context(), f, args...)
tailmemaybe(::Nothing) = nothing
tailmemaybe(x::Tuple) = Base.tail(x)
-function pullback(f, args...)
- y, back = _pullback(f, args...)
+@inline pullback(f, args...) = pullback(f, Context(), args...)
+function pullback(f, cx::AContext, args...)
+ y, back = _pullback(cx, f, args...)
y, Δ -> tailmemaybe(back(Δ))
end
+function pullback(cx::Context, f, args...)
+ ChainRulesCore.ignore_derivatives() do
+ @warn """
+ Incorrect argument order for pullback, please use:
+
+ pullback(f, __context__::Context, args)
+
+ instead of:
+
+ pullback(__context__::Context, f, args)
+
+ This is usually caused by a call to pullback in a higher-order @adjoint.
+ The above warning will become an error in Zygote 0.7.
+ """
+ end
+ return pullback(f, cx, args...)
+end
sensitivity(y::Number) = one(y)
sensitivity(y::Complex) = error("Output is complex, so the gradient is not defined.")
@@ -334,21 +352,21 @@ function Base.map(f, gs1::Grads, gss::ADictOrGrads...)
end
function Base.map!(f, gsout::Grads, gss::ADictOrGrads...)
- all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
+ all(issetequal(gsout.params, keys(gs)) for gs in gss) ||
throw(ArgumentError("map! expects Grads objects with the same Params."))
for p in gsout.params
- gsout[p] = f((_getformap(gs, p) for gs in gss)...)
+ gsout[p] = f((_getformap(gs, p) for gs in gss)...)
end
return gsout
end
function _getformap(gs, p)
g = gs[p]
- isnothing(g) ? fill!(similar(p), 0) : g
+ isnothing(g) ? fill!(similar(p), 0) : g
end
function pullback(f, ps::Params)
- cx = Context()
+ cx = Context{true}(nothing)
y, back = _pullback(cx, f)
y, function (Δ)
for p in ps
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 4e72713d9..293801b21 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -310,7 +310,7 @@ end
@adjoint function sum(f, xs::AbstractArray{<:AbstractArray}; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
- return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
+ return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
end
@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
@@ -318,7 +318,7 @@ end
end
function _pullback(cx::AContext, ::typeof(prod), f, xs::AbstractArray)
- y, back = pullback(cx, ((f, xs) -> prod(f.(xs))), f, xs)
+ y, back = pullback((f, xs) -> prod(f.(xs)), cx, f, xs)
y, ȳ -> (nothing, back(ȳ)...)
end
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index b3c16e823..8c0d3c54c 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -30,6 +30,10 @@ using Base.Broadcast: Broadcasted, AbstractArrayStyle, broadcasted, materialize
# Utilities
# =========
+# ChainRules already marks this non-differentiable,
+# But inference can still give up because of the Zygote -> CR wrapper layer
+@nograd Broadcast.combine_styles
+
accum_sum(xs; dims = :) = reduce(accum, xs, dims = dims)
# Work around reducedim_init issue
@@ -82,16 +86,16 @@ _minus(::Nothing) = nothing
@adjoint broadcasted(::typeof(*), x::Numeric, y::Numeric) = x.*y,
Δ -> (nothing, unbroadcast(x, Δ .* conj.(y)), unbroadcast(y, Δ .* conj.(x)))
@adjoint broadcasted(::typeof(*), x::Number, y::AbstractArray{<:Number}) =
- _pullback(*, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
+ _pullback(__context__, *, x, y) # this uses dot(y,Δ) instead of sum(Δ .* conj.(y))
@adjoint broadcasted(::typeof(*), x::AbstractArray{<:Number}, y::Number) =
- _pullback(*, x, y)
+ _pullback(__context__, *, x, y)
@adjoint function broadcasted(::typeof(/), x::Numeric, y::Numeric)
res = x ./ y
res, Δ -> (nothing, unbroadcast(x, Δ ./ conj.(y)), unbroadcast(y, .-Δ .* conj.(res ./ y)))
end
@adjoint broadcasted(::typeof(/), x::AbstractArray{<:Number}, y::Number) =
- _pullback(/, x, y)
+ _pullback(__context__, /, x, y)
@adjoint function broadcasted(::typeof(Base.literal_pow), ::typeof(^), x::Numeric, exp::Val{p}) where p
y = Base.literal_pow.(^, x, exp)
@@ -273,7 +277,7 @@ using GPUArraysCore # replaces @require CUDA block, weird indenting to preserve
# Not the ChainRules.rrule which will use the Zygote.Context and thus not be GPU compatible
@adjoint function sum(f, xs::AbstractGPUArray; kws...)
@assert !haskey(kws, :init) # TODO add init support (julia 1.6)
- return pullback(__context__, (f, xs) -> sum(f.(xs); kws...), f, xs)
+ return pullback((f, xs) -> sum(f.(xs); kws...), __context__, f, xs)
end
@adjoint function Base.convert(::Type{T}, xs::Array) where {T<:AbstractGPUArray}
diff --git a/src/lib/lib.jl b/src/lib/lib.jl
index 22bda1e19..52a734809 100644
--- a/src/lib/lib.jl
+++ b/src/lib/lib.jl
@@ -21,7 +21,7 @@ accum(x, y) =
accum(x, y, zs...) = accum(accum(x, y), zs...)
-accum(x::Tuple, ys::Tuple...) = accum.(x, ys...)
+accum(x::Tuple, ys::Tuple...) = map(accum, x, ys...)
accum(x::AbstractArray, ys::AbstractArray...) = accum.(x, ys...)
@generated function accum(x::NamedTuple, y::NamedTuple)
@@ -48,6 +48,7 @@ end
@adjoint Base.typeassert(x, T) = Base.typeassert(x, T), Δ -> (Δ, nothing)
+accum_param(::Context{false}, _, Δ) = Δ
@generated function accum_param(cx::Context, x, Δ)
isbitstype(x) && return :(Δ)
quote
diff --git a/test/compiler.jl b/test/compiler.jl
index bc37d271e..c5ddf1f38 100644
--- a/test/compiler.jl
+++ b/test/compiler.jl
@@ -1,5 +1,5 @@
using Zygote, Test
-using Zygote: pullback, @adjoint
+using Zygote: pullback, @adjoint, Context
macro test_inferred(ex)
:(let res = nothing
@@ -160,13 +160,18 @@ end
@testset "inference for `getproperty`" begin
Gaussian = _Gaussian(:getproperty)
g = Gaussian(randn(3), randn(3, 3))
- y, back = @inferred pullback(x -> x.m, g)
- @test y == getfield(g, :m)
- # This type instability is due to the handling of non-bitstypes in `accum_param`
+ y_explicit, back_explicit = @inferred pullback(x -> x.m, g)
+ y_implicit, back_implicit = @inferred pullback(x -> x.m, Context{true}(nothing), g)
+ @test y_explicit == y_implicit == getfield(g, :m)
+
+ ∇args = ((m = [1.0, 0.0, 0.0], P = nothing),)
if VERSION > v"1.7-"
- @test Base.return_types(back, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(((m = [1.0, 0.0, 0.0], P = nothing),))}]
+ # This type instability is due to the handling of non-bitstypes in `accum_param`
+ @test Base.return_types(back_implicit, Tuple{Vector{Float64}}) == Any[Union{Tuple{Nothing}, typeof(∇args)}]
+ # But the same should infer if implicit parameters are disabled
+ @test Base.return_types(back_explicit, Tuple{Vector{Float64}}) == Any[typeof(∇args)]
end
- @test back([1., 0, 0]) == ((m = [1.0, 0.0, 0.0], P = nothing),)
+ @test back_explicit([1., 0, 0]) == back_implicit([1., 0, 0]) == ∇args
Base.getproperty(g::Gaussian, s::Symbol) = 2getfield(g, s)
y, back = pullback(x -> x.m, g)
From e9a60757196368f8999e2413f2658a3d9cfffab4 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Fri, 24 Jun 2022 12:43:33 -0700
Subject: [PATCH 337/490] basic comment for Context{I}
---
src/compiler/interface.jl | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/src/compiler/interface.jl b/src/compiler/interface.jl
index f429102f6..ee2a69528 100644
--- a/src/compiler/interface.jl
+++ b/src/compiler/interface.jl
@@ -4,6 +4,9 @@ using Core: Typeof
import Base: copy!, IdSet
import Base.Broadcast: broadcasted, materialize!
+# Internal container used to track accumulated gradients of mutable types (including params).
+# Type param I ∈ (true, false) indicates whether implicit params are in use.
+# By default, this should be false unless pullback(f, ::Params) is called.
mutable struct Context{I} <: AContext
cache::Union{IdDict{Any,Any},Nothing}
end
@@ -47,11 +50,11 @@ function pullback(cx::Context, f, args...)
Incorrect argument order for pullback, please use:
pullback(f, __context__::Context, args)
-
+
instead of:
pullback(__context__::Context, f, args)
-
+
This is usually caused by a call to pullback in a higher-order @adjoint.
The above warning will become an error in Zygote 0.7.
"""
From 3433cdd310cf3f262ea72370caba13028d8d4e0e Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Sun, 31 Jul 2022 08:36:57 -0700
Subject: [PATCH 338/490] Add accum_param tests
Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
---
test/features.jl | 61 ++++++++++++++++++++++++++++++++++++++++++++----
1 file changed, 57 insertions(+), 4 deletions(-)
diff --git a/test/features.jl b/test/features.jl
index cdfe7329e..d4f68d36b 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -476,7 +476,7 @@ end
@test_broken gradient(x -> abs2(x[1].x) + 7 * x[1].x.re, [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],)
@test_broken gradient(x -> abs2(x[1].x) + 7 * real(x[1].x), [Ref(1+im)]) == ([(x = 9.0 + 2.0im,)],) # worked on 0.6.0, 0.6.20
- @test_broken gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = 9.0 + 2.0im,),) # gives nothing, same in 0.6.0
+ @test gradient(x -> abs2(x[].x) + 7 * real(x[].x), Ref(Ref(1+im))) == ((x = (x = 9.0 + 2.0im,),),) # gave `nothing` from 0.6.0 to 0.6.41
# Array of mutables:
@test gradient(x -> sum(getindex.(x).^2), Ref.(1:3))[1] == [(;x=2i) for i in 1:3]
@@ -490,6 +490,59 @@ end
@test gradient(x -> sum(sum, Ref(x) .* [1,2,3]), [4,5]) == ([6.0, 6.0],)
end
+@testset "mutable accum_param bugs" begin
+ mutable struct Mut{T}; x::T; end
+ struct Imm{T}; x::T; end
+
+ # Indexing a tuple containing a mutable struct gave `nothing`
+ x1 = (Mut(3.0),)
+ x2 = (Imm(3.0),)
+ x3 = (Ref(3.0),)
+ @test gradient(x -> x[1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+ @test gradient(x -> x[1].x^2, x2)[1] == ((x = 6.0,),)
+ @test gradient(x -> x[1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+ i1 = 1
+ @test gradient(x -> x[i1].x^2, x1)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+ @test gradient(x -> x[i1].x^2, x2)[1] == ((x = 6.0,),)
+ @test gradient(x -> x[i1].x^2, x3)[1] == ((x = 6.0,),) # fails on v0.6.0 v0.6.41
+
+ @test gradient(x -> x[1][1].x^2, [x1])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41
+ @test gradient(x -> x[1][1].x^2, [x2])[1] == [((x = 6.0,),)]
+ @test gradient(x -> x[1][1].x^2, [x3])[1] == [((x = 6.0,),)] # fails on v0.6.0 v0.6.41
+
+ # When `getfield` returns a mutable struct, it gave `nothing`:
+ x4 = Imm(Mut(4.0))
+ x5 = Mut(Mut(4.0))
+ x6 = Imm(Imm(4.0))
+ @test gradient(x -> x.x.x^3, x4)[1] == (x = (x = 48.0,),) # fails on v0.6.0 v0.6.41
+ @test gradient(x -> x.x.x^3, x5)[1] == (x = (x = 48.0,),) # fails on v0.6.0
+ @test gradient(x -> x.x.x^3, x6)[1] == (x = (x = 48.0,),) # fails on v0.6.41
+
+ @test gradient(x -> x[2].x.x^3, [x4, x4])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0 v0.6.41
+ @test gradient(x -> x[2].x.x^3, [x4, x5])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.0
+ @test gradient(x -> x[2].x.x^3, [x4, x6])[1] == [nothing, (x = (x = 48.0,),)] # fails on v0.6.41
+
+ # Check when using implicit parameters, Params cases used to pass:
+ y1 = [3.0]
+ y2 = (Mut(y1),)
+ y3 = (Imm(y1),)
+ @test gradient(x -> sum(x[1].x)^2, y2)[1] == ((x = [6.0],),) # fails on v0.6.0 v0.6.41
+ @test gradient(() -> sum(y2[1].x)^2, Params([y1]))[y1] == [6.0]
+ @test gradient(x -> sum(x[1].x)^2, y3)[1] == ((x = [6.0],),)
+ @test gradient(() -> sum(y3[1].x)^2, Params([y1]))[y1] == [6.0]
+
+ @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41
+ @test gradient(() -> sum(y2[1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0]
+ @test gradient(x -> sum(x[1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),)
+ @test gradient(() -> sum(y3[1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0]
+
+ i1 = 1
+ @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y2)[1] == ((x = [216.0],),) # fails on v0.6.0 v0.6.41
+ @test gradient(() -> sum(y2[i1].x .+ y2[1].x)^3, Params([y1]))[y1] == [216.0]
+ @test gradient(x -> sum(x[i1].x .+ x[1].x)^3, y3)[1] == ((x = [216.0],),)
+ @test gradient(() -> sum(y3[i1].x .+ y3[1].x)^3, Params([y1]))[y1] == [216.0]
+end
+
@testset "NamedTuples" begin
@test gradient(x -> x.a, (a=1, b=2)) == ((a = 1, b = nothing),)
@test gradient(x -> x[1].a, [(a=1, b=2)]) == ([(a = 1, b = nothing)],)
@@ -517,7 +570,7 @@ end
@test (x->10*(x => 2)[2])'(100) === nothing
@test gradient(x-> (:x => x)[2], 17) == (1,)
-
+
d = Dict(:x=>1.0, :y=>3.0);
@test gradient(d -> Dict(:x => d[:x])[:x], d) == (Dict(:x => 1),)
end
@@ -546,7 +599,7 @@ end
# zip
if VERSION >= v"1.5"
# On Julia 1.4 and earlier, [x/y for (x,y) in zip(10:14, 1:10)] is a DimensionMismatch,
- # while on 1.5 - 1.7 it stops early.
+ # while on 1.5 - 1.7 it stops early.
@test gradient(10:14, 1:10) do xs, ys
sum([x/y for (x,y) in zip(xs, ys)])
@@ -608,7 +661,7 @@ end
# Iterators.Product with enumerate
@test gradient([2 3; 4 5]) do xs
- sum([x^i+y for (i,x) in enumerate(xs), y in xs])
+ sum([x^i+y for (i,x) in enumerate(xs), y in xs])
end == ([8 112; 36 2004],)
end
From c098f37a643223e7e4f397abf9c3ea6c8c542325 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Thu, 28 Jul 2022 13:20:39 +0200
Subject: [PATCH 339/490] number adjoints to rrules
---
src/lib/number.jl | 63 ++++++++++++++++++++++++++++++++++------------
test/lib/number.jl | 51 +++++++++++++++++++++++++++++++------
2 files changed, 91 insertions(+), 23 deletions(-)
diff --git a/src/lib/number.jl b/src/lib/number.jl
index 296852dbc..0e629518d 100644
--- a/src/lib/number.jl
+++ b/src/lib/number.jl
@@ -1,29 +1,60 @@
-@adjoint Base.literal_pow(::typeof(^), x::Number, ::Val{p}) where {p} =
- Base.literal_pow(^,x,Val(p)),
- Δ -> (nothing, Δ * conj(p * Base.literal_pow(^,x,Val(p-1))), nothing)
+function ChainRulesCore.rrule(
+ ::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{p}
+) where {p}
+ function literal_pow_pullback(Δ)
+ dx = Δ * conj(p * Base.literal_pow(^,x,Val(p-1)))
+ return (NoTangent(), NoTangent(), dx, NoTangent())
+ end
+ return Base.literal_pow(^,x,Val(p)), literal_pow_pullback
+end
-@adjoint Base.convert(T::Type{<:Real}, x::Real) = convert(T, x), ȳ -> (nothing, ȳ)
-@adjoint (T::Type{<:Real})(x::Real) = T(x), ȳ -> (nothing, ȳ)
+function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Real}, x::Real)
+ Real_pullback(Δ) = (NoTangent(), Δ)
+ return T(x), Real_pullback
+end
for T in Base.uniontypes(Core.BuiltinInts)
- @adjoint (::Type{T})(x::Core.BuiltinInts) = T(x), Δ -> (Δ,)
+ @eval function ChainRulesCore.rrule(::ZygoteRuleConfig, ::Type{$T}, x::Core.BuiltinInts)
+ IntX_pullback(Δ) = (NoTangent(), Δ)
+ return $T(x), IntX_pullback
+ end
end
-@adjoint Base.:+(xs::Number...) = +(xs...), Δ -> map(_ -> Δ, xs)
+function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(+), xs::Number...)
+ plus_pullback(Δ) = (NoTangent(), map(_ -> Δ, xs)...)
+ return +(xs...), plus_pullback
+end
-@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, - c̄ * a // b // b))
+function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(//), a, b)
+ divide_pullback(r̄) = (NoTangent(), r̄ * 1//b, - r̄ * a // b // b)
+ return a // b, divide_pullback
+end
# Complex Numbers
-@adjoint (T::Type{<:Complex})(re, im) = T(re, im), c̄ -> (nothing, real(c̄), imag(c̄))
+function ChainRulesCore.rrule(::ZygoteRuleConfig, T::Type{<:Complex}, r, i)
+ Complex_pullback(c̄) = (NoTangent(), real(c̄), imag(c̄))
+ return T(r, i), Complex_pullback
+end
# we define these here because ChainRules.jl only defines them for x::Union{Real,Complex}
-@adjoint abs2(x::Number) = abs2(x), Δ -> (real(Δ)*(x + x),)
-@adjoint real(x::Number) = real(x), r̄ -> (real(r̄),)
-@adjoint conj(x::Number) = conj(x), r̄ -> (conj(r̄),)
-@adjoint imag(x::Number) = imag(x), ī -> (real(ī)*im,)
+function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(abs2), x::Number)
+ abs2_pullback(Δ) = (NoTangent(), real(Δ)*(x + x))
+ return abs2(x), abs2_pullback
+end
-# for real x, ChainRules pulls back a zero real adjoint, whereas we treat x
-# as embedded in the complex numbers and pull back a pure imaginary adjoint
-@adjoint imag(x::Real) = zero(x), ī -> (real(ī)*im,)
+function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(real), x::Number)
+ real_pullback(r̄) = (NoTangent(), real(r̄))
+ return real(x), real_pullback
+end
+
+function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(conj), x::Number)
+ conj_pullback(c̄) = (NoTangent(), conj(c̄))
+ return conj(x), conj_pullback
+end
+
+function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(imag), x::Number)
+ imag_pullback(ī) = (NoTangent(), real(ī)*im)
+ return imag(x), imag_pullback
+end
diff --git a/test/lib/number.jl b/test/lib/number.jl
index ae7b1cb75..ce0a64bef 100644
--- a/test/lib/number.jl
+++ b/test/lib/number.jl
@@ -1,7 +1,44 @@
-@testset "nograds" begin
- @test gradient(floor, 1) === (0.0,)
- @test gradient(ceil, 1) === (0.0,)
- @test gradient(round, 1) === (0.0,)
- @test gradient(hash, 1) === nothing
- @test gradient(div, 1, 2) === nothing
-end #testset
+@testset "number.jl" begin
+ @testset "nograds" begin
+ @test gradient(floor, 1) === (0.0,)
+ @test gradient(ceil, 1) === (0.0,)
+ @test gradient(round, 1) === (0.0,)
+ @test gradient(hash, 1) === nothing
+ @test gradient(div, 1, 2) === nothing
+ end
+
+ @testset "basics" begin
+ @test gradient(Base.literal_pow, ^, 3//2, Val(-5))[2] isa Rational
+
+ @test gradient(convert, Rational, 3.14) == (nothing, 1.0)
+ @test gradient(convert, Rational, 2.3) == (nothing, 1.0)
+ @test gradient(convert, UInt64, 2) == (nothing, 1.0)
+ @test gradient(convert, BigFloat, π) == (nothing, 1.0)
+
+ @test gradient(Rational, 2) == (1//1,)
+
+ @test gradient(Bool, 1) == (1.0,)
+ @test gradient(Int32, 2) == (1.0,)
+ @test gradient(UInt16, 2) == (1.0,)
+
+ @test gradient(+, 2.0, 3, 4.0, 5.0) == (1.0, 1.0, 1.0, 1.0)
+
+ @test gradient(//, 3, 2) == (1//2, -3//4)
+ end
+
+ @testset "Complex numbers" begin
+ @test gradient(imag, 3.0) == (0.0,)
+ @test gradient(imag, 3.0 + 3.0im) == (0.0 + 1.0im,)
+
+ @test gradient(conj, 3.0) == (1.0,)
+ @test gradient(real ∘ conj, 3.0 + 1im) == (1.0 + 0im,)
+
+ @test gradient(real, 3.0) == (1.0,)
+ @test gradient(real, 3.0 + 1im) == (1.0 + 0im,)
+
+ @test gradient(abs2, 3.0) == (2*3.0,)
+ @test gradient(abs2, 3.0+2im) == (2*3.0 + 2*2.0im,)
+
+ @test gradient(real ∘ Complex, 3.0, 2.0) == (1.0, 0.0)
+ end
+end
From cbe800d129564a925958ce6f8de5ca10ff017490 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Thu, 28 Jul 2022 22:54:42 +0200
Subject: [PATCH 340/490] replace convert as well
---
src/lib/number.jl | 7 +++++++
1 file changed, 7 insertions(+)
diff --git a/src/lib/number.jl b/src/lib/number.jl
index 0e629518d..30b702254 100644
--- a/src/lib/number.jl
+++ b/src/lib/number.jl
@@ -1,3 +1,10 @@
+function ChainRulesCore.rrule(
+ ::ZygoteRuleConfig, ::typeof(convert), T::Type{<:Real}, x::Real
+)
+ convert_pullback(Δ) = (NoTangent(), NoTangent(), Δ)
+ return convert(T, x), convert_pullback
+end
+
function ChainRulesCore.rrule(
::ZygoteRuleConfig, ::typeof(Base.literal_pow), ::typeof(^), x::Number, ::Val{p}
) where {p}
From cdadaffe55b69a1202014ab3b49522dc209c3425 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Mon, 1 Aug 2022 10:02:55 +0200
Subject: [PATCH 341/490] comment and version number
---
Project.toml | 2 +-
src/lib/number.jl | 2 ++
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 1ac51616e..8328eb69a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.42"
+version = "0.6.43"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
diff --git a/src/lib/number.jl b/src/lib/number.jl
index 30b702254..aa50c54dc 100644
--- a/src/lib/number.jl
+++ b/src/lib/number.jl
@@ -61,6 +61,8 @@ function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(conj), x::Number)
return conj(x), conj_pullback
end
+# for real x, ChainRules pulls back a zero real adjoint, whereas we treat x
+# as embedded in the complex numbers and pull back a pure imaginary adjoint
function ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(imag), x::Number)
imag_pullback(ī) = (NoTangent(), real(ī)*im)
return imag(x), imag_pullback
From 1207b4d0fb0c1d3b2cb64e3cb38bfaf60c03f5fc Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Mon, 1 Aug 2022 23:27:45 +0200
Subject: [PATCH 342/490] remove adjoint for hv/h/v/cat
---
src/lib/array.jl | 21 ---------------------
1 file changed, 21 deletions(-)
diff --git a/src/lib/array.jl b/src/lib/array.jl
index 293801b21..420e3716f 100644
--- a/src/lib/array.jl
+++ b/src/lib/array.jl
@@ -104,27 +104,6 @@ end
@adjoint reshape(xs, dims...) = reshape(xs, dims...),
Δ -> (reshape(Δ, size(xs)),map(_->nothing,dims)...)
-@adjoint function hvcat(rows::Tuple{Vararg{Int}}, xs::Number...)
- hvcat(rows, xs...), ȳ -> (nothing, permutedims(ȳ)...)
-end
-
-pull_block_vert(sz, Δ, A::Number) = Δ[sz]
-pull_block_vert(sz, Δ, A::AbstractVector) = Δ[sz-length(A)+1:sz]
-pull_block_vert(sz, Δ, A::AbstractMatrix) = Δ[sz-size(A, 1)+1:sz, :]
-@adjoint function vcat(A::Union{AbstractVector, AbstractMatrix, Number}...)
- sz = cumsum([size.(A, 1)...])
- return vcat(A...), Δ->(map(n->pull_block_vert(sz[n], Δ, A[n]), eachindex(A))...,)
-end
-@adjoint vcat(xs::Number...) = vcat(xs...), Δ -> (Δ...,)
-
-pull_block_horz(sz, Δ, A::AbstractVector) = Δ[:, sz]
-pull_block_horz(sz, Δ, A::AbstractMatrix) = Δ[:, sz-size(A, 2)+1:sz]
-@adjoint function hcat(A::Union{AbstractVector, AbstractMatrix}...)
- sz = cumsum([size.(A, 2)...])
- return hcat(A...), Δ->(map(n->pull_block_horz(sz[n], Δ, A[n]), eachindex(A))...,)
-end
-@adjoint hcat(xs::Number...) = hcat(xs...), Δ -> (Δ...,)
-
@adjoint function repeat(xs; inner=ntuple(_->1, ndims(xs)), outer=ntuple(_->1, ndims(xs)))
repeat(xs, inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs)
From 1d63da63d47ad66372c8092e6cb89932603627e8 Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Mon, 1 Aug 2022 23:29:10 +0200
Subject: [PATCH 343/490] v0.6.44
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 8328eb69a..c61ff68c2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.43"
+version = "0.6.44"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 328eb4d122a12b0a6c4947e17278081e877169cf Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Mon, 1 Aug 2022 18:52:06 -0700
Subject: [PATCH 344/490] propagate ambiguities from rrule lookup instead of
failing inexplicably
---
src/compiler/chainrules.jl | 3 ++-
test/chainrules.jl | 9 +++++++++
2 files changed, 11 insertions(+), 1 deletion(-)
diff --git a/src/compiler/chainrules.jl b/src/compiler/chainrules.jl
index 99d8f4652..7c7de8655 100644
--- a/src/compiler/chainrules.jl
+++ b/src/compiler/chainrules.jl
@@ -73,7 +73,8 @@ matching_cr_sig(t, s) = matching_cr_sig(t.method.sig, s.method.sig)
matching_cr_sig(::DataType, ::UnionAll) = false
matching_cr_sig(::UnionAll, ::DataType) = false
matching_cr_sig(t::Type, s::Type) = type_tuple_tail(t) == type_tuple_tail(s)
-
+matching_cr_sig(::Any, ::Nothing) = false # https://github.com/FluxML/Zygote.jl/issues/1234
+
type_tuple_tail(d::DataType) = Tuple{d.parameters[2:end]...}
function type_tuple_tail(d::UnionAll)
body = Base.unwrap_unionall(d)
diff --git a/test/chainrules.jl b/test/chainrules.jl
index 94ab9584a..e9cb4afbc 100644
--- a/test/chainrules.jl
+++ b/test/chainrules.jl
@@ -275,6 +275,15 @@ using Zygote: ZygoteRuleConfig
@test Zygote.gradient(x -> f_notimplemented(only(x)), [0.1]) === (nothing,)
end
end
+
+ # https://github.com/FluxML/Zygote.jl/issues/1234
+ @testset "rrule lookup ambiguities" begin
+ f_ambig(x, y) = x + y
+ ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0)
+ ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0)
+
+ @test_throws MethodError pullback(f_ambig, 1, 2)
+ end
end
@testset "ChainRulesCore.rrule_via_ad" begin
From 4da04412628ad2803038dd38240cb482d6f22a5a Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Tue, 2 Aug 2022 18:55:30 -0700
Subject: [PATCH 345/490] passthrough safe ccalls in threading code
---
src/compiler/reverse.jl | 14 +++++++++++---
1 file changed, 11 insertions(+), 3 deletions(-)
diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl
index ba10ea5b1..6e88e7273 100644
--- a/src/compiler/reverse.jl
+++ b/src/compiler/reverse.jl
@@ -254,6 +254,15 @@ xaccum(ir) = nothing
xaccum(ir, x) = x
xaccum(ir, xs...) = push!(ir, xcall(Zygote, :accum, xs...))
+function passthrough_expr(ex::Expr)
+ # Metadata we want to preserve
+ isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) && return true
+ # ccalls and more that are safe to preserve/required for proper operation:
+ # - jl_set_task_threadpoolid: added in 1.9 for @spawn
+ isexpr(ex, :foreigncall) && ex.args[1] in (:jl_set_task_threadpoolid,) && return true
+ return false
+end
+
function adjoint(pr::Primal)
ir, sigs = adjointcfg(pr)
for b in reverse(blocks(pr.ir))
@@ -278,10 +287,9 @@ function adjoint(pr::Primal)
end
elseif ex isa Core.PiNode
grads[ex.val] = grads[v]
- elseif isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo)
- elseif isexpr(ex)
+ elseif isexpr(ex) && !passthrough_expr(ex)
push!(rb, stmt(xcall(Base, :error, """
- Can't differentiate $(ex.head) expression.
+ Can't differentiate $(ex.head) expression $ex.
You might want to check the Zygote limitations documentation.
https://fluxml.ai/Zygote.jl/latest/limitations
"""),
From 7d0376a1a0cea719b943292594b4753fe6b0f3e0 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Tue, 2 Aug 2022 20:41:37 -0700
Subject: [PATCH 346/490] function name is actually a QuoteNode
---
src/compiler/reverse.jl | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/compiler/reverse.jl b/src/compiler/reverse.jl
index 6e88e7273..532644914 100644
--- a/src/compiler/reverse.jl
+++ b/src/compiler/reverse.jl
@@ -259,7 +259,7 @@ function passthrough_expr(ex::Expr)
isexpr(ex, GlobalRef, :call, :isdefined, :inbounds, :meta, :loopinfo) && return true
# ccalls and more that are safe to preserve/required for proper operation:
# - jl_set_task_threadpoolid: added in 1.9 for @spawn
- isexpr(ex, :foreigncall) && ex.args[1] in (:jl_set_task_threadpoolid,) && return true
+ isexpr(ex, :foreigncall) && unwrapquote(ex.args[1]) in (:jl_set_task_threadpoolid,) && return true
return false
end
From 83cdacca7d150a13d4f9ff302cd452b3a3961fa4 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Tue, 9 Aug 2022 22:05:43 -0700
Subject: [PATCH 347/490] Add rule for Dict iteration
---
src/lib/base.jl | 39 +++++++++++++++++++++++++++++++++++++++
test/lib/base.jl | 32 ++++++++++++++++++++++++++++++++
2 files changed, 71 insertions(+)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index 79dfb77b6..161dd6e60 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -47,6 +47,45 @@ end
end
end
+# This rule behaves much like the getindex adjoint,
+# just with an (internal) ordinal index instead of a key.
+function _pullback(cx::AContext, ::typeof(iterate), d::Dict, i)
+ iter = iterate(d, i)
+ function dict_iterate_pullback(Δ)
+ (iter === nothing || Δ === nothing) && return
+ k, v = iter[1]
+ _, dv = Δ[1]
+ accum_param(cx, v, dv) === nothing && return
+ grad = grad_mut(cx, d)
+ grad[k] = accum(get(grad, k, nothing), dv)
+ return (nothing, grad, nothing)
+ end
+ return iter, dict_iterate_pullback
+end
+
+# ...while this one is to avoid duplicating code or differentiating skip_deleted.
+# The alternative would be to write a rule for the private _iterate(::Dict, i).
+function _pullback(cx::AContext, ::typeof(iterate), d::Dict)
+ # Calculation of i is the same used in iterate(::Dict)
+ return _pullback(cx, iterate, d, Base.skip_deleted(d, d.idxfloor))
+end
+
+function _pullback(cx::AContext, ::typeof(iterate), vi::Base.ValueIterator{<:Dict}, i::Int)
+ iter = iterate(vi, i)
+ function values_iterate_pullback(Δ)
+ (iter === nothing || Δ === nothing) && return
+ v, dv = iter[1], Δ[1]
+ accum_param(cx, v, dv) === nothing && return
+ # Same as vi.dict.keys[i], but without reaching into Dict internals.
+ # Iterating the dict instead of keys() is to hit the rules above in nested AD.
+ k = iterate(vi.dict, i)[1][1]
+ grad = grad_mut(cx, vi.dict)
+ grad[k] = accum(get(grad, k, nothing), dv)
+ return (nothing, (; dict = grad), nothing)
+ end
+ return iter, values_iterate_pullback
+end
+
# Channels
grad_mut(ch::Channel) = Channel(ch.sz_max)
diff --git a/test/lib/base.jl b/test/lib/base.jl
index 5186483da..74f129f6d 100644
--- a/test/lib/base.jl
+++ b/test/lib/base.jl
@@ -10,4 +10,36 @@
@test result1 == result2
end
+
+ @testset "Dict iteration" begin
+ # https://github.com/FluxML/Zygote.jl/issues/1065
+ function sumkv(d)
+ s = zero(d["c"])
+ for (k, v) in d
+ s += v
+ k == :b && (s += v)
+ end
+ return sum(s)
+ end
+
+ function sumvals(d)
+ s = zero(d["c"])
+ for v in values(d)
+ s += v
+ end
+ return sum(s)
+ end
+
+ d_num = Dict(:a => 3, :b => 4, "c" => 5)
+ d_arr = Dict(:a => [3], :b => [4], "c" => [5])
+ ps = d_arr |> values |> collect |> Params
+
+ @test gradient(sumkv, d_num)[1] == Dict(:a => 1, :b => 2, "c" => 1)
+ grads = gradient(() -> sumkv(d_arr), ps)
+ @test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [2], [1])
+
+ @test gradient(sumvals, d_num)[1] == Dict(:a => 1, :b => 1, "c" => 1)
+ grads = gradient(() -> sumvals(d_arr), ps)
+ @test (grads[d_arr[:a]], grads[d_arr[:b]], grads[d_arr["c"]]) == ([1], [1], [1])
+ end
end
From 24a6111c851c8dfab87628c5227b11a2dbf89648 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Mon, 8 Aug 2022 16:28:29 -0700
Subject: [PATCH 348/490] Treat Pairs(NamedTuple) as NamedTuple for indexing
This prevents issues with double-counting when using kwargs.
---
src/lib/base.jl | 28 ++++++++++++++++++++++++++--
test/features.jl | 17 ++++++++++++++---
2 files changed, 40 insertions(+), 5 deletions(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index 79dfb77b6..21ca62b1c 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -119,11 +119,11 @@ end
# named tuple
@adjoint function pairs(t::NamedTuple{N}) where N
-
+
pairs_namedtuple_pullback(dx::NamedTuple) = (dx.data,)
pairs_namedtuple_pullback(dx::Tuple{}) = (NamedTuple(),)
-
+
function pairs_namedtuple_pullback(Δ::Dict)
t0 = map(zero, t)
for (idx, v) in Δ
@@ -145,6 +145,30 @@ else
@adjoint merge(nt::NamedTuple, dict::Dict) = pullback(merge, nt, (;dict...))
end
+# Keyword arguments pretend to be a Dict, but are secretly wrapping a NamedTuple.
+# We can treat them much the same, just with some plumbing to handle the extra `itr` field.
+function _pullback(::AContext, ::typeof(getindex),
+ ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, k)
+ # So we don't close over kwarg values in the pullback
+ data = map(_ -> nothing, NamedTuple(ps))
+ function kwargs_getindex_pullback(Δ)
+ dps = (data = Base.setindex(data, Δ, k), itr = nothing)
+ return (nothing, dps, nothing)
+ end
+ return ps[k], kwargs_getindex_pullback
+end
+
+function _pullback(cx::AContext, ::typeof(literal_getindex),
+ ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, ::Val{K}) where K
+ val, gf_back = _pullback(cx, literal_getfield, NamedTuple(ps), Val(K))
+ function kwargs_literal_getindex_pullback(Δ)
+ dps = (data = gf_back(Δ)[2], itr = nothing)
+ return (nothing, dps, nothing)
+ end
+ return val, kwargs_literal_getindex_pullback
+end
+
+# Misc.
@adjoint function Base.getfield(p::Pair, i::Int)
function pair_getfield_pullback(Δ)
f, s = i == 1 ? (Δ, nothing) : (nothing, Δ)
diff --git a/test/features.jl b/test/features.jl
index d4f68d36b..4c16267f2 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -552,6 +552,17 @@ end
@test gradient(x -> x[].a, Ref((a=1, b=2))) == ((x = (a = 1, b = nothing),),)
@test gradient(x -> x[1][].a, [Ref((a=1, b=2)), Ref((a=3, b=4))]) == ([(x = (a = 1, b = nothing),), nothing],)
@test gradient(x -> x[1].a, [(a=1, b=2), "three"]) == ([(a = 1, b = nothing), nothing],)
+
+ @testset "indexing kwargs" begin
+ inner_lit_index(; kwargs...) = kwargs[:x]
+ outer_lit_index(; kwargs...) = inner_lit_index(; x=kwargs[:x])
+
+ inner_dyn_index(k; kwargs...) = kwargs[k]
+ outer_dyn_index(k; kwargs...) = inner_dyn_index(k; x=kwargs[k])
+
+ @test gradient(x -> outer_lit_index(; x), 0.0) == (1.0,)
+ @test gradient((x, k) -> outer_dyn_index(k; x), 0.0, :x) == (1.0, nothing)
+ end
end
function type_test()
@@ -562,7 +573,7 @@ end
@testset "Pairs" begin
@test (x->10*pairs((a=x, b=2))[1])'(100) === 10.0
- @test (x->10*pairs((a=x, b=2))[2])'(100) === 0
+ @test (x->10*pairs((a=x, b=2))[2])'(100) === nothing
foo(;kw...) = 1
@test gradient(() -> foo(a=1,b=2.0)) === ()
@@ -578,8 +589,8 @@ end
@testset "kwarg splatting, pass in object" begin
g(; kwargs...) = kwargs[:x] * kwargs[:z]
h(somedata) = g(; somedata...)
- @test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = 0.0, z = 3.0),)
- @test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = 0.0, z = 3.0, x = 2.3),)
+ @test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = nothing, z = 3.0),)
+ @test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = nothing, z = 3.0, x = 2.3),)
end
@testset "Iterators" begin
From 17a5673bdfd7d8fcdf27454b75b15efdb6477c9e Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Sun, 14 Aug 2022 20:33:17 -0700
Subject: [PATCH 349/490] broadcast adjoint for unary minus (#1287)
---
src/lib/broadcast.jl | 2 ++
1 file changed, 2 insertions(+)
diff --git a/src/lib/broadcast.jl b/src/lib/broadcast.jl
index 8c0d3c54c..98124bd03 100644
--- a/src/lib/broadcast.jl
+++ b/src/lib/broadcast.jl
@@ -80,6 +80,8 @@ unbroadcast(x::AbstractArray, x̄::Nothing) = nothing
@adjoint broadcasted(::typeof(-), x::Numeric, y::Numeric) = x .- y,
Δ -> (nothing, unbroadcast(x, Δ), _minus(unbroadcast(y, Δ)))
+@adjoint broadcasted(::typeof(-), x::Numeric) = .-x,
+ Δ -> (nothing, _minus(Δ))
_minus(Δ) = -Δ
_minus(::Nothing) = nothing
From bd5ce6e6e394081b6e7b28d669e8b8e7b3e05176 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Mon, 15 Aug 2022 15:58:02 -0700
Subject: [PATCH 350/490] v0.6.44
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 8328eb69a..c61ff68c2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.43"
+version = "0.6.44"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From 7855c46db1d1e4dc133f76b0cdf61e199001fadd Mon Sep 17 00:00:00 2001
From: Miha Zgubic
Date: Tue, 23 Aug 2022 15:38:51 +0100
Subject: [PATCH 351/490] compat ChainRules for @allowscalar
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index c61ff68c2..37d40b4f5 100644
--- a/Project.toml
+++ b/Project.toml
@@ -27,7 +27,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractFFTs = "0.5, 1.0"
-ChainRules = "1.37"
+ChainRules = "1.44.1"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.4"
From 77dea335d1d51b8a3a4ba1893bd7cf1160862628 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Tue, 23 Aug 2022 19:07:15 -0700
Subject: [PATCH 352/490] Bump version to v0.6.45
---
Project.toml | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml
index 37d40b4f5..58e84dbf2 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
-version = "0.6.44"
+version = "0.6.45"
[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
From fb34703a2c57cec8cbdd23be816250dc9de17e91 Mon Sep 17 00:00:00 2001
From: Brian Chen
Date: Thu, 25 Aug 2022 21:28:19 -0700
Subject: [PATCH 353/490] Handle nothing grads for Pairs.data
---
src/lib/base.jl | 2 +-
test/features.jl | 4 ++++
2 files changed, 5 insertions(+), 1 deletion(-)
diff --git a/src/lib/base.jl b/src/lib/base.jl
index 21ca62b1c..1a85cc56c 100644
--- a/src/lib/base.jl
+++ b/src/lib/base.jl
@@ -162,7 +162,7 @@ function _pullback(cx::AContext, ::typeof(literal_getindex),
ps::Iterators.Pairs{<:Any,<:Any,<:Any,<:NamedTuple}, ::Val{K}) where K
val, gf_back = _pullback(cx, literal_getfield, NamedTuple(ps), Val(K))
function kwargs_literal_getindex_pullback(Δ)
- dps = (data = gf_back(Δ)[2], itr = nothing)
+ dps = (data = gradindex(gf_back(Δ), 2), itr = nothing)
return (nothing, dps, nothing)
end
return val, kwargs_literal_getindex_pullback
diff --git a/test/features.jl b/test/features.jl
index 4c16267f2..e3e0e55bd 100644
--- a/test/features.jl
+++ b/test/features.jl
@@ -591,6 +591,10 @@ end
h(somedata) = g(; somedata...)
@test gradient(h, (; x=3.0, y=4.0, z=2.3)) == ((x = 2.3, y = nothing, z = 3.0),)
@test gradient(h, Dict(:x=>3.0, :y=>4.0, :z=>2.3)) == ((y = nothing, z = 3.0, x = 2.3),)
+
+  # for when no kwargs have grads backpropagated
+ no_kwarg_grad(x; kwargs...) = x[kwargs[:i]]
+ @test gradient(x -> no_kwarg_grad(x; i=1), [1]) == (1,)
end
@testset "Iterators" begin
From 3de236ec143b73cddd1e081f30730ad4c5952b77 Mon Sep 17 00:00:00 2001
From: Christopher Rackauckas
Date: Fri, 26 Aug 2022 00:36:23 -0400
Subject: [PATCH 354/490] Add DiffEqFlux BasicNeuralDE Test
---
.github/workflows/Downstream.yml | 1 +
1 file changed, 1 insertion(+)
diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml
index 09f4f2f5d..47f8032a5 100644
--- a/.github/workflows/Downstream.yml
+++ b/.github/workflows/Downstream.yml
@@ -23,6 +23,7 @@ jobs:
- {user: TuringLang, repo: DynamicPPL.jl, group: All}
- {user: TuringLang, repo: DistributionsAD.jl, group: Zygote}
- {user: SciML, repo: DiffEqFlux.jl, group: Layers}
+ - {user: SciML, repo: DiffEqFlux.jl, group: BasicNeuralDE}
- {user: SciML, repo: NeuralPDE.jl, group: NNPDE}
- {user: JuliaMolSim, repo: Molly.jl, group: Zygote}
steps:
From 4183226eff3ed45ecee36701eb6a569ad08fd3cb Mon Sep 17 00:00:00 2001
From: Brian Chen