Skip to content

Commit

Permalink
Support list/dict when wrapping function in wrap_ad
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasw21 authored and Speierers committed Oct 5, 2022
1 parent e73234e commit 9f711c5
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion drjit/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5869,10 +5869,19 @@ def backward(self):
grad_out_torch = drjit_to_torch(self.grad_out())
grad_out_torch = torch_ensure_shape(grad_out_torch, self.res_torch)
_torch.autograd.backward(self.res_torch, grad_out_torch)
args_grad_torch = [getattr(a, 'grad', None) for a in self.args_torch]
args_grad_torch = self.get_grads(self.args_torch)
args_grad = torch_to_drjit(args_grad_torch)
self.set_grad_in('args', args_grad)

@classmethod
def get_grads(cls, args):
if isinstance(args, _Sequence) and not isinstance(args, str):
return tuple(cls.get_grads(b) for b in args)
elif isinstance(args, _Mapping):
return {k: cls.get_grads(v) for k, v in args.items()}
else:
return getattr(args, 'grad', None)

return _dr.custom(ToTorch, args)

return f
Expand Down

0 comments on commit 9f711c5

Please sign in to comment.