Feature/fix gec predictor #13

Merged · 18 commits · Oct 24, 2022
Changes from 1 commit
Addressed PR comments
Frost45 committed Oct 24, 2022
commit d96fb80a9d70781eece16a5cf9401f4676296e22
83 changes: 56 additions & 27 deletions gector/gec_predictor.py
@@ -105,6 +105,8 @@ def predict_instance(self, instance: Instance) -> JsonDict:

Parameters
----------
instance: Instance
Instance to be predicted

Returns
-------
@@ -137,6 +139,8 @@ def predict_batch_instance(

Parameters
----------
instances: List[Instance]
Instances to be predicted

Returns
-------
@@ -157,51 +161,70 @@ def predict_batch_instance(
# Make shallow copy of batch so entries can be replaced per sentence
final_batch = instances[:]

# Create list to store final predictions
final_outputs = [None] * len(instances)

# Track every token sequence produced for each sentence across iterations
prev_preds_dict = {
id: [final_batch[id].fields["tokens"].tokens]
for id in range(len(final_batch))
}

# Sentences shorter than 4 tokens ($START + 3 tokens) are not corrected
short_ids = [
id
for id in range(len(final_batch))
if len(final_batch[id].fields["tokens"].tokens) < 4
]

# Short sentences are written directly to the output, minus the $START token
for id in short_ids:
final_outputs[id] = {
"logits_labels": None,
"logits_d_tags": None,
"class_probabilities_labels": None,
"class_probabilities_d_tags": None,
"max_error_probability": None,
"words": final_batch[id].fields["tokens"].tokens[1:],
"labels": None,
"d_tags": None,
"corrected_words": final_batch[id].fields["tokens"].tokens[1:],
}

# IDs of sentences that still need to be passed through the model
pred_ids = [id for id in range(len(instances)) if id not in short_ids]

for n_iter in range(self._iterations):
# This dictionary keeps track of predictions made in every iteration
prev_preds_dict = {}

# This list contains IDs of sentences to be passed into model
pred_ids = []

# Populating `prev_preds_dict` and `pred_ids`
for id, instance in enumerate(final_batch):
prev_preds_dict[id] = [instance.fields["tokens"].tokens]
# If len(tokens) is less than 4 ($START + 3 tokens),
# we do not correct it; it is written directly to the output.
if len(instance.fields["tokens"].tokens) < 4:
final_outputs[id] = {
"logits_labels": None,
"logits_d_tags": None,
"class_probabilities_labels": None,
"class_probabilities_d_tags": None,
"max_error_probability": None,
"words": instance.fields["tokens"].tokens[1:],
"labels": None,
"d_tags": None,
"corrected_words": instance.fields["tokens"].tokens[1:],
}
else:
pred_ids.append(id)

# Applying correction model multiple times
for _ in range(self._iterations):

# If no sentences need to be passed into model
if len(pred_ids) == 0:
break

# Create batch of instances to be passed into model
orig_batch = [final_batch[pred_id] for pred_id in pred_ids]

# Pass into model
outputs = self._model.forward_on_instances(orig_batch)

new_pred_ids = []

# op_ind indexes `outputs`; pred_id is the sentence's position in the batch
for op_ind, pred_id in enumerate(pred_ids):

# Update final outputs
final_outputs[pred_id] = outputs[op_ind]
orig = final_batch[pred_id]

# Create tokens from corrected words for next iter
tokens = [
Token(word)
for word in ["$START"] + outputs[op_ind]["corrected_words"]
]

# Tokens to instance
pred = self._dataset_reader.text_to_instance(tokens)
prev_preds = prev_preds_dict[pred_id]

# If the model changed the sentence and the result is new (not seen in any
# previous iteration): update the input batch, record the prediction,
# and keep the sentence in `pred_ids` for another pass
if (
orig.fields["tokens"].tokens != pred.fields["tokens"].tokens
and pred.fields["tokens"].tokens not in prev_preds
@@ -214,6 +237,9 @@ def predict_batch_instance(
prev_preds_dict[pred_id].append(
pred.fields["tokens"].tokens
)
# If the model changed the sentence but only back to an output already seen
# in an earlier iteration: keep the new output in the final batch, but stop
# passing the sentence to the model, since it would keep cycling
elif (
orig.fields["tokens"].tokens != pred.fields["tokens"].tokens
and pred.fields["tokens"].tokens in prev_preds
@@ -222,7 +248,10 @@ def predict_batch_instance(
final_batch[pred_id] = pred
else:
continue

# Update `pred_ids` with new indices to be predicted
pred_ids = new_pred_ids

return sanitize(final_outputs)

@overrides
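
For readers tracing the loop above: each pass runs the model on the still-active sentences, keeps a per-sentence history of every output, and retires a sentence once the model's output either stops changing or revisits an earlier state. Below is a minimal, self-contained sketch of that control flow; iterate_until_stable and toy_correct are hypothetical stand-ins for illustration, not part of this repository.

from typing import Callable, List


def iterate_until_stable(
    sentences: List[List[str]],
    correct: Callable[[List[str]], List[str]],
    max_iterations: int = 5,
) -> List[List[str]]:
    # Shallow copy, as in the diff: list slots are reassigned per sentence
    final = list(sentences)
    # History of every token sequence seen per sentence (cf. prev_preds_dict)
    history = {i: [tokens] for i, tokens in enumerate(final)}
    # Sentences shorter than 4 tokens bypass correction (cf. short_ids)
    pending = [i for i, tokens in enumerate(final) if len(tokens) >= 4]

    for _ in range(max_iterations):
        if not pending:
            break
        still_pending = []
        for i in pending:
            new_tokens = correct(final[i])
            if new_tokens != final[i] and new_tokens not in history[i]:
                # Genuinely new output: keep it and iterate on it again
                final[i] = new_tokens
                history[i].append(new_tokens)
                still_pending.append(i)
            elif new_tokens != final[i]:
                # Output revisits an earlier state: keep it, but stop iterating
                final[i] = new_tokens
        pending = still_pending
    return final


def toy_correct(tokens: List[str]) -> List[str]:
    # Pretend model: fixes one agreement error, then has nothing left to change
    if tokens == ["$START", "She", "go", "home"]:
        return ["$START", "She", "goes", "home"]
    return tokens


print(iterate_until_stable([["$START", "She", "go", "home"]], toy_correct))
# [['$START', 'She', 'goes', 'home']]

Run on the toy sentence, the sketch applies one correction, then sees an unchanged output on the next pass and stops, mirroring how `pred_ids` drains across iterations.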
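The docstrings added in this commit describe the inputs to predict_instance and predict_batch_instance. A hypothetical call pattern follows, assuming an AllenNLP model archive and a registered predictor name; neither appears in this diff, so both are placeholders, and the instance is built the same way the loop rebuilds instances between iterations.

from allennlp.data.tokenizers import Token
from allennlp.predictors import Predictor

# Assumed setup: the archive path and predictor name below are illustrative
predictor = Predictor.from_path("model.tar.gz", predictor_name="gec-predictor")
reader = predictor._dataset_reader

tokens = [Token(word) for word in ["$START", "She", "go", "to", "school"]]
instance = reader.text_to_instance(tokens)

outputs = predictor.predict_batch_instance([instance])
print(outputs[0]["corrected_words"])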