Try cropping vocab size
borzunov committed Aug 8, 2023
1 parent 85fcab0 commit a065dce
Showing 1 changed file with 15 additions and 10 deletions.
tests/test_full_model.py: 25 changes (15 additions & 10 deletions)
@@ -20,14 +20,19 @@ def tokenizer():
 @pytest.mark.forked
 @pytest.mark.parametrize("use_peft", (True, False) if ADAPTER_NAME else (False,))
 @pytest.mark.parametrize("pass_empty_tensors", (True, False))
-def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
+def test_full_model_exact_match(
+    tokenizer, use_peft, pass_empty_tensors, crop_vocab_size=32768, atol_forward=1e-3, atol_inference=1e-3
+):
     model = AutoDistributedModelForCausalLM.from_pretrained(
         MODEL_NAME,
         initial_peers=INITIAL_PEERS,
         torch_dtype=torch.float32,
         active_adapter=ADAPTER_NAME if use_peft else None,
     )
-    config = model.config
+    if model.config.vocab_size > crop_vocab_size:
+        logger.warning(f"Cropping embeddings to {crop_vocab_size} tokens to save RAM")
+        model.resize_token_embeddings(crop_vocab_size)
+
     assert len(model.transformer.h) == model.config.num_hidden_layers
 
     test_inputs = tokenizer("A quick brown fox was minding its own business", return_tensors="pt")["input_ids"]
@@ -42,7 +47,7 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
     recurrent_outputs = []
     with model.transformer.h.inference_session(max_length=embs.shape[1]) as sess:
         if pass_empty_tensors:
-            recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
+            recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
 
         for t in range(embs.shape[1]):
             if t == 4:
@@ -53,8 +58,8 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
                 recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
 
             if t == 2 and pass_empty_tensors:
-                recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
-                recurrent_outputs.append(sess.step(torch.empty(1, 0, config.hidden_size)))
+                recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
+                recurrent_outputs.append(sess.step(torch.empty(1, 0, model.config.hidden_size)))
 
     recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
     recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
@@ -66,14 +71,14 @@ def test_full_model_exact_match(tokenizer, use_peft, pass_empty_tensors, atol_forward=1e-3, atol_inference=1e-3):
 
     if REF_NAME:
         ref_model = transformers.AutoModelForCausalLM.from_pretrained(
-            REF_NAME, low_cpu_mem_usage=True, offload_state_dict=True, device_map="auto", torch_dtype=torch.float32
-        )  # device_map="auto" may use disk offloading for some weights
+            REF_NAME, low_cpu_mem_usage=True, offload_state_dict=True, torch_dtype=torch.float32
+        )
+        if ref_model.config.vocab_size > crop_vocab_size:
+            logger.warning(f"Cropping embeddings to {crop_vocab_size} tokens to save RAM")
+            ref_model.resize_token_embeddings(crop_vocab_size)
         if use_peft:
             ref_model = peft.PeftModel.from_pretrained(ref_model, ADAPTER_NAME)
             ref_model.train(False)
-        if config.vocab_size < ref_model.config.vocab_size:
-            ref_model.resize_token_embeddings(config.vocab_size)
-            logger.warning(f"Resized the reference model embeddings, new total = {ref_model.config.vocab_size}")
 
         dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
         # note: this creates a dummy mask to make the test compatible with older transformer versions
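
A note on what the new crop_vocab_size logic relies on: resize_token_embeddings() is a standard transformers.PreTrainedModel method that shrinks (or grows) the input embedding matrix and the output LM head to the requested vocabulary size and updates config.vocab_size accordingly. Below is a minimal sketch of the effect, using a tiny randomly initialized GPT-2 purely for illustration; the model class and sizes here are assumptions, not taken from this commit:

import torch
import transformers

# Tiny stand-in model; the test itself loads MODEL_NAME / REF_NAME instead.
config = transformers.GPT2Config(vocab_size=50257, n_layer=2, n_embd=64, n_head=4)
model = transformers.GPT2LMHeadModel(config)

crop_vocab_size = 32768  # same default as the new test parameter
if model.config.vocab_size > crop_vocab_size:
    # Drops the embedding rows for all token ids >= crop_vocab_size.
    model.resize_token_embeddings(crop_vocab_size)

assert model.config.vocab_size == crop_vocab_size
assert model.get_input_embeddings().weight.shape[0] == crop_vocab_size

# The LM head is resized as well, so logits now cover only the cropped vocabulary.
logits = model(input_ids=torch.tensor([[1, 2, 3]])).logits
assert logits.shape[-1] == crop_vocab_size

Since the test crops both the distributed model and the reference model to the same size, the exact-match comparison still holds token-for-token, provided the tokenized test inputs only contain ids below crop_vocab_size.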
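And a back-of-the-envelope view of the RAM being saved; the vocabulary and hidden sizes below are hypothetical, chosen only to make the arithmetic concrete:

# Hypothetical dimensions, not taken from the commit: 50k-token vocab, hidden size 8192.
vocab_size, crop_vocab_size, hidden_size = 50_000, 32_768, 8192
bytes_per_param = 4  # the test loads weights in torch.float32

# The input embeddings and an untied LM head each hold a (vocab_size x hidden_size) matrix.
saved_bytes = 2 * (vocab_size - crop_vocab_size) * hidden_size * bytes_per_param
print(f"~{saved_bytes / 2**30:.2f} GiB saved")  # prints "~1.05 GiB saved"

With tied embeddings, drop the factor of 2; either way the saving scales linearly with the number of cropped token ids.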
