
fix online metric calculation for EVALB #956

Merged 1 commit on Mar 6, 2018
fix online metric calculation for EVALB
Mark Neumann committed Mar 6, 2018
commit 1d79812c8ce642408df0e088c751c0d9cd166ebb
65 changes: 19 additions & 46 deletions allennlp/training/metrics/evalb_bracketing_scorer.py
@@ -2,8 +2,6 @@
 from typing import List
 import os
 import tempfile
-import re
-import math
 import subprocess
 
 from overrides import overrides
@@ -49,14 +47,12 @@ def __init__(self, evalb_directory_path: str, evalb_param_filename: str = "COLLI
             raise ConfigurationError("You must compile the EVALB scorer before using it."
                                      " Run 'make' in the 'scripts/EVALB' directory.")
 
-        self._recall_regex = re.compile(r"Bracketing Recall\s+=\s+(\d+\.\d+)")
-        self._precision_regex = re.compile(r"Bracketing Precision\s+=\s+(\d+\.\d+)")
-        self._f1_measure_regex = re.compile(r"Bracketing FMeasure\s+=\s+(\d+\.\d+)")
+        self._header_line = ['ID', 'Len.', 'Stat.', 'Recal', 'Prec.', 'Bracket',
+                             'gold', 'test', 'Bracket', 'Words', 'Tags', 'Accracy']
 
-        self._precision = 0.0
-        self._recall = 0.0
-        self._f1_measure = 0.0
-        self._count = 0.0
+        self._correct_predicted_brackets = 0.0
+        self._gold_brackets = 0.0
+        self._predicted_brackets = 0.0
 
     @overrides
     def __call__(self, predicted_trees: List[Tree], gold_trees: List[Tree]) -> None: # type: ignore
@@ -84,36 +80,15 @@ def __call__(self, predicted_trees: List[Tree], gold_trees: List[Tree]) -> None:
                   f"{gold_path} {predicted_path} > {output_path}"
         subprocess.run(command, shell=True, check=True)
 
-        recall = math.nan
-        precision = math.nan
-        fmeasure = math.nan
         with open(output_path) as infile:
             for line in infile:
-                recall_match = self._recall_regex.match(line)
-                if recall_match:
-                    recall = float(recall_match.group(1))
-                precision_match = self._precision_regex.match(line)
-                if precision_match:
-                    precision = float(precision_match.group(1))
-                f1_measure_match = self._f1_measure_regex.match(line)
-                if f1_measure_match:
-                    fmeasure = float(f1_measure_match.group(1))
-                    break
-        if any([math.isnan(recall), math.isnan(precision)]):
-            raise RuntimeError(f"Call to EVALB produced invalid metrics: recall: "
-                               f"{recall}, precision: {precision}, fmeasure: {fmeasure}")
-
-        if math.isnan(fmeasure) and recall == 0.0 and precision == 0.0:
-            fmeasure = 0.0
-        elif math.isnan(fmeasure):
-            raise RuntimeError(f"Call to EVALB produced an invalid f1 measure, "
-                               f"which was not due to zero division: recall: "
-                               f"{recall}, precision: {precision}, fmeasure: {fmeasure}")
-
-        self._precision += precision / 100.0
-        self._recall += recall / 100.0
-        self._f1_measure += fmeasure / 100.0
-        self._count += 1
+                stripped = line.strip().split()
+                if len(stripped) == 12 and stripped != self._header_line:
+                    # This line contains results for a single tree.
+                    numeric_line = [float(x) for x in stripped]
+                    self._correct_predicted_brackets += numeric_line[5]
+                    self._gold_brackets += numeric_line[6]
+                    self._predicted_brackets += numeric_line[7]
 
     @overrides
     def get_metric(self, reset: bool = False):
@@ -122,18 +97,16 @@ def get_metric(self, reset: bool = False):
         -------
         The average precision, recall and f1.
         """
-        metrics = {}
-        metrics["evalb_precision"] = self._precision / self._count if self._count > 0 else 0.0
-        metrics["evalb_recall"] = self._recall / self._count if self._count > 0 else 0.0
-        metrics["evalb_f1_measure"] = self._f1_measure / self._count if self._count > 0 else 0.0
+        recall = self._correct_predicted_brackets / self._gold_brackets if self._gold_brackets > 0 else 0.0
+        precision = self._correct_predicted_brackets / self._predicted_brackets if self._predicted_brackets > 0 else 0.0
+        f1_measure = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0.0
 
         if reset:
             self.reset()
-        return metrics
+        return {"evalb_recall": recall, "evalb_precision": precision, "evalb_f1_measure": f1_measure}
 
     @overrides
     def reset(self):
-        self._recall = 0.0
-        self._precision = 0.0
-        self._f1_measure = 0.0
-        self._count = 0
+        self._correct_predicted_brackets = 0.0
+        self._gold_brackets = 0.0
+        self._predicted_brackets = 0.0
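
For context on why the accumulation changed: the removed code averaged the percentage summary that EVALB prints for each call, which weights every batch equally, while the new code totals bracket counts and derives a single corpus-level (micro-averaged) score in get_metric. A minimal sketch of the difference, using invented bracket counts (not taken from EVALB output or the tests below):

    # Two hypothetical batches of different sizes (numbers are made up).
    batches = [
        {"matched": 9, "gold": 10, "predicted": 10},     # small batch, scores well
        {"matched": 50, "gold": 100, "predicted": 100},  # large batch, scores worse
    ]

    # Old behaviour: compute F1 per batch (as EVALB's summary does), then average.
    def batch_f1(batch):
        precision = batch["matched"] / batch["predicted"]
        recall = batch["matched"] / batch["gold"]
        return 2 * precision * recall / (precision + recall)

    macro_f1 = sum(batch_f1(b) for b in batches) / len(batches)  # 0.7, over-weights the small batch

    # New behaviour: accumulate raw bracket counts, then score the whole corpus once.
    matched = sum(b["matched"] for b in batches)      # 59
    gold = sum(b["gold"] for b in batches)            # 110
    predicted = sum(b["predicted"] for b in batches)  # 110
    precision, recall = matched / predicted, matched / gold
    micro_f1 = 2 * precision * recall / (precision + recall)

    print(round(macro_f1, 3), round(micro_f1, 3))  # 0.7 0.536

The new test added below checks exactly this corpus-level behaviour over two trees.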
11 changes: 11 additions & 0 deletions tests/training/metrics/evalb_bracketing_scorer_test.py
@@ -38,6 +38,17 @@ def test_evalb_correctly_scores_imperfect_trees(self):
         assert metrics["evalb_precision"] == 0.75
         assert metrics["evalb_f1_measure"] == 0.75
 
+    def test_evalb_correctly_calculates_bracketing_metrics_over_multiple_trees(self):
+        tree1 = Tree.fromstring("(S (VP (D the) (NP dog)) (VP (V chased) (NP (D the) (N cat))))")
+        tree2 = Tree.fromstring("(S (NP (D the) (N dog)) (VP (V chased) (NP (D the) (N cat))))")
+        evalb_scorer = EvalbBracketingScorer("scripts/EVALB/")
+        evalb_scorer([tree1, tree2], [tree2, tree2])
+        metrics = evalb_scorer.get_metric()
+        assert metrics["evalb_recall"] == 0.875
+        assert metrics["evalb_precision"] == 0.875
+        assert metrics["evalb_f1_measure"] == 0.875
+
+
     def test_evalb_with_terrible_trees_handles_nan_f1(self):
         # If precision and recall are zero, evalb returns nan f1.
         # This checks that we handle the zero division.
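
As a reading aid for the new test's expected value of 0.875, here is my working, assuming EVALB with the default parameter file counts labeled constituent brackets above the part-of-speech level:

    # Gold trees are [tree2, tree2]; predicted trees are [tree1, tree2].
    # tree2 contributes 4 brackets: S, NP(the dog), VP(chased the cat), NP(the cat).
    # tree1 also has 4 brackets, but labels the span over "the dog" VP instead of NP,
    # so tree1 vs tree2 matches only 3 brackets; tree2 vs tree2 matches all 4.
    correct_brackets = 3 + 4    # 7
    gold_brackets = 4 + 4       # 8
    predicted_brackets = 4 + 4  # 8

    recall = correct_brackets / gold_brackets               # 0.875
    precision = correct_brackets / predicted_brackets       # 0.875
    f1 = 2 * precision * recall / (precision + recall)      # 0.875
    assert (recall, precision, f1) == (0.875, 0.875, 0.875)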