
GPT Neo #10848

Merged · 46 commits · Mar 30, 2021

Changes from 1 commit

Commits (46):
f015a91
lets begin
patil-suraj Mar 22, 2021
36f9c94
boom boom
patil-suraj Mar 24, 2021
8255753
fix out proj in attn
patil-suraj Mar 24, 2021
eb6c00f
fix attention
patil-suraj Mar 25, 2021
0c135f9
fix local attention
patil-suraj Mar 25, 2021
36c827d
add tokenizer
patil-suraj Mar 25, 2021
bcee6c7
fix imports
patil-suraj Mar 25, 2021
39954ff
autotokenizer
patil-suraj Mar 25, 2021
b302780
fix checkpoint name
patil-suraj Mar 25, 2021
efa6003
cleanup
patil-suraj Mar 25, 2021
d970255
more clean-up
patil-suraj Mar 26, 2021
dac3f89
more cleanup
patil-suraj Mar 26, 2021
30cf9ca
output attentions
patil-suraj Mar 26, 2021
ca4bad5
fix attn mask creation
patil-suraj Mar 26, 2021
5685de9
fix imports
patil-suraj Mar 26, 2021
784d1cd
config doc
patil-suraj Mar 26, 2021
a474df5
add tests
patil-suraj Mar 26, 2021
7c90f3b
add slow tests
patil-suraj Mar 26, 2021
a5d1161
quality
patil-suraj Mar 26, 2021
647aec4
add conversion script
patil-suraj Mar 26, 2021
4fc464a
copyright
patil-suraj Mar 26, 2021
8781740
typo
patil-suraj Mar 26, 2021
eecbeea
another bites the dust
patil-suraj Mar 26, 2021
f5ca1b9
fix attention tests
patil-suraj Mar 26, 2021
22c9441
doc
patil-suraj Mar 26, 2021
2683d8f
add embed init in convert function
patil-suraj Mar 26, 2021
6b9aef4
fix copies
patil-suraj Mar 26, 2021
8be570a
remove tokenizer
patil-suraj Mar 28, 2021
0a44cbb
enable caching
patil-suraj Mar 28, 2021
bae1b69
address review comments
patil-suraj Mar 29, 2021
c859513
improve config and create attn layer list internally
patil-suraj Mar 29, 2021
7336c6f
more consistent naming
patil-suraj Mar 29, 2021
0d8d2bc
init hf config from mesh-tf config json file
patil-suraj Mar 29, 2021
1eb0bfe
remove neo tokenizer from doc
patil-suraj Mar 29, 2021
23849f7
handle attention_mask in local attn layer
patil-suraj Mar 29, 2021
c46278f
attn_layers => attention_layers
patil-suraj Mar 29, 2021
a59f111
add tokenizer_class in config
patil-suraj Mar 29, 2021
cbb81f9
fix docstring
patil-suraj Mar 29, 2021
08988ab
raise if len of attention_layers is not same as num_layers
patil-suraj Mar 29, 2021
6869ee7
remove tokenizer_class from config
patil-suraj Mar 29, 2021
29663ab
more consistent naming
patil-suraj Mar 29, 2021
e80fc91
fix doc
patil-suraj Mar 29, 2021
7bb186b
fix checkpoint names
patil-suraj Mar 29, 2021
22150cc
fp16 compat
patil-suraj Mar 30, 2021
83c07a0
Merge branch 'master' into gpt-neo
patil-suraj Mar 30, 2021
33c9ada
Apply suggestions from code review
LysandreJik Mar 30, 2021
more consistent naming
patil-suraj committed Mar 29, 2021
commit 29663abb50a7dcdea5cd74a60daa704c6b032019
24 changes: 12 additions & 12 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -130,7 +130,7 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
    return model


-class Attention(nn.Module):
+class GPTNeoSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -232,7 +232,7 @@ def forward(
        return (a, present) + attn_outputs[1:] # a, present, (attentions)


-class LocalAttention(nn.Module):
+class GPTNeoLocalSelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

@@ -287,8 +287,8 @@ def merge_heads(self, x):
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

-    def create_buckets(self, tensor, num_buckets, block_length):
-        return tensor.reshape(tensor.size()[0], num_buckets, block_length, -1)
+    def _split_seq_length_dim_to(self, tensors, num_blocks, block_length):
+        return tensors.reshape(tensors.size()[0], num_blocks, block_length, -1)

    def create_attention_mask(self, bs, seq_len, windows, block_length, attention_mask):
        ticker = torch.arange(seq_len)[None, :]
Contributor: Very hard to understand this function... a refactor would be useful here, with better naming and more comments.

Contributor: There should be a `device=attention_mask.device` in `torch.arange(...)` IMO.
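A minimal standalone sketch of the reviewer's suggestion (the helper name and shapes here are hypothetical, not the PR's code): create the index tensor directly on `attention_mask`'s device so no implicit CPU tensor is produced when the model runs on GPU.

```python
import torch


def make_position_ticker(seq_len: int, attention_mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper illustrating the comment: build the position indices
    # on the same device as the attention mask instead of the default CPU.
    return torch.arange(seq_len, device=attention_mask.device)[None, :]


# Usage sketch
attention_mask = torch.ones(2, 8)  # would live on GPU in a real forward pass
ticker = make_position_ticker(8, attention_mask)
assert ticker.device == attention_mask.device
```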

@@ -366,17 +366,17 @@ def forward(
        block_length = self.window_size
        while full_seq_length % block_length != 0:
            block_length -= 1
-        windows = full_seq_length // block_length
+        num_blocks = full_seq_length // block_length

        # create buckets
        if layer_past is not None:
            # we just need 1 window with block_length 1 when caching is enabled
-            query = self.create_buckets(query, 1, 1)
+            query = self._split_seq_length_dim_to(query, 1, 1)
        else:
-            query = self.create_buckets(query, windows, block_length)
+            query = self._split_seq_length_dim_to(query, num_blocks, block_length)

-        key = self.create_buckets(key, windows, block_length)
-        value = self.create_buckets(value, windows, block_length)
+        key = self._split_seq_length_dim_to(key, num_blocks, block_length)
+        value = self._split_seq_length_dim_to(value, num_blocks, block_length)

        key = self.look_around(key, block_length, self.window_size)
        value = self.look_around(value, block_length, self.window_size)
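For readers following the rename above, here is a small self-contained sketch of what `_split_seq_length_dim_to` does (the body is taken from the diff; the example shapes are made up): it folds the sequence dimension of a `(batch, seq_len, hidden)` tensor into `(batch, num_blocks, block_length, hidden)` blocks so local attention can work window by window.

```python
import torch


def _split_seq_length_dim_to(tensors: torch.Tensor, num_blocks: int, block_length: int) -> torch.Tensor:
    # Reshape (batch, seq_len, ...) -> (batch, num_blocks, block_length, ...),
    # letting -1 absorb the trailing feature dimension.
    return tensors.reshape(tensors.size()[0], num_blocks, block_length, -1)


# Example: batch of 2, sequence length 12, hidden size 4,
# split into 3 blocks of length 4 for windowed attention.
hidden_states = torch.randn(2, 12, 4)
blocks = _split_seq_length_dim_to(hidden_states, num_blocks=3, block_length=4)
print(blocks.shape)  # torch.Size([2, 3, 4, 4])
```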
@@ -390,7 +390,7 @@ def forward(
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)

-        mask = self.create_attention_mask(bs, full_seq_length, windows, block_length, attention_mask)
+        mask = self.create_attention_mask(bs, full_seq_length, num_blocks, block_length, attention_mask)
        if layer_past is not None:
            mask = mask[:, -1:, :, -1:, :] # only take the mask for the last window
        mask = mask.to(hidden_states.device)
Contributor: I think the mask should have the correct device when creating it in the function.
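A minimal sketch of the reviewer's point, with a hypothetical `device` argument and a heavily simplified mask (the real `create_attention_mask` also handles padding and look-around): if the helper builds its tensors on the caller's device, the later `mask = mask.to(hidden_states.device)` copy becomes unnecessary.

```python
import torch


def create_causal_block_mask(seq_len: int, num_blocks: int, block_length: int,
                             device: torch.device) -> torch.Tensor:
    # Hypothetical, simplified builder: indices are created directly on
    # `device`, so the returned mask never needs to be moved afterwards.
    ticker = torch.arange(seq_len, device=device)[None, :]
    blocked = ticker.reshape(1, num_blocks, block_length)
    # Allow each query position to attend to itself and earlier positions
    # inside its own block (a stand-in for the real look-around logic).
    return blocked[..., :, None] >= blocked[..., None, :]


mask = create_causal_block_mask(8, num_blocks=2, block_length=4, device=torch.device("cpu"))
print(mask.shape, mask.device)  # torch.Size([1, 2, 4, 4]) cpu
```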

@@ -415,9 +415,9 @@ def __init__(self, config, layer_id=0):
        self.attention_type = self.attention_layers[layer_id]

        if self.attention_type == "global":
-            self.attention = Attention(config)
+            self.attention = GPTNeoSelfAttention(config)
        elif self.attention_type == "local":
-            self.attention = LocalAttention(config)
+            self.attention = GPTNeoLocalSelfAttention(config)
        else:
            raise NotImplementedError(
                "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: {}. Select attn layer types from ['global', 'local'] only.".format(
Contributor (suggested change):
-                "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: {}. Select attn layer types from ['global', 'local'] only.".format(
+                f"Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: {self.attention_layers}. Select attn layer types from ['global', 'local'] only."

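In the same spirit as this suggestion and the earlier "raise if len of attention_layers is not same as num_layers" commit, a hypothetical standalone validation helper (names and messages simplified, not the PR's exact code):

```python
def validate_attention_layers(attention_layers, num_layers):
    # Hypothetical helper: the list must name one attention type per layer,
    # and every entry must be either 'global' or 'local'.
    if len(attention_layers) != num_layers:
        raise ValueError(
            f"`attention_layers` has {len(attention_layers)} entries, "
            f"but `num_layers` is {num_layers}."
        )
    for attention_type in attention_layers:
        if attention_type not in ("global", "local"):
            raise NotImplementedError(
                f"Only attn layer types 'global' and 'local' exist, but got "
                f"`config.attention_layers`: {attention_layers}. "
                f"Select attn layer types from ['global', 'local'] only."
            )


validate_attention_layers(["global", "local"] * 2, num_layers=4)  # passes silently
```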