Miscellaneous fixes to automatic tests #35

Merged
merged 12 commits on Jul 22, 2022
set dtype='auto'
justheuristic committed Jul 22, 2022
commit dfe57ebf9a02713ef1d298f2bef315d0072a04c8
9 changes: 3 additions & 6 deletions tests/test_full_model.py
@@ -13,7 +13,7 @@
 @pytest.mark.forked
 def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
     tokenizer = transformers.BloomTokenizerFast.from_pretrained(MODEL_NAME)
-    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS)
+    model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME, initial_peers=INITIAL_PEERS, torch_dtype="auto")
     assert isinstance(model, DistributedBloomForCausalLM)
     assert len(model.transformer.h) == model.config.n_layer
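For context: passing torch_dtype="auto" makes from_pretrained keep the dtype recorded in the checkpoint (half precision for BLOOM) instead of transformers' default of loading the weights in float32, so the test compares both models at the precision they are actually served in. A minimal sketch of the difference, using bigscience/bloom-560m as a hypothetical stand-in for MODEL_NAME:

    import transformers

    # Default behavior: weights are loaded in torch.float32.
    model_fp32 = transformers.AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
    print(next(model_fp32.parameters()).dtype)  # torch.float32

    # With torch_dtype="auto", the dtype stored in the checkpoint is kept.
    model_native = transformers.AutoModelForCausalLM.from_pretrained(
        "bigscience/bloom-560m", torch_dtype="auto"
    )
    print(next(model_native.parameters()).dtype)  # the checkpoint's native half-precision dtype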

@@ -32,17 +32,14 @@ def test_full_model_exact_match(atol_forward=1e-3, atol_inference=1e-3):
             recurrent_outputs.append(sess.step(embs[:, t : t + 1, :]))
     recurrent_outputs = torch.cat(recurrent_outputs, dim=1)
     recurrent_outputs = model.transformer.ln_f(recurrent_outputs)
-
-    dictionary = model.transformer.word_embeddings.weight.t()
-    recurrent_outputs = recurrent_outputs.to(dictionary.dtype)
-    recurrent_outputs = (recurrent_outputs @ dictionary).float()
+    recurrent_outputs = model.lm_head(recurrent_outputs)
     assert torch.allclose(recurrent_outputs, parallel_outputs, rtol=0, atol=atol_inference)
     logger.info("Inference is consistent with forward")

     del model, recurrent_outputs

     if REF_NAME:
-        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME)
+        ref_model = transformers.AutoModelForCausalLM.from_pretrained(REF_NAME, torch_dtype="auto")
         dummy_mask = torch.ones_like(test_inputs, dtype=torch.bool)
         # note: this creates a dummy mask to make the test compatible with older transformer versions
         # prior to https://github.com/huggingface/transformers/pull/17837
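The deleted lines above reimplemented the tied output projection by hand, multiplying by the transposed word-embedding matrix; model.lm_head computes the same logits (its weight is presumably tied to the word embeddings, as is standard for BLOOM) and handles the dtype cast itself. A rough sketch of the equivalence, with tiny made-up dimensions and a random stand-in for the tied weight:

    import torch

    hidden_dim, vocab_size = 64, 1000  # tiny stand-ins for BLOOM's hidden size and vocabulary
    hidden = torch.randn(1, 5, hidden_dim, dtype=torch.bfloat16)  # stands in for the ln_f output
    embeddings = 0.02 * torch.randn(vocab_size, hidden_dim, dtype=torch.bfloat16)  # "tied" word embeddings

    # Old approach: cast the hidden states, multiply by the transposed embedding matrix.
    dictionary = embeddings.t()
    manual_logits = (hidden.to(dictionary.dtype) @ dictionary).float()

    # A linear head sharing the same weight computes the same projection.
    lm_head = torch.nn.Linear(hidden_dim, vocab_size, bias=False, dtype=torch.bfloat16)
    lm_head.weight = torch.nn.Parameter(embeddings)
    head_logits = lm_head(hidden).float()

    assert torch.allclose(manual_logits, head_logits, atol=1e-2)  # loose tolerance: bf16 kernels may differ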