Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstream/2 write a predictor class to mimic the gec model #172

Open
wants to merge 77 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
aa05578
Create Python package & bump to Python 3.8
Sep 14, 2022
03ab20e
Address pr comments
Sep 14, 2022
ad1d4c0
Address pr comments
Sep 14, 2022
47b9779
Merge pull request #2 from EducationalTestingService/create_python_pa…
damien2012eng Sep 14, 2022
4a2ec3d
add unit tests for tokenization file
Sep 7, 2022
87a4997
Address pr comments
Sep 14, 2022
67d2256
add unit tests for token_indexer
Sep 14, 2022
6e57109
Address pr comments
Sep 15, 2022
f311813
Address pr comments
Sep 15, 2022
5c0cf10
Merge pull request #1 from EducationalTestingService/features/add_uni…
damien2012eng Sep 16, 2022
b00e946
Add unit tests for pretrained BERT embedder
Frost45 Sep 9, 2022
05407f3
Add unit tests for pretrained RoBERTa embedder
Frost45 Sep 14, 2022
b649aeb
Add unit tests for seq2labels model
Frost45 Sep 20, 2022
6801c42
Update tokenization tests to use AllenNLP test modules
Frost45 Sep 23, 2022
2964b05
Address PR comments
Frost45 Sep 26, 2022
a3c74c2
Add CI plan
Frost45 Sep 26, 2022
dbc39e2
Addressed PR comments
Frost45 Sep 27, 2022
ca950c3
Unit test for GecModel Prediction.
Sep 27, 2022
e3ea139
Removing duplicate import.
Sep 27, 2022
8291809
Adding readme file so that fixtures dir exists for downloading gector
ksteimel Sep 27, 2022
9ba0aeb
Minor changes to make these tests pass if a cuda device is available.
ksteimel Sep 27, 2022
34386bd
Adding registered names for use by predictor
ksteimel Sep 27, 2022
36ab359
Adding expected test output.
ksteimel Sep 27, 2022
9bd6a7b
Added WIP docstring to GecBERTModel
ksteimel Sep 27, 2022
6f3f7ae
WIP Gec Predictor.
ksteimel Sep 27, 2022
9dd6bf1
words metadata is getting filled if unspecified when text_to_instance…
ksteimel Sep 28, 2022
64cbc8d
Using JustSpacesWordSplitter so that tokenization matches that used b…
ksteimel Sep 28, 2022
10e8069
Decode now adds the corrected sentence to the output dict.
ksteimel Sep 28, 2022
f6a9185
Updating gitignore to prevent adding .th files
ksteimel Sep 28, 2022
657f72b
Adding directory fixture as analogue to model archive
ksteimel Sep 28, 2022
e3cca59
Fixing errors in modeling code now that model.decode adds the origina…
ksteimel Sep 28, 2022
1c0025c
Adding conditional so that no correction is performed in decode if no…
ksteimel Sep 28, 2022
0f204ea
Appending start token when creating instances from json or string.
ksteimel Sep 28, 2022
67ef511
Start token is expected in ouptut.
ksteimel Sep 28, 2022
53a94dd
Drop START_TOKEN from output_dict["words"]. This interferes with the …
ksteimel Sep 28, 2022
86212b7
The outputs now no longer have $START_TOKEN in the corrected sentence…
ksteimel Sep 28, 2022
40619fe
Merge pull request #5 from EducationalTestingService/feature/add_inte…
ksteimel Sep 28, 2022
20a0692
Handling multiple iterations of correction in predictor now.
ksteimel Sep 28, 2022
0b88028
Changed location of weights file so it can be used by gec_predictor a…
ksteimel Sep 28, 2022
0b6b508
setup is now downloading weights file if it does not already exist.
ksteimel Sep 28, 2022
c535c53
Add how to run unit tests on README
Sep 29, 2022
86a7aff
Add regression data for raw and predictions
Sep 29, 2022
469d162
Merge pull request #7 from EducationalTestingService/features/add_reg…
damien2012eng Oct 3, 2022
bf4e209
Add regression test file
Frost45 Oct 6, 2022
2b4513e
Update CI plan to run regression tests
Frost45 Oct 6, 2022
68d900e
Addressed PR comments
Frost45 Oct 10, 2022
1b60b1b
Addressed PR comments
Frost45 Oct 12, 2022
f1db9a9
Addressed PR comments
Frost45 Oct 12, 2022
6f8f23c
Apply suggestions from code review
ksteimel Oct 12, 2022
7427f93
Removing unused imports, adding docstrings.
Oct 12, 2022
73aef1b
Removing unused predictions to labeled_instances method.
Oct 12, 2022
944993c
Updated docstring for decode()
Oct 12, 2022
709ba28
Removed unused imports.
Oct 12, 2022
323d61a
Add environment.yml
damien2012eng Oct 17, 2022
21d496c
versioning starting with 1.0.0
damien2012eng Oct 17, 2022
5da5955
Address PR comments
damien2012eng Oct 17, 2022
b04376c
Adding back import of gec_predictor that shouldn't have been removed
ksteimel Oct 22, 2022
ba588e2
Add how to run unit tests on README
Sep 29, 2022
f9fb4d7
Add regression data for raw and predictions
Sep 29, 2022
fbbcf10
Add regression test file
Frost45 Oct 6, 2022
d523ed9
Update CI plan to run regression tests
Frost45 Oct 6, 2022
e2014bc
Addressed PR comments
Frost45 Oct 10, 2022
9e84cb8
Addressed PR comments
Frost45 Oct 12, 2022
93f2722
Addressed PR comments
Frost45 Oct 12, 2022
bc89564
Add environment.yml
damien2012eng Oct 17, 2022
e96c38d
versioning starting with 1.0.0
damien2012eng Oct 17, 2022
618d686
Address PR comments
damien2012eng Oct 17, 2022
41a1cf0
Modify gec_predictor and seq2labels to work as gec_model does
Frost45 Oct 20, 2022
1bfc956
Add regression test script for predictor
Frost45 Oct 20, 2022
d952b18
Modify gec_predictor and seq2labels to work as gec_model does
Frost45 Oct 24, 2022
5d7ef1d
Update CI plan to run on all PRs and add regression tests
Frost45 Oct 24, 2022
d96fb80
Addressed PR comments
Frost45 Oct 24, 2022
c937981
Merge pull request #6 from EducationalTestingService/feature/predicto…
ksteimel Oct 24, 2022
3dabb0f
Merge branch 'master' into feature/fix-gec-predictor
Frost45 Oct 24, 2022
7f46ba1
Addressed PR comments
Frost45 Oct 24, 2022
7c6882b
Merge branch 'feature/fix-gec-predictor' of github.com:EducationalTes…
Frost45 Oct 24, 2022
7c5610d
Merge pull request #13 from EducationalTestingService/feature/fix-gec…
Frost45 Oct 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add regression test script for predictor
  • Loading branch information
Frost45 committed Oct 24, 2022
commit 1bfc956a8071638e2fb30dbaa89a4573eb7a0e45
172 changes: 172 additions & 0 deletions regression_tests/test_regression_data_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
import filecmp
from pathlib import Path
import requests
import tempfile
import torch
from tqdm import tqdm

from allennlp.common.testing import ModelTestCase
from allennlp.predictors import Predictor
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.data.vocabulary import Vocabulary
from allennlp.data import Token
from allennlp.data.instance import Instance
from allennlp.data.fields import TextField
from allennlp.data.dataset import Batch

from gector.gec_predictor import GecPredictor

# These imports are required so that instantiating the predictor can be done automatically
from gector.datareader import Seq2LabelsDatasetReader
from gector.seq2labels_model import Seq2Labels
from gector.bert_token_embedder import PretrainedBertEmbedder
from gector.tokenizer_indexer import PretrainedBertIndexer
from utils.helpers import read_lines

ORIG_FILE_DIR = Path(__file__).parent / "original"
GOLD_FILE_DIR = Path(__file__).parent / "prediction"
TEST_FIXTURES_DIR_PATH = Path(__file__).parent.parent / "test_fixtures"
# VOCAB_PATH = TEST_FIXTURES_DIR_PATH / "roberta_model" / "vocabulary"


def download_weights():
"""
Downloads model weights from S3 if not already present at path.

Returns
-------
Path
Path to model directory
"""

# Download weights for model archive
weights_url = "https://grammarly-nlp-data-public.s3.amazonaws.com/gector/roberta_1_gectorv2.th"
model_path = TEST_FIXTURES_DIR_PATH / "roberta_model" / "weights.th"
if not model_path.exists():
response = requests.get(weights_url)
with model_path.open("wb") as out_fp:
# Write out data with progress bar
for data in tqdm(response.iter_content()):
out_fp.write(data)
model_path = TEST_FIXTURES_DIR_PATH / "roberta_model"

return model_path


def predict_for_file(input_file, temp_file, model, batch_size=32):
"""
Generates predictions for a single file and store it in a temp file.

Parameters
----------
input_file : str
Path to input file
temp_file : TemporaryFileWrapper
Temp file object
model : GecBERTModel
Initialized model object
batch_size : int, optional
Batch size, by default 32

Returns
-------
int
Total number of corrections made
"""

test_data = read_lines(input_file)
predictions = []
batch = []
for sent in test_data:
batch.append(sent)
if len(batch) == batch_size:
preds = model.predict_batch(batch)
preds_corrected_words = [x["corrected_words"] for x in preds]
# print(preds_corrected_words)
predictions.extend(preds_corrected_words)
batch = []
if batch:
preds = model.predict_batch(batch)
preds_corrected_words = [x["corrected_words"] for x in preds]
# print(preds_corrected_words)
predictions.extend(preds_corrected_words)

# print(predictions)

result_lines = [" ".join(x) for x in predictions]

# with open("timetestpredictor_5iters.txt", "w") as f:
# f.write("\n".join(result_lines) + "\n")

with open(temp_file.name, "w") as f:
f.write("\n".join(result_lines) + "\n")


def compare_files(filename, gold_file, temp_file):
"""
Compares two files and tests that they are equal.

Parameters
----------
filename : str
Name of file being compared
gold_file : str
Path to gold standard file
temp_file : str
Path to file containing generated prediction
"""

assert filecmp.cmp(
gold_file, temp_file, shallow=False
), f"Output of {filename} does not match gold output."
print(filename, "passed.")


def predict_and_compare(model):
"""
Generate predictions for all test files and tests that there are no changes.

Parameters
----------
model : Predictor
Initialized model
"""

for child in ORIG_FILE_DIR.iterdir():
if child.is_file():
input_file = str(ORIG_FILE_DIR.joinpath(child.name))
gold_standard_file = str(GOLD_FILE_DIR.joinpath(child.name))
# Create temp file to store generated output
with tempfile.NamedTemporaryFile() as temp_file:
predict_for_file(input_file, temp_file, model)
compare_files(child.name, gold_standard_file, temp_file.name)


def main():

# Download weights from S3
model_path = download_weights()

# Initialize model
model = Predictor.from_path(model_path, predictor_name="gec-predictor")

# Generate predictions and compare to previous output.
predict_and_compare(model)
# with tempfile.NamedTemporaryFile() as temp_file:
# predict_for_file(
# str(ORIG_FILE_DIR.joinpath("conll14_10.txt")), temp_file, model
# )
# predict_for_file(
# "/Users/skashyap/Work/local/temp/prepositions/ped/preposition_error_detection/test_pipeline/time_test.txt",
# temp_file,
# model,
# )
# # compare_files(
# # "conll14_10.txt",
# # str(GOLD_FILE_DIR.joinpath("conll14_10.txt")),
# # temp_file.name,
# # )


if __name__ == "__main__":
main()