Skip to content

Commit

Permalink
move require usage to extensions on 1.9+ (FluxML#1390)
Browse files Browse the repository at this point in the history
* move require usage to extensions on 1.9+

* remove extra loads in tracker extension

* fix an unexprted function
  • Loading branch information
KristofferC committed Mar 14, 2023
1 parent 108e5a1 commit 1aec78f
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 18 deletions.
15 changes: 15 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,24 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Tracker= "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extensions]
ZygoteColorsExt = "Colors"
ZygoteDistancesExt = "Distances"
ZygoteTrackerExt = "Tracker"

[compat]
AbstractFFTs = "1.3.1"
ChainRules = "1.44.1"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
Colors = "0.12"
DiffRules = "1.4"
Distances = "0.10"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
ForwardDiff = "0.10"
GPUArrays = "8.4.2"
Expand All @@ -43,17 +55,20 @@ NaNMath = "0.3, 1"
Requires = "1.1"
SnoopPrecompile = "1.0.3"
SpecialFunctions = "1.6, 2"
Tracker = "0.2"
ZygoteRules = "0.2.1"
julia = "1.6"

[extras]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[targets]
test = ["ChainRulesTestUtils", "CUDA", "Distances", "FFTW", "FiniteDifferences", "PyCall", "Test"]
13 changes: 13 additions & 0 deletions ext/ZygoteColorsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module ZygoteColorsExt

if isdefined(Base, :get_extension)
using Zygote
using Colors
else
using ..Zygote
using ..Colors
end

Zygote.@non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)

end
24 changes: 19 additions & 5 deletions src/lib/distances.jl → ext/ZygoteDistancesExt.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
using .Distances
module ZygoteDistancesExt

if isdefined(Base, :get_extension)
using Zygote
using Distances
using LinearAlgebra
else
using ..Zygote
using ..Distances
using ..LinearAlgebra
end

using Zygote: @adjoint, @adjoint, AContext, _pullback

@adjoint function (::SqEuclidean)(x::AbstractVector, y::AbstractVector)
δ = x .- y
Expand Down Expand Up @@ -66,7 +78,7 @@ end

_sqrt_if_positive(d, δ) = d > δ ? sqrt(d) : zero(d)

function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
function Zygote._pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
X::AbstractMatrix, Y::AbstractMatrix)
# Modify the forwards-pass slightly to ensure stability on the reverse.
Expand All @@ -77,11 +89,11 @@ function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
return _sqrt_if_positive.(D2, δ)
end
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X, Y)
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(unthunk_tangent(Δ))...)
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(Zygote.unthunk_tangent(Δ))...)
return res, pairwise_Euclidean_pullback
end

function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
function Zygote._pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
kws::@NamedTuple{dims::Int}, ::typeof(pairwise), dist::Euclidean,
X::AbstractMatrix)
# Modify the forwards-pass slightly to ensure stability on the reverse.
Expand All @@ -92,6 +104,8 @@ function _pullback(cx::AContext, ::Core.kwftype(typeof(pairwise)),
return _sqrt_if_positive.(D2, δ)
end
res, back = _pullback(cx, _pairwise_euclidean, SqEuclidean(dist.thresh), X)
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(unthunk_tangent(Δ))...)
pairwise_Euclidean_pullback(Δ) = (nothing, nothing, back(Zygote.unthunk_tangent(Δ))...)
return res, pairwise_Euclidean_pullback
end

end
17 changes: 17 additions & 0 deletions ext/ZygoteTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
module ZygoteTrackerExt

if isdefined(Base, :get_extension)
using Zygote
using Tracker: Tracker, TrackedArray, TrackedReal
else
using ..Zygote
using ..Tracker: Tracker, TrackedArray, TrackedReal
end

Zygote.unwrap(x::Union{TrackedArray,TrackedReal}) = Tracker.data(x)

Zygote.pullback(f, ps::Tracker.Params) = pullback(f, ZygtParams(ps))
Tracker.forward(f, ps::Params) = Tracker.forward(f, Tracker.Params(ps))
Tracker.gradient_(f, ps::Params) = Tracker.gradient_(f, Tracker.Params(ps))

end
10 changes: 4 additions & 6 deletions src/Zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ include("lib/forward.jl")
include("lib/utils.jl")
include("lib/range.jl")
include("lib/logexpfunctions.jl")
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("lib/distances.jl")

# we need to define this late, so that the genfuncs see lib.jl
# Move using statements out of this file to help with sysimage building
Expand All @@ -53,12 +52,11 @@ include("compiler/interface2.jl")

include("profiler/Profile.jl")

@init @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin
include("flux.jl")
end

@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" begin
@non_differentiable Colors.ColorTypes._parameter_upper_bound(::Any...)
if !isdefined(Base, :get_extension)
@init @require Distances="b4f34e82-e78d-54a5-968a-f98e89d6e8f7" include("../ext/ZygoteDistancesExt.jl")
@init @require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" include("../ext/ZygoteTrackerExt.jl")
@init @require Colors="5ae59095-9a9b-59fe-a467-6f913c188581" include("../ext/ZygoteColorsExt.jl")
end

using InteractiveUtils
Expand Down
7 changes: 0 additions & 7 deletions src/flux.jl

This file was deleted.

0 comments on commit 1aec78f

Please sign in to comment.