Skip to content

Commit

Permalink
[feat] finish tensor functions
Browse files Browse the repository at this point in the history
  • Loading branch information
liushuwei committed Jun 29, 2023
1 parent 43aa645 commit 64dc28b
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion minitorch/tensor_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ class Permute(Function):
@staticmethod
def forward(ctx: Context, a: Tensor, order: Tensor) -> Tensor:
ctx.save_for_backward(a.shape, a._tensor.strides)
return a._new(a._tensor.permute(*order))
ord = [int(order._tensor.get(i)) for i in order._tensor.indices()]
return a._new(a._tensor.permute(*ord))

@staticmethod
def backward(ctx: Context, grad_output: Tensor) -> Tuple[Tensor, float]:
Expand Down

0 comments on commit 64dc28b

Please sign in to comment.