
GPT Neo #10848

Merged: 46 commits, Mar 30, 2021

Conversation

@patil-suraj
Contributor

patil-suraj commented Mar 22, 2021

What does this PR do?

This PR adds the GPT Neo model.
The model architecture is very similar to GPT-2, except that it uses local attention in alternate layers.

  • The LocalAttention module implements the local attention. The implementation is not as clean as it should be and will be cleaned up in a follow-up PR.
  • To enable caching (use_cache), the local attention layer caches the hidden_states instead of past_key_value_states.
    Also, right now when use_cache is enabled, the current input length cannot be greater than 1.
  • The model uses the same tokenizer as GPT-2, so it does not need a new tokenizer class.

Example usage:

import torch
from transformers import GPTNeoForCausalLM, AutoTokenizer

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")

unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \
           "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \
           "researchers was the fact that the unicorns spoke perfect English."

input_ids = tokenizer(unicorns, return_tensors="pt").input_ids

# add the length of the prompt tokens to match with the mesh-tf generation
max_length = 400 + input_ids.shape[1] 

temperature = .9
do_sample = True

# set seed to reproduce samples
torch.manual_seed(42) 

gen_tokens = model.generate(
  input_ids,
  do_sample=do_sample,
  min_length=max_length,
  max_length=max_length,
  temperature=temperature,
)
gen_text = tokenizer.batch_decode(gen_tokens)[0]

Future TODOs:

  • Clean up the implementation of LocalAttention, especially the creation of the attention_mask.
  • Test fine-tuning.
  • Enable current length > 1 when use_cache is enabled.
  • Add more robust and aggressive tests for the LocalAttention module.
  • Add TF model.

@TevenLeScao
Contributor

TevenLeScao commented Mar 26, 2021

@sdtblck @leogao2 this is the Neo PR, reviews/comments appreciated!

@StellaAthena
Contributor

I tried running this with the 2.7B checkpoint and got

(base) stellabiderman@Stellas-MacBook-Pro research % python transformers/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py --tf_checkpoint_path GPT3_2.7B/checkpoint --config_file GPT3_2-7B/config.json --pytorch_dump_path GPT3_2-7B
Building PyTorch model from configuration: GPTNeoConfig {
  "activation_function": "gelu",
  "ada_epsilon1": "1e-30",
  "ada_epsilon2": 0.001,
  "attention_types": [
    [
      [
        "global",
        "local"
      ],
      16
    ]
  ],
  "attn_dropout": 0,
  "attn_layers": [
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local",
    "global",
    "local"
  ],
  "attn_pdrop": 0.1,
  "beta1": 0.9,
  "beta2": 0.95,
  "bos_token_id": 50256,
  "datasets": [
    [
      "pile",
      null,
      null,
      null
    ]
  ],
  "embd_pdrop": 0.1,
  "embed_dropout": 0,
  "eos_id": 50256,
  "eos_token_id": 50256,
  "epsilon": 1e-08,
  "eval_batch_size": 128,
  "eval_steps": 10,
  "gradient_checkpointing": false,
  "gradient_clipping": 1.0,
  "initializer_range": 0.02,
  "iterations": 500,
  "layer_norm_epsilon": 1e-05,
  "layout": "batch:x,embd:y",
  "lr": 0.00016,
  "lr_decay": "cosine",
  "lr_decay_end": 300000,
  "mesh_shape": "x:64,y:4",
  "model_path": "gs://neo-d/models/GPT3_2-7B",
  "model_type": "gpt_neo",
  "n_ctx": 2048,
  "n_embd": 2560,
  "n_head": 20,
  "n_inner": null,
  "n_layer": 32,
  "n_positions": 2048,
  "n_vocab": 50257,
  "opt_name": "adam",
  "padding_id": 50257,
  "predict_batch_size": 1,
  "predict_steps": 0,
  "recompute_grad": true,
  "res_dropout": 0,
  "resid_pdrop": 0.1,
  "scale_by_depth": true,
  "scale_by_in": false,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "tokens_per_mb_per_replica": 4096,
  "train_batch_size": 512,
  "train_steps": 400000,
  "transformers_version": "4.5.0.dev0",
  "use_cache": false,
  "vocab_size": 50257,
  "warmup_steps": 3000,
  "weight_decay": 0,
  "window_size": 256
}

Traceback (most recent call last):
  File "transformers/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py", line 59, in <module>
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.config_file, args.pytorch_dump_path)
  File "transformers/src/transformers/models/gpt_neo/convert_gpt_neo_mesh_tf_to_pytorch.py", line 31, in convert_tf_checkpoint_to_pytorch
    model = GPTNeoForCausalLM(config)
  File "/Users/stellabiderman/Documents/Research/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 778, in __init__
    self.transformer = GPTNeoModel(config)
  File "/Users/stellabiderman/Documents/Research/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 597, in __init__
    self.h = nn.ModuleList([Block(config, layer_id=i) for i in range(config.n_layer)])
  File "/Users/stellabiderman/Documents/Research/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 597, in <listcomp>
    self.h = nn.ModuleList([Block(config, layer_id=i) for i in range(config.n_layer)])
  File "/Users/stellabiderman/Documents/Research/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 434, in __init__
    self.attn = GPTNeoAttention(config, layer_id)
  File "/Users/stellabiderman/Documents/Research/transformers/src/transformers/models/gpt_neo/modeling_gpt_neo.py", line 381, in __init__
    self.attention_type = self.attn_layers[layer_id]
IndexError: list index out of range

@patil-suraj
Contributor Author

Hi @StellaAthena,
The 2.7B model has 32 layers, so attn_layers should be

['global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local',
 'global',
 'local']
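
For the 32-layer model this is just the two-element pattern repeated 16 times; a quick illustrative way to build it (not code from this PR):

# ["global", "local"] repeated 16 times yields the 32-entry pattern above
attn_layers = ["global", "local"] * 16
assert len(attn_layers) == 32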

I've converted these checkpoints and will push them to the hub in a couple of hours. I'll ping you once that's done, so you can directly download them.

@StellaAthena
Contributor

StellaAthena commented Mar 26, 2021

I see! Is this a problem with my local config file, or is something up with the code on the repo? I downloaded my file directly from the-eye before running the conversion script, so if the local config file is wrong that’s a bit of a problem for us.

@sdtblck

sdtblck commented Mar 26, 2021

Hey @patil-suraj, I haven't had a chance to look over the whole PR yet, so I'm not sure how you load the configuration, but I wonder why you even have separate fields for "attention_types" and "attention_layers", since they configure the same thing and attention_layers can be derived from attention_types.

@patil-suraj
Contributor Author

patil-suraj commented Mar 26, 2021

Hi @sdtblck

attention_types is not used by the config; it only uses attention_layers. But yes, attention_layers can be derived from attention_types.

For an example config file, see https://huggingface.co/valhalla/gpt_neo_xl_test/blob/main/config.json

I've uploaded the 1.3B checkpoint under my namespace temporarily, here's a colab if you wanna give it a try.

@StellaAthena
Contributor

StellaAthena commented Mar 26, 2021

Hi @sdtblck

attention_types is not used by the config; it only uses attention_layers. But yes, attention_layers can be derived from attention_types.

Our config file doesn't define attention_layers. It appears that you hard-coded this specific attention pattern. I agree with @sdtblck that it would make much more sense to derive attention_layers from attention_types. I believe the correct place to do that would be here.

@patil-suraj
Contributor Author

Yes, you are right! I hardcoded it since we usually prefer to keep everything explicit, but I agree this would be a problem on your side. I will change it so that attention_layers is derived from attention_types.

Are there any other issues?

@patil-suraj
Contributor Author

@StellaAthena @sdtblck

The 2.7B model is up! https://huggingface.co/valhalla/gpt_neo_2.7B/tree/main

@StellaAthena
Contributor

StellaAthena commented Mar 26, 2021

I tried out the 2.7B model you posted @patil-suraj but it wouldn't run. I get the error

Some weights of the model checkpoint at valhalla/gpt_neo_2.7B were not used when initializing GPT2LMHeadModel:
...
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Traceback (most recent call last):
  File "main.py", line 9, in <module>
    from lm_eval import models, tasks, evaluator, base
  File "/home/mchorse/lm-evaluation-harness/lm_eval/models/__init__.py", line 7, in <module>
    "gpt-neo": gpt2.GPT2LM(device="cuda",pretrained="valhalla/gpt_neo_2.7B"),
  File "/home/mchorse/lm-evaluation-harness/lm_eval/models/gpt2.py", line 14, in __init__
    self.gpt2 = transformers.GPT2LMHeadModel.from_pretrained(pretrained).to(self.device)
  File "/home/mchorse/.local/lib/python3.8/site-packages/transformers/modeling_utils.py", line 1181, in from_pretrained
    raise RuntimeError(
RuntimeError: Error(s) in loading state_dict for GPT2LMHeadModel:
...

Looking through the readout, I see

size mismatch for transformer.h.0.mlp.c_fc.weight: copying a param with shape torch.Size([10240, 2560]) from checkpoint, the shape in current model is torch.Size([2560, 10240]).
size mismatch for transformer.h.0.mlp.c_proj.weight: copying a param with shape torch.Size([2560, 10240]) from checkpoint, the shape in current model is torch.Size([10240, 2560]).
size mismatch for transformer.h.1.mlp.c_fc.weight: copying a param with shape torch.Size([10240, 2560]) from checkpoint, the shape in current model is torch.Size([2560, 10240]).
size mismatch for transformer.h.1.mlp.c_proj.weight: copying a param with shape torch.Size([2560, 10240]) from checkpoint, the shape in current model is torch.Size([10240, 2560]).
size mismatch for transformer.h.2.mlp.c_fc.weight: copying a param with shape torch.Size([10240, 2560]) from checkpoint, the shape in current model is torch.Size([2560, 10240]).
size mismatch for transformer.h.2.mlp.c_proj.weight: copying a param with shape torch.Size([2560, 10240]) from checkpoint, the shape in current model is torch.Size([10240, 2560]).

I think that there's an unneeded transpose hanging out in the code.

@patil-suraj
Contributor Author

patil-suraj commented Mar 26, 2021

It looks like you are using the GPT2LMHeadModel class. We've added a new class, GPTNeoForCausalLM, for GPT-Neo, which should be used instead of GPT2LMHeadModel.

Could you check out this PR and try loading the model using the GPTNeoForCausalLM class?

And yes, GPT2 uses the Conv1D layer, which has transposed weights, hence the error.
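
As an illustration of that shape mismatch (a sketch, not code from this PR): GPT-2's Conv1D stores its weight as (in_features, out_features), while the nn.Linear layers used in the GPT-Neo MLP store (out_features, in_features), which is exactly the transposed pair reported above.

import torch.nn as nn

hidden_size, intermediate_size = 2560, 10240

# nn.Linear weight layout: (out_features, in_features)
c_fc = nn.Linear(hidden_size, intermediate_size)
print(c_fc.weight.shape)  # torch.Size([10240, 2560]) -- the checkpoint shape

# GPT-2's Conv1D keeps the transposed layout, (in_features, out_features),
# i.e. torch.Size([2560, 10240]), which is what GPT2LMHeadModel expects and
# why loading the GPT-Neo checkpoint into it fails with a size mismatch.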

Member

@LysandreJik left a comment

This is fantastic, good job going so fast @patil-suraj! Could you list here the tasks left to be done, even if you plan on doing them in a future PR? Thanks!

Member

@LysandreJik left a comment

Was there no way to add some "# Copied from" statements to ensure that the two models do not diverge?

@patil-suraj
Contributor Author

patil-suraj commented Mar 26, 2021

Was there no way to add some "# Copied from" statements to ensure that the two models do not diverge?

I have made some changes to the code, mostly related to naming and to passing the config to Block and Attention instead of individual arguments, so I can't really use # Copied from.

Collaborator

@sgugger left a comment

Very nice work! I left mostly comments on style/docstrings and names; the two main comments I have are:

  • there should be an easier API in the config to get the attn_layers (instead of having to pass a tuple of 24 elements following a given pattern); I made a suggestion.
  • if the tokenizer is a copy of GPT-2, the model should just use the GPT-2 tokenizer.

:obj:`inputs_ids` passed when calling :class:`~transformers.GPTNeoModel` or
:class:`~transformers.TFGPTNeoModel`. Vocabulary size of the model. Defines the different tokens that can
be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.GPTNeoModel`.
attn_layers (:obj:`Tuple[str]`, `optional`, defaults to :obj:`("global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local","global","local")`):
Contributor

Our training code produces a config file that includes the entry "attention_types" : [[["global", "local"], 16]]. I would highly recommend using attn_layers = [copy.copy(e) for _ in range(args.attention_types[1]) for e in args.attention_types[0]]. This has the added advantage of allowing support for additional layer types that aren't used in these pretrained models but are implemented in the repo, like Mixture of Experts and Linear Attention.
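
A minimal sketch of that derivation, assuming the nested [[["global", "local"], 16]] structure shown above (variable names here are illustrative, not the final implementation):

# Expand "attention_types" into a flat, per-layer list of attention types.
attention_types = [[["global", "local"], 16]]

attention_layers = []
for pattern, repetitions in attention_types:
    for _ in range(repetitions):
        attention_layers.extend(pattern)

assert len(attention_layers) == 32
assert attention_layers[:4] == ["global", "local", "global", "local"]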

Comment on lines 102 to 127
attn_layers=(
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
"global",
"local",
),
Contributor

Following up on my comment above, I think this should just be removed as an argument and entirely calculated internally.

@StellaAthena
Contributor

StellaAthena commented Mar 27, 2021

An update from our end: We got the 2.7B model up and running in our evaluation harness! Unfortunately the run revealed that the harness is bugged...

Running it by hand gives reasonable-looking results, but I don't know how much I should trust myself to judge that.

@leogao2
Contributor

leogao2 commented Mar 27, 2021

(To clarify: the bugs in the eval harness were introduced by a series of pretty aggressive optimizations I implemented just a few hours earlier today.)

@Xirider

Xirider commented Mar 28, 2021

I tried fine-tuning the model with DeepSpeed and gradient checkpointing, but unlike with GPT-2, the loss explodes. I used the default run_clm.py from the examples folder, but added one line to activate gradient checkpointing. Here is the command I ran:

deepspeed --num_gpus=1 run_clm.py \
--deepspeed ds_config_gptneo.json \
--model_name_or_path valhalla/gpt_neo_2.7B \
--train_file train.csv \
--validation_file validation.csv \
--do_train \
--do_eval \
--fp16 \
--overwrite_cache \
--evaluation_strategy="steps" \
--output_dir finetuned \
--num_train_epochs 2 \
--eval_steps 15 \
--gradient_accumulation_steps 2 \
--per_device_train_batch_size 4 \
--use_fast_tokenizer False \
--learning_rate 1e-05 \
--adam_beta1 0.9 \
--adam_beta2 0.95 \
--weight_decay 0.1 \
--warmup_steps 50

Here is my ds_config_gptneo.json (it is almost the default, except for a lower min_loss_scale, otherwise I got overflows; the optimizer and warmup hyperparameters are overwritten by the flags above):


{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": -3,
        "hysteresis": 2,
        "min_loss_scale": -1000
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 5e7,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e7,
        "contiguous_gradients": true,
        "cpu_offload": true
    },
    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": 0.00001,
            "betas": [
                0.9,
                0.95
            ],
            "eps": 1e-8,
            "weight_decay": 0.1
        }
    },
    "scheduler": {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.00001,
            "warmup_num_steps":  50
        }
    },
    "steps_per_print": 1000,
    "wall_clock_breakdown": false
}

I also tried the exact hyperparameters that EleutherAI used, with long warmup phases, but the result is the same. If the learning rate is low enough the loss doesn't change, and once it is big enough, it immediately explodes. I also did a hyperparameter sweep with the same result. Could this be an issue with the model implementation, since fine-tuning with EleutherAI's implementation in Mesh TensorFlow on Colab seems to work?

Here are the exact steps that I did (in the bottom half of the page): https://github.com/Xirider/finetune-gpt2xl

@patil-suraj
Contributor Author

patil-suraj commented Mar 28, 2021

Hi @Xirider, let me take a look. In the meantime, could you try without fp16?

@Xirider

Xirider commented Mar 28, 2021

Hi, yes, I will try it.

@Xirider

Xirider commented Mar 28, 2021

Hm, disabling fp16 doesn't work with ZeRO:
AssertionError: DeepSpeedConfig: ZeRO is only supported if fp16 is enabled.
And without DeepSpeed's ZeRO I don't think I have enough GPU memory.

@sdtblck

sdtblck commented Mar 29, 2021

One thing I've caught while testing the Neo model: if I try to add a padding token to the tokenizer after loading it from pretrained (i.e. to predict batches instead of a single sequence at a time), then I get:

RuntimeError: CUDA error: device-side assert triggered

I guess this is because the tokenizer vocabulary is different from the way it was initialized. I'm not sure if this is an HF-wide problem (although I don't recall this being a problem with GPT2Tokenizer.from_pretrained('gpt2')) or specific to Neo, but here is the code to reproduce the error:

import torch
from transformers import GPTNeoForCausalLM, GPT2Tokenizer
ckpt_2b = "EleutherAI/gpt_neo_2-7B"
tokenizer = GPT2Tokenizer.from_pretrained(ckpt_2b)
tokenizer.add_special_tokens({'pad_token': '<|padding|>'})
ids = tokenizer("hello world", return_tensors="pt").input_ids.to("cuda")

@sdtblck

sdtblck commented Mar 29, 2021

Maybe I'm just going insane, or doing something stupid, because swapping out ckpt_2b for 'gpt2' gives the same error. We never had this problem training with gpt-neox. Can anyone reproduce, and if so, should I open up a new issue?

@LysandreJik
Member

Hey @sdtblck! I think the issue here is because you're adding a new token to your tokenizer (so you're extending your vocab), but you're not resizing the token embedding matrix.

When you're creating the GPT-2 tokenizer from your checkpoint, you should have a tokenizer size of 50257:

from transformers import GPTNeoForCausalLM, GPT2Tokenizer
ckpt_2b = "EleutherAI/gpt_neo_2-7B"
tokenizer = GPT2Tokenizer.from_pretrained(ckpt_2b)
print(len(tokenizer))
# 50257

That's the same size as the model token embedding matrix:

print(model.get_input_embeddings())
# Embedding(50257, 2560)

When adding a new token, you should also resize the token embedding matrix alongside it. Otherwise you'll get some index out of range issues, as you'll be trying to obtain the 50258th row of a matrix with 50257 rows. Please add the following line to your code, once you have added a token to your tokenizer and instantiated your model:

model.resize_token_embeddings(len(tokenizer))

Everything should be working smoothly now :)
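
Putting those pieces together, a short illustrative end-to-end version (using the checkpoint name from the snippet above) would look like this:

from transformers import GPTNeoForCausalLM, GPT2Tokenizer

ckpt_2b = "EleutherAI/gpt_neo_2-7B"
tokenizer = GPT2Tokenizer.from_pretrained(ckpt_2b)
model = GPTNeoForCausalLM.from_pretrained(ckpt_2b)

# Grow the vocab and the embedding matrix in lockstep: the new pad token gets
# id 50257, so the embedding matrix must be resized to 50258 rows.
tokenizer.add_special_tokens({'pad_token': '<|padding|>'})
model.resize_token_embeddings(len(tokenizer))

ids = tokenizer("hello world", return_tensors="pt").input_ids
outputs = model(ids)  # no more index-out-of-range / device-side assert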

@sdtblck

sdtblck commented Mar 30, 2021

Hm, @LysandreJik, doing that does make the error go away, but sampling with the model when I've added padding tokens causes almost everything in the prediction to become padding. Let me know if I should take this somewhere else, by the way; I don't want to clog up this PR if this issue doesn't relate to it at all.

predict below is pretty much just a wrapper around model.generate()

prompt = "Q: What is the meaning of life? A:"

gen_text = predict(prompt)
print('-'*100)
print(gen_text)

tokenizer.add_special_tokens({'pad_token': '<|padding|>'})
model.resize_token_embeddings(len(tokenizer))
model.half()

gen_text = predict(prompt)
print('-'*100)
print(gen_text)

Outputs:

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.

----------------------------------------------------------------------------------------------------
Q: What is the meaning of life? A: It is the sum total of the events and happenings which lead to the end of this human life. A person dies because of the event or occurrence which gives birth to his life. In other words, every time a person dies he brings a new life beginning from his own death. In short, if something happens in a human life, it will lead to a life, but if there is no event or occurrence, it will lead to death. Every life matters greatly - everyone has their own life. Life is a measure of happiness, a measure of fulfillment, and a measure of the value and the quality of a person. It is a reflection of everything that has led to a person's development; therefore, Column 1 of the book contains the questions, "What is the meaning of life?" and "What is happiness?" Column 2 contains the answers. The third column contains the answers taken from the column of questions raised by the readers.

Q: What is the meaning of life? A: It is the sum total of the events and happenings which lead to the end of this human life. A person dies because of the event or occurrence which gives birth to his life. In other words, every time a person dies he brings a new life beginning from his own death. In short, if something happens in a human life, it will lead to a life, but if there is no event or occurrence, it will lead to death. Every life matters greatly - everyone has their
----------------------------------------------------------------------------------------------------
Q: What is the meaning of life? A: It<|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|><|padding|> ... ```

@patil-suraj
Contributor Author

patil-suraj commented Mar 30, 2021

Hi @sdtblck

For batch generation with GPT-like models, the text should be padded on the left.

This is how batch generation works:

model.config.pad_token_id = tokenizer.pad_token_id
tokenizer.padding_side = "left"

inputs = tokenizer(sentences, return_tensors="pt", padding=True)
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"]
)

@patil-suraj
Contributor Author

patil-suraj commented Mar 30, 2021

Also, the actual vocab size of the model is 50257, so token ids range from 0 to 50256. The <|padding|> token is not in the embedding matrix, so I doubt generation will work as expected when using <|padding|> as the pad token. Instead, we can set the eos_token as the pad token and set the padding side to left:

tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id

tokenizer.padding_side = "left"

inputs = tokenizer(sentences, return_tensors="pt", padding=True)
gen_tokens = model.generate(
  inputs["input_ids"],
  attention_mask=inputs["attention_mask"]
)

This should work. Or feel free to open an issue if this is not working.
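
A complete, illustrative version of that recipe (the checkpoint name is taken from the usage example at the top of this PR; the prompts are placeholders):

from transformers import GPTNeoForCausalLM, GPT2Tokenizer

model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")

# Reuse the EOS token for padding and pad on the left so that generation
# continues directly from the end of each prompt.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
model.config.pad_token_id = tokenizer.eos_token_id

sentences = ["Hello, my name is", "The meaning of life is"]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)

gen_tokens = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    do_sample=True,
    max_length=32,
)
print(tokenizer.batch_decode(gen_tokens, skip_special_tokens=True))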

patil-suraj and others added 3 commits March 30, 2021 18:34
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Member

@LysandreJik left a comment

LGTM! Thanks for the very quick implementation @patil-suraj.

Follow-up PRs will clean up some cosmetics and add more robust tests for the attention mechanism; parity has been achieved with EleutherAI's implementation on inference, and it has been verified that the model can be fine-tuned.

Merging!

@LysandreJik merged commit 8602643 into huggingface:master Mar 30, 2021
@patil-suraj
Contributor Author

patil-suraj commented Mar 30, 2021

@StellaAthena

The convert_gpt2_original_tf_checkpoint_to_pytorch.py script now works with the GPT-Neo config: it reads the Neo config and initializes the HF config from it. It should now be easy to convert the mesh-tf models to PT.

@sink-chan

@StellaAthena

The convert_gpt2_original_tf_checkpoint_to_pytorch.py script now works with the GPT-Neo config: it reads the Neo config and initializes the HF config from it. It should now be easy to convert the mesh-tf models to PT.

Do you by any chance have an example input/output with the conversion script? I was having trouble getting the new code to work with the default configs in the gpt-neo repo.

@StellaAthena
Contributor

There are models listed on the EleutherAI Hugging Face account that AFAIK we did not post. Are these the pretrained models @patil-suraj had been hosting?

@sink-chan

I was referring to the pre-trained models posted here: https://the-eye.eu/public/AI/gptneo-release/

@LysandreJik
Member

Hi @StellaAthena, which models are you talking about? The only two models available are the 1.3B and the 2.7B versions.

@BigSalmon2

Hi. I'm getting this issue on colab when trying to import it:

cannot import name 'GPTNeoForCausalLM' from 'transformers' (unknown location)

@LysandreJik
Member

LysandreJik commented Mar 30, 2021

Hi @zanderbush, please make sure you:

  • are using the master branch, as it is only available from source as of now
  • have torch installed in your environment, as otherwise the model cannot be imported. Actually that's untrue. You can import it, but can't correctly instantiate it and the error should be more explicit.

@BigSalmon2

@LysandreJik Thank you! That worked. I face a new issue, however, as I look to return the most probable next token. This works with the typical GPT-2, but not with this model for some reason:

import torch
prompt = """In the"""
prompt = prompt.strip()
text = tokenizer.encode(prompt)
myinput, past = torch.tensor([text]), None
logits, past = model(myinput, past_key_values = past)
logits = logits[0,-1]
probabilities = torch.nn.functional.softmax(logits)
best_logits, best_indices = logits.topk(10)
best_words = [tokenizer.decode([idx.item()]) for idx in best_indices]
text.append(best_indices[0].item())
best_probabilities = probabilities[best_indices].tolist()
words = []
for i in range(10):
    m = (best_words[i])
    print(m)

TypeError: string indices must be integers
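
A likely explanation (not confirmed in this thread): recent transformers versions return ModelOutput objects rather than plain tuples, so tuple-unpacking the forward call yields the string field names, and indexing the string "logits" with [0,-1] raises exactly this TypeError. A sketch of the adjusted call, reusing the checkpoint name from the usage example above:

import torch
from transformers import GPTNeoForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-1.3B")

myinput = torch.tensor([tokenizer.encode("In the")])
outputs = model(myinput)                  # a ModelOutput, not a plain tuple
logits = outputs.logits                   # shape: (batch, seq_len, vocab_size)
probabilities = torch.nn.functional.softmax(logits[0, -1], dim=-1)
best_probabilities, best_indices = probabilities.topk(10)
print([tokenizer.decode([idx.item()]) for idx in best_indices])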

@leogao2
Contributor

leogao2 commented Mar 31, 2021

Why is n_ctx not present on GPTNeoConfig? AFAICT, max_position_embeddings is the closest replacement, but I just wanted to double-check that it's reasonable to use it as a guarantee that the model can handle sequences of that length.

@patil-suraj
Contributor Author

@leogao2
Yes, max_position_embeddings is the replacement for n_ctx, and the position embeddings are initialized using that value, so the model does accept the specified sequence length (2048); see
https://github.com/huggingface/transformers/blob/master/src/transformers/models/gpt_neo/modeling_gpt_neo.py#L651
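
A quick illustrative check (assuming the 1.3B checkpoint name used earlier in this thread):

from transformers import GPTNeoConfig

# max_position_embeddings plays the role that n_ctx plays for GPT-2.
config = GPTNeoConfig.from_pretrained("EleutherAI/gpt-neo-1.3B")
print(config.max_position_embeddings)  # 2048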

@LysandreJik
Member

@zanderbush I believe this is unrelated to GPT Neo and related to your code instead. Please open a new issue with a reproducible code example (tokenizer and model defined). Thank you!

Iwontbecreative pushed a commit to Iwontbecreative/transformers that referenced this pull request Jul 15, 2021
* lets begin
* boom boom
* fix out proj in attn
* fix attention
* fix local attention
* add tokenizer
* fix imports
* autotokenizer
* fix checkpoint name
* cleanup
* more clean-up
* more cleanup
* output attentions
* fix attn mask creation
* fix imports
* config doc
* add tests
* add slow tests
* quality
* add conversion script
* copyright
* typo
* another bites the dust
* fix attention tests
* doc
* add embed init in convert function
* fix copies
* remove tokenizer
* enable caching
* address review comments
* improve config and create attn layer list internally
* more consistent naming
* init hf config from mesh-tf config json file
* remove neo tokenizer from doc
* handle attention_mask in local attn layer
* attn_layers => attention_layers
* add tokenizer_class in config
* fix docstring
* raise if len of attention_layers is not same as num_layers
* remove tokenizer_class from config
* more consistent naming
* fix doc
* fix checkpoint names
* fp16 compat
* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>