diff --git a/snli/model.py b/snli/model.py index 8a17baa9df..4de4f100e2 100644 --- a/snli/model.py +++ b/snli/model.py @@ -50,7 +50,7 @@ def __init__(self, config): seq_in_size *= 2 lin_config = [seq_in_size]*2 self.out = nn.Sequential( - Linear(lin_config), + Linear(*lin_config), self.relu, self.dropout, Linear(*lin_config), diff --git a/snli/train.py b/snli/train.py index eefaa56a3f..06dfe4d109 100644 --- a/snli/train.py +++ b/snli/train.py @@ -20,16 +20,14 @@ train, dev, test = datasets.SNLI.splits(inputs, answers) -if args.word_vectors and os.path.isfile(args.vector_cache): - inputs.build_vocab(train, dev, test, lower=args.lower) - inputs.vocab.vectors = torch.load(args.vector_cache) -else: - if args.word_vectors: - inputs.build_vocab(train, dev, test, vectors=(args.data_cache, args.word_vectors, args.d_embed), lower=args.lower) +inputs.build_vocab(train, dev, test, lower=args.lower) +if args.word_vectors: + if os.path.isfile(args.vector_cache): + inputs.vocab.vectors = torch.load(args.vector_cache) + else: + inputs.vocab.load_vectors(vectors=(args.data_cache, args.word_vectors, args.d_embed)) os.makedirs(os.path.dirname(args.vector_cache), exist_ok=True) torch.save(inputs.vocab.vectors, args.vector_cache) - else: - inputs.build_vocab(train, dev, test, lower=args.lower) answers.build_vocab(train) train_iter, dev_iter, test_iter = data.BucketIterator.splits( diff --git a/snli/util.py b/snli/util.py index 592089f199..ac8fe47ff0 100644 --- a/snli/util.py +++ b/snli/util.py @@ -8,12 +8,12 @@ def get_args(): parser.add_argument('--d_embed', type=int, default=300) parser.add_argument('--d_proj', type=int, default=300) parser.add_argument('--d_hidden', type=int, default=300) - parser.add_argument('--n_layers', type=int, default=2) + parser.add_argument('--n_layers', type=int, default=1) parser.add_argument('--log_every', type=int, default=50) parser.add_argument('--lr', type=float, default=.001) parser.add_argument('--dev_every', type=int, default=1000) parser.add_argument('--save_every', type=int, default=1000) - parser.add_argument('--dp_ratio', type=int, default=0.0) + parser.add_argument('--dp_ratio', type=int, default=0.2) parser.add_argument('--bidirectional', action='store_true', dest='birnn') parser.add_argument('--preserve-case', action='store_false', dest='lower') parser.add_argument('--no-projection', action='store_false', dest='projection')