Skip to content

Commit

Permalink
Fix IndexError: tensors used as indices must be long, byte or bool te…
Browse files Browse the repository at this point in the history
…nsors
  • Loading branch information
muhammad_wasim committed Dec 28, 2022
1 parent 88a6f36 commit 3898290
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions decompressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,12 @@

# batch
batch = samples[torch.randint(0, len(samples), size=[batchsize])]

# print(X[batch.long()])
# (seq. length, batch size, features) <- (batch size, seq. length, features)
Xgnd = X[batch].transpose(0, 1)
Ygnd = Ytxy[batch].transpose(0, 1)
Qgnd = Qtxy[batch].transpose(0, 1)
Qgnd_xfm = Qxfm[batch].transpose(0, 1)
Xgnd = X[batch.long()].transpose(0, 1)
Ygnd = Ytxy[batch.long()].transpose(0, 1)
Qgnd = Qtxy[batch.long()].transpose(0, 1)
Qgnd_xfm = Qxfm[batch.long()].transpose(0, 1)

# Generate latent variables Z
Zgnd = compressor((torch.cat((Ygnd, Qgnd), dim=-1) - compressor_mean) / compressor_std)
Expand Down
4 changes: 2 additions & 2 deletions stepper.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@
# batch
batch = samples[torch.randint(0, len(samples), size=[batchsize])]

Xgnd = X[batch]
Zgnd = Z[batch]
Xgnd = X[batch.long()]
Zgnd = Z[batch.long()]

# Predict delta x and delta z over a window of s frames
Xtil = [Xgnd[:,0]]
Expand Down

0 comments on commit 3898290

Please sign in to comment.