Commit 91539a7

palme

Kye committed Aug 28, 2023
1 parent 8c52f20
Showing 3 changed files with 2,076 additions and 0 deletions.
380 changes: 380 additions & 0 deletions rt2/attend.py
@@ -0,0 +1,380 @@
from functools import partial
from typing import Optional

import torch
from torch import nn, einsum, Tensor
import torch.nn.functional as F

from collections import namedtuple
from functools import wraps
from packaging import version
from dataclasses import dataclass

from einops import rearrange, repeat

# constants

EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])

@dataclass
class Intermediates:
qk_similarities: Optional[Tensor] = None
pre_softmax_attn: Optional[Tensor] = None
post_softmax_attn: Optional[Tensor] = None

def to_tuple(self):
return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn)

# helpers

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def compact(arr):
return [*filter(exists, arr)]

def once(fn):
called = False
@wraps(fn)
def inner(x):
nonlocal called
if called:
return
called = True
return fn(x)
return inner

print_once = once(print)
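# e.g. print_once('flash attention enabled') prints only on the first call; subsequent calls are no-ops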

# functions for creating causal mask
# need a special one for onnx cpu (no support for .triu)

def create_causal_mask(i, j, device):
return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)

def onnx_create_causal_mask(i, j, device):
r = torch.arange(i, device = device)
causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j')
causal_mask = F.pad(causal_mask, (j - i, 0), value = False)
return causal_mask
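
# illustrative example: create_causal_mask(3, 3, device) (and equivalently onnx_create_causal_mask)
# produces the boolean mask
#   [[False,  True,  True],
#    [False, False,  True],
#    [False, False, False]]
# where True marks future positions to be masked out; when j > i (e.g. a cached prefix),
# the extra j - i leading key columns are left unmasked for every query position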

# main class

class Attend(nn.Module):
def __init__(
self,
*,
dropout = 0.,
causal = False,
heads = None,
talking_heads = False,
sparse_topk = None,
scale = None,
qk_norm = False,
flash = False,
add_zero_kv = False,
onnxable = False
):
super().__init__()
self.scale = scale
self.qk_norm = qk_norm

self.causal = causal
self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask

self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax

self.dropout = dropout
self.attn_dropout = nn.Dropout(dropout)

# talking heads

assert not (flash and talking_heads), 'talking heads not compatible with flash attention'

self.talking_heads = talking_heads
if talking_heads:
self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)
self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False)

# sparse topk

assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention'
self.sparse_topk = sparse_topk

        # add a key / value token composed of zeros
        # in case this helps control outliers, as proposed by https://www.evanmiller.org/attention-is-off-by-one.html

self.add_zero_kv = add_zero_kv

# flash attention

self.flash = flash
assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'

# determine efficient attention configs for cuda and cpu

self.cpu_config = EfficientAttentionConfig(True, True, True)
self.cuda_config = None

if not torch.cuda.is_available() or not flash:
return

device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

if device_properties.major == 8 and device_properties.minor == 0:
print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
self.cuda_config = EfficientAttentionConfig(True, False, False)
else:
print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
self.cuda_config = EfficientAttentionConfig(False, True, True)

def flash_attn(
self,
q, k, v,
mask = None,
attn_bias = None
):
batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device

# Recommended for multi-query single-key-value attention by Tri Dao
# kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])

if k.ndim == 3:
k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)

if v.ndim == 3:
v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)

        # handle scale - F.scaled_dot_product_attention already scales by dim_head ** -0.5,
        # so when a custom scale is in use (e.g. with qk norm / cosine sim attention), fold the ratio into q
        # so that the effective scale matches the non-flash path below

        if self.qk_norm:
            default_scale = q.shape[-1] ** -0.5
            q = q * (self.scale / default_scale)

        # check if mask exists and expand it to a compatible shape
        # the boolean mask is expected to already be 4-dimensional (e.g. b 1 1 j) and is expanded to b h q_len k_len

causal = self.causal

if exists(mask):
assert mask.ndim == 4
mask = mask.expand(batch, heads, q_len, k_len)

# manually handle causal mask, if another mask was given

if causal:
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
mask = mask & ~causal_mask
causal = False

# handle alibi positional bias
# convert from bool to float

if exists(attn_bias):
attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1)

# if mask given, the mask would already contain the causal mask from above logic
# otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number

mask_value = -torch.finfo(q.dtype).max

if exists(mask):
attn_bias = attn_bias.masked_fill(~mask, mask_value // 2)
elif causal:
causal_mask = self.create_causal_mask(q_len, k_len, device = device)
attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2)
causal = False

# scaled_dot_product_attention handles attn_mask either as bool or additive bias
# make it an additive bias here

mask = attn_bias

# Check if there is a compatible device for flash attention

config = self.cuda_config if is_cuda else self.cpu_config

# pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale

with torch.backends.cuda.sdp_kernel(**config._asdict()):
out = F.scaled_dot_product_attention(
q, k, v,
attn_mask = mask,
dropout_p = self.dropout if self.training else 0.,
is_causal = causal
)

return out, Intermediates()

def forward(
self,
q, k, v,
mask = None,
attn_bias = None,
prev_attn = None
):
"""
einstein notation
b - batch
h - heads
n, i, j - sequence length (base sequence length, source, target)
d - feature dimension
"""

n, heads, kv_heads, device = q.shape[-2], q.shape[1], k.shape[1], q.device

scale = default(self.scale, q.shape[-1] ** -0.5)

# handle grouped multi-query attention

if kv_heads == 1:
k, v = map(lambda t: rearrange(t, 'b 1 n d -> b n d'), (k, v))
elif kv_heads < heads:
k, v = map(lambda t: repeat(t, 'b kvh n d -> b (r kvh) n d', r = heads // kv_heads), (k, v))
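        # illustrative shapes: q (b, 8, n, d) with k / v of shape (b, 2, n, d) -> k / v repeated to (b, 8, n, d);
        # with k / v of shape (b, 1, n, d) they are instead squeezed to (b, n, d) and broadcast across heads in the einsum below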

        # handle zero kv, as a means of allowing the network to attend to nothing

if self.add_zero_kv:
k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v))

if exists(mask):
mask = F.pad(mask, (1, 0), value = True)

if exists(attn_bias):
attn_bias = F.pad(attn_bias, (1, 0), value = 0.)
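            # the pads above prepend a single all-zero key / value slot (with a matching always-True mask
            # and zero bias entry), giving softmax a "null" position to assign attention mass to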

if self.flash:
assert not exists(prev_attn), 'residual attention not compatible with flash attention'
return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias)

kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale

if exists(prev_attn):
dots = dots + prev_attn

qk_similarities = dots.clone()

if self.talking_heads:
dots = self.pre_softmax_talking_heads(dots)

if exists(attn_bias):
dots = dots + attn_bias

i, j, dtype = *dots.shape[-2:], dots.dtype

mask_value = -torch.finfo(dots.dtype).max

if exists(self.sparse_topk) and self.sparse_topk < j:
top_values, _ = dots.topk(self.sparse_topk, dim = -1)
sparse_topk_mask = dots < top_values[..., -1:]
mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask

if exists(mask):
dots = dots.masked_fill(~mask, mask_value)

if self.causal:
causal_mask = self.create_causal_mask(i, j, device = device)
dots = dots.masked_fill(causal_mask, mask_value)

pre_softmax_attn = dots.clone()

attn = self.attn_fn(dots, dim = -1)
attn = attn.type(dtype)

post_softmax_attn = attn.clone()

attn = self.attn_dropout(attn)

if self.talking_heads:
attn = self.post_softmax_talking_heads(attn)

out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)

intermediates = Intermediates(
qk_similarities = qk_similarities,
pre_softmax_attn = pre_softmax_attn,
post_softmax_attn = post_softmax_attn
)

return out, intermediates

# cascading heads logic

def to_single_heads(t, dim = 1):
heads = t.unbind(dim = dim)
return tuple(head.unsqueeze(dim) for head in heads)
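# e.g. a tensor of shape (b, 8, n, d) becomes a tuple of 8 tensors, each of shape (b, 1, n, d)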

class CascadingHeads(nn.Module):
def __init__(self, attend: Attend):
super().__init__()
self.attend = attend

def forward(
self,
q, k, v,
mask = None,
attn_bias = None,
prev_attn = None
):
assert q.shape[-1] == v.shape[-1], 'cascading heads can only be done if query / key and value head dimensions are the same'

# split inputs into per-head inputs

heads = q.shape[1]

queries = to_single_heads(q)
keys = to_single_heads(k) if k.ndim == 4 else ((k,) * heads)
values = to_single_heads(v) if v.ndim == 4 else ((v,) * heads)

mask = (mask,) * heads

attn_bias = to_single_heads(attn_bias, dim = 0) if exists(attn_bias) else ((None,) * heads)
prev_attn = to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads)

        # now loop through each head, with the output of the previous head summed into the next head's query
        # thus cascading

all_outs = []
all_intermediates = []

prev_head_out = None

for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip(queries, keys, values, mask, attn_bias, prev_attn):

if exists(prev_head_out):
h_q = h_q + prev_head_out

out, intermediates = self.attend(
h_q, h_k, h_v,
mask = h_mask,
attn_bias = h_attn_bias,
prev_attn = h_prev_attn
)

prev_head_out = out

all_outs.append(out)
all_intermediates.append(intermediates)

# cat all output heads

all_outs = torch.cat(all_outs, dim = 1)

# cat all intermediates, if they exist

qk_similarities, pre_softmax_attn, post_softmax_attn = zip(*map(lambda i: i.to_tuple(), all_intermediates))

qk_similarities, pre_softmax_attn, post_softmax_attn = map(compact, (qk_similarities, pre_softmax_attn, post_softmax_attn))

aggregated_intermediates = Intermediates(
qk_similarities = torch.cat(qk_similarities, dim = 1) if len(qk_similarities) > 0 else None,
pre_softmax_attn = torch.cat(pre_softmax_attn, dim = 1) if len(pre_softmax_attn) > 0 else None,
post_softmax_attn = torch.cat(post_softmax_attn, dim = 1) if len(post_softmax_attn) > 0 else None
)

return all_outs, aggregated_intermediates
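
A minimal usage sketch for the Attend module added in this commit (not part of the diff; the shapes and hyperparameters below are illustrative assumptions):

# usage sketch (illustrative; assumes the module lives at rt2/attend.py as in this commit)
import torch
from rt2.attend import Attend

attend = Attend(causal = True, dropout = 0.1, flash = False)

b, h, n, d = 2, 8, 128, 64
q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

out, intermediates = attend(q, k, v)
# out: (b, h, n, d); intermediates holds qk_similarities and pre- / post-softmax attention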