SRL elmo 5b #1310
Modified file: the script that writes CoNLL-format SRL predictions to file from a pretrained model.

@@ -1,40 +1,62 @@
 import os
 import sys

+import torch
+
 sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir))))
 import argparse
+from allennlp.common import Params
+
 from allennlp.common.tqdm import Tqdm
-from allennlp.common import Params
+from allennlp.models.archival import load_archive
 from allennlp.data.iterators import BasicIterator
 from allennlp.data import DatasetReader
-from allennlp.models import Model
 from allennlp.models.semantic_role_labeler import write_to_conll_eval_file
+from allennlp.modules.elmo import Elmo


-def main(serialization_directory, device):
+def main(serialization_directory: int,
+         device: int,

Review comment: Maybe name this …

+         data: str,
+         prefix: str,
+         domain: str = None):
     """
     serialization_directory : str, required.
-        The directory containing the serialized weights.
+        the directory containing the serialized weights.

Review comment: Did you mean to do all of these capitalization changes?

     device: int, default = -1
-        The device to run the evaluation on.
+        the device to run the evaluation on.
+    data: str, default = None
+        The data to evaluate on. By default, we use the validation data from
+        the original experiment.
+    prefix: str, default=""
+        The prefix to prepend to the generated gold and prediction files, to distinguish
+        different models/data.
+    domain: str, optional (default = None)
+        If passed, filters the ontonotes evaluation/test dataset to only contain the
+        specified domain.
     """
+    torch.set_grad_enabled(False)

Review comment: Isn't there some …
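
The review question above is truncated in this capture; it appears to ask about a scoped alternative to the global torch.set_grad_enabled(False) call. A minimal sketch, assuming the reviewer means the torch.no_grad() context manager (not something this PR actually uses):

    import torch

    # hypothetical model and batch, just to make the sketch self-contained;
    # the point is that no_grad() scopes "no gradients" to the block instead
    # of turning gradients off globally for the rest of the process
    model = torch.nn.Linear(10, 2)
    batch = torch.randn(4, 10)

    with torch.no_grad():
        output = model(batch)   # runs without building the autograd graph

    assert not output.requires_grad
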

     config = Params.from_file(os.path.join(serialization_directory, "config.json"))

+    if domain is not None:
+        config["dataset_reader"]["domain_identifier"] = domain
+        prefix = f"{domain}_{prefix}"

Review comment: This isn't noted in the docs.
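
The comment above refers to prefix silently gaining a "{domain}_" prefix when domain is passed. A possible wording for the missing docstring note (hypothetical, not part of the PR):

    domain: str, optional (default = None)
        If passed, filters the ontonotes evaluation/test dataset to only contain the
        specified domain. The domain name is also prepended to ``prefix``, so the
        generated files are named e.g. ``<domain>_<prefix>_predictions.txt``.
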
+    else:
+        config["dataset_reader"].pop("domain_identifier", None)

     dataset_reader = DatasetReader.from_params(config['dataset_reader'])
-    evaluation_data_path = config['validation_data_path']

-    model = Model.load(config, serialization_dir=serialization_directory, cuda_device=device)
+    evaluation_data_path = data if data else config['validation_data_path']

-    prediction_file_path = os.path.join(serialization_directory, "predictions.txt")
-    gold_file_path = os.path.join(serialization_directory, "gold.txt")
+    archive = load_archive(os.path.join(serialization_directory, "model.tar.gz"), cuda_device=device)
+    model = archive.model
+    model.eval()

+    prediction_file_path = os.path.join(serialization_directory, prefix + "_predictions.txt")

Review comment: Suggestion: if prefix is empty, or already ends with an underscore, this behaves a little oddly. You could make that nicer.

Reply: The second argument is a string concatenation, rather than multiple arguments to os.path.join, so I don't think it does? It would just be "/_predictions.txt"

Reply: Yes, and …

Reply: I was doing the review from a tablet yesterday, so I didn't want to type out the whole thing, but I was thinking something like:

    if prefix and prefix[-1] != '_':
        prefix += '_'

Then you don't need to add an underscore at the beginning of the filename. Again, this really isn't a big deal; it's up to you.
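
For concreteness, a small illustration of the edge case being discussed (the directory and prefix values are made up; only the second argument to os.path.join matters here):

    import os

    serialization_directory = "/tmp/srl_model"   # hypothetical directory

    # with an empty prefix, the file name simply starts with an underscore --
    # the path itself is still valid:
    print(os.path.join(serialization_directory, "" + "_predictions.txt"))
    # /tmp/srl_model/_predictions.txt

    # with a non-empty prefix the name reads naturally:
    print(os.path.join(serialization_directory, "conll2012" + "_predictions.txt"))
    # /tmp/srl_model/conll2012_predictions.txt
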
+    gold_file_path = os.path.join(serialization_directory, prefix + "_gold.txt")
     prediction_file = open(prediction_file_path, "w+")
     gold_file = open(gold_file_path, "w+")

-    # Load the evaluation data and index it.
-    print("Reading evaluation data from {}".format(evaluation_data_path))
+    # load the evaluation data and index it.
+    print("reading evaluation data from {}".format(evaluation_data_path))
     instances = dataset_reader.read(evaluation_data_path)
     iterator = BasicIterator(batch_size=32)
     iterator.index_with(model.vocab)

@@ -49,13 +71,13 @@ def main(serialization_directory, device):
     for instance, prediction in zip(instances, model_predictions):
         fields = instance.fields
         try:
-            # Most sentences have a verbal predicate, but not all.
+            # most sentences have a verbal predicate, but not all.
             verb_index = fields["verb_indicator"].labels.index(1)
         except ValueError:
             verb_index = None

         gold_tags = fields["tags"].labels
-        sentence = fields["tokens"].tokens
+        sentence = [x.text for x in fields["tokens"].tokens]

         write_to_conll_eval_file(prediction_file, gold_file,
                                  verb_index, sentence, prediction, gold_tags)

@@ -64,10 +86,12 @@ def main(serialization_directory, device):

 if __name__ == "__main__":

-    parser = argparse.ArgumentParser(description="Write CONLL format SRL predictions"
+    parser = argparse.ArgumentParser(description="write conll format srl predictions"
                                                  " to file from a pretrained model.")
-    parser.add_argument('--path', type=str, help='The serialization directory.')
-    parser.add_argument('--device', type=int, default=-1, help='The device to load the model onto.')
+    parser.add_argument('--path', type=str, help='the serialization directory.')
+    parser.add_argument('--device', type=int, default=-1, help='the device to load the model onto.')
+    parser.add_argument('--data', type=str, default=None, help='A directory containing a dataset to evaluate on.')
+    parser.add_argument('--prefix', type=str, default="", help='A prefix to distinguish model outputs.')
+    parser.add_argument('--domain', type=str, default=None, help='An optional domain to filter by for producing results.')
     args = parser.parse_args()
-    main(args.path, args.device)
+    main(args.path, args.device, args.data, args.prefix, args.domain)
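
Taken together, the new arguments let the script be pointed at arbitrary data and filtered by domain. A hypothetical invocation (the script's file name and all paths are assumptions, not shown in this diff):

    # e.g., from the repository root:
    #   python scripts/write_srl_predictions_to_conll_format.py \
    #       --path /models/srl_elmo_5.5B --device 0 \
    #       --data /data/ontonotes/test --prefix elmo_5b --domain nw
    #
    # which is equivalent to calling the function directly:
    main(serialization_directory="/models/srl_elmo_5.5B",
         device=0,
         data="/data/ontonotes/test",
         prefix="elmo_5b",
         domain="nw")   # "nw" is just an example domain identifier
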

New file: training configuration for the SRL model with ELMo (5.5B).

@@ -0,0 +1,76 @@
{
    "dataset_reader": {
        "type": "srl",
        "token_indexers": {
            "elmo": {
                "type": "elmo_characters"
            }
        }
    },
    "train_data_path": ${SRL_TRAIN_DATA_PATH},
    "validation_data_path": ${SRL_VALIDATION_DATA_PATH},
    "model": {
        "type": "srl",
        "text_field_embedder": {
            "elmo": {
                "type": "elmo_token_embedder",
                "options_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json",
                "weight_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5",
                "do_layer_norm": false,
                "dropout": 0.1
            }
        },
        "initializer": [
            [
                "tag_projection_layer.*weight",
                {
                    "type": "orthogonal"
                }
            ]
        ],
        // NOTE: This configuration is correct, but slow.
        // If you are interested in training the SRL model
        // from scratch, you should use the 'alternating_lstm_cuda'
        // encoder instead.
        "encoder": {
            "type": "alternating_lstm",
            "input_size": 1124,
            "hidden_size": 300,
            "num_layers": 8,
            "recurrent_dropout_probability": 0.1,
            "use_input_projection_bias": false
        },
        "binary_feature_dim": 100,
        "regularizer": [
            [
                ".*scalar_parameters.*",
                {
                    "type": "l2",
                    "alpha": 0.001
                }
            ]
        ]
    },
    "iterator": {
        "type": "bucket",
        "sorting_keys": [
            [
                "tokens",
                "num_tokens"
            ]
        ],
        "batch_size": 80
    },
    "trainer": {
        "num_epochs": 500,
        "grad_clipping": 1.0,
        "patience": 200,
        "num_serialized_models_to_keep": 10,
        "validation_metric": "+f1-measure-overall",
        "cuda_device": 0,
        "optimizer": {
            "type": "adadelta",
            "rho": 0.95
        }
    }
}
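
One detail of the encoder block above that is easy to miss: its "input_size" of 1124 is the ELMo representation size plus the verb-indicator embedding size (the SRL model concatenates the two, which is assumed here rather than shown in this diff). A quick check of the arithmetic:

    elmo_dim = 1024            # ELMo representations from elmo_token_embedder are 1024-dimensional
    binary_feature_dim = 100   # "binary_feature_dim" in the config above (verb indicator embedding)
    assert elmo_dim + binary_feature_dim == 1124   # matches "input_size" in the "encoder" block
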

Review comment: It's pretty cool that the new model picks up this case - I hadn't thought that the previous model was incorrect, but in this sentence, "will" is a light verb and doesn't really hold any semantic meaning (I think). Progress 🎉