
GPT Neo #10848

Merged: 46 commits, Mar 30, 2021

Changes from 1 commit

Commits (46)
f015a91
lets begin
patil-suraj Mar 22, 2021
36f9c94
boom boom
patil-suraj Mar 24, 2021
8255753
fix out proj in attn
patil-suraj Mar 24, 2021
eb6c00f
fix attention
patil-suraj Mar 25, 2021
0c135f9
fix local attention
patil-suraj Mar 25, 2021
36c827d
add tokenizer
patil-suraj Mar 25, 2021
bcee6c7
fix imports
patil-suraj Mar 25, 2021
39954ff
autotokenizer
patil-suraj Mar 25, 2021
b302780
fix checkpoint name
patil-suraj Mar 25, 2021
efa6003
cleanup
patil-suraj Mar 25, 2021
d970255
more clean-up
patil-suraj Mar 26, 2021
dac3f89
more cleanup
patil-suraj Mar 26, 2021
30cf9ca
output attentions
patil-suraj Mar 26, 2021
ca4bad5
fix attn mask creation
patil-suraj Mar 26, 2021
5685de9
fix imports
patil-suraj Mar 26, 2021
784d1cd
config doc
patil-suraj Mar 26, 2021
a474df5
add tests
patil-suraj Mar 26, 2021
7c90f3b
add slow tests
patil-suraj Mar 26, 2021
a5d1161
quality
patil-suraj Mar 26, 2021
647aec4
add conversion script
patil-suraj Mar 26, 2021
4fc464a
copyright
patil-suraj Mar 26, 2021
8781740
typo
patil-suraj Mar 26, 2021
eecbeea
another bites the dust
patil-suraj Mar 26, 2021
f5ca1b9
fix attention tests
patil-suraj Mar 26, 2021
22c9441
doc
patil-suraj Mar 26, 2021
2683d8f
add embed init in convert function
patil-suraj Mar 26, 2021
6b9aef4
fix copies
patil-suraj Mar 26, 2021
8be570a
remove tokenizer
patil-suraj Mar 28, 2021
0a44cbb
enable caching
patil-suraj Mar 28, 2021
bae1b69
address review comments
patil-suraj Mar 29, 2021
c859513
improve config and create attn layer list internally
patil-suraj Mar 29, 2021
7336c6f
more consistent naming
patil-suraj Mar 29, 2021
0d8d2bc
init hf config from mesh-tf config json file
patil-suraj Mar 29, 2021
1eb0bfe
remove neo tokenizer from doc
patil-suraj Mar 29, 2021
23849f7
handle attention_mask in local attn layer
patil-suraj Mar 29, 2021
c46278f
attn_layers => attention_layers
patil-suraj Mar 29, 2021
a59f111
add tokenizer_class in config
patil-suraj Mar 29, 2021
cbb81f9
fix docstring
patil-suraj Mar 29, 2021
08988ab
raise if len of attention_layers is not same as num_layers
patil-suraj Mar 29, 2021
6869ee7
remove tokenizer_class from config
patil-suraj Mar 29, 2021
29663ab
more consistent naming
patil-suraj Mar 29, 2021
e80fc91
fix doc
patil-suraj Mar 29, 2021
7bb186b
fix checkpoint names
patil-suraj Mar 29, 2021
22150cc
fp16 compat
patil-suraj Mar 30, 2021
83c07a0
Merge branch 'master' into gpt-neo
patil-suraj Mar 30, 2021
33c9ada
Apply suggestions from code review
LysandreJik Mar 30, 2021
fix out proj in attn
patil-suraj committed Mar 29, 2021
commit 8255753cebde7e0955ad9ebfd475fc9cccdedc52
15 changes: 14 additions & 1 deletion src/transformers/models/gpt_neo/configuration_gpt_neo.py
@@ -96,7 +96,20 @@ def __init__(
n_ctx=1024,
n_embd=768,
n_layer=12,
attn_layers=["global", "local", "global", "local", "global", "local", "global", "local", "global", "local", "global", "local"],
attn_layers=[
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
],
n_head=12,
n_inner=None,
activation_function="gelu_new",
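Note on the default shown above: the 12-entry attn_layers list simply alternates "global" and "local" attention across the 12 layers. Later commits in this PR ("improve config and create attn layer list internally", "raise if len of attention_layers is not same as num_layers") move this construction and validation into the config itself. Below is a minimal sketch of how such a list could be expanded from a compact description; the helper name and its input format are assumptions for illustration, not the API that was eventually merged.

# Sketch only: expand_attention_types and the [[pattern, repeat], ...] input
# format are illustrative assumptions, not the merged config API.
def expand_attention_types(attention_types, num_layers):
    layers = []
    for pattern, repeat in attention_types:
        layers.extend(pattern * repeat)
    if len(layers) != num_layers:
        raise ValueError(
            f"Got {len(layers)} attention layer types for {num_layers} layers."
        )
    return layers

# [[["global", "local"], 6]] expands to the 12-entry default in the diff above.
assert expand_attention_types([[["global", "local"], 6]], 12) == ["global", "local"] * 6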
19 changes: 6 additions & 13 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -35,7 +35,8 @@
from ...modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions, CausalLMOutputWithPast,
CausalLMOutputWithCrossAttentions,
CausalLMOutputWithPast,
MaskedLMOutput,
MultipleChoiceModelOutput,
QuestionAnsweringModelOutput,
@@ -154,9 +155,7 @@ def __init__(self, nx, n_ctx, config, scale=False):
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

self.attn_bias = nn.Parameter(torch.zeros(self.embed_dim))
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
w = torch.matmul(q, k)
@@ -230,17 +229,14 @@ def forward(

a = self.merge_heads(a)
a = self.out_proj(a)
# a = self.resid_dropout(a)

a += self.attn_bias
a = self.resid_dropout(a)

return (a, present) + attn_outputs[1:] # a, present, (attentions)
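The substance of this commit is visible in the two hunks above: the standalone attn_bias parameter is removed, out_proj switches from bias=False to bias=True, and resid_dropout is re-enabled after the projection. Folding the bias into out_proj is mathematically equivalent to adding it separately, since nn.Linear with a bias computes x @ W.T + b. A small standalone check of that equivalence (not code from the PR):

import torch
import torch.nn as nn

embed_dim = 8
x = torch.randn(2, 4, embed_dim)

# Old parameterization: bias-free projection plus a standalone bias parameter.
proj_no_bias = nn.Linear(embed_dim, embed_dim, bias=False)
attn_bias = nn.Parameter(torch.randn(embed_dim))
out_old = proj_no_bias(x) + attn_bias

# New parameterization: the bias lives inside the Linear layer itself.
proj_with_bias = nn.Linear(embed_dim, embed_dim, bias=True)
with torch.no_grad():
    proj_with_bias.weight.copy_(proj_no_bias.weight)
    proj_with_bias.bias.copy_(attn_bias)
out_new = proj_with_bias(x)

print(torch.allclose(out_old, out_new, atol=1e-6))  # True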


class LocalAttention(nn.Module):
Review comment (Member): Same here regarding the GPTNeo prefix

def __init__(self, nx, n_ctx, config, scale=False):
super().__init__()
print("init local")
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert n_state % config.n_head == 0
@@ -268,9 +264,7 @@ def __init__(self, nx, n_ctx, config, scale=False):
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)

self.attn_bias = nn.Parameter(torch.zeros(self.embed_dim))
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

self.window_size = config.window_size

@@ -350,8 +344,7 @@ def forward(
attn = attn.reshape(-1, seq_len, self.embed_dim)

attn = self.out_proj(attn)

attn += self.attn_bias
attn = self.resid_dropout(attn)
Review comment (Contributor), suggested change:
- attn = self.resid_dropout(attn)
+ attn = self.residual_dropout(attn)

return (attn,)


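For context on the LocalAttention module above: config.window_size bounds how many preceding positions each token can attend to, instead of the full causal prefix used by the global layers. The exact masking scheme is not shown in these hunks; the following is a generic sketch of a banded causal mask, assuming the window includes the current token:

import torch

def local_causal_mask(seq_len: int, window_size: int) -> torch.Tensor:
    # True where query position i may attend to key position j:
    # causal (j <= i) and inside the window (i - j < window_size).
    positions = torch.arange(seq_len)
    distance = positions[:, None] - positions[None, :]
    return (distance >= 0) & (distance < window_size)

# Each row has at most `window_size` True entries, ending at the diagonal.
print(local_causal_mask(seq_len=6, window_size=3).int())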