clean up
Kye committed Aug 28, 2023
1 parent 7512ee2 commit 5b126d3
Showing 2 changed files with 18 additions and 70 deletions.
Binary file added .DS_Store
Binary file not shown.
88 changes: 18 additions & 70 deletions rt2/model.py
@@ -1,16 +1,16 @@
+from typing import Callable, List, Optional, Tuple
+
 import torch
 import torch.nn.functional as F
-from torch import nn, einsum
-
-from typing import List, Optional, Callable, Tuple
 from beartype import beartype
+from classifier_free_guidance_pytorch import (
+    AttentionTextConditioner,
+    TextConditioner,
+    classifier_free_guidance,
+)
+from einops import pack, rearrange, reduce, repeat, unpack
-from einops import pack, unpack, repeat, reduce, rearrange
 from einops.layers.torch import Rearrange, Reduce
+from torch import einsum, nn
 
 from functools import partial
-
-from classifier_free_guidance_pytorch import TextConditioner, AttentionTextConditioner, classifier_free_guidance
 
 # helpers

@@ -398,7 +398,7 @@ def forward(
         attn_mask = None,
         cond_fn: Optional[Callable] = None
     ):
-        x.shape[0]
+        b = x.shape[0]
 
         if exists(context):
             context = self.context_norm(context)
@@ -515,8 +515,10 @@ def forward(self, x):
         x = unpack_one(x, ps, '* c n')
         return x
 
+# Robotic Transformer
+
 @beartype
-class RT1(nn.Module):
+class RT2(nn.Module):
     def __init__(
         self,
         *,
@@ -588,8 +590,7 @@ def forward(
         cond_fns = self.conditioner(
             texts,
             cond_drop_prob = cond_drop_prob,
-            repeat_batch = (*((frames,) * self.num_vit_stages), *((1,) *
-            self.transformer_depth * 2))
+            repeat_batch = (*((frames,) * self.num_vit_stages), *((1,) * self.transformer_depth * 2))
         )
 
         vit_cond_fns, transformer_cond_fns = cond_fns[:-(depth * 2)], cond_fns[-(depth * 2):]
@@ -613,72 +614,19 @@ def forward(
         # causal attention mask
 
         attn_mask = torch.ones((frames, frames), dtype = torch.bool, device = device).triu(1)
-        attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens,
-        r2 = self.num_learned_tokens)
+        attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = self.num_learned_tokens, r2 = self.num_learned_tokens)
 
         # sinusoidal positional embedding
 
-        pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1],
-        dtype = learned_tokens.dtype, device = learned_tokens.device)
+        pos_emb = posemb_sincos_1d(frames, learned_tokens.shape[-1], dtype = learned_tokens.dtype, device = learned_tokens.device)
 
         learned_tokens = learned_tokens + repeat(pos_emb, 'n d -> (n r) d', r = self.num_learned_tokens)
 
         # attention
 
-        attended_tokens = self.transformer(learned_tokens,
-        cond_fns = transformer_cond_fns,
-        attn_mask = ~attn_mask)
+        attended_tokens = self.transformer(learned_tokens, cond_fns = transformer_cond_fns, attn_mask = ~attn_mask)
 
         pooled = reduce(attended_tokens, 'b (f n) d -> b f d', 'mean', f = frames)
 
         logits = self.to_logits(pooled)
-        return logits
-
-
-class RT2:
-    def __init__(
-        self,
-        num_classes: int = 1000,
-        dim_conv_stem: int = 64,
-        dim: int = 96,
-        dim_head: int = 32,
-        vit_depth: tuple = (2, 2, 5, 2),
-        window_size: int = 7,
-        mbconv_expansion_rate: int = 4,
-        mbconv_shrinkage_rate: float = 0.25,
-        dropout: float = 0.1,
-        num_actions: int = 11,
-        rt1_depth: int = 6,
-        heads: int = 8,
-        cond_drop_prob: float = 0.2
-    ):
-        self.vit = MaxViT(
-            num_classes=num_classes,
-            dim_conv_stem=dim_conv_stem,
-            dim=dim,
-            dim_head=32,
-            depth=vit_depth,
-            window_size=window_size,
-            mbconv_expansion_rate=mbconv_expansion_rate,
-            mbconv_shrinkage_rate=mbconv_shrinkage_rate,
-            dropout=dropout
-        )
-
-        self.model = RT1(
-            vit=self.vit,
-            num_actions=num_actions,
-            depth=rt1_depth,
-            heads=heads,
-            dim_head=dim_head,
-            cond_drop_prob=cond_drop_prob
-        )
-
-    def eval(self):
-        self.model.eval()
-
-    def __call__(self, video, instructions, cond_scale=None):
-        if cond_scale:
-            return self.model(video, instructions, cond_scale=cond_scale)
-        else:
-            return self.model(video, instructions)
+        return logits
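
Note: the retained forward pass builds a frame-level causal mask, then expands it so that every learned token within a frame inherits that frame's visibility. A minimal sketch of that expansion, with tiny sizes (frames = 2, num_learned_tokens = 2) standing in for the model's real values:

import torch
from einops import repeat

frames, num_learned_tokens = 2, 2

# True marks positions to mask out: frame i may not attend to frames > i
attn_mask = torch.ones((frames, frames), dtype = torch.bool).triu(1)

# expand each frame-level entry into a block covering that frame's learned tokens
attn_mask = repeat(attn_mask, 'i j -> (i r1) (j r2)', r1 = num_learned_tokens, r2 = num_learned_tokens)

print(~attn_mask)  # allowed positions: block lower-triangular over frames
# tensor([[ True,  True, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True,  True,  True,  True]])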
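With the standalone RT2 wrapper deleted and RT1 renamed to RT2, the model is now constructed directly. A usage sketch, assuming MaxViT and RT2 are importable from rt2.model and reusing the defaults of the removed wrapper; the video tensor layout (batch, channels, frames, height, width) is an assumption, not something this diff confirms:

import torch

from rt2.model import RT2, MaxViT

vit = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,
    dim = 96,
    dim_head = 32,
    depth = (2, 2, 5, 2),
    window_size = 7,
    mbconv_expansion_rate = 4,
    mbconv_shrinkage_rate = 0.25,
    dropout = 0.1
)

model = RT2(
    vit = vit,
    num_actions = 11,
    depth = 6,
    heads = 8,
    dim_head = 32,
    cond_drop_prob = 0.2
)

video = torch.randn(1, 3, 6, 224, 224)  # assumed (batch, channels, frames, height, width)
texts = ['pick up the cup']

logits = model(video, texts)
# logits = model(video, texts, cond_scale = 3.)  # guidance scale, mirroring the removed wrapper's cond_scale pass-through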
