diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index daab26013350e..b029c4d3a1c3a 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -892,17 +892,14 @@ def eval_batch(self, inputs, labels=None):
         if self._nranks > 1:
             outputs = [_all_gather(o) for o in to_list(outputs)]
             labels = [_all_gather(l) for l in labels]
-        metrics = []
-        for metric in self.model._metrics:
-            # cut off padding value.
-            if (
-                self.model._test_dataloader is not None
-                and self._nranks > 1
-                and isinstance(self.model._test_dataloader, DataLoader)
+
+        if self.model._test_dataloader is not None and isinstance(
+            self.model._test_dataloader, DataLoader
         ):
             total_size = len(self.model._test_dataloader.dataset)
             samples = outputs[0].shape[0]
             current_count = self._merge_count.get(self.mode + '_total', 0)
+
             if current_count + samples >= total_size:
                 outputs = [
                     o[: int(total_size - current_count)] for o in outputs
@@ -918,6 +915,9 @@ def eval_batch(self, inputs, labels=None):
                 self._merge_count[self.mode + '_total'] += samples
                 self._merge_count[self.mode + '_batch'] = samples
 
+        metrics = []
+        for metric in self.model._metrics:
+            # cut off padding value.
             metric_outs = metric.compute(*(to_list(outputs) + labels))
             m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
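
Note on the change: the padding cut-off is hoisted out of the metric loop and is no longer gated on self._nranks > 1, so single-process evaluation also trims the rows a DataLoader pads onto the last batch before they reach metric.compute(). The following is a minimal, standalone sketch of that bookkeeping under those assumptions; trim_padded_batch and the toy numbers are hypothetical and not part of Paddle's API.

import numpy as np

def trim_padded_batch(outputs, labels, total_size, current_count):
    """Truncate the final (padded) eval batch back to the real dataset size."""
    samples = outputs[0].shape[0]
    if current_count + samples >= total_size:
        # Last batch of the epoch: keep only the real samples.
        keep = int(total_size - current_count)
        outputs = [o[:keep] for o in outputs]
        labels = [l[:keep] for l in labels]
        current_count = 0  # epoch complete; reset the running counter
    else:
        current_count += samples
    return outputs, labels, current_count

# Example: a 10-sample dataset read in batches of 4, so the last batch
# carries 2 padded rows that should never reach the metric update.
count = 0
for start in (0, 4, 8):
    batch = [np.arange(start, start + 4)]
    label = [np.arange(start, start + 4)]
    batch, label, count = trim_padded_batch(
        batch, label, total_size=10, current_count=count
    )
    print(batch[0].shape[0])  # prints 4, 4, 2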