Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upstream/2 write a predictor class to mimic the gec model #172

Open
wants to merge 77 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
aa05578
Create Python package & bump to Python 3.8
Sep 14, 2022
03ab20e
Address pr comments
Sep 14, 2022
ad1d4c0
Address pr comments
Sep 14, 2022
47b9779
Merge pull request #2 from EducationalTestingService/create_python_pa…
damien2012eng Sep 14, 2022
4a2ec3d
add unit tests for tokenization file
Sep 7, 2022
87a4997
Address pr comments
Sep 14, 2022
67d2256
add unit tests for token_indexer
Sep 14, 2022
6e57109
Address pr comments
Sep 15, 2022
f311813
Address pr comments
Sep 15, 2022
5c0cf10
Merge pull request #1 from EducationalTestingService/features/add_uni…
damien2012eng Sep 16, 2022
b00e946
Add unit tests for pretrained BERT embedder
Frost45 Sep 9, 2022
05407f3
Add unit tests for pretrained RoBERTa embedder
Frost45 Sep 14, 2022
b649aeb
Add unit tests for seq2labels model
Frost45 Sep 20, 2022
6801c42
Update tokenization tests to use AllenNLP test modules
Frost45 Sep 23, 2022
2964b05
Address PR comments
Frost45 Sep 26, 2022
a3c74c2
Add CI plan
Frost45 Sep 26, 2022
dbc39e2
Addressed PR comments
Frost45 Sep 27, 2022
ca950c3
Unit test for GecModel Prediction.
Sep 27, 2022
e3ea139
Removing duplicate import.
Sep 27, 2022
8291809
Adding readme file so that fixtures dir exists for downloading gector
ksteimel Sep 27, 2022
9ba0aeb
Minor changes to make these tests pass if a cuda device is available.
ksteimel Sep 27, 2022
34386bd
Adding registered names for use by predictor
ksteimel Sep 27, 2022
36ab359
Adding expected test output.
ksteimel Sep 27, 2022
9bd6a7b
Added WIP docstring to GecBERTModel
ksteimel Sep 27, 2022
6f3f7ae
WIP Gec Predictor.
ksteimel Sep 27, 2022
9dd6bf1
words metadata is getting filled if unspecified when text_to_instance…
ksteimel Sep 28, 2022
64cbc8d
Using JustSpacesWordSplitter so that tokenization matches that used b…
ksteimel Sep 28, 2022
10e8069
Decode now adds the corrected sentence to the output dict.
ksteimel Sep 28, 2022
f6a9185
Updating gitignore to prevent adding .th files
ksteimel Sep 28, 2022
657f72b
Adding directory fixture as analogue to model archive
ksteimel Sep 28, 2022
e3cca59
Fixing errors in modeling code now that model.decode adds the origina…
ksteimel Sep 28, 2022
1c0025c
Adding conditional so that no correction is performed in decode if no…
ksteimel Sep 28, 2022
0f204ea
Appending start token when creating instances from json or string.
ksteimel Sep 28, 2022
67ef511
Start token is expected in ouptut.
ksteimel Sep 28, 2022
53a94dd
Drop START_TOKEN from output_dict["words"]. This interferes with the …
ksteimel Sep 28, 2022
86212b7
The outputs now no longer have $START_TOKEN in the corrected sentence…
ksteimel Sep 28, 2022
40619fe
Merge pull request #5 from EducationalTestingService/feature/add_inte…
ksteimel Sep 28, 2022
20a0692
Handling multiple iterations of correction in predictor now.
ksteimel Sep 28, 2022
0b88028
Changed location of weights file so it can be used by gec_predictor a…
ksteimel Sep 28, 2022
0b6b508
setup is now downloading weights file if it does not already exist.
ksteimel Sep 28, 2022
c535c53
Add how to run unit tests on README
Sep 29, 2022
86a7aff
Add regression data for raw and predictions
Sep 29, 2022
469d162
Merge pull request #7 from EducationalTestingService/features/add_reg…
damien2012eng Oct 3, 2022
bf4e209
Add regression test file
Frost45 Oct 6, 2022
2b4513e
Update CI plan to run regression tests
Frost45 Oct 6, 2022
68d900e
Addressed PR comments
Frost45 Oct 10, 2022
1b60b1b
Addressed PR comments
Frost45 Oct 12, 2022
f1db9a9
Addressed PR comments
Frost45 Oct 12, 2022
6f8f23c
Apply suggestions from code review
ksteimel Oct 12, 2022
7427f93
Removing unused imports, adding docstrings.
Oct 12, 2022
73aef1b
Removing unused predictions to labeled_instances method.
Oct 12, 2022
944993c
Updated docstring for decode()
Oct 12, 2022
709ba28
Removed unused imports.
Oct 12, 2022
323d61a
Add environment.yml
damien2012eng Oct 17, 2022
21d496c
versioning starting with 1.0.0
damien2012eng Oct 17, 2022
5da5955
Address PR comments
damien2012eng Oct 17, 2022
b04376c
Adding back import of gec_predictor that shouldn't have been removed
ksteimel Oct 22, 2022
ba588e2
Add how to run unit tests on README
Sep 29, 2022
f9fb4d7
Add regression data for raw and predictions
Sep 29, 2022
fbbcf10
Add regression test file
Frost45 Oct 6, 2022
d523ed9
Update CI plan to run regression tests
Frost45 Oct 6, 2022
e2014bc
Addressed PR comments
Frost45 Oct 10, 2022
9e84cb8
Addressed PR comments
Frost45 Oct 12, 2022
93f2722
Addressed PR comments
Frost45 Oct 12, 2022
bc89564
Add environment.yml
damien2012eng Oct 17, 2022
e96c38d
versioning starting with 1.0.0
damien2012eng Oct 17, 2022
618d686
Address PR comments
damien2012eng Oct 17, 2022
41a1cf0
Modify gec_predictor and seq2labels to work as gec_model does
Frost45 Oct 20, 2022
1bfc956
Add regression test script for predictor
Frost45 Oct 20, 2022
d952b18
Modify gec_predictor and seq2labels to work as gec_model does
Frost45 Oct 24, 2022
5d7ef1d
Update CI plan to run on all PRs and add regression tests
Frost45 Oct 24, 2022
d96fb80
Addressed PR comments
Frost45 Oct 24, 2022
c937981
Merge pull request #6 from EducationalTestingService/feature/predicto…
ksteimel Oct 24, 2022
3dabb0f
Merge branch 'master' into feature/fix-gec-predictor
Frost45 Oct 24, 2022
7f46ba1
Addressed PR comments
Frost45 Oct 24, 2022
7c6882b
Merge branch 'feature/fix-gec-predictor' of github.com:EducationalTes…
Frost45 Oct 24, 2022
7c5610d
Merge pull request #13 from EducationalTestingService/feature/fix-gec…
Frost45 Oct 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Removing unused imports, adding docstrings.
  • Loading branch information
ksteimel committed Oct 12, 2022
commit 7427f93fd3af8db5095662e31493f58789301eb6
125 changes: 115 additions & 10 deletions gector/gec_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from overrides import overrides
from allennlp.common.util import JsonDict
from allennlp.data import DatasetReader, Instance, Token
from allennlp.data.fields import TextField, SequenceLabelField
from allennlp.data.tokenizers.word_splitter import SpacyWordSplitter
from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter
from allennlp.models import Model
from utils.helpers import START_TOKEN
Expand All @@ -21,25 +19,99 @@ class GecPredictor(Predictor):

Note that currently, this is unable to handle ensemble predictions.
"""
def __init__(self, model: Model,

def __init__(self,
model: Model,
dataset_reader: DatasetReader,
language: str = 'en_core_web_sm',
iterations: int = 3) -> None:
"""
Parameters
---------
model: Model
An instantiated `Seq2Labels` model for performing grammatical error correction.
dataset_reader: DatasetReader
An instantiated dataset reader, typically `Seq2LabelsDatasetReader`.
iterations: int
This represents the number of times grammatical error correction is applied to the input.
"""
super().__init__(model, dataset_reader)
#self._tokenizer = SpacyWordSplitter(language=language, pos_tags=True)
self._tokenizer = JustSpacesWordSplitter()
self._iterations = iterations

def predict(self, sentence: str) -> JsonDict:
"""
Generate error correction predictions for a single input string.

Parameters
----------
sentence: str
The input text to perform error correction on.

Returns
-------
JsonDict
A dictionary containing the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words
- labels
- d_tags
- corrected_words
For an explanation of each of these see `Seq2Labels.decode()`.
"""
return self.predict_json({"sentence": sentence})

def predict_batch(self, sentences: List[str]) -> JsonDict:
def predict_batch(self, sentences: List[str]) -> List[JsonDict]:
"""
Generate predictions for a sequence of input strings.

Parameters
----------
sentences: List[str]
A list of strings to correct.

Returns
-------
List[JsonDict]
A list of dictionaries, each containing the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words
- labels
- d_tags
- corrected_words
For an explanation of each of these see `Seq2Labels.decode()`.
"""
return self.predict_batch_json([{"sentence": sentence} for sentence in sentences])

@overrides
def predict_instance(self, instance: Instance) -> JsonDict:
"""
This special predict_instance method allows for applying the correction model multiple times.

Parameters
---------

Returns
-------
JsonDict
A dictionary containing the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words
- labels
- d_tags
- corrected_words
For an explanation of each of these see `Seq2Labels.decode()`.
"""
for i in range(self._iterations):
output = self._model.forward_on_instance(instance)
Expand All @@ -52,6 +124,24 @@ def predict_instance(self, instance: Instance) -> JsonDict:
def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]:
"""
This special predict_batch_instance method allows for applying the correction model multiple times.

Parameters
----------

Returns
-------
List[JsonDict]
A list of dictionaries, each containing the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words
- labels
- d_tags
- corrected_words
For an explanation of each of these see `Seq2Labels.decode()`.
"""
for i in range(self._iterations):
outputs = self._model.forward_on_instances(instances)
Expand All @@ -66,21 +156,36 @@ def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]:
@overrides
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
"""
Expects JSON that looks like ``{"sentence": "..."}``.
Runs the underlying model, and adds the ``"words"`` to the output.
Convert a JsonDict into an Instance.

This is used internally by `self.predict_json()`.

Parameters
----------
json_dict: JsonDict
Expects a dict with a single key "sentence" with a value representing the string to correct.
i.e. ``{"sentence": "..."}``.

Returns
------
Instance
An instance with the following fields:
- tokens
- metadata
- labels
- d_tags
"""
sentence = json_dict["sentence"]
tokens = self._tokenizer.split_words(sentence)
# Add start token to front
tokens = [Token(START_TOKEN)] + tokens
return self._dataset_reader.text_to_instance(tokens)


@overrides
def predictions_to_labeled_instances(self,
instance: Instance,
outputs: Dict[str, numpy.ndarray]) -> List[Instance]:
"""
This method creates an instance out of the predictions generated by the model.
"""
NotImplemented
NotImplemented