Skip to content

Commit

Permalink
generated z2d
Browse files Browse the repository at this point in the history
  • Loading branch information
chengchingwen committed Mar 14, 2023
1 parent 756dd37 commit bb560f2
Showing 1 changed file with 27 additions and 8 deletions.
35 changes: 27 additions & 8 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -313,15 +313,34 @@ end
z2d(dx::NamedTuple, primal::AbstractDict) = dx

function z2d(delta::NamedTuple, primal::T) where T # arbitrart struct
fnames = fieldnames(T)
deltas = map(n -> get(delta, n, nothing), fnames)
primals = map(n -> getfield(primal, n), fnames)
inner = map(z2d, deltas, primals) # recurse into fields
if inner isa Tuple{Vararg{AbstractZero}}
return NoTangent() # collapse all-zero case
if @generated
fnames = fieldnames(T)
N = length(fnames)
deltas = [ :($(Symbol(:delta_, fname)) = get(delta, $(QuoteNode(fname)), nothing)) for fname in fnames ]
primals = [ :($(Symbol(:primal_, fname)) = getfield(primal, $(QuoteNode(fname)))) for fname in fnames ]
inner = Expr(:tuple, [ :(z2d($(Symbol(:delta_, fname)), $(Symbol(:primal_, fname)))) for fname in fnames ]...)
return quote
$(deltas...)
$(primals...)
inner = $inner
if inner isa Tuple{Vararg{AbstractZero}}
return NoTangent() # collapse all-zero case
else
backing = NamedTuple{$fnames}(inner)
return canonicalize(Tangent{T, typeof(backing)}(backing))
end
end
else
backing = NamedTuple{fnames}(inner)
return canonicalize(Tangent{T, typeof(backing)}(backing))
fnames = fieldnames(T)
deltas = map(n -> get(delta, n, nothing), fnames)
primals = map(n -> getfield(primal, n), fnames)
inner = map(z2d, deltas, primals) # recurse into fields
if inner isa Tuple{Vararg{AbstractZero}}
return NoTangent() # collapse all-zero case
else
backing = NamedTuple{fnames}(inner)
return canonicalize(Tangent{T, typeof(backing)}(backing))
end
end
end

Expand Down

0 comments on commit bb560f2

Please sign in to comment.