GPT Neo #10848
```diff
@@ -130,7 +130,7 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
     return model


-class Attention(nn.Module):
+class GPTNeoSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
```
```diff
@@ -232,7 +232,7 @@ def forward(
         return (a, present) + attn_outputs[1:]  # a, present, (attentions)


-class LocalAttention(nn.Module):
+class GPTNeoLocalSelfAttention(nn.Module):
     def __init__(self, config):
         super().__init__()
```
```diff
@@ -287,8 +287,8 @@ def merge_heads(self, x):
         new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
         return x.view(*new_x_shape)

-    def create_buckets(self, tensor, num_buckets, block_length):
-        return tensor.reshape(tensor.size()[0], num_buckets, block_length, -1)
+    def _split_seq_length_dim_to(self, tensors, num_blocks, block_length):
+        return tensors.reshape(tensors.size()[0], num_blocks, block_length, -1)

     def create_attention_mask(self, bs, seq_len, windows, block_length, attention_mask):
         ticker = torch.arange(seq_len)[None, :]
```
**Review comment** on `create_attention_mask`: There should be a …
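For readers of the diff: the renamed helper only reshapes the sequence dimension into blocks, without touching the data. A minimal standalone sketch of what it does (the example shapes are assumptions, not values from the PR):

```python
import torch

# Hypothetical standalone version of _split_seq_length_dim_to:
# split a (batch, seq_length, hidden) tensor into
# (batch, num_blocks, block_length, hidden), where
# seq_length == num_blocks * block_length.
def split_seq_length_dim_to(tensors, num_blocks, block_length):
    return tensors.reshape(tensors.size()[0], num_blocks, block_length, -1)

hidden_states = torch.randn(2, 256, 64)            # batch=2, seq_length=256, hidden=64
blocks = split_seq_length_dim_to(hidden_states, 4, 64)
print(blocks.shape)                                 # torch.Size([2, 4, 64, 64])
```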
```diff
@@ -366,17 +366,17 @@ def forward(
         block_length = self.window_size
         while full_seq_length % block_length != 0:
             block_length -= 1
-        windows = full_seq_length // block_length
+        num_blocks = full_seq_length // block_length

         # create buckets
         if layer_past is not None:
             # we just need 1 window with block_length 1 when caching is enabled
-            query = self.create_buckets(query, 1, 1)
+            query = self._split_seq_length_dim_to(query, 1, 1)
         else:
-            query = self.create_buckets(query, windows, block_length)
+            query = self._split_seq_length_dim_to(query, num_blocks, block_length)

-        key = self.create_buckets(key, windows, block_length)
-        value = self.create_buckets(value, windows, block_length)
+        key = self._split_seq_length_dim_to(key, num_blocks, block_length)
+        value = self._split_seq_length_dim_to(value, num_blocks, block_length)

         key = self.look_around(key, block_length, self.window_size)
         value = self.look_around(value, block_length, self.window_size)
```
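The `while` loop in this hunk walks down from `window_size` to the largest block length that divides `full_seq_length` evenly, so the sequence can be reshaped into equal blocks without padding. A quick worked example (the sequence length here is made up; 256 is, as far as I recall, the GPT Neo default window size):

```python
window_size = 256       # GPT Neo's local attention window (default config value)
full_seq_length = 900   # arbitrary example length

# Largest block_length <= window_size that evenly divides full_seq_length.
block_length = window_size
while full_seq_length % block_length != 0:
    block_length -= 1
num_blocks = full_seq_length // block_length

print(block_length, num_blocks)  # 225 4 -> 900 tokens split into 4 blocks of 225
```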
```diff
@@ -390,7 +390,7 @@ def forward(
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)

-        mask = self.create_attention_mask(bs, full_seq_length, windows, block_length, attention_mask)
+        mask = self.create_attention_mask(bs, full_seq_length, num_blocks, block_length, attention_mask)
         if layer_past is not None:
             mask = mask[:, -1:, :, -1:, :]  # only take the mask for the last window
         mask = mask.to(hidden_states.device)
```
**Review comment**: I think the mask should have the correct device when creating it in the function
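One way to address this comment is to build the index tensor on the target device from the start, so the finished mask never needs a follow-up `.to(...)` call. A hedged sketch of the idea (the helper name and `device` parameter are illustrative, not the code that landed in the PR):

```python
import torch

def make_position_ids(seq_len, device):
    # Create the index tensor directly on the target device so no
    # later mask.to(device) call is needed.
    return torch.arange(seq_len, device=device)[None, :]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ticker = make_position_ids(8, device)
print(ticker.device)  # matches hidden_states.device by construction
```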
```diff
@@ -415,9 +415,9 @@ def __init__(self, config, layer_id=0):
         self.attention_type = self.attention_layers[layer_id]

         if self.attention_type == "global":
-            self.attention = Attention(config)
+            self.attention = GPTNeoSelfAttention(config)
         elif self.attention_type == "local":
-            self.attention = LocalAttention(config)
+            self.attention = GPTNeoLocalSelfAttention(config)
         else:
             raise NotImplementedError(
                 "Only attn layer types 'global' and 'local' exist, but got `config.attention_layers`: {}. Select attn layer types from ['global', 'local'] only.".format(
```
**Review comment**: Very hard to understand this function... a refactor would be useful here, with better naming and more comments.