new paper suggests that feedforwards need to be present
lucidrains committed Mar 8, 2021
1 parent 785a241 commit 365d5c8
Showing 4 changed files with 16 additions and 12 deletions.
12 changes: 10 additions & 2 deletions README.md
@@ -5,7 +5,7 @@

This is a Pytorch implementation of Reformer https://openreview.net/pdf?id=rkgNKkHtvB

-It includes LSH attention, reversible network, and chunking. It has been validated with an auto-regressive task (enwik8). It also includes additional features to make the entire network pure attention all the way down.
+It includes LSH attention, reversible network, and chunking. It has been validated with an auto-regressive task (enwik8).

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1am1DRl80Kd3o6n_4u3MomPzYS0NfdHAC) 32k tokens

@@ -45,7 +45,6 @@ model = ReformerLM(
    ff_chunks = 200, # number of chunks for feedforward layer, make higher if there are memory issues
    attn_chunks = 8, # process lsh attention in chunks, only way for memory to fit when scaling to 16k tokens
    num_mem_kv = 128, # persistent learned memory key values, from all-attention paper
-    twin_attention = False, # both branches of the reversible network will be attention
    full_attn_thres = 1024, # use full attention if context length is less than set value
    reverse_thres = 1024, # turn off reversibility for 2x speed for sequence lengths shorter or equal to the designated value
    use_scale_norm = False, # use scale norm from 'Transformers without tears' paper
@@ -554,6 +553,15 @@ print(sample.shape) # (1, <=100) token ids
}
```

+```bibtex
+@misc{dong2021attention,
+    title = {Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth},
+    author = {Yihe Dong and Jean-Baptiste Cordonnier and Andreas Loukas},
+    year = {2021},
+    eprint = {2103.03404}
+}
+```

```bibtex
@misc{vaswani2017attention,
    title = {Attention Is All You Need},
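For reference, a minimal usage sketch of `ReformerLM` after this change might look as follows; the hyperparameter values are illustrative only and are not taken from this diff:

```python
# Sketch only -- illustrative settings, not values prescribed by this commit.
import torch
from reformer_pytorch import ReformerLM

model = ReformerLM(
    num_tokens = 20000,
    dim = 512,
    depth = 6,
    max_seq_len = 8192,
    heads = 8,
    bucket_size = 64,
    n_hashes = 4,
    ff_chunks = 200,          # chunked feedforward, still present in every layer
    attn_chunks = 8,
    num_mem_kv = 128,
    full_attn_thres = 1024,
    causal = True
    # note: no twin_attention -- the second branch of each reversible block
    # is now always a feedforward (or PKM) layer
)

x = torch.randint(0, 20000, (1, 8192))
logits = model(x)             # (1, 8192, 20000)
```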
12 changes: 4 additions & 8 deletions reformer_pytorch/reformer_pytorch.py
@@ -632,15 +632,14 @@ def forward(self, x):
# reformer lm

class Reformer(nn.Module):
-    def __init__(self, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_activation = None, ff_mult = 4, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, twin_attention = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
+    def __init__(self, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 8, ff_chunks = 100, attn_chunks = None, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_activation = None, ff_mult = 4, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., lsh_attend_across_buckets = True, lsh_allow_duplicate_attention = True, random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
        super().__init__()
        self.dim = dim
        self.depth = depth

        self.bucket_size = bucket_size
        self.num_mem_kv = num_mem_kv

-        self.twin_attention = twin_attention
        self.full_attn_thres = full_attn_thres

        get_attn = lambda: LSHSelfAttention(dim, heads, bucket_size, n_hashes, causal = causal, dim_head = dim_head, dropout = lsh_dropout, post_attn_dropout = post_attn_dropout, attn_chunks = attn_chunks, allow_duplicate_attention = lsh_allow_duplicate_attention, attend_across_buckets = lsh_attend_across_buckets, random_rotations_per_head = random_rotations_per_head, num_mem_kv = num_mem_kv, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads)
@@ -665,8 +664,6 @@ def __init__(self, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_s

            if use_pkm:
                parallel_net = get_pkm()
-            elif twin_attention:
-                parallel_net = get_attn()
            else:
                parallel_net = get_ff()

@@ -679,12 +676,11 @@ def __init__(self, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_s

    def forward(self, x, **kwargs):
        x = torch.cat([x, x], dim = -1)
-        arg_route = (True, self.twin_attention)
-        x = self.layers(x, arg_route = arg_route, **kwargs)
+        x = self.layers(x, **kwargs)
        return torch.stack(x.chunk(2, dim=-1)).mean(dim=0)

class ReformerLM(nn.Module):
-    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, twin_attention = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
+    def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = None, bucket_size = 64, n_hashes = 4, ff_chunks = 100, attn_chunks = 1, causal = False, weight_tie = False, lsh_dropout = 0., ff_dropout = 0., ff_mult = 4, ff_activation = None, ff_glu = False, post_attn_dropout = 0., layer_dropout = 0., random_rotations_per_head = False, use_scale_norm = False, use_rezero = False, use_full_attn = False, full_attn_thres = 0, reverse_thres = 0, num_mem_kv = 0, one_value_head = False, emb_dim = None, return_embeddings = False, weight_tie_embedding = False, fixed_position_emb = False, absolute_position_emb = False, axial_position_shape = None, n_local_attn_heads = 0, pkm_layers = tuple(), pkm_num_keys = 128):
        super().__init__()
        emb_dim = default(emb_dim, dim)
        self.max_seq_len = max_seq_len
@@ -701,7 +697,7 @@ def __init__(self, num_tokens, dim, depth, max_seq_len, heads = 8, dim_head = No
            axial_position_shape = default(axial_position_shape, (math.ceil(max_seq_len / bucket_size), bucket_size))
            self.pos_emb = AxialPositionalEmbedding(emb_dim, axial_position_shape)

-        self.reformer = Reformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head, bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie, lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = 0., layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head, twin_attention = twin_attention, use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
+        self.reformer = Reformer(dim, depth, max_seq_len, heads = heads, dim_head = dim_head, bucket_size = bucket_size, n_hashes = n_hashes, ff_chunks = ff_chunks, attn_chunks = attn_chunks, causal = causal, weight_tie = weight_tie, lsh_dropout = lsh_dropout, ff_mult = ff_mult, ff_activation = ff_activation, ff_glu = ff_glu, ff_dropout = ff_dropout, post_attn_dropout = 0., layer_dropout = layer_dropout, random_rotations_per_head = random_rotations_per_head, use_scale_norm = use_scale_norm, use_rezero = use_rezero, use_full_attn = use_full_attn, full_attn_thres = full_attn_thres, reverse_thres = reverse_thres, num_mem_kv = num_mem_kv, one_value_head = one_value_head, n_local_attn_heads = n_local_attn_heads, pkm_layers = pkm_layers, pkm_num_keys = pkm_num_keys)
        self.norm = nn.LayerNorm(dim)

        if return_embeddings:
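The structural effect of the change above: each Reformer layer is a reversible block pairing two functions, f and g, and after this commit g can only be a feedforward (or PKM) branch, never a second attention. A toy sketch of that pairing follows; `SelfAttnBranch` and `ToyReversibleBlock` are hypothetical stand-ins for illustration, not the library's actual classes.

```python
import torch
from torch import nn

class SelfAttnBranch(nn.Module):
    # stand-in for the library's LSH self-attention branch (hypothetical helper)
    def __init__(self, dim, heads = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
    def forward(self, x):
        out, _ = self.attn(x, x, x, need_weights = False)
        return out

class ToyReversibleBlock(nn.Module):
    # toy RevNet-style coupling, not the library's ReversibleBlock
    def __init__(self, f, g):
        super().__init__()
        self.f, self.g = f, g
    def forward(self, x1, x2):
        y1 = x1 + self.f(x2)   # f: attention branch
        y2 = x2 + self.g(y1)   # g: feedforward branch (the part this commit makes mandatory)
        return y1, y2

dim = 64
f = SelfAttnBranch(dim)
g = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
block = ToyReversibleBlock(f, g)

x = torch.randn(1, 16, dim)
y1, y2 = block(x, x)          # mirrors the duplicate-then-average scheme in Reformer.forward
print(y1.shape, y2.shape)     # torch.Size([1, 16, 64]) each
```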
2 changes: 1 addition & 1 deletion reformer_pytorch/reversible.py
@@ -142,7 +142,7 @@ def __init__(self, blocks, layer_dropout = 0., reverse_thres = 0, send_signal =
        self.blocks = nn.ModuleList([ReversibleBlock(f, g, depth, send_signal) for depth, (f, g) in enumerate(blocks)])
        self.irrev_blocks = nn.ModuleList([IrreversibleBlock(f=f, g=g) for f, g in blocks])

-    def forward(self, x, arg_route = (True, True), **kwargs):
+    def forward(self, x, arg_route = (True, False), **kwargs):
        reverse = x.shape[1] > self.reverse_thres
        blocks = self.blocks if reverse else self.irrev_blocks

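A note on the default change above: arg_route controls which of the two branches (f, then g) receives the forward keyword arguments. With twin attention gone, only the attention branch f needs kwargs such as the input mask, so the default flips to (True, False). Below is a simplified sketch of that routing idea; it is an illustration, not the library's exact code.

```python
# Simplified sketch of kwarg routing between the two branches of a reversible block.
def route_args(arg_route, **kwargs):
    # arg_route = (route_to_f, route_to_g); a branch whose flag is False gets no kwargs
    f_args, g_args = (kwargs if flag else {} for flag in arg_route)
    return f_args, g_args

f_args, g_args = route_args((True, False), input_mask = None)
print(f_args)  # {'input_mask': None} -> only the attention branch sees the mask
print(g_args)  # {}                   -> the feedforward branch gets nothing
```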
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'reformer_pytorch',
  packages = find_packages(exclude=['examples', 'pretraining']),
-  version = '1.2.5',
+  version = '1.2.6',
  license='MIT',
  description = 'Reformer, the Efficient Transformer, Pytorch',
  author = 'Phil Wang',
