Skip to content

Commit

Permalink
make DataLoader warning less noisy. test=develop (PaddlePaddle#33712)
Browse files Browse the repository at this point in the history
  • Loading branch information
heavengate committed Jul 6, 2021
1 parent c9ae136 commit 5085c44
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions python/paddle/fluid/dataloader/fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

import logging
from ..log_helper import get_logger
from collections.abc import Sequence, Mapping

from collections.abc import Sequence
_WARNING_TO_LOG = True


class _DatasetFetcher(object):
Expand All @@ -24,13 +25,17 @@ def __init__(self, dataset, auto_collate_batch, collate_fn, drop_last):
self.auto_collate_batch = auto_collate_batch
self.collate_fn = collate_fn
self.drop_last = drop_last
self._is_warning_logged = False

def fetch(self, batch_indices):
# Placeholder implementation: unconditionally raises NotImplementedError,
# formatting the instance's class name into the message. `batch_indices`
# is accepted but never read here.
# NOTE(review): presumably concrete fetcher classes override this method;
# confirm against the subclasses further down the file.
raise NotImplementedError("'fetch' not implement for class {}".format(
self.__class__.__name__))

def _log_warning(self):
# only log the warning on rank 0 (GPU 0) when running under a distributed launch
from ...distributed import get_world_size, get_rank
if get_world_size() >= 2 and get_rank() != 0:
return

warn_str = "Detect dataset only contains single fileds, return format " \
"changed since Paddle 2.1. In Paddle <= 2.0, DataLoader add " \
"a list surround output data(e.g. return [data]), and in " \
Expand Down Expand Up @@ -77,10 +82,12 @@ def fetch(self, batch_indices):
if len(data) == 0 or (self.drop_last and
len(data) < len(batch_indices)):
raise StopIteration
if not isinstance(data[0],
Sequence) and not self._is_warning_logged:

global _WARNING_TO_LOG
if not isinstance(data[0], (Sequence, Mapping)) \
and _WARNING_TO_LOG:
self._log_warning()
self._is_warning_logged = True
_WARNING_TO_LOG = False
else:
data = next(self.dataset_iter)

Expand All @@ -98,10 +105,11 @@ def fetch(self, batch_indices):
if self.auto_collate_batch:
data = [self.dataset[idx] for idx in batch_indices]

if not isinstance(data[0],
Sequence) and not self._is_warning_logged:
global _WARNING_TO_LOG
if not isinstance(data[0], (Sequence, Mapping)) \
and _WARNING_TO_LOG:
self._log_warning()
self._is_warning_logged = True
_WARNING_TO_LOG = False
else:
data = self.dataset[batch_indices]

Expand Down

0 comments on commit 5085c44

Please sign in to comment.