evaluation can receive metric values from anywhere #877

Merged · 1 commit · Jan 11, 2022
38 changes: 23 additions & 15 deletions avalanche/training/plugins/evaluation.py
@@ -3,6 +3,7 @@
 from collections import defaultdict
 from typing import Union, Sequence, TYPE_CHECKING

+from avalanche.evaluation.metric_results import MetricValue
 from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
 from avalanche.training.plugins.strategy_plugin import StrategyPlugin
 from avalanche.logging import StrategyLogger, InteractiveLogger
@@ -110,6 +111,9 @@ def __init__(self,
         self._active = True
         """If False, no metrics will be collected."""

+        self._metric_values = []
+        """List of metrics that have yet to be processed by loggers."""
+
     @property
     def active(self):
         return self._active
@@ -120,31 +124,35 @@ def active(self, value):
             "Active must be set as either True or False"
         self._active = value

+    def publish_metric_value(self, mval: MetricValue):
+        """Publish a MetricValue to be processed by the loggers."""
+        self._metric_values.append(mval)
+
+        name = mval.name
+        x = mval.x_plot
+        val = mval.value
+        if self.collect_all:
+            self.all_metric_results[name][0].append(x)
+            self.all_metric_results[name][1].append(val)
+        self.last_metric_results[name] = val
+
     def _update_metrics(self, strategy: 'BaseStrategy', callback: str):
         """Call the metric plugins with the correct callback `callback` and
         update the loggers with the new metric values."""
         if not self._active:
             return []

-        metric_values = []
         for metric in self.metrics:
             metric_result = getattr(metric, callback)(strategy)
             if isinstance(metric_result, Sequence):
-                metric_values += list(metric_result)
+                for mval in metric_result:
+                    self.publish_metric_value(mval)
             elif metric_result is not None:
-                metric_values.append(metric_result)
-
-        for metric_value in metric_values:
-            name = metric_value.name
-            x = metric_value.x_plot
-            val = metric_value.value
-            if self.collect_all:
-                self.all_metric_results[name][0].append(x)
-                self.all_metric_results[name][1].append(val)
-
-            self.last_metric_results[name] = val
+                self.publish_metric_value(metric_result)

         for logger in self.loggers:
-            getattr(logger, callback)(strategy, metric_values)
-        return metric_values
+            getattr(logger, callback)(strategy, self._metric_values)
+        self._metric_values = []

     def get_last_metrics(self):
         """
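The core of the change is the new `publish_metric_value` method: anything that can reach the `EvaluationPlugin` instance can push a `MetricValue`, which is buffered in `_metric_values`, recorded in `all_metric_results` / `last_metric_results`, and handed to every logger at the next metric callback before the buffer is cleared. Below is a minimal sketch of publishing from a custom plugin; the class name, the gradient-norm metric, the `after_backward` hook choice, and the pattern of passing the evaluator in explicitly are illustrative assumptions rather than part of this PR, while `MetricValue(origin, name, value, x_plot)` follows the constructor usage in the test below.

```python
# Illustrative sketch only: a custom StrategyPlugin that publishes its own
# values through the EvaluationPlugin. The class name, metric name and the
# explicit `evaluator` argument are assumptions, not part of this PR.
from avalanche.evaluation.metric_results import MetricValue
from avalanche.training.plugins import EvaluationPlugin, StrategyPlugin


class GradientNormPublisher(StrategyPlugin):
    """Publishes the total gradient norm after each backward pass."""

    def __init__(self, evaluator: EvaluationPlugin):
        super().__init__()
        self.evaluator = evaluator  # the same plugin passed to the strategy
        self.step = 0

    def after_backward(self, strategy, **kwargs):
        # Sum the gradient norms of all parameters that received a gradient.
        total_norm = sum(
            p.grad.norm().item()
            for p in strategy.model.parameters()
            if p.grad is not None
        )
        # origin, name, value, x_plot -- same order as in the unit test below
        self.evaluator.publish_metric_value(
            MetricValue(self, 'grad_norm', total_norm, self.step))
        self.step += 1
```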
13 changes: 12 additions & 1 deletion tests/training/test_plugins.py
@@ -16,10 +16,11 @@
 from avalanche.benchmarks import nc_benchmark, GenericCLScenario, \
     benchmark_with_validation_stream
 from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader
+from avalanche.evaluation.metric_results import MetricValue
 from avalanche.evaluation.metrics import Mean
 from avalanche.logging import TextLogger
 from avalanche.models import BaseModel
-from avalanche.training.plugins import StrategyPlugin
+from avalanche.training.plugins import StrategyPlugin, EvaluationPlugin
 from avalanche.training.plugins.lr_scheduling import LRSchedulerPlugin
 from avalanche.training.strategies import Naive

@@ -542,5 +543,15 @@ def get_features(self, x):
         return x


+class EvaluationPluginTest(unittest.TestCase):
+    def test_publish_metric(self):
+        ep = EvaluationPlugin()
+        mval = MetricValue(self, 'metric', 1.0, 0)
+        ep.publish_metric_value(mval)
+
+        # check key exists
+        assert len(ep.get_all_metrics()['metric'][1]) == 1
+
+
 if __name__ == '__main__':
     unittest.main()
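For reference, published values can be read back the same way the unit test does. A minimal sketch follows, assuming the default `EvaluationPlugin()` configuration (`collect_all=True`) and interpreting `get_all_metrics()` as `{name: ([x_plot, ...], [value, ...])}` and `get_last_metrics()` as `{name: last_value}`, as implied by the `all_metric_results` / `last_metric_results` updates in `publish_metric_value`.

```python
# Sketch: publish two values under the same metric name and read them back.
# Any object can serve as the MetricValue origin, as the test does with `self`.
from avalanche.evaluation.metric_results import MetricValue
from avalanche.training.plugins import EvaluationPlugin

ep = EvaluationPlugin()
origin = object()  # placeholder origin

ep.publish_metric_value(MetricValue(origin, 'my_metric', 0.50, 0))
ep.publish_metric_value(MetricValue(origin, 'my_metric', 0.75, 1))

# all_metric_results keeps parallel lists of x positions and values per name
xs, vals = ep.get_all_metrics()['my_metric']
print(xs, vals)                            # [0, 1] [0.5, 0.75]

# last_metric_results keeps only the most recent value per name
print(ep.get_last_metrics()['my_metric'])  # 0.75
```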