Skip to content

Commit

Permalink
Updated clippy settings and bart download script
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 14, 2020
1 parent 44ea6dd commit 882ec17
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
- before_script:
- rustup component add clippy
script:
- cargo clippy --all-targets --all-features -- -D warnings
- cargo clippy --all-targets --all-features -- -D warnings -A clippy::assign_op_pattern
- script:
- cargo build --verbose
- os:
Expand Down
4 changes: 2 additions & 2 deletions src/gpt2/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ impl Attention {
) -> (Tensor, Option<Tensor>) {
let mut w = query.matmul(&key);
if self.scale {
w /= (*value.size().last().unwrap() as f64).sqrt();
w = w / (*value.size().last().unwrap() as f64).sqrt();
}

let (nd, ns) = (w.size()[2], w.size()[3]);
let b = self.bias.narrow(2, ns - nd, nd).narrow(3, 0, ns);

let mut w: Tensor = w * &b + 1e4 * (&b - 1);
if let Some(mask) = attention_mask {
w += mask;
w = w + mask;
}
w = w.softmax(-1, Float).apply_t(&self.attn_dropout, train);
let output = w.matmul(&value);
Expand Down
22 changes: 11 additions & 11 deletions utils/download-dependencies_bart_large_mnli.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy())

# np.savez(target_path / 'model.npz', **nps)
#
# source = str(target_path / 'model.npz')
# target = str(target_path / 'model.ot')
#
# toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()
#
# subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
#
# os.remove(str(target_path / 'model.bin'))
# os.remove(str(target_path / 'model.npz'))
np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()

subprocess.call(['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])

os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

0 comments on commit 882ec17

Please sign in to comment.