Skip to content

Commit

Permalink
simpler vector loading and defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
bmccann committed Jan 25, 2017
1 parent 57b48f3 commit 5f8d3ca
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 11 deletions.
2 changes: 1 addition & 1 deletion snli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, config):
seq_in_size *= 2
lin_config = [seq_in_size]*2
self.out = nn.Sequential(
Linear(lin_config),
Linear(*lin_config),
self.relu,
self.dropout,
Linear(*lin_config),
Expand Down
14 changes: 6 additions & 8 deletions snli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@

train, dev, test = datasets.SNLI.splits(inputs, answers)

if args.word_vectors and os.path.isfile(args.vector_cache):
inputs.build_vocab(train, dev, test, lower=args.lower)
inputs.vocab.vectors = torch.load(args.vector_cache)
else:
if args.word_vectors:
inputs.build_vocab(train, dev, test, vectors=(args.data_cache, args.word_vectors, args.d_embed), lower=args.lower)
inputs.build_vocab(train, dev, test, lower=args.lower)
if args.word_vectors:
if os.path.isfile(args.vector_cache):
inputs.vocab.vectors = torch.load(args.vector_cache)
else:
inputs.vocab.load_vectors(vectors=(args.data_cache, args.word_vectors, args.d_embed))
os.makedirs(os.path.dirname(args.vector_cache), exist_ok=True)
torch.save(inputs.vocab.vectors, args.vector_cache)
else:
inputs.build_vocab(train, dev, test, lower=args.lower)
answers.build_vocab(train)

train_iter, dev_iter, test_iter = data.BucketIterator.splits(
Expand Down
4 changes: 2 additions & 2 deletions snli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ def get_args():
parser.add_argument('--d_embed', type=int, default=300)
parser.add_argument('--d_proj', type=int, default=300)
parser.add_argument('--d_hidden', type=int, default=300)
parser.add_argument('--n_layers', type=int, default=2)
parser.add_argument('--n_layers', type=int, default=1)
parser.add_argument('--log_every', type=int, default=50)
parser.add_argument('--lr', type=float, default=.001)
parser.add_argument('--dev_every', type=int, default=1000)
parser.add_argument('--save_every', type=int, default=1000)
parser.add_argument('--dp_ratio', type=int, default=0.0)
parser.add_argument('--dp_ratio', type=int, default=0.2)
parser.add_argument('--bidirectional', action='store_true', dest='birnn')
parser.add_argument('--preserve-case', action='store_false', dest='lower')
parser.add_argument('--no-projection', action='store_false', dest='projection')
Expand Down

0 comments on commit 5f8d3ca

Please sign in to comment.