Skip to content

Commit

Permalink
clean up of rt2
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Aug 30, 2023
1 parent 91539a7 commit b797040
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 4 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ transformers
palm-rlhf-pytorch
tokenizers
wandb
classifier-free-guidance-pytorch>
classifier-free-guidance-pytorch
38 changes: 36 additions & 2 deletions rt2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def forward(self, x):
# Robotic Transformer

@beartype
class RT2(nn.Module):
class RT1(nn.Module):
def __init__(
self,
*,
Expand Down Expand Up @@ -629,4 +629,38 @@ def forward(
pooled = reduce(attended_tokens, 'b (f n) d -> b f d', 'mean', f = frames)

logits = self.to_logits(pooled)
return logits
return logits


class RT2:
def __init__(self):
self.vit = MaxViT(
num_classes=1000,
dim=96,
dim_conv_stem=64,
dim_head=32,
depth = (2, 2, 5, 2),
window_size=7,
mbconv_expansion_rate=4,
mbconv_shrinkage_rate=0.25,
dropout=0.1
)

self.model = RT1(
vit = self.vit,
num_actions = 11,
depth=6,
heads=8,
dim_head=64,
cond_drop_prob=0.2
)

def train(self, video, instructions):
train_logits = self.model(video, instructions)
return train_logits

def eval(self, video, instructions, cond_scale=1.0):
self.model.eval()
eval_logits = self.model(video, instructions, cond_scale=cond_scale)
return eval_logits

1 change: 0 additions & 1 deletion rt2/palme.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
import torch.nn as nn

from rt2.transformer import (
AutoregressiveWrapper,
Expand Down

0 comments on commit b797040

Please sign in to comment.