From 5085c44b91c42c7e65773472f010358e7d83b08f Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Tue, 6 Jul 2021 16:29:46 +0800 Subject: [PATCH] make DataLoader warning less noisy. test=develop (#33712) --- python/paddle/fluid/dataloader/fetcher.py | 24 +++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/python/paddle/fluid/dataloader/fetcher.py b/python/paddle/fluid/dataloader/fetcher.py index 05382b04dc457..8ccec81810a0a 100644 --- a/python/paddle/fluid/dataloader/fetcher.py +++ b/python/paddle/fluid/dataloader/fetcher.py @@ -14,8 +14,9 @@ import logging from ..log_helper import get_logger +from collections.abc import Sequence, Mapping -from collections.abc import Sequence +_WARNING_TO_LOG = True class _DatasetFetcher(object): @@ -24,13 +25,17 @@ def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last): self.auto_collate_batch = auto_collate_batch self.collate_fn = collate_fn self.drop_last = drop_last - self._is_warning_logged = False def fetch(self, batch_indices): raise NotImplementedError("'fetch' not implement for class {}".format( self.__class__.__name__)) def _log_warning(self): + # only log warning on GPU 0 when distributed launch + from ...distributed import get_world_size, get_rank + if get_world_size() >= 2 and get_rank() != 0: + return + warn_str = "Detect dataset only contains single fileds, return format " \ "changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \ "a list surround output data(e.g. return [data]), and in " \ @@ -77,10 +82,12 @@ def fetch(self, batch_indices): if len(data) == 0 or (self.drop_last and len(data) < len(batch_indices)): raise StopIteration - if not isinstance(data[0], - Sequence) and not self._is_warning_logged: + + global _WARNING_TO_LOG + if not isinstance(data[0], (Sequence, Mapping)) \ + and _WARNING_TO_LOG: self._log_warning() - self._is_warning_logged = True + _WARNING_TO_LOG = False else: data = next(self.dataset_iter) @@ -98,10 +105,11 @@ def fetch(self, batch_indices): if self.auto_collate_batch: data = [self.dataset[idx] for idx in batch_indices] - if not isinstance(data[0], - Sequence) and not self._is_warning_logged: + global _WARNING_TO_LOG + if not isinstance(data[0], (Sequence, Mapping)) \ + and _WARNING_TO_LOG: self._log_warning() - self._is_warning_logged = True + _WARNING_TO_LOG = False else: data = self.dataset[batch_indices]