Skip to content

Commit

Permalink
rm rules for eachslice, cumsum (FluxML#1253)
Browse files Browse the repository at this point in the history
* rm rules for eachslice, cumsum

* bump

* bound chainrules

* bump
  • Loading branch information
mcabbott committed Jun 30, 2022
1 parent 1936109 commit 7604288
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 33 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.41"
version = "0.6.42"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "1.35.3"
ChainRules = "1.36.2"
ChainRulesCore = "1.9"
ChainRulesTestUtils = "1"
DiffRules = "1.4"
Expand Down
31 changes: 0 additions & 31 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -351,37 +351,6 @@ _backvar(xs, Δ, N::Int, mean) = (convert(eltype(xs), 2/N) .* Δ .* (xs .- mean)
return s, Δ -> _backvar(xs, Δ ./ (2 .* s), corrected, mean, dims)
end

@adjoint function cumsum(xs::AbstractVector; dims::Integer = 1)
dims == 1 || return copy(xs), Δ -> (Δ,)
cumsum(xs), Δ -> (reverse(cumsum(reverse(Δ))),)
end
@adjoint function cumsum(xs::AbstractArray; dims::Integer)
dims <= ndims(xs) || return copy(xs), Δ -> (Δ,)
cumsum(xs; dims=dims), Δ -> begin
(reverse(cumsum(reverse(Δ, dims=dims), dims=dims), dims=dims),)
end
end

@adjoint eachrow(x::AbstractVecOrMat) = collect(eachrow(x)), dys -> ∇eachslice(dys, x, 1)
@adjoint eachcol(x::AbstractVecOrMat) = collect(eachcol(x)), dys -> ∇eachslice(dys, x, 2)
@adjoint eachslice(x::AbstractArray; dims::Integer) =
collect(eachslice(x; dims=dims)), dys -> ∇eachslice(dys, x, dims)

function ∇eachslice(dys, x::AbstractArray, dim::Integer) where {TX}
i1 = findfirst(dy -> dy isa AbstractArray, dys)
i1 === nothing && return (zero(x),) # all slices get nothing
T = promote_type(eltype(dys[i1]), eltype(x))
dx = similar(x, T)
for i in axes(x, dim)
if dys[i] isa AbstractArray
copyto!(selectdim(dx,dim,i), dys[i])
else
selectdim(dx,dim,i) .= 0
end
end
(dx,)
end


# LinearAlgebra
# =============
Expand Down

0 comments on commit 7604288

Please sign in to comment.