Remove caching logic from xarray.Variable #1128

Merged
14 commits merged on Nov 30, 2016
1 change: 1 addition & 0 deletions ci/requirements-py27-cdat+pynio.yml
@@ -5,6 +5,7 @@ dependencies:
- python=2.7
- cdat-lite
- dask
- distributed
- pytest
- numpy
- pandas>=0.15.0
1 change: 1 addition & 0 deletions ci/requirements-py27-netcdf4-dev.yml
@@ -3,6 +3,7 @@ dependencies:
- python=2.7
- cython
- dask
- distributed
- h5py
- pytest
- numpy
1 change: 1 addition & 0 deletions ci/requirements-py27-pydap.yml
@@ -2,6 +2,7 @@ name: test_env
dependencies:
- python=2.7
- dask
- distributed
- h5py
- netcdf4
- pytest
1 change: 1 addition & 0 deletions ci/requirements-py35.yml
@@ -3,6 +3,7 @@ dependencies:
- python=3.5
- cython
- dask
- distributed
- h5py
- matplotlib
- netcdf4
20 changes: 14 additions & 6 deletions doc/whats-new.rst
@@ -25,12 +25,20 @@ Breaking changes
merges will now succeed in cases that previously raised
``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the
previous default.
- Pickling an xarray object based on the dask backend, or reading its
:py:meth:`values` property, won't automatically convert the array from dask
to numpy in the original object anymore.
If a dask object is used as a coord of a :py:class:`~xarray.DataArray` or
:py:class:`~xarray.Dataset`, its values are eagerly computed and cached,
but only if it's used to index a dim (e.g. it's used for alignment).
- Reading :py:attr:`~DataArray.values` no longer always caches values in a NumPy
array :issue:`1128`. Caching of ``.values`` on variables read from netCDF
files on disk is still the default when :py:func:`open_dataset` is called with
``cache=True``.
By `Guido Imperiale <https://github.com/crusaderky>`_ and
`Stephan Hoyer <https://github.com/shoyer>`_.
- Pickling a ``Dataset`` or ``DataArray`` linked to a file on disk no longer
caches its values into memory before pickling :issue:`1128`. Instead, pickle
stores file paths and restores objects by reopening file references. This
enables preliminary, experimental use of xarray for opening files with
`dask.distributed <https://distributed.readthedocs.io>`_.
By `Stephan Hoyer <https://github.com/shoyer>`_.
- Coordinates used to index a dimension are now loaded eagerly into
:py:class:`pandas.Index` objects, instead of loading the values lazily.
By `Guido Imperiale <https://github.com/crusaderky>`_.
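
A minimal usage sketch of the caching and pickling behaviour described in the release-note entries above (not part of the release notes themselves; the file name example.nc is hypothetical):

import pickle
import xarray as xr

# Default: open_dataset(..., cache=True), so values read from disk are
# cached in memory as NumPy arrays after the first access.
ds = xr.open_dataset('example.nc')

# Opening with dask chunks switches the default to cache=False, so .values
# is computed from the dask-backed variables without filling a cache.
ds_lazy = xr.open_dataset('example.nc', chunks={})

# Caching can also be disabled explicitly.
ds_nocache = xr.open_dataset('example.nc', cache=False)

# Pickling no longer loads the data into memory first: the file path is
# stored and the file is reopened when the object is restored.
roundtripped = pickle.loads(pickle.dumps(ds))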

Deprecations
66 changes: 59 additions & 7 deletions xarray/backends/api.py
@@ -13,13 +13,15 @@

from .. import backends, conventions
from .common import ArrayWriter
from ..core import indexing
from ..core.combine import auto_combine
from ..core.utils import close_on_error, is_remote_uri
from ..core.pycompat import basestring

DATAARRAY_NAME = '__xarray_dataarray_name__'
DATAARRAY_VARIABLE = '__xarray_dataarray_variable__'


def _get_default_engine(path, allow_remote=False):
if allow_remote and is_remote_uri(path): # pragma: no cover
try:
@@ -46,6 +48,13 @@ def _get_default_engine(path, allow_remote=False):
return engine


def _normalize_path(path):
if is_remote_uri(path):
return path
else:
return os.path.abspath(os.path.expanduser(path))


_global_lock = threading.Lock()


@@ -117,10 +126,20 @@ def check_attr(name, value):
check_attr(k, v)


def _protect_dataset_variables_inplace(dataset, cache):
for name, variable in dataset.variables.items():
if name not in variable.dims:
# no need to protect IndexVariable objects
data = indexing.CopyOnWriteArray(variable._data)

I ran a test where I create a Dataset from a custom data store that initializes Variables with dask arrays for data. In this case the dask array is still converted to an ndarray when accessing the Variable's data property, since that property only checks for the dask array type; here, however, the array is wrapped in a CopyOnWriteArray, so Variable.values is called, which loads the data eagerly.

Member Author

Indeed, you would need to use cache=False in such a case.

Xarray's decoding logic in conventions.py uses its own array objects instead of dask arrays (for reasons I could get into), which unfortunately makes using dask.array objects to produce variables in a custom data store non-ideal. The problem is that the graphs from such dask arrays don't get linked up into xarray, which means that even if you rechunk the arrays in the xarray Dataset, they still get executed separately by dask. Duck typing for dask objects would probably help here (dask/dask#1068).

if cache:
data = indexing.MemoryCachedArray(data)
variable.data = data


def open_dataset(filename_or_obj, group=None, decode_cf=True,
mask_and_scale=True, decode_times=True,
concat_characters=True, decode_coords=True, engine=None,
chunks=None, lock=None, drop_variables=None):
chunks=None, lock=None, cache=None, drop_variables=None):
"""Load and decode a dataset from a file or file-like object.

Parameters
@@ -162,14 +181,22 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True,
'netcdf4'.
chunks : int or dict, optional
        If chunks is provided, it is used to load the new dataset into dask
arrays. This is an experimental feature; see the documentation for more
details.
arrays. ``chunks={}`` loads the dataset with dask using a single
chunk for all arrays. This is an experimental feature; see the
documentation for more details.
lock : False, True or threading.Lock, optional
If chunks is provided, this argument is passed on to
:py:func:`dask.array.from_array`. By default, a per-variable lock is
used when reading data from netCDF files with the netcdf4 and h5netcdf
engines to avoid issues with concurrent access when using dask's
multithreaded backend.
cache : bool, optional
If True, cache data loaded from the underlying datastore in memory as
NumPy arrays when accessed to avoid reading from the underlying data-
store multiple times. Defaults to True unless you specify the `chunks`
argument to use dask, in which case it defaults to False. Does not
change the behavior of coordinates corresponding to dimensions, which
always load their data from disk into a ``pandas.Index``.
drop_variables: string or iterable, optional
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
@@ -190,12 +217,17 @@ def open_dataset(filename_or_obj, group=None, decode_cf=True,
concat_characters = False
decode_coords = False

if cache is None:
cache = chunks is None

def maybe_decode_store(store, lock=False):
ds = conventions.decode_cf(
store, mask_and_scale=mask_and_scale, decode_times=decode_times,
concat_characters=concat_characters, decode_coords=decode_coords,
drop_variables=drop_variables)

_protect_dataset_variables_inplace(ds, cache)

if chunks is not None:
try:
from dask.base import tokenize
@@ -226,6 +258,17 @@ def maybe_decode_store(store, lock=False):
if isinstance(filename_or_obj, backends.AbstractDataStore):
store = filename_or_obj
elif isinstance(filename_or_obj, basestring):

if (isinstance(filename_or_obj, bytes) and
filename_or_obj.startswith(b'\x89HDF')):
raise ValueError('cannot read netCDF4/HDF5 file images')
elif (isinstance(filename_or_obj, bytes) and
filename_or_obj.startswith(b'CDF')):
# netCDF3 file images are handled by scipy
pass
elif isinstance(filename_or_obj, basestring):
filename_or_obj = _normalize_path(filename_or_obj)

if filename_or_obj.endswith('.gz'):
if engine is not None and engine != 'scipy':
raise ValueError('can only read gzipped netCDF files with '
@@ -274,7 +317,7 @@ def maybe_decode_store(store, lock=False):
def open_dataarray(filename_or_obj, group=None, decode_cf=True,
mask_and_scale=True, decode_times=True,
concat_characters=True, decode_coords=True, engine=None,
chunks=None, lock=None, drop_variables=None):
chunks=None, lock=None, cache=None, drop_variables=None):
"""
    Opens a DataArray from a netCDF file containing a single data variable.

@@ -328,6 +371,13 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True,
used when reading data from netCDF files with the netcdf4 and h5netcdf
engines to avoid issues with concurrent access when using dask's
multithreaded backend.
cache : bool, optional
If True, cache data loaded from the underlying datastore in memory as
NumPy arrays when accessed to avoid reading from the underlying data-
store multiple times. Defaults to True unless you specify the `chunks`
argument to use dask, in which case it defaults to False. Does not
change the behavior of coordinates corresponding to dimensions, which
always load their data from disk into a ``pandas.Index``.
drop_variables: string or iterable, optional
A variable or list of variables to exclude from being parsed from the
dataset. This may be useful to drop variables with problems or
@@ -349,7 +399,7 @@ def open_dataarray(filename_or_obj, group=None, decode_cf=True,
dataset = open_dataset(filename_or_obj, group, decode_cf,
mask_and_scale, decode_times,
concat_characters, decode_coords, engine,
chunks, lock, drop_variables)
chunks, lock, cache, drop_variables)

if len(dataset.data_vars) != 1:
raise ValueError('Given file dataset contains more than one data '
@@ -494,8 +544,10 @@ def to_netcdf(dataset, path=None, mode='w', format=None, group=None,
raise ValueError('invalid engine for creating bytes with '
'to_netcdf: %r. Only the default engine '
"or engine='scipy' is supported" % engine)
elif engine is None:
engine = _get_default_engine(path)
else:
if engine is None:
engine = _get_default_engine(path)
path = _normalize_path(path)

# validate Dataset keys, DataArray names, and attr keys/values
_validate_dataset_names(dataset)
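
For readers following the ``_protect_dataset_variables_inplace`` helper added in this file, here is a simplified, self-contained sketch of what the two wrapper layers do. The classes ending in ``Sketch`` are illustrative stand-ins, not xarray's actual ``indexing.CopyOnWriteArray`` and ``indexing.MemoryCachedArray`` implementations:

import numpy as np

class CopyOnWriteArraySketch:
    """Defers copying: the wrapped array is only copied on first write."""
    def __init__(self, array):
        self.array = array
        self._copied = False

    def _ensure_copy(self):
        if not self._copied:
            self.array = np.array(self.array)
            self._copied = True

    def __array__(self, dtype=None):
        return np.asarray(self.array, dtype=dtype)

    def __getitem__(self, key):
        return self.array[key]

    def __setitem__(self, key, value):
        self._ensure_copy()
        self.array[key] = value

class MemoryCachedArraySketch:
    """Loads the wrapped array into memory on first access and reuses it."""
    def __init__(self, array):
        self.array = array

    def __getitem__(self, key):
        if not isinstance(self.array, np.ndarray):
            self.array = np.asarray(self.array)  # cache as a NumPy array
        return self.array[key]

# Mirroring the helper: always protect against writes, and additionally
# cache in memory when cache=True.
data = CopyOnWriteArraySketch(np.arange(10))
cache = True
if cache:
    data = MemoryCachedArraySketch(data)
assert data[3] == 3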
19 changes: 19 additions & 0 deletions xarray/backends/common.py
@@ -235,3 +235,22 @@ def store(self, variables, attributes, check_encoding_set=frozenset()):
cf_variables, cf_attrs = cf_encoder(variables, attributes)
AbstractWritableDataStore.store(self, cf_variables, cf_attrs,
check_encoding_set)


class DataStorePickleMixin(object):
"""Subclasses must define `ds`, `_opener` and `_mode` attributes.

Do not subclass this class: it is not part of xarray's external API.
"""

def __getstate__(self):
state = self.__dict__.copy()
del state['ds']
if self._mode == 'w':
# file has already been created, don't override when restoring
state['_mode'] = 'a'
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.ds = self._opener(mode=self._mode)
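
A self-contained sketch of the pickling pattern that ``DataStorePickleMixin`` introduces above: the open handle is dropped from the pickled state and recreated from the stored opener on unpickle. ``ExampleStore`` and ``open_resource`` are hypothetical stand-ins for a real data store and its file opener:

import functools
import pickle

def open_resource(path, mode='r'):
    # Stand-in for opening a netCDF/HDF5 handle; here just a plain file.
    return open(path, mode)

class ExampleStore(object):
    def __init__(self, path, mode='r'):
        self._opener = functools.partial(open_resource, path)
        self._mode = mode
        self.ds = self._opener(mode=mode)

    def __getstate__(self):
        state = self.__dict__.copy()
        del state['ds']           # open handles cannot be pickled
        if self._mode == 'w':
            state['_mode'] = 'a'  # the file already exists; don't clobber it
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.ds = self._opener(mode=self._mode)  # reopen on unpickle

# store = ExampleStore('example.txt')           # assumes the file exists
# restored = pickle.loads(pickle.dumps(store))  # reopens 'example.txt'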
26 changes: 17 additions & 9 deletions xarray/backends/h5netcdf_.py
@@ -8,7 +8,7 @@
from ..core.utils import FrozenOrderedDict, close_on_error, Frozen
from ..core.pycompat import iteritems, bytes_type, unicode_type, OrderedDict

from .common import WritableCFDataStore
from .common import WritableCFDataStore, DataStorePickleMixin
from .netCDF4_ import (_nc4_group, _nc4_values_and_dtype, _extract_nc4_encoding,
BaseNetCDF4Array)

@@ -37,24 +37,32 @@ def _read_attributes(h5netcdf_var):
lsd_okay=False, backend='h5netcdf')


class H5NetCDFStore(WritableCFDataStore):
def _open_h5netcdf_group(filename, mode, group):
import h5netcdf.legacyapi
ds = h5netcdf.legacyapi.Dataset(filename, mode=mode)
with close_on_error(ds):
return _nc4_group(ds, group, mode)


class H5NetCDFStore(WritableCFDataStore, DataStorePickleMixin):
"""Store for reading and writing data via h5netcdf
"""
def __init__(self, filename, mode='r', format=None, group=None,
writer=None):
import h5netcdf.legacyapi
if format not in [None, 'NETCDF4']:
raise ValueError('invalid format for h5netcdf backend')
ds = h5netcdf.legacyapi.Dataset(filename, mode=mode)
with close_on_error(ds):
self.ds = _nc4_group(ds, group, mode)
opener = functools.partial(_open_h5netcdf_group, filename, mode=mode,
group=group)
Contributor

Might be worth noting that this is only cloud-picklable, not stdlib pickleable.

Member Author

Can you clarify why this won't work with stdlib pickle? Is the issue doing the h5netcdf import inside the function definition? My understanding was that functools.partial is pickleable.

Contributor

Ah, you're right. I'm surprised:

In [1]: from operator import add

In [2]: from functools import partial

In [3]: from pickle import dumps, loads

In [4]: loads(dumps(partial(add, 1)))
Out[4]: functools.partial(<built-in function add>, 1)

In [5]: loads(dumps(partial(add, 1)))(2)
Out[5]: 3

self.ds = opener()
self.format = format
self._opener = opener
self._filename = filename
self._mode = mode
super(H5NetCDFStore, self).__init__(writer)

def open_store_variable(self, var):
def open_store_variable(self, name, var):
dimensions = var.dimensions
data = indexing.LazilyIndexedArray(BaseNetCDF4Array(var))
data = indexing.LazilyIndexedArray(BaseNetCDF4Array(name, self))
attrs = _read_attributes(var)

# netCDF4 specific encoding
@@ -69,7 +77,7 @@ def open_store_variable(self, var):
return Variable(dimensions, data, attrs, encoding)

def get_variables(self):
return FrozenOrderedDict((k, self.open_store_variable(v))
return FrozenOrderedDict((k, self.open_store_variable(k, v))
for k, v in iteritems(self.ds.variables))

def get_attrs(self):
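
Related to the review exchange above about pickling the ``functools.partial`` opener: a partial of a module-level function round-trips through stdlib pickle, whereas an equivalent lambda would not. A small illustrative sketch, not part of this diff:

import functools
import pickle
from operator import add

opener = functools.partial(add, 1)
assert pickle.loads(pickle.dumps(opener))(2) == 3   # partial pickles fine

try:
    pickle.dumps(lambda x: add(1, x))               # a lambda does not
except (pickle.PicklingError, AttributeError, TypeError) as exc:
    print('cannot pickle a lambda opener:', exc)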