fix distributed writes (#1793)
* distributed tests that write dask arrays

* Change zarr test to synchronous API

* initial go at __setitem__ on array wrappers

* fixes for scipy

* cleanup after merging with upstream/master

* needless duplication of tests to work around pytest bug

* use netcdf_variable instead of get_array()

* use synchronous dask.distributed test harness

* cleanup tests

* per scheduler locks and autoclose behavior for writes

* HDF5_LOCK and CombinedLock

* integration test for distributed locks

* more tests and set isopen to false when pickling

* Fixing style errors.

* ds property on DataStorePickleMixin

* stickler-ci

* compat fixes for other backends

* HDF5_USE_FILE_LOCKING = False in test_distributed

* style fix

* update tests to only expect netcdf4 to work, docstrings, and some cleanup in to_netcdf

* Fixing style errors.

* fix imports after merge

* fix more import bugs

* update docs

* fix for pynio

* cleanup locks and use pytest monkeypatch for environment variable

* fix failing test using combined lock
Joe Hamman authored Mar 10, 2018
1 parent 8c6a284 commit 2f590f7
Showing 12 changed files with 331 additions and 64 deletions.
8 changes: 8 additions & 0 deletions doc/dask.rst
@@ -100,6 +100,14 @@ Once you've manipulated a dask array, you can still write a dataset too big to
fit into memory back to disk by using :py:meth:`~xarray.Dataset.to_netcdf` in the
usual way.

.. note::

When using dask's distributed scheduler to write NETCDF4 files,
it may be necessary to set the environment variable `HDF5_USE_FILE_LOCKING=FALSE`
to avoid competing locks within the HDF5 SWMR file locking scheme. Note that
writing netCDF files with dask's distributed scheduler is only supported for
the `netcdf4` backend.
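
   For a rough illustration of the workflow this note describes (the file names and
   chunk sizes are hypothetical, and a running `dask.distributed` cluster is assumed)::

       import os

       import xarray as xr
       from dask.distributed import Client

       # ideally exported in the shell before Python starts, so HDF5 sees it
       os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
       client = Client()  # connect to (or start) a distributed scheduler

       ds = xr.open_dataset('big_input.nc', chunks={'time': 100})
       ds.to_netcdf('big_output.nc', engine='netcdf4')  # only the netcdf4 backend is supported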

A dataset can also be converted to a dask DataFrame using :py:meth:`~xarray.Dataset.to_dask_dataframe`.

.. ipython:: python
6 changes: 3 additions & 3 deletions doc/io.rst
@@ -672,9 +672,9 @@ files into a single Dataset by making use of :py:func:`~xarray.concat`.

.. note::

Version 0.5 includes support for manipulating datasets that
don't fit into memory with dask_. If you have dask installed, you can open
multiple files simultaneously using :py:func:`~xarray.open_mfdataset`::
Xarray includes support for manipulating datasets that don't fit into memory
with dask_. If you have dask installed, you can open multiple files
simultaneously using :py:func:`~xarray.open_mfdataset`::

xr.open_mfdataset('my/files/*.nc')

7 changes: 7 additions & 0 deletions doc/whats-new.rst
@@ -38,6 +38,13 @@ Documentation
Enhancements
~~~~~~~~~~~~

- Support for writing xarray datasets to netCDF files (netcdf4 backend only)
when using the `dask.distributed <https://distributed.readthedocs.io>`_
scheduler (:issue:`1464`).
By `Joe Hamman <https://github.com/jhamman>`_.


- Fixed to_netcdf when using dask distributed
- Support lazy vectorized-indexing. After this change, flexible indexing such
as orthogonal/vectorized indexing, becomes possible for all the backend
arrays. Also, lazy ``transpose`` is now also supported. (:issue:`1897`)
35 changes: 31 additions & 4 deletions xarray/backends/api.py
@@ -12,7 +12,8 @@
from ..core.combine import auto_combine
from ..core.pycompat import basestring, path_type
from ..core.utils import close_on_error, is_remote_uri
from .common import GLOBAL_LOCK, ArrayWriter
from .common import (
HDF5_LOCK, ArrayWriter, CombinedLock, get_scheduler, get_scheduler_lock)

DATAARRAY_NAME = '__xarray_dataarray_name__'
DATAARRAY_VARIABLE = '__xarray_dataarray_variable__'
@@ -64,9 +65,9 @@ def _default_lock(filename, engine):
else:
# TODO: identify netcdf3 files and don't use the global lock
# for them
lock = GLOBAL_LOCK
lock = HDF5_LOCK
elif engine in {'h5netcdf', 'pynio'}:
lock = GLOBAL_LOCK
lock = HDF5_LOCK
else:
lock = False
return lock
@@ -129,6 +130,20 @@ def _protect_dataset_variables_inplace(dataset, cache):
variable.data = data


def _get_lock(engine, scheduler, format, path_or_file):
""" Get the lock(s) that apply to a particular scheduler/engine/format"""

locks = []
if format in ['NETCDF4', None] and engine in ['h5netcdf', 'netcdf4']:
locks.append(HDF5_LOCK)
locks.append(get_scheduler_lock(scheduler, path_or_file))

# When we have more than one lock, use the CombinedLock wrapper class
lock = CombinedLock(locks) if len(locks) > 1 else locks[0]

return lock
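
As a sketch of how this helper behaves (the file names are hypothetical, and the
distributed case assumes a running `dask.distributed` Client)::

    from xarray.backends.api import _get_lock

    # NETCDF4 + netcdf4 engine + distributed scheduler: the HDF5 lock plus a
    # per-file distributed lock, wrapped in a CombinedLock
    _get_lock('netcdf4', 'distributed', 'NETCDF4', 'out.nc')

    # netCDF3 via scipy under the threaded scheduler: a single
    # SerializableLock, no HDF5 lock needed
    _get_lock('scipy', 'threaded', 'NETCDF3_64BIT', 'out3.nc')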


def open_dataset(filename_or_obj, group=None, decode_cf=True,
mask_and_scale=True, decode_times=True, autoclose=False,
concat_characters=True, decode_coords=True, engine=None,
@@ -620,8 +635,20 @@ def to_netcdf(dataset, path_or_file=None, mode='w', format=None, group=None,
# if a writer is provided, store asynchronously
sync = writer is None

# handle scheduler specific logic
scheduler = get_scheduler()
if (dataset.chunks and scheduler in ['distributed', 'multiprocessing'] and
engine != 'netcdf4'):
raise NotImplementedError("Writing netCDF files with the %s backend "
"is not currently supported with dask's %s "
"scheduler" % (engine, scheduler))
lock = _get_lock(engine, scheduler, format, path_or_file)
autoclose = (dataset.chunks and
scheduler in ['distributed', 'multiprocessing'])

target = path_or_file if path_or_file is not None else BytesIO()
store = store_open(target, mode, format, group, writer)
store = store_open(target, mode, format, group, writer,
autoclose=autoclose, lock=lock)

if unlimited_dims is None:
unlimited_dims = dataset.encoding.get('unlimited_dims', None)
135 changes: 116 additions & 19 deletions xarray/backends/common.py
@@ -2,6 +2,8 @@

import contextlib
import logging
import multiprocessing
import threading
import time
import traceback
import warnings
@@ -14,11 +16,12 @@
from ..core.pycompat import dask_array_type, iteritems
from ..core.utils import FrozenOrderedDict, NdimSizeLenMixin

# Import default lock
try:
from dask.utils import SerializableLock as Lock
from dask.utils import SerializableLock
HDF5_LOCK = SerializableLock()
except ImportError:
from threading import Lock

HDF5_LOCK = threading.Lock()

# Create a logger object, but don't add any handlers. Leave that to user code.
logger = logging.getLogger(__name__)
@@ -27,8 +30,54 @@
NONE_VAR_NAME = '__values__'


# dask.utils.SerializableLock if available, otherwise just a threading.Lock
GLOBAL_LOCK = Lock()
def get_scheduler(get=None, collection=None):
""" Determine the dask scheduler that is being used.
None is returned if no dask scheduler is active.
See also
--------
dask.utils.effective_get
"""
try:
from dask.utils import effective_get
actual_get = effective_get(get, collection)
try:
from dask.distributed import Client
if isinstance(actual_get.__self__, Client):
return 'distributed'
except (ImportError, AttributeError):
try:
import dask.multiprocessing
if actual_get == dask.multiprocessing.get:
return 'multiprocessing'
else:
return 'threaded'
except ImportError:
return 'threaded'
except ImportError:
return None
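
A quick sketch of the detection logic above (dask is assumed to be installed; the
array only gives `effective_get` a collection to inspect)::

    import dask.array as da
    from xarray.backends.common import get_scheduler

    arr = da.zeros(10, chunks=5)
    get_scheduler(collection=arr)  # -> 'threaded' under dask's default scheduler

    # With a dask.distributed Client registered as the default get:
    #     from dask.distributed import Client
    #     client = Client()
    #     get_scheduler()  # -> 'distributed'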


def get_scheduler_lock(scheduler, path_or_file=None):
""" Get the appropriate lock for a certain situation based onthe dask
scheduler used.
See Also
--------
dask.utils.get_scheduler_lock
"""

if scheduler == 'distributed':
from dask.distributed import Lock
return Lock(path_or_file)
elif scheduler == 'multiprocessing':
return multiprocessing.Lock()
elif scheduler == 'threaded':
from dask.utils import SerializableLock
return SerializableLock()
else:
return threading.Lock()
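
Roughly, each scheduler name maps onto a lock type (the file name is hypothetical,
and the distributed case needs a running Client)::

    from xarray.backends.common import get_scheduler_lock

    get_scheduler_lock('threaded')         # dask SerializableLock
    get_scheduler_lock('multiprocessing')  # multiprocessing.Lock
    get_scheduler_lock(None)               # plain threading.Lock fallback
    # get_scheduler_lock('distributed', 'out.nc')  # distributed.Lock named after the target file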


def _encode_variable_name(name):
@@ -77,6 +126,39 @@ def robust_getitem(array, key, catch=Exception, max_retries=6,
time.sleep(1e-3 * next_delay)


class CombinedLock(object):
"""A combination of multiple locks.
Like a locked door, a CombinedLock is locked if any of its constituent
locks are locked.
"""

def __init__(self, locks):
self.locks = tuple(set(locks)) # remove duplicates

def acquire(self, *args):
return all(lock.acquire(*args) for lock in self.locks)

def release(self, *args):
for lock in self.locks:
lock.release(*args)

def __enter__(self):
for lock in self.locks:
lock.__enter__()

def __exit__(self, *args):
for lock in self.locks:
lock.__exit__(*args)

@property
def locked(self):
return any(lock.locked for lock in self.locks)

def __repr__(self):
return "CombinedLock(%r)" % list(self.locks)


class BackendArray(NdimSizeLenMixin, indexing.ExplicitlyIndexed):

def __array__(self, dtype=None):
@@ -85,7 +167,9 @@ def __array__(self, dtype=None):


class AbstractDataStore(Mapping):
_autoclose = False
_autoclose = None
_ds = None
_isopen = False

def __iter__(self):
return iter(self.variables)
@@ -168,7 +252,7 @@ def __exit__(self, exception_type, exception_value, traceback):


class ArrayWriter(object):
def __init__(self, lock=GLOBAL_LOCK):
def __init__(self, lock=HDF5_LOCK):
self.sources = []
self.targets = []
self.lock = lock
@@ -178,11 +262,7 @@ def add(self, source, target):
self.sources.append(source)
self.targets.append(target)
else:
try:
target[...] = source
except TypeError:
# workaround for GH: scipy/scipy#6880
target[:] = source
target[...] = source

def sync(self):
if self.sources:
@@ -193,9 +273,9 @@ def sync(self):


class AbstractWritableDataStore(AbstractDataStore):
def __init__(self, writer=None):
def __init__(self, writer=None, lock=HDF5_LOCK):
if writer is None:
writer = ArrayWriter()
writer = ArrayWriter(lock=lock)
self.writer = writer

def encode(self, variables, attributes):
@@ -239,6 +319,9 @@ def set_variable(self, k, v): # pragma: no cover
raise NotImplementedError

def sync(self):
if self._isopen and self._autoclose:
# datastore will be reopened during write
self.close()
self.writer.sync()

def store_dataset(self, dataset):
@@ -373,27 +456,41 @@ class DataStorePickleMixin(object):

def __getstate__(self):
state = self.__dict__.copy()
del state['ds']
del state['_ds']
del state['_isopen']
if self._mode == 'w':
# file has already been created, don't override when restoring
state['_mode'] = 'a'
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.ds = self._opener(mode=self._mode)
self._ds = None
self._isopen = False

@property
def ds(self):
if self._ds is not None and self._isopen:
return self._ds
ds = self._opener(mode=self._mode)
self._isopen = True
return ds

@contextlib.contextmanager
def ensure_open(self, autoclose):
def ensure_open(self, autoclose=None):
"""
Helper function to make sure datasets are closed and opened
at appropriate times to avoid too many open file errors.
Using this requires the `autoclose=True` argument to `open_mfdataset`.
"""
if self._autoclose and not self._isopen:

if autoclose is None:
autoclose = self._autoclose

if not self._isopen:
try:
self.ds = self._opener()
self._ds = self._opener()
self._isopen = True
yield
finally:
14 changes: 9 additions & 5 deletions xarray/backends/h5netcdf_.py
@@ -8,7 +8,8 @@
from ..core import indexing
from ..core.pycompat import OrderedDict, bytes_type, iteritems, unicode_type
from ..core.utils import FrozenOrderedDict, close_on_error
from .common import DataStorePickleMixin, WritableCFDataStore, find_root
from .common import (
HDF5_LOCK, DataStorePickleMixin, WritableCFDataStore, find_root)
from .netCDF4_ import (
BaseNetCDF4Array, _encode_nc4_variable, _extract_nc4_variable_encoding,
_get_datatype, _nc4_group)
@@ -68,12 +69,12 @@ class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin):
"""

def __init__(self, filename, mode='r', format=None, group=None,
writer=None, autoclose=False):
writer=None, autoclose=False, lock=HDF5_LOCK):
if format not in [None, 'NETCDF4']:
raise ValueError('invalid format for h5netcdf backend')
opener = functools.partial(_open_h5netcdf_group, filename, mode=mode,
group=group)
self.ds = opener()
self._ds = opener()
if autoclose:
raise NotImplementedError('autoclose=True is not implemented '
'for the h5netcdf backend pending '
@@ -85,7 +86,7 @@ def __init__(self, filename, mode='r', format=None, group=None,
self._opener = opener
self._filename = filename
self._mode = mode
super(H5NetCDFStore, self).__init__(writer)
super(H5NetCDFStore, self).__init__(writer, lock=lock)

def open_store_variable(self, name, var):
with self.ensure_open(autoclose=False):
@@ -177,7 +178,10 @@ def prepare_variable(self, name, variable, check_encoding=False,

for k, v in iteritems(attrs):
nc4_var.setncattr(k, v)
return nc4_var, variable.data

target = H5NetCDFArrayWrapper(name, self)

return target, variable.data

def sync(self):
with self.ensure_open(autoclose=True):