GPT Neo #10848
Merged

Changes from 1 commit

Commits (46):
f015a91 lets begin (patil-suraj)
36f9c94 boom boom (patil-suraj)
8255753 fix out proj in attn (patil-suraj)
eb6c00f fix attention (patil-suraj)
0c135f9 fix local attention (patil-suraj)
36c827d add tokenizer (patil-suraj)
bcee6c7 fix imports (patil-suraj)
39954ff autotokenizer (patil-suraj)
b302780 fix checkpoint name (patil-suraj)
efa6003 cleanup (patil-suraj)
d970255 more clean-up (patil-suraj)
dac3f89 more cleanup (patil-suraj)
30cf9ca output attentions (patil-suraj)
ca4bad5 fix attn mask creation (patil-suraj)
5685de9 fix imports (patil-suraj)
784d1cd config doc (patil-suraj)
a474df5 add tests (patil-suraj)
7c90f3b add slow tests (patil-suraj)
a5d1161 quality (patil-suraj)
647aec4 add conversion script (patil-suraj)
4fc464a copyright (patil-suraj)
8781740 typo (patil-suraj)
eecbeea another bites the dust (patil-suraj)
f5ca1b9 fix attention tests (patil-suraj)
22c9441 doc (patil-suraj)
2683d8f add embed init in convert function (patil-suraj)
6b9aef4 fix copies (patil-suraj)
8be570a remove tokenizer (patil-suraj)
0a44cbb enable caching (patil-suraj)
bae1b69 address review comments (patil-suraj)
c859513 improve config and create attn layer list internally (patil-suraj)
7336c6f more consistent naming (patil-suraj)
0d8d2bc init hf config from mesh-tf config json file (patil-suraj)
1eb0bfe remove neo tokenizer from doc (patil-suraj)
23849f7 handle attention_mask in local attn layer (patil-suraj)
c46278f attn_layers => attention_layers (patil-suraj)
a59f111 add tokenizer_class in config (patil-suraj)
cbb81f9 fix docstring (patil-suraj)
08988ab raise if len of attention_layers is not same as num_layers (patil-suraj)
6869ee7 remove tokenizer_class from config (patil-suraj)
29663ab more consistent naming (patil-suraj)
e80fc91 fix doc (patil-suraj)
7bb186b fix checkpoint names (patil-suraj)
22150cc fp16 compat (patil-suraj)
83c07a0 Merge branch 'master' into gpt-neo (patil-suraj)
33c9ada Apply suggestions from code review (LysandreJik)
fix out proj in attn
commit 8255753cebde7e0955ad9ebfd475fc9cccdedc52
@@ -35,7 +35,8 @@
 from ...modeling_outputs import (
     BaseModelOutputWithPast,
     BaseModelOutputWithPastAndCrossAttentions,
-    CausalLMOutputWithCrossAttentions, CausalLMOutputWithPast,
+    CausalLMOutputWithCrossAttentions,
+    CausalLMOutputWithPast,
     MaskedLMOutput,
     MultipleChoiceModelOutput,
     QuestionAnsweringModelOutput,

@@ -154,9 +155,7 @@ def __init__(self, nx, n_ctx, config, scale=False):
         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-
-        self.attn_bias = nn.Parameter(torch.zeros(self.embed_dim))
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

     def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
         w = torch.matmul(q, k)

@@ -230,17 +229,14 @@ def forward(
         a = self.merge_heads(a)
         a = self.out_proj(a)
-        # a = self.resid_dropout(a)
-        a += self.attn_bias
         a = self.resid_dropout(a)

         return (a, present) + attn_outputs[1:]  # a, present, (attentions)


 class LocalAttention(nn.Module):
     def __init__(self, nx, n_ctx, config, scale=False):
         super().__init__()
-        print("init local")
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implem]
         assert n_state % config.n_head == 0

@@ -268,9 +264,7 @@ def __init__(self, nx, n_ctx, config, scale=False):
         self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
-
-        self.attn_bias = nn.Parameter(torch.zeros(self.embed_dim))
+        self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=True)

         self.window_size = config.window_size

@@ -350,8 +344,7 @@ def forward(
         attn = attn.reshape(-1, seq_len, self.embed_dim)

         attn = self.out_proj(attn)
-        attn += self.attn_bias
         attn = self.resid_dropout(attn)

         return (attn,)
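The substance of this commit: instead of building `out_proj` without a bias and adding a separately stored `attn_bias` parameter after the projection, `out_proj` now owns its bias (`bias=True`), both in `Attention` and in `LocalAttention`; the stray debug `print` in `LocalAttention.__init__` is also dropped. A minimal sketch of why the two formulations are numerically equivalent once the parameters are copied over (hypothetical sizes, PyTorch assumed, not code from the PR):

```python
import torch
import torch.nn as nn

embed_dim = 8  # hypothetical size for illustration

# Old formulation: bias-free projection plus a separately stored bias parameter.
out_proj_old = nn.Linear(embed_dim, embed_dim, bias=False)
attn_bias = nn.Parameter(torch.zeros(embed_dim))

# New formulation: the projection owns its bias.
out_proj_new = nn.Linear(embed_dim, embed_dim, bias=True)

# Copy the weights and bias so both paths are directly comparable.
with torch.no_grad():
    out_proj_new.weight.copy_(out_proj_old.weight)
    out_proj_new.bias.copy_(attn_bias)

x = torch.randn(2, 4, embed_dim)
old_out = out_proj_old(x) + attn_bias  # project, then add the standalone bias
new_out = out_proj_new(x)              # bias applied inside the Linear
assert torch.allclose(old_out, new_out, atol=1e-6)
```

Folding the bias into the Linear presumably also simplifies the later Mesh-TF to HF conversion, since the checkpoint's output-projection bias can be loaded straight into `out_proj.bias` rather than into a separate parameter.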
Review comment: Same here regarding the GPTNeo prefix.
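The comment presumably refers to the transformers convention of prefixing model-specific classes with the model name; a hypothetical sketch of what that rename would look like for the classes touched in this diff (illustrative names only, not necessarily the ones merged in the PR):

```python
import torch.nn as nn

# Hypothetical renames illustrating the requested GPTNeo prefix.
class GPTNeoAttention(nn.Module):       # was: Attention
    ...

class GPTNeoLocalAttention(nn.Module):  # was: LocalAttention
    ...
```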