Disable automatic cache with dask #1024

Merged: 16 commits, Nov 14, 2016
16 changes: 11 additions & 5 deletions doc/whats-new.rst
@@ -25,6 +25,13 @@ Breaking changes
merges will now succeed in cases that previously raised
``xarray.MergeError``. Set ``compat='broadcast_equals'`` to restore the
previous default.
- Pickling an xarray object based on the dask backend, or reading its
  :py:meth:`values` property, no longer automatically converts the array from
  dask to numpy in the original object.
  If a dask array is used as a coordinate of a :py:class:`~xarray.DataArray` or
  :py:class:`~xarray.Dataset`, its values are still cached automatically, but
  only if it is used to index a dimension (e.g. for alignment).
  By `Guido Imperiale <https://github.com/crusaderky>`_.
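As a minimal sketch of the behaviour this entry describes (not part of the PR; it assumes dask is installed and uses a throwaway in-memory dataset):

import pickle
import numpy as np
import xarray as xr

# Illustrative dask-backed dataset; chunk() wraps the data in dask arrays.
ds = xr.Dataset({'foo': ('x', np.arange(1000))}).chunk({'x': 100})

ds['foo'].values               # computed for this call only ...
pickle.dumps(ds)               # ... and pickling serializes the dask graph
assert not isinstance(ds['foo'].data, np.ndarray)  # original stays lazy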

Deprecations
~~~~~~~~~~~~
@@ -52,32 +59,31 @@ Enhancements
- Add checking of ``attr`` names and values when saving to netCDF, raising useful
error messages if they are invalid. (:issue:`911`).
By `Robin Wilson <https://github.com/robintw>`_.

- Added ability to save ``DataArray`` objects directly to netCDF files using
:py:meth:`~xarray.DataArray.to_netcdf`, and to load directly from netCDF files
using :py:func:`~xarray.open_dataarray` (:issue:`915`). These remove the need
to convert a ``DataArray`` to a ``Dataset`` before saving as a netCDF file,
and deal with names to ensure a perfect 'roundtrip' capability.
By `Robin Wilson <https://github.com/robintw>`_.

- Multi-index levels are now accessible as "virtual" coordinate variables,
e.g., ``ds['time']`` can pull out the ``'time'`` level of a multi-index
(see :ref:`coordinates`). ``sel`` also accepts providing multi-index levels
as keyword arguments, e.g., ``ds.sel(time='2000-01')``
(see :ref:`multi-level indexing`).
By `Benoit Bovy <https://github.com/benbovy>`_.

- Added the ``compat`` option ``'no_conflicts'`` to ``merge``, allowing the
combination of xarray objects with disjoint (:issue:`742`) or
overlapping (:issue:`835`) coordinates as long as all present data agrees.
By `Johnnie Gray <https://github.com/jcmgray>`_. See
:ref:`combining.no_conflicts` for more details.

- It is now possible to set ``concat_dim=None`` explicitly in
:py:func:`~xarray.open_mfdataset` to disable inferring a dimension along
which to concatenate.
By `Stephan Hoyer <https://github.com/shoyer>`_.

- Added methods :py:meth:`DataArray.compute`, :py:meth:`Dataset.compute`, and
:py:meth:`Variable.compute` as a non-mutating alternative to
:py:meth:`~DataArray.load`.
By `Guido Imperiale <https://github.com/crusaderky>`_.
- Adds DataArray and Dataset methods :py:meth:`~xarray.DataArray.cumsum` and
:py:meth:`~xarray.DataArray.cumprod`. By `Phillip J. Wolfram
<https://github.com/pwolfram>`_.
13 changes: 13 additions & 0 deletions xarray/core/dataarray.py
@@ -570,6 +570,19 @@ def load(self):
self._coords = new._coords
return self

def compute(self):
"""Manually trigger loading of this array's data from disk or a
remote source into memory and return a new array. The original is
left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
"""
new = self.copy(deep=False)
return new.load()
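A hedged usage sketch, separate from the diff itself, contrasting the new compute() with the in-place load(); the toy array and its chunking are illustrative assumptions:

import numpy as np
import xarray as xr

da = xr.DataArray(np.arange(10), dims='x').chunk()   # dask-backed

computed = da.compute()                        # new object, data in memory
print(isinstance(computed.data, np.ndarray))   # True
print(isinstance(da.data, np.ndarray))         # False: original untouched

da.load()                                      # load() converts da in place
print(isinstance(da.data, np.ndarray))         # True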

def copy(self, deep=True):
"""Returns a copy of this array.

44 changes: 30 additions & 14 deletions xarray/core/dataset.py
@@ -260,8 +260,11 @@ def load_store(cls, store, decoder=None):
return obj

def __getstate__(self):
"""Always load data in-memory before pickling"""
self.load()
"""Load data in-memory before pickling (except for Dask data)"""
for v in self.variables.values():
if not isinstance(v.data, dask_array_type):
v.load()

# self.__dict__ is the default pickle object, we don't need to
# implement our own __setstate__ method to make pickle work
state = self.__dict__.copy()
@@ -342,6 +345,19 @@ def load(self):

return self

def compute(self):
"""Manually trigger loading of this dataset's data from disk or a
remote source into memory and return a new dataset. The original is
left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically. However, this method can be necessary when
working with many file objects on disk.
"""
new = self.copy(deep=False)
return new.load()

@classmethod
def _construct_direct(cls, variables, coord_names, dims=None, attrs=None,
file_obj=None):
@@ -424,14 +440,12 @@ def copy(self, deep=False):
"""Returns a copy of this dataset.

If `deep=True`, a deep copy is made of each of the component variables.
Otherwise, a shallow copy is made, so each variable in the new dataset
is also a variable in the original dataset.
Otherwise, a shallow copy of each of the component variables is made, so
that the underlying memory region of the new dataset is the same as in
the original dataset.
"""
if deep:
variables = OrderedDict((k, v.copy(deep=True))
for k, v in iteritems(self._variables))
else:
variables = self._variables.copy()
variables = OrderedDict((k, v.copy(deep=deep))
for k, v in iteritems(self._variables))
# skip __init__ to avoid costly validation
return self._construct_direct(variables, self._coord_names.copy(),
self._dims.copy(), self._attrs_copy())
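An illustrative sketch of the revised shallow-copy semantics (toy data, not part of the diff):

import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': ('x', np.arange(4.0))})
shallow = ds.copy(deep=False)
deep = ds.copy(deep=True)

shallow['foo'].values[0] = 99.0
print(ds['foo'].values[0])     # 99.0: the shallow copy shares the memory
print(deep['foo'].values[0])   # 0.0:  the deep copy does not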
@@ -817,11 +831,10 @@ def chunks(self):
chunks = {}
for v in self.variables.values():
if v.chunks is not None:
new_chunks = list(zip(v.dims, v.chunks))
if any(chunk != chunks[d] for d, chunk in new_chunks
if d in chunks):
raise ValueError('inconsistent chunks')
chunks.update(new_chunks)
for dim, c in zip(v.dims, v.chunks):
if dim in chunks and c != chunks[dim]:
raise ValueError('inconsistent chunks')
chunks[dim] = c
return Frozen(SortedKeysDict(chunks))
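For context, a hedged example of what the rewritten property returns; the variable names and chunk sizes below are hypothetical:

import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': (('x', 'y'), np.zeros((100, 50)))}).chunk({'x': 30})
print(dict(ds.chunks))
# {'x': (30, 30, 30, 10), 'y': (50,)}   one chunk-size tuple per dimension
# If two variables disagreed on a dimension's chunk sizes, accessing
# ds.chunks would raise ValueError('inconsistent chunks').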

def chunk(self, chunks=None, name_prefix='xarray-', token=None,
@@ -874,6 +887,9 @@ def selkeys(dict_, keys):
return dict((d, dict_[d]) for d in keys if d in dict_)

def maybe_chunk(name, var, chunks):
if name in self.dims:
Review comment (Member):
Actually, maybe put this logic in IndexVariable instead? We could define a chunk method that looks like:

def chunk(self, ...):
    return self.copy(deep=False)

return var

chunks = selkeys(chunks, var.dims)
if not chunks:
chunks = None
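To illustrate the `if name in self.dims` shortcut above (a sketch, assuming dask is installed): chunking leaves dimension-coordinate variables in memory while data variables become dask arrays.

import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': ('x', np.arange(10))},
                coords={'x': np.arange(10)}).chunk({'x': 5})

print(isinstance(ds['foo'].data, np.ndarray))  # False: data variable is dask
print(isinstance(ds['x'].data, np.ndarray))    # True: the index stays in memory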
47 changes: 39 additions & 8 deletions xarray/core/variable.py
@@ -277,10 +277,21 @@ def data(self, data):
"replacement data must match the Variable's shape")
self._data = data

def _data_cast(self):
if isinstance(self._data, (np.ndarray, PandasIndexAdapter)):
Review comment:

Should this branch not also apply to dask_array_type?

Review comment:

In fact, if you manually create a Variable with a dask array you'll get a LazilyIndexedArray at this point. Should this not also be kept unchanged?

return self._data
else:
return np.asarray(self._data)

def _data_cached(self):
if not isinstance(self._data, (np.ndarray, PandasIndexAdapter)):
self._data = np.asarray(self._data)
return self._data
"""Load data into memory and return it.
Do not cache dask arrays automatically; that should
require an explicit load() call.
"""
new_data = self._data_cast()
if not isinstance(self._data, dask_array_type):
self._data = new_data
return new_data

@property
def _indexable_data(self):
@@ -294,12 +305,26 @@ def load(self):
because all xarray functions should either work on deferred data or
load data automatically.
"""
self._data_cached()
self._data = self._data_cast()
return self

def compute(self):
"""Manually trigger loading of this variable's data from disk or a
remote source into memory and return a new variable. The original is
left unaltered.

Normally, it should not be necessary to call this method in user code,
because all xarray functions should either work on deferred data or
load data automatically.
"""
new = self.copy(deep=False)
return new.load()

def __getstate__(self):
"""Always cache data as an in-memory array before pickling"""
self._data_cached()
"""Always cache data as an in-memory array before pickling
(with the exception of the dask backend)"""
if not isinstance(self._data, dask_array_type):
self._data_cached()
# self.__dict__ is the default pickle object, we don't need to
# implement our own __setstate__ method to make pickle work
return self.__dict__
@@ -1076,10 +1101,16 @@ def __init__(self, dims, data, attrs=None, encoding=None, fastpath=False):
type(self).__name__)

def _data_cached(self):
if not isinstance(self._data, PandasIndexAdapter):
self._data = PandasIndexAdapter(self._data)
# Unlike in Variable._data_cached, always eagerly resolve dask arrays
Review comment from @shoyer (Member), Nov 13, 2016:
I thought we wanted to always eagerly load IndexVariable objects into memory without caching at all?

That would suggest we should put something like self._data = PandasIndexAdapter(self._data) in the constructor, and make _data_cached and _data_cast on the subclass dummy methods.

self._data = self._data_cast()
return self._data

def _data_cast(self):
if isinstance(self._data, PandasIndexAdapter):
return self._data
else:
return PandasIndexAdapter(self._data)

def __getitem__(self, key):
key = self._item_key_to_tuple(key)
values = self._indexable_data[key]
69 changes: 54 additions & 15 deletions xarray/test/test_backends.py
@@ -128,8 +128,10 @@ def assert_loads(vars=None):
if vars is None:
vars = expected
with self.roundtrip(expected) as actual:
for v in actual.variables.values():
self.assertFalse(v._in_memory)
for k, v in actual.variables.items():
# IndexVariables are eagerly cached into memory
if k not in actual.dims:
self.assertFalse(v._in_memory)
yield actual
for k, v in actual.variables.items():
if k in vars:
@@ -152,6 +154,27 @@ def assert_loads(vars=None):
actual = ds.load()
self.assertDatasetAllClose(expected, actual)

def test_dataset_compute(self):
expected = create_test_data()

with self.roundtrip(expected) as actual:
# Test Dataset.compute()
for k, v in actual.variables.items():
# IndexVariables are eagerly cached
if k not in actual.dims:
self.assertFalse(v._in_memory)

computed = actual.compute()

for k, v in actual.variables.items():
if k not in actual.dims:
self.assertFalse(v._in_memory)
for v in computed.variables.values():
self.assertTrue(v._in_memory)

self.assertDatasetAllClose(expected, actual)
self.assertDatasetAllClose(expected, computed)

def test_roundtrip_None_variable(self):
expected = Dataset({None: (('x', 'y'), [[0, 1], [2, 3]])})
with self.roundtrip(expected) as actual:
@@ -233,18 +256,6 @@ def test_roundtrip_coordinates(self):
with self.roundtrip(expected) as actual:
self.assertDatasetIdentical(expected, actual)

expected = original.copy()
expected.attrs['coordinates'] = 'something random'
with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
with self.roundtrip(expected):
pass

expected = original.copy(deep=True)
expected['foo'].attrs['coordinates'] = 'something random'
with self.assertRaisesRegexp(ValueError, 'cannot serialize'):
with self.roundtrip(expected):
pass

def test_roundtrip_boolean_dtype(self):
original = create_boolean_data()
self.assertEqual(original['x'].dtype, 'bool')
@@ -875,7 +886,26 @@ def test_read_byte_attrs_as_unicode(self):
@requires_dask
@requires_scipy
@requires_netCDF4
class DaskTest(TestCase):
class DaskTest(TestCase, DatasetIOTestCases):
@contextlib.contextmanager
def create_store(self):
yield Dataset()

@contextlib.contextmanager
def roundtrip(self, data, save_kwargs={}, open_kwargs={}):
yield data.chunk()

def test_roundtrip_datetime_data(self):
# Override method in DatasetIOTestCases - remove save_kwds that do not apply
times = pd.to_datetime(['2000-01-01', '2000-01-02', 'NaT'])
expected = Dataset({'t': ('t', times), 't0': times[0]})
with self.roundtrip(expected) as actual:
self.assertDatasetIdentical(expected, actual)

def test_write_store(self):
# Override method in DatasetIOTestCases - not applicable to dask
pass

def test_open_mfdataset(self):
original = Dataset({'foo': ('x', np.random.randn(10))})
with create_tmp_file() as tmp1:
@@ -995,6 +1025,15 @@ def test_deterministic_names(self):
self.assertIn(tmp, dask_name)
self.assertEqual(original_names, repeat_names)

def test_dataarray_compute(self):
# Test DataArray.compute() on dask backend.
# The test for Dataset.compute() is already in DatasetIOTestCases;
# however dask is the only tested backend which supports DataArrays
actual = DataArray([1,2]).chunk()
computed = actual.compute()
self.assertFalse(actual._in_memory)
self.assertTrue(computed._in_memory)
self.assertDataArrayAllClose(actual, computed)

@requires_scipy_or_netCDF4
@requires_pydap