Memoization to S3 #28

Merged
9 commits merged on Sep 20, 2023
14 changes: 4 additions & 10 deletions braingeneers/utils/configure.py
@@ -1,11 +1,7 @@
""" Global package functions and helpers for Braingeneers specific configuration and package management. """
import distutils.util
import functools
import os
from typing import List, Tuple, Union, Iterable, Iterator
import re
import itertools
import importlib
import distutils.util


"""
@@ -34,11 +30,12 @@
'tenacity',
# 'sortedcontainers',
'boto3',
'joblib>=1.3.0,<2',
'smart_open @ git+https://github.com/davidparks21/smart_open.git@develop', # 'smart_open>=5.1.0', the hash version fixes the bytes from-to range header issue.
'awswrangler==3.*',
],
'data': [
'h5py',
'smart_open @ git+https://github.com/davidparks21/smart_open.git@develop', # 'smart_open>=5.1.0', the hash version fixes the bytes from-to range header issue.
'awswrangler==3.*',
'pandas',
'nptyping',
'paho-mqtt',
@@ -56,9 +53,6 @@
'pandas',
'powerlaw',
'matplotlib',
# Both of these dependencies are required for read_phy_files
'awswrangler==3.*',
'smart_open @ git+https://github.com/davidparks21/smart_open.git@develop', # 'smart_open>=5.1.0', the hash version fixes the bytes from-to range header issue.
],
'ml': [
'torch',
151 changes: 151 additions & 0 deletions braingeneers/utils/memoize_s3.py
@@ -0,0 +1,151 @@
import glob
import os
from functools import partial

import awswrangler as wr
import boto3
from joblib import Memory, register_store_backend
from joblib._store_backends import StoreBackendBase, StoreBackendMixin
from smart_open.s3 import parse_uri

from .smart_open_braingeneers import open


def s3_isdir(path):
"""
S3 doesn't support directories, so to check whether some path "exists",
instead check whether it is a prefix of at least one object.
"""
try:
next(wr.s3.list_objects(glob.escape(path), chunked=True))
return True
except StopIteration:
return False


class S3StoreBackend(StoreBackendBase, StoreBackendMixin):
_open_item = staticmethod(open)

def _item_exists(self, location):
return wr.s3.does_object_exist(location) or s3_isdir(location)

def _move_item(self, src_uri, dst_uri):
# awswrangler only includes a fancy move/rename method that actually
# makes it pretty hard to just do a simple move.
src, dst = [parse_uri(x) for x in (src_uri, dst_uri)]
self.client.copy_object(
Bucket=dst["bucket_id"],
Key=dst["key_id"],
CopySource=f"{src['bucket_id']}/{src['key_id']}",
)
self.client.delete_object(Bucket=src["bucket_id"], Key=src["key_id"])

def create_location(self, location):
# Actually don't do anything. There are no locations on S3.
pass

def clear_location(self, location):
# This should only ever be used for prefixes contained within a joblib cache
# directory, so make sure that's actually happening before deleting.
if not location.startswith(self.location):
raise ValueError("can only clear locations within the cache directory")
wr.s3.delete_objects(glob.escape(location))

Member:

I make a suggestion in a following comment, but I would enforce a cache-dir prefix here and raise an error when trying to delete outside of the cache dir. The suggested prefix is s3://braingeneersdev/cache/. A user should not be able to accidentally clear data if they naively set, for example, s3://braingeneersdev/ephys/ as their cache dir.

Member Author:

That's a good point, but it's also a lot less dangerous than it looks because of the way joblib uses the cache dir you give it. A call like this:

@memoize("s3://braingeneers/")
def bar(baz): ...

in the module foo.py has its actual cache files stored under s3://braingeneers/joblib/foo/bar/, so only things under that whole prefix get deleted when the cache is cleared.

Member:

Perfect. I would still check for the s3://braingeneers/joblib prefix then, in that case. It's good to be a little paranoid about recursive deletion, I think.

Member Author (@atspaeth, Sep 19, 2023):

Yeah, fair. I'll check that the requested location actually starts with self.location + '/joblib/' and then nothing weird should happen

(edit: self.location already includes the /joblib/ part, so the commit implementing this is right and what I originally said in this comment is wrong)
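
As a side note on the layout discussed in this thread, here is a minimal local-disk sketch (illustrative path, plain joblib "local" backend rather than the S3 one in this PR): joblib nests each function's cache under <location>/joblib/<module>/<funcname>/, and the store backend's location attribute already ends in ".../joblib", which is what the prefix check in clear_location relies on.

```python
from joblib import Memory

mem = Memory("/tmp/example-cache", verbose=0)  # illustrative path

@mem.cache
def bar(baz):
    return baz * 2

bar(3)
# The backend root already includes the trailing "joblib" component, so
# clearing bar's cache only deletes entries under this prefix.
print(bar.store_backend.location)  # /tmp/example-cache/joblib
```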


def get_items(self):
return []

def configure(self, location, verbose, backend_options={}):
# We don't do any logging yet, but `configure()` must accept this
# argument, so store it for forwards compatibility.
self.verbose = verbose

# We have to save this on the backend because joblib queries it, but
# default to True instead of joblib's usual False because S3 is not
# local disk and compression can really help.
self.compress = backend_options.get("compress", True)

# This option is available by default but we can't accept it because
# there's no reasonable way to make joblib use NumpyS3Memmap.
self.mmap_mode = backend_options.get("mmap_mode")
if self.mmap_mode is not None:
raise ValueError("impossible to mmap on S3.")

# Don't attempt to handle local files, just use the default backend
# for that!
if not location.startswith("s3://"):
raise ValueError("location must be an s3:// URI")

# We don't have to check that the bucket exists because joblib
# performs a `list_objects()` in it, but note that this doesn't
# actually check whether we can write to it!
self.location = location

# We need a boto3 client, so create it using the endpoint which was
# configured in awswrangler by importing smart_open_braingeneers.
self.client = boto3.Session().client(
"s3", endpoint_url=wr.config.s3_endpoint_url
)


def memoize(
location=None, backend="s3", ignore=None, cache_validation_callback=None, **kwargs
):
"""
Memoize a function to S3 using joblib.Memory. By default, saves to
`s3://braingeneersdev/$S3_USER/cache`, where $S3_USER defaults to "common" if unset.
Alternately, the cache directory can be provided explicitly.

Accepts all the same keyword arguments as `joblib.Memory`, including `backend`,
which can be set to "local" to recover default behavior. Also accepts the
keyword arguments of `joblib.Memory.cache()` and passes them on. Usage:

```
from braingeneers.utils.memoize_s3 import memoize

# Cache to the default location on NRP S3.
@memoize
def foo(x):
return x

# Cache to a different NRP S3 location.
@memoize("s3://braingeneers/someplace/else/idk")
def bar(x):
return x

# Ignore some parameters when deciding which cache entry to check.
@memoize(ignore=["verbose"])
def plover(x, verbose):
if verbose: ...
return x
```

If the bucket doesn't exist, an error will be raised, but if the only
problem is permissions, silent failure to cache may be all that occurs,
depending on the verbosity setting.

Another known issue is that size-based cache eviction is NOT supported,
and will also silently fail. This is because there is no easy way to get
access times out of S3, so we can't find LRU entries.
"""
if callable(location):
# This case probably means the @memoize decorator was used without
# arguments, but pass the kwargs on anyway just in case.
return memoize(
backend=backend,
ignore=ignore,
cache_validation_callback=cache_validation_callback,
**kwargs,
)(location)

if location is None and backend == "s3":
user = os.environ.get("S3_USER", "common")
location = f"s3://braingeneersdev/{user}/cache"

return partial(
Memory(location, backend=backend, **kwargs).cache,
ignore=ignore,
cache_validation_callback=cache_validation_callback,
)


register_store_backend("s3", S3StoreBackend)
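
As a usage note on the "local" fallback mentioned in the docstring above, a minimal sketch (the path and function name are illustrative, and this assumes the package's S3 dependencies are installed since the module imports them at import time):

```python
from braingeneers.utils.memoize_s3 import memoize

# With backend="local", memoize() wraps an ordinary on-disk joblib.Memory
# cache instead of the S3 store backend registered above.
@memoize("/tmp/local-cache", backend="local")
def slow_square(x):
    return x * x

slow_square(4)  # computed and written under /tmp/local-cache/joblib/...
slow_square(4)  # loaded from the on-disk cache without recomputing
```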
91 changes: 91 additions & 0 deletions braingeneers/utils/memoize_s3_test.py
@@ -0,0 +1,91 @@
import unittest
from unittest import mock

from botocore.exceptions import ClientError

from .configure import skip_unittest_if_offline
from .memoize_s3 import memoize


class TestMemoizeS3(unittest.TestCase):
@skip_unittest_if_offline
def test(self):
# Run these checks in a context where S3_USER is set.
with mock.patch.dict("os.environ", {"S3_USER": "unittest"}):
# Memoize a function that counts its calls.
@memoize()(ignore=["y"])
def foo(x, y):
nonlocal side_effect
side_effect += 1
return x

self.assertEqual(
foo.store_backend.location, "s3://braingeneersdev/unittest/cache/joblib"
)

# Call it a few times and make sure it only runs once.
foo.clear()
side_effect = 0
for i in range(3):
self.assertEqual(foo("bar", i), "bar")
self.assertEqual(side_effect, 1)

# Force it to run again and make sure it happens.
foo("baz", 1)
self.assertEqual(side_effect, 2)

# Clean up by reaching into the cache and clearing the directory
# without recreating the cache.
foo.store_backend.clear()

@skip_unittest_if_offline
def test_uri_validation(self):
# Our backend only supports S3 URIs.
with self.assertRaises(ValueError):

@memoize("this has to start with s3://")
def foo(x):
return x

@skip_unittest_if_offline
def test_cant_mmap(self):
# We have to fail if memory mapping is requested because it's
# impossible on S3.
with self.assertRaises(ValueError):

@memoize("s3://this-uri-doesnt-matter/", mmap_mode=True)
def foo(x):
return x

@skip_unittest_if_offline
def test_bucket_existence(self):
# Bucket existence should be checked at creation.
with self.assertRaises(ClientError):

@memoize("s3://i-sure-hope-this-crazy-bucket-doesnt-exist/")
def foo(x):
return x

@skip_unittest_if_offline
def test_default_location(self):
# Make sure a default location is correctly set when S3_USER is not.
with mock.patch.dict("os.environ", {}, clear=True):

@memoize()
def foo(x):
return x

self.assertEqual(
foo.store_backend.location, "s3://braingeneersdev/common/cache/joblib"
)

@skip_unittest_if_offline
def test_custom_location(self):
# Make sure custom locations get set correctly.
@memoize("s3://braingeneersdev/unittest/cache")
def foo(x):
return x

self.assertEqual(
foo.store_backend.location, "s3://braingeneersdev/unittest/cache/joblib"
)
5 changes: 4 additions & 1 deletion braingeneers/utils/numpy_s3_memmap_test.py
@@ -1,9 +1,11 @@
import unittest
from braingeneers.utils import NumpyS3Memmap # deprecated form
import numpy as np
from .configure import skip_unittest_if_offline
from .numpy_s3_memmap import NumpyS3Memmap


class TestNumpyS3Memmap(unittest.TestCase):
@skip_unittest_if_offline
def test_numpy32memmap_online(self):
""" Note: this is an online test requiring access to the PRP/S3 braingeneersdev bucket. """
x = NumpyS3Memmap('s3://braingeneersdev/dfparks/test/test.npy')
@@ -18,6 +20,7 @@ def test_numpy32memmap_online(self):
self.assertTrue(np.all(x[:, 0:2] == e[:, 0:2]))
self.assertTrue(np.all(x[:, [0, 1]] == e[:, [0, 1]]))

@skip_unittest_if_offline
def test_online_in_the_wild_file(self):
"""
This test assumes online access.