Skip to content

Commit

Permalink
LIN-712 use cloudpickle (LineaLabs#857)
Browse files Browse the repository at this point in the history
* Move pickling and unpickling base functions to api_utils

* Remove dependency on pandas's to_pickle and read_pickle

* Make cloudpickle the default, but allow falling back to the standard pickle module (we might never need the fallback; if so, this can be simplified)

* update requirements to include cloudpickle as min requirement
  • Loading branch information
lionsardesai authored Dec 14, 2022
1 parent 370a73d commit a853e5d
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 42 deletions.
96 changes: 96 additions & 0 deletions lineapy/api/api_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
import logging
import pickle
import re
import warnings
from pathlib import Path

import cloudpickle
from pandas.io.common import get_handle

from lineapy.db.db import RelationalLineaDB
from lineapy.utils.analytics.event_schemas import ErrorType, ExceptionEvent
from lineapy.utils.analytics.usage_tracking import track
from lineapy.utils.config import options

logger = logging.getLogger(__name__)


def de_lineate_code(code: str, db: RelationalLineaDB) -> str:
Expand Down Expand Up @@ -50,3 +62,87 @@ def replace_fun(match):
# logger.debug("replaces made: %s", replaces)

return swapped


def to_pickle(
    value,
    filepath_or_buffer,
    storage_options=None,
):
    """
    Serialize ``value`` and write the bytes to ``filepath_or_buffer``.

    cloudpickle is tried first; if it raises, the standard library pickle
    is used instead. Both use ``pickle.HIGHEST_PROTOCOL``.

    value: the object to serialize.
    filepath_or_buffer: destination handed to pandas' ``get_handle``,
        opened in binary write mode with ``compression="infer"``.
    storage_options: forwarded unchanged to ``get_handle``.

    Raises whatever the fallback ``pickle.dumps`` raises when neither
    serializer can handle ``value``, and any error from opening or
    writing the destination.
    """
    # Serialize fully in memory BEFORE touching the destination. Dumping
    # straight into the handle is more memory-efficient, but a serializer
    # that fails part-way may already have flushed bytes; the fallback would
    # then append a second pickle after that partial data and leave a
    # corrupt file behind. Building the payload first also means a value that
    # cannot be pickled at all never truncates an existing file.
    try:
        payload = cloudpickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
    except Exception:
        # Deliberate best-effort fallback: retry with the stdlib pickler and
        # let its exception propagate if it cannot serialize the value either.
        payload = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)

    with get_handle(
        filepath_or_buffer,
        "wb",
        compression="infer",
        is_text=False,
        storage_options=storage_options,
    ) as handles:
        handles.handle.write(payload)  # type: ignore[arg-type]


def read_pickle(pickle_filename):
    """
    Load a pickled value, by file name, from the configured artifact
    storage dir.

    The location is built from the ``artifact_storage_dir`` option: joined
    as a path when that option is a ``Path``, otherwise concatenated as a
    string with a single "/" separator. Any failure while reading is
    logged, reported through ``track`` and then re-raised.
    """
    # TODO - set unicode etc here
    storage_dir = options.safe_get("artifact_storage_dir")
    if isinstance(storage_dir, Path):
        filepath = storage_dir.joinpath(pickle_filename)
    else:
        filepath = f'{storage_dir.rstrip("/")}/{pickle_filename}'

    try:
        logger.debug(f"Retriving pickle file from {filepath} ")
        storage_options = options.get("storage_options")
        return _try_pickle_read(filepath, storage_options=storage_options)
    except Exception as e:
        logger.error(e)
        failure_event = ExceptionEvent(
            ErrorType.RETRIEVE, "Error in retriving pickle file"
        )
        track(failure_event)
        raise e


def _try_pickle_read(filepath_or_buffer, storage_options=None):
    """
    Open ``filepath_or_buffer`` and return the object unpickled from it.

    cloudpickle is tried first. If it raises, the stream is rewound to
    where reading started and the standard library pickle is tried on the
    same data.

    filepath_or_buffer: source handed to pandas' ``get_handle``, opened in
        binary read mode with ``compression="infer"``.
    storage_options: forwarded unchanged to ``get_handle``.

    Raises the original cloudpickle error when the stream cannot be
    rewound, otherwise whatever the fallback ``pickle.load`` raises.
    """
    with get_handle(
        filepath_or_buffer,
        "rb",
        compression="infer",
        is_text=False,
        storage_options=storage_options,
    ) as handles:

        # Original pandas comment:
        # 1) try standard library Pickle
        # 2) try pickle_compat (older pandas version) to handle subclass changes
        # 3) try pickle_compat with latin-1 encoding upon a UnicodeDecodeError

        # We will not be attempting to support pt 2 and pt 3.
        # Simplifying the cases from original

        with warnings.catch_warnings(record=True):
            # We want to silence any warnings about, e.g. moved modules.
            warnings.simplefilter("ignore", Warning)

            stream = handles.handle
            # Remember where reading starts so the fallback can retry from
            # the same offset (a caller-supplied buffer may not be at 0).
            try:
                start = stream.tell()
            except (AttributeError, OSError, ValueError):
                start = None

            try:
                return cloudpickle.load(stream)  # type: ignore[arg-type]
            except Exception:
                # A failed load has usually consumed part of the stream.
                # Retrying without rewinding would read from the middle of
                # the data and raise a misleading EOF/unpickling error that
                # hides the real failure.
                rewound = False
                if start is not None:
                    try:
                        stream.seek(start)
                        rewound = True
                    except (OSError, ValueError):
                        # Non-seekable stream: no retry is possible.
                        pass
                if not rewound:
                    # Surface the original cloudpickle error unchanged.
                    raise
                return pickle.load(stream)  # type: ignore[arg-type]
3 changes: 1 addition & 2 deletions lineapy/api/artifact_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from pathlib import Path
from typing import Any, Dict, Optional

from pandas.io.pickle import to_pickle

from lineapy.api.api_utils import to_pickle
from lineapy.data.types import ARTIFACT_STORAGE_BACKEND, LineaID
from lineapy.exceptions.db_exceptions import ArtifactSaveException
from lineapy.plugins.serializers.mlflow_io import try_write_to_mlflow
Expand Down
44 changes: 5 additions & 39 deletions lineapy/api/models/linea_artifact.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from __future__ import annotations

import logging
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Optional, Set, Tuple, Union

from IPython.display import display
from pandas.io.pickle import read_pickle

from lineapy.api.api_utils import de_lineate_code
from lineapy.api.api_utils import de_lineate_code, read_pickle
from lineapy.data.graph import Graph
from lineapy.data.types import (
ARTIFACT_STORAGE_BACKEND,
Expand All @@ -28,14 +24,11 @@
)
from lineapy.plugins.serializers.mlflow_io import read_mlflow
from lineapy.utils.analytics.event_schemas import (
ErrorType,
ExceptionEvent,
GetCodeEvent,
GetValueEvent,
GetVersionEvent,
)
from lineapy.utils.analytics.usage_tracking import track
from lineapy.utils.config import options
from lineapy.utils.deprecation_utils import lru_cache
from lineapy.utils.utils import prettify

Expand Down Expand Up @@ -126,7 +119,7 @@ def get_value(self) -> object:
return read_mlflow(metadata["mlflow"])

# read from lineapy
return self._read_pickle(saved_filepath)
return read_pickle(saved_filepath)

@lru_cache(maxsize=None)
def _get_storage_path(self) -> Optional[str]:
Expand Down Expand Up @@ -199,33 +192,6 @@ def get_metadata(self, lineapy_only: bool = False) -> ArtifactInfo:

return metadata

def _read_pickle(self, pickle_filename):
    """
    Load a pickled value, by file name, from the configured artifact
    storage dir.

    The location is built from the ``artifact_storage_dir`` option: joined
    as a path when that option is a ``Path``, otherwise concatenated as a
    string with a single "/" separator. Any failure while reading is
    logged, reported through ``track`` and then re-raised.
    """
    # TODO - set unicode etc here
    storage_dir = options.safe_get("artifact_storage_dir")
    if isinstance(storage_dir, Path):
        filepath = storage_dir.joinpath(pickle_filename)
    else:
        filepath = f'{storage_dir.rstrip("/")}/{pickle_filename}'

    try:
        logger.debug(f"Retriving pickle file from {filepath} ")
        storage_options = options.get("storage_options")
        return read_pickle(filepath, storage_options=storage_options)
    except Exception as e:
        logger.error(e)
        failure_event = ExceptionEvent(
            ErrorType.RETRIEVE, "Error in retriving pickle file"
        )
        track(failure_event)
        raise e

# Note that I removed the @properties because they were not working
# well with the lru_cache
@lru_cache(maxsize=None)
Expand Down Expand Up @@ -347,7 +313,7 @@ def execute(self) -> object:
@staticmethod
def get_artifact_from_orm(
db: RelationalLineaDB, artifactorm: ArtifactORM
) -> LineaArtifact:
) -> "LineaArtifact":
"""
Return LineaArtifact from artifactorm
"""
Expand All @@ -368,7 +334,7 @@ def get_artifact_from_name_and_version(
db: RelationalLineaDB,
artifact_name: str,
version: Optional[int] = None,
) -> LineaArtifact:
) -> "LineaArtifact":
"""
Return LineaArtifact from artifact name and version
"""
Expand All @@ -379,7 +345,7 @@ def get_artifact_from_name_and_version(
@staticmethod
def get_artifact_from_def(
db: RelationalLineaDB, artifactdef: LineaArtifactDef
) -> LineaArtifact:
) -> "LineaArtifact":
"""
Return LineaArtifact from LineaArtifactDef
"""
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ astpretty==2.1.0
asttokens==2.0.5
black==22.3.0
click==8.1.2
cloudpickle==2.2.0
coveralls==3.3.1
fastparquet==0.8.0
flake8==4.0.1
Expand All @@ -24,7 +25,7 @@ pdbpp==0.10.3
pg==0.1
Pillow==9.1.1
pre-commit==2.18.1
psycopg2-binary==2.9.3
psycopg2-binary==2.9.5
pydantic==1.9.0
pytest==6.2.5
pytest-alembic==0.8.2
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def version(path):
"fsspec",
"pandas",
"alembic==1.8.0",
"cloudpickle",
]

graph_libs = [
Expand Down
1 change: 1 addition & 0 deletions tests/tools/requirements_txt_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"nbconvert",
"requests",
"alembic==1.8.0",
"cloudpickle",
]


Expand Down

0 comments on commit a853e5d

Please sign in to comment.