This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
SRL elmo 5b #1310
Merged
Merged
SRL elmo 5b #1310
Changes from all commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
90e5590
update script to write predictions
0b06f92
add training config for new model
d6cfdb2
add new SRL model
0239d48
tweak sniff test
bd70fff
Merge branch 'master' into srl-elmo-5b
DeNeutoy 6ec4080
more tweaks
75a890e
Merge branch 'master' into srl-elmo-5b
DeNeutoy 94beb14
PR comments
492a2e2
Merge branch 'srl-elmo-5b' of https://github.com/DeNeutoy/allennlp into srl-elmo-5b
c3fb02b
Merge branch 'master' into srl-elmo-5b
DeNeutoy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,102 @@ | ||
import os | ||
import sys | ||
|
||
import torch | ||
|
||
sys.path.insert(0, os.path.dirname(os.path.abspath(os.path.join(__file__, os.pardir)))) | ||
import argparse | ||
from allennlp.common import Params | ||
|
||
import torch | ||
|
||
from allennlp.common.tqdm import Tqdm | ||
from allennlp.common import Params | ||
from allennlp.models.archival import load_archive | ||
from allennlp.data.iterators import BasicIterator | ||
from allennlp.data import DatasetReader | ||
from allennlp.models import Model | ||
from allennlp.models.semantic_role_labeler import write_to_conll_eval_file | ||
from allennlp.modules.elmo import Elmo | ||
|
||
|
||
def main(serialization_directory: str,
         device: int,
         data: str,
         prefix: str,
         domain: str = None):
    """
    Write CONLL-format SRL predictions and the matching gold annotations to
    file from a serialized model, so they can be scored with the official
    CoNLL perl evaluation script.

    Parameters
    ----------
    serialization_directory : str, required.
        The directory containing the serialized weights.
    device: int, default = -1
        The device to run the evaluation on.
    data: str, default = None
        The data to evaluate on. By default, we use the validation data from
        the original experiment.
    prefix: str, default=""
        The prefix to prepend to the generated gold and prediction files, to distinguish
        different models/data.
    domain: str, optional (default = None)
        If passed, filters the ontonotes evaluation/test dataset to only contain the
        specified domain. This overwrites the domain in the config file from the model,
        to allow evaluation on domains other than the one the model was trained on.
    """
    config = Params.from_file(os.path.join(serialization_directory, "config.json"))

    if domain is not None:
        # Hack to allow evaluation on different domains than the
        # model was trained on.
        config["dataset_reader"]["domain_identifier"] = domain
        prefix = f"{domain}_{prefix}"
    else:
        config["dataset_reader"].pop("domain_identifier", None)

    dataset_reader = DatasetReader.from_params(config["dataset_reader"])
    evaluation_data_path = data if data else config["validation_data_path"]

    archive = load_archive(os.path.join(serialization_directory, "model.tar.gz"),
                           cuda_device=device)
    model = archive.model
    model.eval()

    prediction_file_path = os.path.join(serialization_directory, prefix + "_predictions.txt")
    gold_file_path = os.path.join(serialization_directory, prefix + "_gold.txt")

    # Load the evaluation data and index it.
    print("reading evaluation data from {}".format(evaluation_data_path))
    instances = dataset_reader.read(evaluation_data_path)

    # `with` guarantees both output files are closed even if evaluation raises;
    # no_grad disables autograd bookkeeping, which is pure overhead at inference.
    with open(prediction_file_path, "w+") as prediction_file, \
         open(gold_file_path, "w+") as gold_file, \
         torch.autograd.no_grad():
        iterator = BasicIterator(batch_size=32)
        iterator.index_with(model.vocab)

        model_predictions = []
        batches = iterator(instances, num_epochs=1, shuffle=False, cuda_device=device)
        for batch in Tqdm.tqdm(batches):
            result = model(**batch)
            predictions = model.decode(result)
            model_predictions.extend(predictions["tags"])

        for instance, prediction in zip(instances, model_predictions):
            fields = instance.fields
            try:
                # Most sentences have a verbal predicate, but not all.
                verb_index = fields["verb_indicator"].labels.index(1)
            except ValueError:
                verb_index = None

            gold_tags = fields["tags"].labels
            # write_to_conll_eval_file expects raw strings, not Token objects.
            sentence = [x.text for x in fields["tokens"].tokens]

            write_to_conll_eval_file(prediction_file, gold_file,
                                     verb_index, sentence, prediction, gold_tags)
|
||
if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="write conll format srl predictions"
                                                 " to file from a pretrained model.")
    # required: without a serialization directory there is nothing to evaluate,
    # and a missing value would otherwise surface as an opaque os.path.join error.
    parser.add_argument('--path', type=str, required=True, help='the serialization directory.')
    parser.add_argument('--device', type=int, default=-1, help='the device to load the model onto.')
    parser.add_argument('--data', type=str, default=None, help='A directory containing a dataset to evaluate on.')
    parser.add_argument('--prefix', type=str, default="", help='A prefix to distinguish model outputs.')
    parser.add_argument('--domain', type=str, default=None, help='An optional domain to filter by for producing results.')
    args = parser.parse_args()
    main(args.path, args.device, args.data, args.prefix, args.domain)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
{
  "dataset_reader": {
    "type": "srl",
    "token_indexers": {
      "elmo": {
        "type": "elmo_characters"
      }
    }
  },
  // Data paths are injected from the environment at experiment-load time.
  "train_data_path": ${SRL_TRAIN_DATA_PATH},
  "validation_data_path": ${SRL_VALIDATION_DATA_PATH},
  "model": {
    "type": "srl",
    "text_field_embedder": {
      "elmo": {
        "type": "elmo_token_embedder",
        "options_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_options.json",
        "weight_file": "https://s3-us-west-2.amazonaws.com/allennlp/models/elmo/2x4096_512_2048cnn_2xhighway_5.5B/elmo_2x4096_512_2048cnn_2xhighway_5.5B_weights.hdf5",
        "do_layer_norm": false,
        "dropout": 0.1
      }
    },
    "initializer": [
      [
        "tag_projection_layer.*weight",
        {
          "type": "orthogonal"
        }
      ]
    ],
    // NOTE: This configuration is correct, but slow.
    // If you are interested in training the SRL model
    // from scratch, you should use the 'alternating_lstm_cuda'
    // encoder instead.
    "encoder": {
      "type": "alternating_lstm",
      "input_size": 1124,
      "hidden_size": 300,
      "num_layers": 8,
      "recurrent_dropout_probability": 0.1,
      "use_input_projection_bias": false
    },
    "binary_feature_dim": 100,
    "regularizer": [
      [
        ".*scalar_parameters.*",
        {
          "type": "l2",
          "alpha": 0.001
        }
      ]
    ]
  },
  "iterator": {
    "type": "bucket",
    "sorting_keys": [
      [
        "tokens",
        "num_tokens"
      ]
    ],
    "batch_size": 80
  },
  "trainer": {
    "num_epochs": 500,
    "grad_clipping": 1.0,
    "patience": 200,
    "num_serialized_models_to_keep": 10,
    "validation_metric": "+f1-measure-overall",
    "cuda_device": 0,
    "optimizer": {
      "type": "adadelta",
      "rho": 0.95
    }
  }
}
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's pretty cool that the new model picks up this case - I hadn't thought that the previous model was incorrect, but in this sentence, "will" is a light verb and doesn't really hold any semantic meaning (I think). Progress 🎉