From ef2a10ec622569a3e6abcfc116e11dc7075635b4 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 30 Jun 2024 11:54:01 +0100 Subject: [PATCH 01/22] feat: use Narwhals to make NumPy and pandas optional, and to support Narwhals-compliant libraries natively without PyArrow conversions --- altair/_magics.py | 4 +- altair/utils/__init__.py | 12 +- altair/utils/_vegafusion_data.py | 4 +- altair/utils/core.py | 175 +++++++++++----------- altair/utils/data.py | 98 +++++++----- altair/utils/schemapi.py | 62 +++++--- altair/vegalite/data.py | 4 +- altair/vegalite/v5/schema/channels.py | 4 +- pyproject.toml | 10 +- tests/utils/test_core.py | 2 +- tests/utils/test_data.py | 4 + tests/utils/test_dataframe_interchange.py | 5 +- tests/utils/test_utils.py | 33 ++-- tests/vegalite/v5/test_api.py | 17 +++ tools/generate_schema_wrapper.py | 5 +- tools/schemapi/schemapi.py | 62 +++++--- 16 files changed, 297 insertions(+), 204 deletions(-) diff --git a/altair/_magics.py b/altair/_magics.py index 28e6d832f..638400c78 100644 --- a/altair/_magics.py +++ b/altair/_magics.py @@ -9,7 +9,7 @@ import IPython from IPython.core import magic_arguments -import pandas as pd +from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe from altair.vegalite import v5 as vegalite_v5 @@ -39,7 +39,7 @@ def _prepare_data(data, data_transformers): """Convert input data to data for use within schema""" if data is None or isinstance(data, dict): return data - elif isinstance(data, pd.DataFrame): + elif _is_pandas_dataframe(data): if func := data_transformers.get(): data = func(data) return data diff --git a/altair/utils/__init__.py b/altair/utils/__init__.py index 64d6f4566..36d35bca4 100644 --- a/altair/utils/__init__.py +++ b/altair/utils/__init__.py @@ -1,8 +1,8 @@ from .core import ( - infer_vegalite_type, + infer_vegalite_type_for_pandas, infer_encoding_types, - sanitize_dataframe, - sanitize_arrow_table, + sanitize_pandas_dataframe, + sanitize_narwhals_dataframe, parse_shorthand, use_signature, update_nested, @@ -23,10 +23,10 @@ "Undefined", "display_traceback", "infer_encoding_types", - "infer_vegalite_type", + "infer_vegalite_type_for_pandas", "parse_shorthand", - "sanitize_arrow_table", - "sanitize_dataframe", + "sanitize_narwhals_dataframe", + "sanitize_pandas_dataframe", "spec_to_html", "update_nested", "use_signature", diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py index ea1ae6dad..7608d3589 100644 --- a/altair/utils/_vegafusion_data.py +++ b/altair/utils/_vegafusion_data.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: - import pandas as pd + from narwhals.typing import IntoDataFrame from vegafusion.runtime import ChartState # type: ignore # Temporary storage for dataframes that have been extracted @@ -60,7 +60,7 @@ def vegafusion_data_transformer( @overload def vegafusion_data_transformer( - data: dict | pd.DataFrame | SupportsGeoInterface, max_rows: int = ... + data: dict | IntoDataFrame | SupportsGeoInterface, max_rows: int = ... ) -> _VegaFusionReturnType: ... diff --git a/altair/utils/core.py b/altair/utils/core.py index a001f7a12..acdbc9df6 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -27,12 +27,10 @@ from operator import itemgetter import jsonschema -import pandas as pd -import numpy as np -from pandas.api.types import infer_dtype +import narwhals.stable.v1 as nw +from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe from altair.utils.schemapi import SchemaBase, Undefined -from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame if sys.version_info >= (3, 10): from typing import ParamSpec @@ -43,8 +41,11 @@ if TYPE_CHECKING: from types import ModuleType import typing as t - from pandas.core.interchange.dataframe_protocol import Column as PandasColumn - import pyarrow as pa + from altair.vegalite.v5.schema._typing import StandardType_T as InferredVegaLiteType + from altair.utils._dfi_types import DataFrame as DfiDataFrame + from altair.utils.data import DataType + from narwhals.typing import IntoExpr, IntoDataFrameT + import pandas as pd V = TypeVar("V") P = ParamSpec("P") @@ -198,10 +199,7 @@ def __dataframe__( ] -InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"] - - -def infer_vegalite_type( +def infer_vegalite_type_for_pandas( data: object, ) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list[Any]]: """ @@ -212,6 +210,9 @@ def infer_vegalite_type( ---------- data: object """ + # This is safe to import here, as this function is only called on pandas input. + from pandas.api.types import infer_dtype + typ = infer_dtype(data, skipna=False) if typ in { @@ -297,13 +298,16 @@ def sanitize_geo_interface(geo: t.MutableMapping[Any, Any]) -> dict[str, Any]: def numpy_is_subtype(dtype: Any, subtype: Any) -> bool: + # This is only called on `numpy` inputs, so it's safe to import it here. + import numpy as np + try: return np.issubdtype(dtype, subtype) except (NotImplementedError, TypeError): return False -def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame: +def sanitize_pandas_dataframe(df: pd.DataFrame) -> pd.DataFrame: """Sanitize a DataFrame to prepare it for serialization. * Make a copy @@ -320,6 +324,11 @@ def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame: * convert dedicated string column to objects and replace NaN with None * Raise a ValueError for TimeDelta dtypes """ + # This is safe to import here, as this function is only called on pandas input. + # NumPy is a required dependency of pandas so is also safe to import. + import pandas as pd + import numpy as np + df = df.copy() if isinstance(df.columns, pd.RangeIndex): @@ -429,30 +438,54 @@ def to_list_if_array(val): return df -def sanitize_arrow_table(pa_table: pa.Table) -> pa.Table: - """Sanitize arrow table for JSON serialization""" - import pyarrow as pa - import pyarrow.compute as pc - - arrays = [] - schema = pa_table.schema - for name in schema.names: - array = pa_table[name] - dtype_name = str(schema.field(name).type) - if dtype_name.startswith(("timestamp", "date")): - arrays.append(pc.strftime(array)) - elif dtype_name.startswith("duration"): +def sanitize_narwhals_dataframe( + data: nw.DataFrame[IntoDataFrameT], +) -> nw.DataFrame[IntoDataFrameT]: + """Sanitize narwhals.DataFrame for JSON serialization""" + schema = data.schema + columns: list[IntoExpr] = [] + # See https://github.com/vega/altair/issues/1027 for why this is necessary. + local_iso_fmt_string = "%Y-%m-%dT%H:%M:%S" + for name, dtype in schema.items(): + if dtype == nw.Date: + # Polars doesn't allow formatting `Date` with time directives. + # The date -> datetime cast is extremely fast compared with `to_string` + columns.append( + nw.col(name).cast(nw.Datetime).dt.to_string(local_iso_fmt_string) + ) + elif dtype == nw.Datetime: + columns.append(nw.col(name).dt.to_string(f"{local_iso_fmt_string}%.f")) + elif dtype == nw.Duration: msg = ( - f'Field "{name}" has type "{dtype_name}" which is ' + f'Field "{name}" has type "{dtype}" which is ' "not supported by Altair. Please convert to " "either a timestamp or a numerical value." "" ) raise ValueError(msg) else: - arrays.append(array) + columns.append(name) + return data.select(columns) + - return pa.Table.from_arrays(arrays, names=schema.names) +def narwhalify(data: DataType) -> nw.DataFrame: + """Wrap `data` in `narwhals.DataFrame`. + + If `data` is not supported by Narwhals, but it is convertible + to a PyArrow table, then first convert to a PyArrow Table, + and then wrap in `narwhals.DataFrame`. + """ + # Using `strict=False` will return `data` as-is if the object cannot be converted. + data = nw.from_native(data, eager_only=True, strict=False) + if isinstance(data, nw.DataFrame): + return data + if isinstance(data, DataFrameLike): + from altair.utils.data import arrow_table_from_dfi_dataframe + + pa_table = arrow_table_from_dfi_dataframe(data) + return nw.from_native(pa_table, eager_only=True) + msg = f"Unsupported data type: {type(data)}" + raise TypeError(msg) def parse_shorthand( @@ -498,6 +531,7 @@ def parse_shorthand( Examples -------- + >>> import pandas as pd >>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'], ... 'bar': [1, 2, 3, 4]}) @@ -537,7 +571,7 @@ def parse_shorthand( >>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'} True """ - from altair.utils._importers import pyarrow_available + from altair.utils.data import is_data_type if not shorthand: return {} @@ -597,39 +631,18 @@ def parse_shorthand( attrs["type"] = "temporal" # if data is specified and type is not, infer type from data - if "type" not in attrs: - if pyarrow_available() and data is not None and isinstance(data, DataFrameLike): - dfi = data.__dataframe__() - if "field" in attrs: - unescaped_field = attrs["field"].replace("\\", "") - if unescaped_field in dfi.column_names(): - column = dfi.get_column_by_name(unescaped_field) - try: - attrs["type"] = infer_vegalite_type_for_dfi_column(column) - except (NotImplementedError, AttributeError, ValueError): - # Fall back to pandas-based inference. - # Note: The AttributeError catch is a workaround for - # https://github.com/pandas-dev/pandas/issues/55332 - if isinstance(data, pd.DataFrame): - attrs["type"] = infer_vegalite_type(data[unescaped_field]) - else: - raise - - if isinstance(attrs["type"], tuple): - attrs["sort"] = attrs["type"][1] - attrs["type"] = attrs["type"][0] - elif isinstance(data, pd.DataFrame): - # Fallback if pyarrow is not installed or if pandas is older than 1.5 - # - # Remove escape sequences so that types can be inferred for columns with special characters - if "field" in attrs and attrs["field"].replace("\\", "") in data.columns: - attrs["type"] = infer_vegalite_type( - data[attrs["field"].replace("\\", "")] - ) - # ordered categorical dataframe columns return the type and sort order as a tuple - if isinstance(attrs["type"], tuple): - attrs["sort"] = attrs["type"][1] - attrs["type"] = attrs["type"][0] + if "type" not in attrs and is_data_type(data): + data_nw = narwhalify(data) + unescaped_field = attrs["field"].replace("\\", "") + if isinstance(data_nw, nw.DataFrame) and unescaped_field in data_nw.columns: + column = data_nw[unescaped_field] + if column.dtype in {nw.Object, nw.Unknown} and _is_pandas_dataframe(data): + attrs["type"] = infer_vegalite_type_for_pandas(nw.to_native(column)) + else: + attrs["type"] = infer_vegalite_type_for_narwhals(column) + if isinstance(attrs["type"], tuple): + attrs["sort"] = attrs["type"][1] + attrs["type"] = attrs["type"][0] # If an unescaped colon is still present, it's often due to an incorrect data type specification # but could also be due to using a column name with ":" in it. @@ -650,41 +663,23 @@ def parse_shorthand( return attrs -def infer_vegalite_type_for_dfi_column( - column: Column | PandasColumn, -) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list[Any]]: - from pyarrow.interchange.from_dataframe import column_to_array - - try: - kind = column.dtype[0] - except NotImplementedError as e: - # Edge case hack: - # dtype access fails for pandas column with datetime64[ns, UTC] type, - # but all we need to know is that its temporal, so check the - # error message for the presence of datetime64. - # - # See https://github.com/pandas-dev/pandas/issues/54239 - if "datetime64" in e.args[0] or "timestamp" in e.args[0]: - return "temporal" - raise e - +def infer_vegalite_type_for_narwhals( + column: nw.Series, +) -> InferredVegaLiteType | tuple[InferredVegaLiteType, list]: + dtype = column.dtype if ( - kind == DtypeKind.CATEGORICAL - and column.describe_categorical["is_ordered"] - and column.describe_categorical["categories"] is not None + nw.is_ordered_categorical(column) + and not (categories := column.cat.get_categories()).is_empty() ): - # Treat ordered categorical column as Vega-Lite ordinal - categories_column = column.describe_categorical["categories"] - categories_array = column_to_array(categories_column) - return "ordinal", categories_array.to_pylist() - if kind in {DtypeKind.STRING, DtypeKind.CATEGORICAL, DtypeKind.BOOL}: + return "ordinal", categories.to_list() + if dtype in {nw.String, nw.Categorical, nw.Boolean}: return "nominal" - elif kind in {DtypeKind.INT, DtypeKind.UINT, DtypeKind.FLOAT}: + elif dtype.is_numeric(): return "quantitative" - elif kind == DtypeKind.DATETIME: + elif dtype in {nw.Datetime, nw.Date}: return "temporal" else: - msg = f"Unexpected DtypeKind: {kind}" + msg = f"Unexpected DtypeKind: {dtype}" raise ValueError(msg) diff --git a/altair/utils/data.py b/altair/utils/data.py index 61923231c..d09c58922 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -22,10 +22,17 @@ from functools import partial import sys -import pandas as pd +import narwhals.stable.v1 as nw +from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe +from narwhals.typing import IntoDataFrame from ._importers import import_pyarrow_interchange -from .core import sanitize_dataframe, sanitize_arrow_table, DataFrameLike +from .core import ( + sanitize_pandas_dataframe, + DataFrameLike, + sanitize_narwhals_dataframe, + narwhalify, +) from .core import sanitize_geo_interface from .plugin_registry import PluginRegistry @@ -36,6 +43,7 @@ if TYPE_CHECKING: import pyarrow as pa + import pandas as pd @runtime_checkable @@ -44,20 +52,23 @@ class SupportsGeoInterface(Protocol): DataType: TypeAlias = Union[ - Dict[Any, Any], pd.DataFrame, SupportsGeoInterface, DataFrameLike + Dict[Any, Any], IntoDataFrame, SupportsGeoInterface, DataFrameLike ] TDataType = TypeVar("TDataType", bound=DataType) +TIntoDataFrame = TypeVar("TIntoDataFrame", bound=IntoDataFrame) VegaLiteDataDict: TypeAlias = Dict[ str, Union[str, Dict[Any, Any], List[Dict[Any, Any]]] ] ToValuesReturnType: TypeAlias = Dict[str, Union[Dict[Any, Any], List[Dict[Any, Any]]]] -SampleReturnType = Union[pd.DataFrame, Dict[str, Sequence], "pa.lib.Table", None] +SampleReturnType = Union[IntoDataFrame, Dict[str, Sequence], None] def is_data_type(obj: Any) -> TypeIs[DataType]: - return isinstance(obj, (dict, pd.DataFrame, DataFrameLike, SupportsGeoInterface)) + return _is_pandas_dataframe(obj) or isinstance( + obj, (dict, DataFrameLike, SupportsGeoInterface, nw.DataFrame) + ) # ============================================================================== @@ -133,20 +144,21 @@ def raise_max_rows_error(): values = data.__geo_interface__["features"] else: values = data.__geo_interface__ - elif isinstance(data, pd.DataFrame): + elif _is_pandas_dataframe(data): values = data elif isinstance(data, dict): if "values" in data: values = data["values"] else: return data - elif isinstance(data, DataFrameLike): - pa_table = arrow_table_from_dfi_dataframe(data) - if max_rows is not None and pa_table.num_rows > max_rows: + else: + data_nw = narwhalify(data) + if max_rows is not None and len(data_nw) > max_rows: raise_max_rows_error() - # Return pyarrow Table instead of input since the - # `arrow_table_from_dfi_dataframe` call above may be expensive - return pa_table + # `narwhalify` may call `arrow_table_from_dfi_dataframe`, + # which can be expensive. Therefore, we return the `narwhals.DataFrame` + # here instead of the original input. + return data_nw if max_rows is not None and len(values) > max_rows: raise_max_rows_error() @@ -159,6 +171,10 @@ def sample( data: None = ..., n: int | None = ..., frac: float | None = ... ) -> partial: ... @overload +def sample( + data: TIntoDataFrame, n: int | None = ..., frac: float | None = ... +) -> TIntoDataFrame: ... +@overload def sample( data: DataType, n: int | None = ..., frac: float | None = ... ) -> SampleReturnType: ... @@ -171,7 +187,7 @@ def sample( if data is None: return partial(sample, n=n, frac=frac) check_data_type(data) - if isinstance(data, pd.DataFrame): + if _is_pandas_dataframe(data): return data.sample(n=n, frac=frac) elif isinstance(data, dict): if "values" in data: @@ -186,19 +202,19 @@ def sample( else: # Maybe this should raise an error or return something useful? return None - elif isinstance(data, DataFrameLike): - pa_table = arrow_table_from_dfi_dataframe(data) - if not n: - if frac is None: - msg = "frac cannot be None if n is None with this data input type" - raise ValueError(msg) - n = int(frac * len(pa_table)) - indices = random.sample(range(len(pa_table)), n) - return pa_table.take(indices) - else: + try: + data = narwhalify(data) + except TypeError: # Maybe this should raise an error or return something useful? Currently, # if data is of type SupportsGeoInterface it lands here return None + if not n: + if frac is None: + msg = "frac cannot be None if n is None with this data input type" + raise ValueError(msg) + n = int(frac * len(data)) + indices = random.sample(range(len(data)), n) + return nw.to_native(data[indices]) _FormatType = Literal["csv", "json"] @@ -310,27 +326,28 @@ def to_values(data: DataType) -> ToValuesReturnType: """Replace a DataFrame by a data model with values.""" check_data_type(data) if isinstance(data, SupportsGeoInterface): - if isinstance(data, pd.DataFrame): - data = sanitize_dataframe(data) + if _is_pandas_dataframe(data): + data = sanitize_pandas_dataframe(data) # Maybe the type could be further clarified here that it is # SupportGeoInterface and then the ignore statement is not needed? data_sanitized = sanitize_geo_interface(data.__geo_interface__) # type: ignore[arg-type] return {"values": data_sanitized} - elif isinstance(data, pd.DataFrame): - data = sanitize_dataframe(data) + elif _is_pandas_dataframe(data): + data = sanitize_pandas_dataframe(data) return {"values": data.to_dict(orient="records")} elif isinstance(data, dict): if "values" not in data: msg = "values expected in data dict, but not present." raise KeyError(msg) return data - elif isinstance(data, DataFrameLike): - pa_table = sanitize_arrow_table(arrow_table_from_dfi_dataframe(data)) - return {"values": pa_table.to_pylist()} - else: + try: + data = narwhalify(data) + except TypeError as exc: # Should never reach this state as tested by check_data_type msg = f"Unrecognized data type: {type(data)}" - raise ValueError(msg) + raise ValueError(msg) from exc + data = sanitize_narwhals_dataframe(data) + return {"values": data.rows(named=True)} def check_data_type(data: DataType) -> None: @@ -350,14 +367,14 @@ def _data_to_json_string(data: DataType) -> str: """Return a JSON string representation of the input data""" check_data_type(data) if isinstance(data, SupportsGeoInterface): - if isinstance(data, pd.DataFrame): - data = sanitize_dataframe(data) + if _is_pandas_dataframe(data): + data = sanitize_pandas_dataframe(data) # Maybe the type could be further clarified here that it is # SupportGeoInterface and then the ignore statement is not needed? data = sanitize_geo_interface(data.__geo_interface__) # type: ignore[arg-type] return json.dumps(data) - elif isinstance(data, pd.DataFrame): - data = sanitize_dataframe(data) + elif _is_pandas_dataframe(data): + data = sanitize_pandas_dataframe(data) return data.to_json(orient="records", double_precision=15) elif isinstance(data, dict): if "values" not in data: @@ -382,13 +399,18 @@ def _data_to_csv_string(data: dict | pd.DataFrame | DataFrameLike) -> str: f"See https://github.com/vega/altair/issues/3441" ) raise NotImplementedError(msg) - elif isinstance(data, pd.DataFrame): - data = sanitize_dataframe(data) + elif _is_pandas_dataframe(data): + data = sanitize_pandas_dataframe(data) return data.to_csv(index=False) elif isinstance(data, dict): if "values" not in data: msg = "values expected in data dict, but not present" raise KeyError(msg) + try: + import pandas as pd + except ImportError as exc: + msg = "pandas is required to convert a dict to a CSV string" + raise ImportError(msg) from exc return pd.DataFrame.from_dict(data["values"]).to_csv(index=False) elif isinstance(data, DataFrameLike): # experimental interchange dataframe support diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py index 8f71247da..6aa7c6381 100644 --- a/altair/utils/schemapi.py +++ b/altair/utils/schemapi.py @@ -7,9 +7,11 @@ import inspect import json import textwrap +from math import ceil from collections import defaultdict from importlib.metadata import version as importlib_version from itertools import chain, zip_longest +import sys from typing import ( TYPE_CHECKING, Any, @@ -29,8 +31,6 @@ import jsonschema import jsonschema.exceptions import jsonschema.validators -import numpy as np -import pandas as pd from packaging.version import Version # This leads to circular imports with the vegalite module. Currently, this works @@ -39,8 +39,6 @@ from altair import vegalite if TYPE_CHECKING: - import sys - from referencing import Registry from altair import ChartType @@ -56,7 +54,6 @@ else: from typing_extensions import Self, Never - ValidationErrorList: TypeAlias = List[jsonschema.exceptions.ValidationError] GroupedValidationErrors: TypeAlias = Dict[str, ValidationErrorList] @@ -477,20 +474,34 @@ def _subclasses(cls: type[Any]) -> Iterator[type[Any]]: yield cls -def _todict(obj: Any, context: dict[str, Any] | None) -> Any: +def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) -> Any: """Convert an object to a dict representation.""" + if np_opt is not None: + np = np_opt + if isinstance(obj, np.ndarray): + return [_todict(v, context, np_opt, pd_opt) for v in obj] + elif isinstance(obj, np.number): + return float(obj) + elif isinstance(obj, np.datetime64): + result = str(obj) + if "T" not in result: + # See https://github.com/vega/altair/issues/1027 for why this is necessary. + result += "T00:00:00" + return result if isinstance(obj, SchemaBase): return obj.to_dict(validate=False, context=context) - elif isinstance(obj, (list, tuple, np.ndarray)): - return [_todict(v, context) for v in obj] + elif isinstance(obj, (list, tuple)): + return [_todict(v, context, np_opt, pd_opt) for v in obj] elif isinstance(obj, dict): - return {k: _todict(v, context) for k, v in obj.items() if v is not Undefined} + return { + k: _todict(v, context, np_opt, pd_opt) + for k, v in obj.items() + if v is not Undefined + } elif hasattr(obj, "to_dict"): return obj.to_dict() - elif isinstance(obj, np.number): - return float(obj) - elif isinstance(obj, (pd.Timestamp, np.datetime64)): - return pd.Timestamp(obj).isoformat() + elif pd_opt is not None and isinstance(obj, pd_opt.Timestamp): + return pd_opt.Timestamp(obj).isoformat() else: return obj @@ -636,7 +647,7 @@ def _format_params_as_table(param_dict_keys: Iterable[str]) -> str: max_column_width = 80 # Output a square table if not too big (since it is easier to read) num_param_names = len(param_names) - square_columns = int(np.ceil(num_param_names**0.5)) + square_columns = int(ceil(num_param_names**0.5)) columns = min(max_column_width // max_name_length, square_columns) # Compute roughly equal column heights to evenly divide the param names @@ -969,9 +980,17 @@ def to_dict( context = {} if ignore is None: ignore = [] + # The following return the package only if it has already been + # imported - otherwise they return None. This is useful for + # isinstance checks - for example, if pandas has not been imported, + # then an object is definitely not a `pandas.Timestamp`. + pd_opt = sys.modules.get("pandas") + np_opt = sys.modules.get("numpy") if self._args and not self._kwds: - result = _todict(self._args[0], context=context) + result = _todict( + self._args[0], context=context, np_opt=np_opt, pd_opt=pd_opt + ) elif not self._args: kwds = self._kwds.copy() # parsed_shorthand is added by FieldChannelMixin. @@ -999,10 +1018,7 @@ def to_dict( } if "mark" in kwds and isinstance(kwds["mark"], str): kwds["mark"] = {"type": kwds["mark"]} - result = _todict( - kwds, - context=context, - ) + result = _todict(kwds, context=context, np_opt=np_opt, pd_opt=pd_opt) else: msg = ( f"{self.__class__} instance has both a value and properties : " @@ -1173,7 +1189,13 @@ def validate_property( Validate a property against property schema in the context of the rootschema """ - value = _todict(value, context={}) + # The following return the package only if it has already been + # imported - otherwise they return None. This is useful for + # isinstance checks - for example, if pandas has not been imported, + # then an object is definitely not a `pandas.Timestamp`. + pd_opt = sys.modules.get("pandas") + np_opt = sys.modules.get("numpy") + value = _todict(value, context={}, np_opt=np_opt, pd_opt=pd_opt) props = cls.resolve_references(schema or cls._schema).get("properties", {}) return validate_jsonschema( value, props.get(name, {}), rootschema=cls._rootschema or cls._schema diff --git a/altair/vegalite/data.py b/altair/vegalite/data.py index db5b4bcdc..19371fc87 100644 --- a/altair/vegalite/data.py +++ b/altair/vegalite/data.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, overload, Callable -from ..utils.core import sanitize_dataframe +from ..utils.core import sanitize_pandas_dataframe from ..utils.data import ( MaxRowsError, limit_rows, @@ -58,7 +58,7 @@ def disable_max_rows(self) -> PluginEnabler: "default_data_transformer", "limit_rows", "sample", - "sanitize_dataframe", + "sanitize_pandas_dataframe", "to_csv", "to_json", "to_values", diff --git a/altair/vegalite/v5/schema/channels.py b/altair/vegalite/v5/schema/channels.py index da3365613..7e2951b71 100644 --- a/altair/vegalite/v5/schema/channels.py +++ b/altair/vegalite/v5/schema/channels.py @@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Literal, Sequence, overload -import pandas as pd +from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe from altair.utils import infer_encoding_types as _infer_encoding_types from altair.utils import parse_shorthand @@ -72,7 +72,7 @@ def to_dict( # We still parse it out of the shorthand, but drop it here. parsed.pop("type", None) elif not (type_in_shorthand or type_defined_explicitly): - if isinstance(context.get("data", None), pd.DataFrame): + if _is_pandas_dataframe(context.get("data", None)): msg = ( f'Unable to determine data type for the field "{shorthand}";' " verify that the field name is not misspelled." diff --git a/pyproject.toml b/pyproject.toml index 2b9dc55f0..7c28cbdcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,10 +19,8 @@ dependencies = [ "jinja2", # If you update the minimum required jsonschema version, also update it in build.yml "jsonschema>=3.0", - "numpy<2.0.0", - # If you update the minimum required pandas version, also update it in build.yml - "pandas>=0.25", - "packaging" + "packaging", + "narwhals>=1.0.0" ] description = "Vega-Altair: A declarative statistical visualization library for Python." readme = "README.md" @@ -59,6 +57,8 @@ Source = "https://github.com/vega/altair" all = [ "vega_datasets>=0.9.0", "vl-convert-python>=1.3.0", + "pandas>=0.25.3", + "numpy<2.0.0", "pyarrow>=11", "vegafusion[embed]>=1.6.6", "anywidget>=0.9.0", @@ -68,6 +68,7 @@ dev = [ "hatch", "ruff>=0.5.1", "ipython", + "pandas>=0.25.3", "pytest", "pytest-cov", "pytest-xdist[psutil]~=3.5", @@ -77,6 +78,7 @@ dev = [ "types-jsonschema", "types-setuptools", "geopandas", + "polars>=0.20.3", ] doc = [ "sphinx", diff --git a/tests/utils/test_core.py b/tests/utils/test_core.py index 8327d3afe..a2344a218 100644 --- a/tests/utils/test_core.py +++ b/tests/utils/test_core.py @@ -8,7 +8,7 @@ import altair as alt from altair.utils.core import parse_shorthand, update_nested, infer_encoding_types -from altair.utils.core import infer_dtype +from pandas.api.types import infer_dtype json_schema_specification = alt.load_schema()["$schema"] json_schema_dict_str = f'{{"$schema": "{json_schema_specification}"}}' diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py index e90474d83..90ca71c67 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -3,6 +3,7 @@ from typing import Any, Callable import pytest import pandas as pd +import polars as pl from altair.utils.data import ( limit_rows, MaxRowsError, @@ -65,6 +66,9 @@ def test_sample(): assert isinstance(result, dict) assert "values" in result assert len(result["values"]) == 10 + result = sample(pl.DataFrame(data), n=10) + assert isinstance(result, pl.DataFrame) + assert len(result) == 10 def test_to_values(): diff --git a/tests/utils/test_dataframe_interchange.py b/tests/utils/test_dataframe_interchange.py index 56e6499ff..5b5cdb877 100644 --- a/tests/utils/test_dataframe_interchange.py +++ b/tests/utils/test_dataframe_interchange.py @@ -59,4 +59,7 @@ def test_duration_raises(): # Check that exception mentions the duration[ns] type, # which is what the pandas timedelta is converted into - assert "duration[ns]" in e.value.args[0] + assert ( + 'Field "timedelta" has type "Duration" which is not supported by Altair' + in e.value.args[0] + ) diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index 875647214..d9569f0b1 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -3,11 +3,16 @@ import sys import warnings +import narwhals.stable.v1 as nw import numpy as np import pandas as pd import pytest -from altair.utils import infer_vegalite_type, sanitize_dataframe, sanitize_arrow_table +from altair.utils import ( + infer_vegalite_type_for_pandas, + sanitize_pandas_dataframe, + sanitize_narwhals_dataframe, +) try: import pyarrow as pa @@ -17,7 +22,7 @@ def test_infer_vegalite_type(): def _check(arr, typ): - assert infer_vegalite_type(arr) == typ + assert infer_vegalite_type_for_pandas(arr) == typ _check(np.arange(5, dtype=float), "quantitative") _check(np.arange(5, dtype=int), "quantitative") @@ -64,7 +69,7 @@ def test_sanitize_dataframe(): # JSON serialize. This will fail on non-sanitized dataframes print(df[["s", "c2"]]) - df_clean = sanitize_dataframe(df) + df_clean = sanitize_pandas_dataframe(df) print(df_clean[["s", "c2"]]) print(df_clean[["s", "c2"]].to_dict()) s = json.dumps(df_clean.to_dict(orient="records")) @@ -107,7 +112,7 @@ def test_sanitize_dataframe_arrow_columns(): } ) df_arrow = pa.Table.from_pandas(df).to_pandas(types_mapper=pd.ArrowDtype) - df_clean = sanitize_dataframe(df_arrow) + df_clean = sanitize_pandas_dataframe(df_arrow) records = df_clean.to_dict(orient="records") assert records[0] == { "s": "a", @@ -157,15 +162,15 @@ def test_sanitize_pyarrow_table_columns() -> None: ] ), ) - sanitized = sanitize_arrow_table(pa_table) - values = sanitized.to_pylist() + sanitized = sanitize_narwhals_dataframe(nw.from_native(pa_table, eager_only=True)) + values = sanitized.rows(named=True) assert values[0] == { "s": "a", "f": 0.0, "i": 0, "b": True, - "d": "2012-01-01T00:00:00", + "d": "2012-01-01T00:00:00.000000", "c": "a", "p": "2012-01-01T00:00:00.000000000", } @@ -178,26 +183,26 @@ def test_sanitize_dataframe_colnames(): df = pd.DataFrame(np.arange(12).reshape(4, 3)) # Test that RangeIndex is converted to strings - df = sanitize_dataframe(df) + df = sanitize_pandas_dataframe(df) assert [isinstance(col, str) for col in df.columns] # Test that non-string columns result in an error df.columns = [4, "foo", "bar"] with pytest.raises(ValueError) as err: # noqa: PT011 - sanitize_dataframe(df) + sanitize_pandas_dataframe(df) assert str(err.value).startswith("Dataframe contains invalid column name: 4.") def test_sanitize_dataframe_timedelta(): df = pd.DataFrame({"r": pd.timedelta_range(start="1 day", periods=4)}) with pytest.raises(ValueError) as err: # noqa: PT011 - sanitize_dataframe(df) + sanitize_pandas_dataframe(df) assert str(err.value).startswith('Field "r" has type "timedelta') def test_sanitize_dataframe_infs(): df = pd.DataFrame({"x": [0, 1, 2, np.inf, -np.inf, np.nan]}) - df_clean = sanitize_dataframe(df) + df_clean = sanitize_pandas_dataframe(df) assert list(df_clean.dtypes) == [object] assert list(df_clean["x"]) == [0, 1, 2, None, None, None] @@ -218,7 +223,7 @@ def test_sanitize_nullable_integers(): } ) - df_clean = sanitize_dataframe(df) + df_clean = sanitize_pandas_dataframe(df) assert {col.dtype.name for _, col in df_clean.items()} == {"object"} result_python = {col_name: list(col) for col_name, col in df_clean.items()} @@ -246,7 +251,7 @@ def test_sanitize_string_dtype(): } ) - df_clean = sanitize_dataframe(df) + df_clean = sanitize_pandas_dataframe(df) assert {col.dtype.name for _, col in df_clean.items()} == {"object"} result_python = {col_name: list(col) for col_name, col in df_clean.items()} @@ -271,7 +276,7 @@ def test_sanitize_boolean_dtype(): } ) - df_clean = sanitize_dataframe(df) + df_clean = sanitize_pandas_dataframe(df) assert {col.dtype.name for _, col in df_clean.items()} == {"object"} result_python = {col_name: list(col) for col_name, col in df_clean.items()} diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index a7197e0b1..1a65d83c4 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -1,6 +1,7 @@ """Unit tests for altair API""" import io +import sys import json import operator import os @@ -10,6 +11,7 @@ import jsonschema import pytest import pandas as pd +import polars as pl import altair.vegalite.v5 as alt @@ -1065,3 +1067,18 @@ def test_validate_dataset(): jsn = chart.to_json() assert jsn + + +def test_polars_with_pandas_nor_pyarrow(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delitem(sys.modules, "pandas") + monkeypatch.delitem(sys.modules, "numpy") + monkeypatch.delitem(sys.modules, "pyarrow", raising=False) + + df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + _ = alt.Chart(df).mark_line().encode(x="a", y="b").to_json() + # Check pandas and PyArrow weren't imported anywhere along the way, + # confirming that the plot above would work without pandas no PyArrow + # installed. + assert "pandas" not in sys.modules + assert "pyarrow" not in sys.modules + assert "numpy" not in sys.modules diff --git a/tools/generate_schema_wrapper.py b/tools/generate_schema_wrapper.py index f46096960..68f910522 100644 --- a/tools/generate_schema_wrapper.py +++ b/tools/generate_schema_wrapper.py @@ -108,7 +108,7 @@ def to_dict( # We still parse it out of the shorthand, but drop it here. parsed.pop("type", None) elif not (type_in_shorthand or type_defined_explicitly): - if isinstance(context.get("data", None), pd.DataFrame): + if _is_pandas_dataframe(context.get("data", None)): msg = ( f'Unable to determine data type for the field "{shorthand}";' " verify that the field name is not misspelled." @@ -520,6 +520,7 @@ def generate_vegalite_schema_wrapper(schema_file: Path) -> str: "from typing import Any, Literal, Union, Protocol, Sequence, List, Iterator, TYPE_CHECKING", "import pkgutil", "import json\n", + "from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe", "from altair.utils.schemapi import SchemaBase, Undefined, UndefinedType, _subclasses # noqa: F401\n", _type_checking_only_imports( "from altair import Parameter", @@ -570,7 +571,7 @@ def generate_vegalite_channel_wrappers( imports = imports or [ "from __future__ import annotations\n", "from typing import Any, overload, Sequence, List, Literal, Union, TYPE_CHECKING", - "import pandas as pd", + "from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe", "from altair.utils.schemapi import Undefined, with_property_setters", "from altair.utils import infer_encoding_types as _infer_encoding_types", "from altair.utils import parse_shorthand", diff --git a/tools/schemapi/schemapi.py b/tools/schemapi/schemapi.py index b8bbe34c2..b0804978b 100644 --- a/tools/schemapi/schemapi.py +++ b/tools/schemapi/schemapi.py @@ -5,9 +5,11 @@ import inspect import json import textwrap +from math import ceil from collections import defaultdict from importlib.metadata import version as importlib_version from itertools import chain, zip_longest +import sys from typing import ( TYPE_CHECKING, Any, @@ -27,8 +29,6 @@ import jsonschema import jsonschema.exceptions import jsonschema.validators -import numpy as np -import pandas as pd from packaging.version import Version # This leads to circular imports with the vegalite module. Currently, this works @@ -37,8 +37,6 @@ from altair import vegalite if TYPE_CHECKING: - import sys - from referencing import Registry from altair import ChartType @@ -54,7 +52,6 @@ else: from typing_extensions import Self, Never - ValidationErrorList: TypeAlias = List[jsonschema.exceptions.ValidationError] GroupedValidationErrors: TypeAlias = Dict[str, ValidationErrorList] @@ -475,20 +472,34 @@ def _subclasses(cls: type[Any]) -> Iterator[type[Any]]: yield cls -def _todict(obj: Any, context: dict[str, Any] | None) -> Any: +def _todict(obj: Any, context: dict[str, Any] | None, np_opt: Any, pd_opt: Any) -> Any: """Convert an object to a dict representation.""" + if np_opt is not None: + np = np_opt + if isinstance(obj, np.ndarray): + return [_todict(v, context, np_opt, pd_opt) for v in obj] + elif isinstance(obj, np.number): + return float(obj) + elif isinstance(obj, np.datetime64): + result = str(obj) + if "T" not in result: + # See https://github.com/vega/altair/issues/1027 for why this is necessary. + result += "T00:00:00" + return result if isinstance(obj, SchemaBase): return obj.to_dict(validate=False, context=context) - elif isinstance(obj, (list, tuple, np.ndarray)): - return [_todict(v, context) for v in obj] + elif isinstance(obj, (list, tuple)): + return [_todict(v, context, np_opt, pd_opt) for v in obj] elif isinstance(obj, dict): - return {k: _todict(v, context) for k, v in obj.items() if v is not Undefined} + return { + k: _todict(v, context, np_opt, pd_opt) + for k, v in obj.items() + if v is not Undefined + } elif hasattr(obj, "to_dict"): return obj.to_dict() - elif isinstance(obj, np.number): - return float(obj) - elif isinstance(obj, (pd.Timestamp, np.datetime64)): - return pd.Timestamp(obj).isoformat() + elif pd_opt is not None and isinstance(obj, pd_opt.Timestamp): + return pd_opt.Timestamp(obj).isoformat() else: return obj @@ -634,7 +645,7 @@ def _format_params_as_table(param_dict_keys: Iterable[str]) -> str: max_column_width = 80 # Output a square table if not too big (since it is easier to read) num_param_names = len(param_names) - square_columns = int(np.ceil(num_param_names**0.5)) + square_columns = int(ceil(num_param_names**0.5)) columns = min(max_column_width // max_name_length, square_columns) # Compute roughly equal column heights to evenly divide the param names @@ -967,9 +978,17 @@ def to_dict( context = {} if ignore is None: ignore = [] + # The following return the package only if it has already been + # imported - otherwise they return None. This is useful for + # isinstance checks - for example, if pandas has not been imported, + # then an object is definitely not a `pandas.Timestamp`. + pd_opt = sys.modules.get("pandas") + np_opt = sys.modules.get("numpy") if self._args and not self._kwds: - result = _todict(self._args[0], context=context) + result = _todict( + self._args[0], context=context, np_opt=np_opt, pd_opt=pd_opt + ) elif not self._args: kwds = self._kwds.copy() # parsed_shorthand is added by FieldChannelMixin. @@ -997,10 +1016,7 @@ def to_dict( } if "mark" in kwds and isinstance(kwds["mark"], str): kwds["mark"] = {"type": kwds["mark"]} - result = _todict( - kwds, - context=context, - ) + result = _todict(kwds, context=context, np_opt=np_opt, pd_opt=pd_opt) else: msg = ( f"{self.__class__} instance has both a value and properties : " @@ -1171,7 +1187,13 @@ def validate_property( Validate a property against property schema in the context of the rootschema """ - value = _todict(value, context={}) + # The following return the package only if it has already been + # imported - otherwise they return None. This is useful for + # isinstance checks - for example, if pandas has not been imported, + # then an object is definitely not a `pandas.Timestamp`. + pd_opt = sys.modules.get("pandas") + np_opt = sys.modules.get("numpy") + value = _todict(value, context={}, np_opt=np_opt, pd_opt=pd_opt) props = cls.resolve_references(schema or cls._schema).get("properties", {}) return validate_jsonschema( value, props.get(name, {}), rootschema=cls._rootschema or cls._schema From e91ed4d9783fd1910df1132234bc9628e880ce12 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:02:57 +0100 Subject: [PATCH 02/22] fix ibis-date32 bug too, make sure to only convert to arrow once --- altair/utils/_vegafusion_data.py | 10 +++--- altair/utils/core.py | 7 ++-- altair/utils/data.py | 39 +++++++++++------------ altair/vegalite/v5/api.py | 15 ++++++--- pyproject.toml | 1 + tests/utils/test_dataframe_interchange.py | 7 ++-- tests/vegalite/v5/test_api.py | 15 +++++++++ 7 files changed, 61 insertions(+), 33 deletions(-) diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py index 7608d3589..274a3d32a 100644 --- a/altair/utils/_vegafusion_data.py +++ b/altair/utils/_vegafusion_data.py @@ -12,8 +12,9 @@ Callable, ) +import narwhals as nw + from altair.utils._importers import import_vegafusion -from altair.utils.core import DataFrameLike from altair.utils.data import ( DataType, ToValuesReturnType, @@ -22,9 +23,9 @@ ) from altair.vegalite.data import default_data_transformer - if TYPE_CHECKING: from narwhals.typing import IntoDataFrame + from altair.utils.core import DataFrameLike from vegafusion.runtime import ChartState # type: ignore # Temporary storage for dataframes that have been extracted @@ -70,9 +71,10 @@ def vegafusion_data_transformer( """VegaFusion Data Transformer""" if data is None: return vegafusion_data_transformer - elif isinstance(data, DataFrameLike) and not isinstance(data, SupportsGeoInterface): + elif isinstance(data, nw.DataFrame) and not isinstance(data, SupportsGeoInterface): table_name = f"table_{uuid.uuid4()}".replace("-", "_") - extracted_inline_tables[table_name] = data + # vegafusion doesn't support Narwhals, so we extract the native object. + extracted_inline_tables[table_name] = nw.to_native(data) return {"url": VEGAFUSION_PREFIX + table_name} else: # Use default transformer for geo interface objects diff --git a/altair/utils/core.py b/altair/utils/core.py index acdbc9df6..87cf01bf0 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -468,12 +468,13 @@ def sanitize_narwhals_dataframe( return data.select(columns) -def narwhalify(data: DataType) -> nw.DataFrame: - """Wrap `data` in `narwhals.DataFrame`. +def narwhalify(data: DataType) -> nw.DataFrame[Any]: + """Wrap `data` in `narwhals.DataFrame` (if possible). If `data` is not supported by Narwhals, but it is convertible to a PyArrow table, then first convert to a PyArrow Table, and then wrap in `narwhals.DataFrame`. + If it can't even be converted to a PyArrow Table, return as-is. """ # Using `strict=False` will return `data` as-is if the object cannot be converted. data = nw.from_native(data, eager_only=True, strict=False) @@ -634,7 +635,7 @@ def parse_shorthand( if "type" not in attrs and is_data_type(data): data_nw = narwhalify(data) unescaped_field = attrs["field"].replace("\\", "") - if isinstance(data_nw, nw.DataFrame) and unescaped_field in data_nw.columns: + if unescaped_field in data_nw.columns: column = data_nw[unescaped_field] if column.dtype in {nw.Object, nw.Unknown} and _is_pandas_dataframe(data): attrs["type"] = infer_vegalite_type_for_pandas(nw.to_native(column)) diff --git a/altair/utils/data.py b/altair/utils/data.py index d09c58922..2afa09618 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -202,9 +202,8 @@ def sample( else: # Maybe this should raise an error or return something useful? return None - try: - data = narwhalify(data) - except TypeError: + data = narwhalify(data) + if not isinstance(data, nw.DataFrame): # Maybe this should raise an error or return something useful? Currently, # if data is of type SupportsGeoInterface it lands here return None @@ -325,29 +324,29 @@ def _to_text_kwds(prefix: str, extension: str, filename: str, urlpath: str, /) - def to_values(data: DataType) -> ToValuesReturnType: """Replace a DataFrame by a data model with values.""" check_data_type(data) - if isinstance(data, SupportsGeoInterface): - if _is_pandas_dataframe(data): - data = sanitize_pandas_dataframe(data) + data_native = nw.to_native(data, strict=False) + if isinstance(data_native, SupportsGeoInterface): + if _is_pandas_dataframe(data_native): + data_native = sanitize_pandas_dataframe(data_native) # Maybe the type could be further clarified here that it is # SupportGeoInterface and then the ignore statement is not needed? - data_sanitized = sanitize_geo_interface(data.__geo_interface__) # type: ignore[arg-type] + data_sanitized = sanitize_geo_interface(data_native.__geo_interface__) return {"values": data_sanitized} - elif _is_pandas_dataframe(data): - data = sanitize_pandas_dataframe(data) - return {"values": data.to_dict(orient="records")} - elif isinstance(data, dict): - if "values" not in data: + elif _is_pandas_dataframe(data_native): + data_native = sanitize_pandas_dataframe(data_native) + return {"values": data_native.to_dict(orient="records")} + elif isinstance(data_native, dict): + if "values" not in data_native: msg = "values expected in data dict, but not present." raise KeyError(msg) - return data - try: - data = narwhalify(data) - except TypeError as exc: + return data_native + elif isinstance(data, nw.DataFrame): + data = sanitize_narwhals_dataframe(data) + return {"values": data.rows(named=True)} + else: # Should never reach this state as tested by check_data_type msg = f"Unrecognized data type: {type(data)}" - raise ValueError(msg) from exc - data = sanitize_narwhals_dataframe(data) - return {"values": data.rows(named=True)} + raise ValueError(msg) def check_data_type(data: DataType) -> None: @@ -435,7 +434,7 @@ def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> pa.Table: # has more control over the conversion, and may have broader compatibility. # This is the case for Polars, which supports Date32 columns in direct conversion # while pyarrow does not yet support this type in from_dataframe - for convert_method_name in ("arrow", "to_arrow", "to_arrow_table"): + for convert_method_name in ("arrow", "to_arrow", "to_arrow_table", "to_pyarrow"): convert_method = getattr(dfi_df, convert_method_name, None) if callable(convert_method): result = convert_method() diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 843315fa7..82fc2059d 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -13,6 +13,7 @@ from .schema import core, channels, mixins, Undefined, SCHEMA_URL from altair.utils import Optional +from altair.utils.data import narwhalify as _narwhalify from .data import data_transformers from ... import utils from ...expr import core as _expr_core @@ -1015,11 +1016,17 @@ def to_dict( # TopLevelMixin instance does not necessarily have copy defined but due to how # Altair is set up this should hold. Too complex to type hint right now copy = self.copy(deep=False) # type: ignore[attr-defined] - original_data = getattr(copy, "data", Undefined) - copy.data = _prepare_data(original_data, context) - if original_data is not Undefined: - context["data"] = original_data + data = getattr(copy, "data", Undefined) + try: + data = _narwhalify(data) # type: ignore[arg-type] + except TypeError: + # Non-narwhalifiable type still supported by Altair, such as dict. + pass + copy.data = _prepare_data(data, context) + + if data is not Undefined: + context["data"] = data # remaining to_dict calls are not at top level context["top_level"] = False diff --git a/pyproject.toml b/pyproject.toml index 7c28cbdcb..acc9a1db9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ all = [ dev = [ "hatch", "ruff>=0.5.1", + "ibis-framework", "ipython", "pandas>=0.25.3", "pytest", diff --git a/tests/utils/test_dataframe_interchange.py b/tests/utils/test_dataframe_interchange.py index 5b5cdb877..3e7977d18 100644 --- a/tests/utils/test_dataframe_interchange.py +++ b/tests/utils/test_dataframe_interchange.py @@ -3,6 +3,7 @@ import pandas as pd import pytest import sys +import narwhals.stable.v1 as nw try: import pyarrow as pa @@ -36,8 +37,9 @@ def test_arrow_timestamp_conversion(): "value": [102, 129, 139], } pa_table = pa.table(data) + nw_frame = nw.from_native(pa_table) - values = to_values(pa_table) + values = to_values(nw_frame) expected_values = { "values": [ {"date": "2004-08-01T00:00:00.000000", "value": 102}, @@ -54,8 +56,9 @@ def test_duration_raises(): df = pd.DataFrame(td).reset_index() df.columns = ["id", "timedelta"] pa_table = pa.table(df) + nw_frame = nw.from_native(pa_table) with pytest.raises(ValueError) as e: # noqa: PT011 - to_values(pa_table) + to_values(nw_frame) # Check that exception mentions the duration[ns] type, # which is what the pandas timedelta is converted into diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 1a65d83c4..5d7a68818 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -1,6 +1,8 @@ """Unit tests for altair API""" +from datetime import date import io +import ibis import sys import json import operator @@ -1082,3 +1084,16 @@ def test_polars_with_pandas_nor_pyarrow(monkeypatch: pytest.MonkeyPatch): assert "pandas" not in sys.modules assert "pyarrow" not in sys.modules assert "numpy" not in sys.modules + + +def test_ibis_with_date_32(): + df = pl.DataFrame( + {"a": [1, 2, 3], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} + ) + tbl = ibis.memtable(df) + result = alt.Chart(tbl).mark_line().encode(x="a", y="b").to_dict() + assert next(iter(result["datasets"].values())) == [ + {"a": 1, "b": "2020-01-01T00:00:00.000000"}, + {"a": 2, "b": "2020-01-02T00:00:00.000000"}, + {"a": 3, "b": "2020-01-03T00:00:00.000000"}, + ] From f6d639ec75de3257c0becd0f6ccf34d975b2ca46 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:05:28 +0100 Subject: [PATCH 03/22] stable api --- altair/utils/_vegafusion_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py index 274a3d32a..57e1099ad 100644 --- a/altair/utils/_vegafusion_data.py +++ b/altair/utils/_vegafusion_data.py @@ -12,7 +12,7 @@ Callable, ) -import narwhals as nw +import narwhals.stable.v1 as nw from altair.utils._importers import import_vegafusion from altair.utils.data import ( From 7c052c0af1bb87b7340e24227d14bb8a3350005f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 16:07:17 +0100 Subject: [PATCH 04/22] update docstring --- altair/utils/core.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index 87cf01bf0..fda5d3eec 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -469,12 +469,11 @@ def sanitize_narwhals_dataframe( def narwhalify(data: DataType) -> nw.DataFrame[Any]: - """Wrap `data` in `narwhals.DataFrame` (if possible). + """Wrap `data` in `narwhals.DataFrame`. If `data` is not supported by Narwhals, but it is convertible to a PyArrow table, then first convert to a PyArrow Table, and then wrap in `narwhals.DataFrame`. - If it can't even be converted to a PyArrow Table, return as-is. """ # Using `strict=False` will return `data` as-is if the object cannot be converted. data = nw.from_native(data, eager_only=True, strict=False) From 81da742b065122aea1e09bacdea590fc4d882128 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 17:06:50 +0100 Subject: [PATCH 05/22] fixup --- altair/_magics.py | 3 ++- altair/utils/core.py | 9 +++++++-- altair/utils/data.py | 16 ++++------------ pyproject.toml | 1 + tests/utils/test_data.py | 3 ++- tests/vegalite/v5/test_api.py | 4 +++- 6 files changed, 19 insertions(+), 17 deletions(-) diff --git a/altair/_magics.py b/altair/_magics.py index 638400c78..5b879b060 100644 --- a/altair/_magics.py +++ b/altair/_magics.py @@ -10,6 +10,7 @@ import IPython from IPython.core import magic_arguments from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe +import narwhals.stable.v1 as nw from altair.vegalite import v5 as vegalite_v5 @@ -41,7 +42,7 @@ def _prepare_data(data, data_transformers): return data elif _is_pandas_dataframe(data): if func := data_transformers.get(): - data = func(data) + data = func(nw.from_native(data, eager_only=True)) return data elif isinstance(data, str): return {"url": data} diff --git a/altair/utils/core.py b/altair/utils/core.py index fda5d3eec..4d7995193 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -475,6 +475,9 @@ def narwhalify(data: DataType) -> nw.DataFrame[Any]: to a PyArrow table, then first convert to a PyArrow Table, and then wrap in `narwhals.DataFrame`. """ + if isinstance(data, nw.DataFrame): + # Early return if already a Narwhals DataFrame + return data # Using `strict=False` will return `data` as-is if the object cannot be converted. data = nw.from_native(data, eager_only=True, strict=False) if isinstance(data, nw.DataFrame): @@ -632,11 +635,13 @@ def parse_shorthand( # if data is specified and type is not, infer type from data if "type" not in attrs and is_data_type(data): - data_nw = narwhalify(data) unescaped_field = attrs["field"].replace("\\", "") + data_nw = narwhalify(data) if unescaped_field in data_nw.columns: column = data_nw[unescaped_field] - if column.dtype in {nw.Object, nw.Unknown} and _is_pandas_dataframe(data): + if column.dtype in {nw.Object, nw.Unknown} and _is_pandas_dataframe( + nw.to_native(data_nw) + ): attrs["type"] = infer_vegalite_type_for_pandas(nw.to_native(column)) else: attrs["type"] = infer_vegalite_type_for_narwhals(column) diff --git a/altair/utils/data.py b/altair/utils/data.py index 2afa09618..c643274a6 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -144,21 +144,14 @@ def raise_max_rows_error(): values = data.__geo_interface__["features"] else: values = data.__geo_interface__ - elif _is_pandas_dataframe(data): - values = data elif isinstance(data, dict): if "values" in data: values = data["values"] else: return data else: - data_nw = narwhalify(data) - if max_rows is not None and len(data_nw) > max_rows: - raise_max_rows_error() - # `narwhalify` may call `arrow_table_from_dfi_dataframe`, - # which can be expensive. Therefore, we return the `narwhals.DataFrame` - # here instead of the original input. - return data_nw + data = narwhalify(data) + values = data if max_rows is not None and len(values) > max_rows: raise_max_rows_error() @@ -380,9 +373,8 @@ def _data_to_json_string(data: DataType) -> str: msg = "values expected in data dict, but not present." raise KeyError(msg) return json.dumps(data["values"], sort_keys=True) - elif isinstance(data, DataFrameLike): - pa_table = arrow_table_from_dfi_dataframe(data) - return json.dumps(pa_table.to_pylist()) + elif isinstance(data, nw.DataFrame): + return json.dumps(data.rows(named=True)) else: msg = "to_json only works with data expressed as " "a DataFrame or as a dict" raise NotImplementedError(msg) diff --git a/pyproject.toml b/pyproject.toml index acc9a1db9..ed632bfa6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -353,6 +353,7 @@ module = [ "geopandas.*", "nbformat.*", "ipykernel.*", + "ibis.*", "m2r.*", # This refers to schemapi in the tools folder which is imported # by the tools scripts such as generate_schema_wrapper.py diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py index 90ca71c67..f3a91790d 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -4,6 +4,7 @@ import pytest import pandas as pd import polars as pl +import narwhals.stable.v1 as nw from altair.utils.data import ( limit_rows, MaxRowsError, @@ -34,7 +35,7 @@ def _create_data_with_values(N): def test_limit_rows(): """Test the limit_rows data transformer.""" - data = _create_dataframe(10) + data = nw.from_native(_create_dataframe(10)) result = limit_rows(data, max_rows=20) assert data is result with pytest.raises(MaxRowsError): diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 5d7a68818..db87f0ab8 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -9,6 +9,7 @@ import os import pathlib import tempfile +import narwhals.stable.v1 as nw import jsonschema import pytest @@ -741,7 +742,7 @@ def test_selection_property(): def test_LookupData(): - df = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]}) + df = nw.from_native(pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})) lookup = alt.LookupData(data=df, key="x") dct = lookup.to_dict() @@ -1090,6 +1091,7 @@ def test_ibis_with_date_32(): df = pl.DataFrame( {"a": [1, 2, 3], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ) + ibis.set_backend("polars") tbl = ibis.memtable(df) result = alt.Chart(tbl).mark_line().encode(x="a", y="b").to_dict() assert next(iter(result["datasets"].values())) == [ From c72fb9b24cb35bb40e337ce7c5767491d49fd3e0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 17:10:13 +0100 Subject: [PATCH 06/22] silence warning for old pandas --- tests/vegalite/v5/test_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index db87f0ab8..ac00f211d 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -1087,11 +1087,11 @@ def test_polars_with_pandas_nor_pyarrow(monkeypatch: pytest.MonkeyPatch): assert "numpy" not in sys.modules +@pytest.mark.skipif(not pd.__version__.startswith('2'), reason="A warning is thrown on old pandas versions") def test_ibis_with_date_32(): df = pl.DataFrame( {"a": [1, 2, 3], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} ) - ibis.set_backend("polars") tbl = ibis.memtable(df) result = alt.Chart(tbl).mark_line().encode(x="a", y="b").to_dict() assert next(iter(result["datasets"].values())) == [ From 1078d38e137a6953c72711d46699b89031f6482b Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 17:14:11 +0100 Subject: [PATCH 07/22] improve skipif condition, simplify --- altair/_magics.py | 2 +- tests/vegalite/v5/test_api.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/altair/_magics.py b/altair/_magics.py index 5b879b060..962ff04ce 100644 --- a/altair/_magics.py +++ b/altair/_magics.py @@ -42,7 +42,7 @@ def _prepare_data(data, data_transformers): return data elif _is_pandas_dataframe(data): if func := data_transformers.get(): - data = func(nw.from_native(data, eager_only=True)) + data = func(data) return data elif isinstance(data, str): return {"url": data} diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index ac00f211d..72f8026d3 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -1087,7 +1087,7 @@ def test_polars_with_pandas_nor_pyarrow(monkeypatch: pytest.MonkeyPatch): assert "numpy" not in sys.modules -@pytest.mark.skipif(not pd.__version__.startswith('2'), reason="A warning is thrown on old pandas versions") +@pytest.mark.skipif(int(pd.__version__.split('.')[0]) < 2, reason="A warning is thrown on old pandas versions") def test_ibis_with_date_32(): df = pl.DataFrame( {"a": [1, 2, 3], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} From bb84f2235c06fad9494dd20a21caf5fb36e312eb Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 17:16:58 +0100 Subject: [PATCH 08/22] lint --- altair/_magics.py | 1 - tests/vegalite/v5/test_api.py | 5 ++++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/altair/_magics.py b/altair/_magics.py index 962ff04ce..638400c78 100644 --- a/altair/_magics.py +++ b/altair/_magics.py @@ -10,7 +10,6 @@ import IPython from IPython.core import magic_arguments from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe -import narwhals.stable.v1 as nw from altair.vegalite import v5 as vegalite_v5 diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 72f8026d3..ec8990ff4 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -1087,7 +1087,10 @@ def test_polars_with_pandas_nor_pyarrow(monkeypatch: pytest.MonkeyPatch): assert "numpy" not in sys.modules -@pytest.mark.skipif(int(pd.__version__.split('.')[0]) < 2, reason="A warning is thrown on old pandas versions") +@pytest.mark.skipif( + int(pd.__version__.split(".")[0]) < 2, + reason="A warning is thrown on old pandas versions", +) def test_ibis_with_date_32(): df = pl.DataFrame( {"a": [1, 2, 3], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} From b111fe9e54825d5f2f2b58453ad3a5ec34e4dc51 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 17:20:19 +0100 Subject: [PATCH 09/22] explicitly set ibis backend --- tests/vegalite/v5/test_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index ec8990ff4..e887f6573 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -23,6 +23,8 @@ except ImportError: vlc = None +ibis.set_backend("polars") + def getargs(*args, **kwargs): return args, kwargs From b2118f96dde5fb0a2c0ab09a8a6ead519ae757dc Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 18:48:52 +0100 Subject: [PATCH 10/22] use importlib / packaging --- tests/vegalite/v5/test_api.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index e887f6573..615706ff7 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -9,9 +9,11 @@ import os import pathlib import tempfile -import narwhals.stable.v1 as nw +from importlib.metadata import version as importlib_version +from packaging.version import Version import jsonschema +import narwhals.stable.v1 as nw import pytest import pandas as pd import polars as pl @@ -25,6 +27,8 @@ ibis.set_backend("polars") +PANDAS_VERSION = Version(importlib_version("pandas")) + def getargs(*args, **kwargs): return args, kwargs @@ -1090,7 +1094,7 @@ def test_polars_with_pandas_nor_pyarrow(monkeypatch: pytest.MonkeyPatch): @pytest.mark.skipif( - int(pd.__version__.split(".")[0]) < 2, + Version("1.5") > PANDAS_VERSION, reason="A warning is thrown on old pandas versions", ) def test_ibis_with_date_32(): From 84bda857f9c966b65fb1c7d9e703960cf597c7fd Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 11 Jul 2024 18:50:10 +0100 Subject: [PATCH 11/22] xfail ibis date32 test on windows --- tests/vegalite/v5/test_api.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 615706ff7..0c9a720c6 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -1097,6 +1097,9 @@ def test_polars_with_pandas_nor_pyarrow(monkeypatch: pytest.MonkeyPatch): Version("1.5") > PANDAS_VERSION, reason="A warning is thrown on old pandas versions", ) +@pytest.mark.xfail( + sys.platform == "win32", reason="Timezone database is not installed on Windows" +) def test_ibis_with_date_32(): df = pl.DataFrame( {"a": [1, 2, 3], "b": [date(2020, 1, 1), date(2020, 1, 2), date(2020, 1, 3)]} From 110f848abcd5bee351d0da9b18275ff6d98cfb7e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 14 Jul 2024 12:01:53 +0100 Subject: [PATCH 12/22] wip --- altair/utils/core.py | 33 ++++++++++++++++++--------------- altair/utils/data.py | 20 ++++++-------------- altair/vegalite/v5/api.py | 15 ++++----------- 3 files changed, 28 insertions(+), 40 deletions(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index 4d7995193..6392512d9 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -468,27 +468,29 @@ def sanitize_narwhals_dataframe( return data.select(columns) -def narwhalify(data: DataType) -> nw.DataFrame[Any]: +def to_eager_narwhals_dataframe(data: DataType) -> nw.DataFrame[Any]: """Wrap `data` in `narwhals.DataFrame`. If `data` is not supported by Narwhals, but it is convertible to a PyArrow table, then first convert to a PyArrow Table, and then wrap in `narwhals.DataFrame`. """ - if isinstance(data, nw.DataFrame): - # Early return if already a Narwhals DataFrame - return data - # Using `strict=False` will return `data` as-is if the object cannot be converted. - data = nw.from_native(data, eager_only=True, strict=False) - if isinstance(data, nw.DataFrame): - return data - if isinstance(data, DataFrameLike): + data_nw = nw.from_native(data) + if nw.get_level(data_nw) == 'metadata': + # If Narwhals' support for `data`'s class is only metadata-level, then we + # use the interchange protocol to convert to a PyArrow Table. from altair.utils.data import arrow_table_from_dfi_dataframe - pa_table = arrow_table_from_dfi_dataframe(data) - return nw.from_native(pa_table, eager_only=True) - msg = f"Unsupported data type: {type(data)}" - raise TypeError(msg) + data_nw = nw.from_native(pa_table, eager_only=True) + elif isinstance(data_nw, nw.LazyFrame): + msg = ( + "Lazy objects which do not implement the dataframe interchange protocol " + "are not supported. Please collect your lazy object into an eager one " + "first." + ) + raise NotImplementedError(msg) + + return data_nw def parse_shorthand( @@ -636,8 +638,9 @@ def parse_shorthand( # if data is specified and type is not, infer type from data if "type" not in attrs and is_data_type(data): unescaped_field = attrs["field"].replace("\\", "") - data_nw = narwhalify(data) - if unescaped_field in data_nw.columns: + data_nw = nw.from_native(data) + schema = data_nw.schema + if unescaped_field in schema: column = data_nw[unescaped_field] if column.dtype in {nw.Object, nw.Unknown} and _is_pandas_dataframe( nw.to_native(data_nw) diff --git a/altair/utils/data.py b/altair/utils/data.py index c643274a6..affd8c3c0 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -31,7 +31,6 @@ sanitize_pandas_dataframe, DataFrameLike, sanitize_narwhals_dataframe, - narwhalify, ) from .core import sanitize_geo_interface from .plugin_registry import PluginRegistry @@ -150,8 +149,8 @@ def raise_max_rows_error(): else: return data else: - data = narwhalify(data) - values = data + from altair.utils.core import to_eager_narwhals_dataframe + values = to_eager_narwhals_dataframe(data) if max_rows is not None and len(values) > max_rows: raise_max_rows_error() @@ -196,10 +195,6 @@ def sample( # Maybe this should raise an error or return something useful? return None data = narwhalify(data) - if not isinstance(data, nw.DataFrame): - # Maybe this should raise an error or return something useful? Currently, - # if data is of type SupportsGeoInterface it lands here - return None if not n: if frac is None: msg = "frac cannot be None if n is None with this data input type" @@ -333,13 +328,10 @@ def to_values(data: DataType) -> ToValuesReturnType: msg = "values expected in data dict, but not present." raise KeyError(msg) return data_native - elif isinstance(data, nw.DataFrame): - data = sanitize_narwhals_dataframe(data) - return {"values": data.rows(named=True)} - else: - # Should never reach this state as tested by check_data_type - msg = f"Unrecognized data type: {type(data)}" - raise ValueError(msg) + from altair.utils.core import to_eager_narwhals_dataframe + data = to_eager_narwhals_dataframe(data) + data = sanitize_narwhals_dataframe(data) + return {"values": data.rows(named=True)} def check_data_type(data: DataType) -> None: diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 82fc2059d..843315fa7 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -13,7 +13,6 @@ from .schema import core, channels, mixins, Undefined, SCHEMA_URL from altair.utils import Optional -from altair.utils.data import narwhalify as _narwhalify from .data import data_transformers from ... import utils from ...expr import core as _expr_core @@ -1016,17 +1015,11 @@ def to_dict( # TopLevelMixin instance does not necessarily have copy defined but due to how # Altair is set up this should hold. Too complex to type hint right now copy = self.copy(deep=False) # type: ignore[attr-defined] + original_data = getattr(copy, "data", Undefined) + copy.data = _prepare_data(original_data, context) - data = getattr(copy, "data", Undefined) - try: - data = _narwhalify(data) # type: ignore[arg-type] - except TypeError: - # Non-narwhalifiable type still supported by Altair, such as dict. - pass - copy.data = _prepare_data(data, context) - - if data is not Undefined: - context["data"] = data + if original_data is not Undefined: + context["data"] = original_data # remaining to_dict calls are not at top level context["top_level"] = False From 6be087aa353406a78b0e2c1b75e846bb28aa2ec7 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 14 Jul 2024 12:12:51 +0100 Subject: [PATCH 13/22] wip --- altair/utils/data.py | 12 +++++++----- tests/utils/test_data.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/altair/utils/data.py b/altair/utils/data.py index affd8c3c0..9ad730e8b 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -194,7 +194,7 @@ def sample( else: # Maybe this should raise an error or return something useful? return None - data = narwhalify(data) + data = nw.from_native(data, eager_only=True) if not n: if frac is None: msg = "frac cannot be None if n is None with this data input type" @@ -312,6 +312,7 @@ def _to_text_kwds(prefix: str, extension: str, filename: str, urlpath: str, /) - def to_values(data: DataType) -> ToValuesReturnType: """Replace a DataFrame by a data model with values.""" check_data_type(data) + breakpoint() data_native = nw.to_native(data, strict=False) if isinstance(data_native, SupportsGeoInterface): if _is_pandas_dataframe(data_native): @@ -328,10 +329,11 @@ def to_values(data: DataType) -> ToValuesReturnType: msg = "values expected in data dict, but not present." raise KeyError(msg) return data_native - from altair.utils.core import to_eager_narwhals_dataframe - data = to_eager_narwhals_dataframe(data) - data = sanitize_narwhals_dataframe(data) - return {"values": data.rows(named=True)} + elif isinstance(data_native, nw.DataFrame): + data = sanitize_narwhals_dataframe(data) + return {"values": data.rows(named=True)} + else: + raise TypeError("Unreachable") def check_data_type(data: DataType) -> None: diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py index f3a91790d..3be90f3bf 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -35,7 +35,7 @@ def _create_data_with_values(N): def test_limit_rows(): """Test the limit_rows data transformer.""" - data = nw.from_native(_create_dataframe(10)) + data = _create_dataframe(10) result = limit_rows(data, max_rows=20) assert data is result with pytest.raises(MaxRowsError): From 3063fdf616312c051e9c60cef1613d9cab630000 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 14 Jul 2024 12:13:18 +0100 Subject: [PATCH 14/22] wip --- altair/utils/data.py | 1 - 1 file changed, 1 deletion(-) diff --git a/altair/utils/data.py b/altair/utils/data.py index 9ad730e8b..f9d2f512f 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -312,7 +312,6 @@ def _to_text_kwds(prefix: str, extension: str, filename: str, urlpath: str, /) - def to_values(data: DataType) -> ToValuesReturnType: """Replace a DataFrame by a data model with values.""" check_data_type(data) - breakpoint() data_native = nw.to_native(data, strict=False) if isinstance(data_native, SupportsGeoInterface): if _is_pandas_dataframe(data_native): From 1c8c5a3c2a3e34da15700b97fe2f9b4b0a921633 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 14 Jul 2024 12:23:20 +0100 Subject: [PATCH 15/22] wip date32 --- altair/utils/_vegafusion_data.py | 6 +++--- altair/utils/data.py | 16 ++++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py index 57e1099ad..13480034a 100644 --- a/altair/utils/_vegafusion_data.py +++ b/altair/utils/_vegafusion_data.py @@ -21,6 +21,7 @@ MaxRowsError, SupportsGeoInterface, ) +from altair.utils.core import DataFrameLike from altair.vegalite.data import default_data_transformer if TYPE_CHECKING: @@ -71,10 +72,9 @@ def vegafusion_data_transformer( """VegaFusion Data Transformer""" if data is None: return vegafusion_data_transformer - elif isinstance(data, nw.DataFrame) and not isinstance(data, SupportsGeoInterface): + elif isinstance(data, DataFrameLike) and not isinstance(data, SupportsGeoInterface): table_name = f"table_{uuid.uuid4()}".replace("-", "_") - # vegafusion doesn't support Narwhals, so we extract the native object. - extracted_inline_tables[table_name] = nw.to_native(data) + extracted_inline_tables[table_name] = data return {"url": VEGAFUSION_PREFIX + table_name} else: # Use default transformer for geo interface objects diff --git a/altair/utils/data.py b/altair/utils/data.py index f9d2f512f..2d740e93d 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -150,7 +150,8 @@ def raise_max_rows_error(): return data else: from altair.utils.core import to_eager_narwhals_dataframe - values = to_eager_narwhals_dataframe(data) + data = to_eager_narwhals_dataframe(data) + values = data if max_rows is not None and len(values) > max_rows: raise_max_rows_error() @@ -328,11 +329,13 @@ def to_values(data: DataType) -> ToValuesReturnType: msg = "values expected in data dict, but not present." raise KeyError(msg) return data_native - elif isinstance(data_native, nw.DataFrame): + elif isinstance(data, nw.DataFrame): data = sanitize_narwhals_dataframe(data) return {"values": data.rows(named=True)} else: - raise TypeError("Unreachable") + # Should never reach this state as tested by check_data_type + msg = f"Unrecognized data type: {type(data)}" + raise ValueError(msg) def check_data_type(data: DataType) -> None: @@ -351,6 +354,7 @@ def _compute_data_hash(data_str: str) -> str: def _data_to_json_string(data: DataType) -> str: """Return a JSON string representation of the input data""" check_data_type(data) + data_native = nw.to_native(data, strict=False) if isinstance(data, SupportsGeoInterface): if _is_pandas_dataframe(data): data = sanitize_pandas_dataframe(data) @@ -358,9 +362,9 @@ def _data_to_json_string(data: DataType) -> str: # SupportGeoInterface and then the ignore statement is not needed? data = sanitize_geo_interface(data.__geo_interface__) # type: ignore[arg-type] return json.dumps(data) - elif _is_pandas_dataframe(data): - data = sanitize_pandas_dataframe(data) - return data.to_json(orient="records", double_precision=15) + elif _is_pandas_dataframe(data_native): + data = sanitize_pandas_dataframe(data_native) + return data_native.to_json(orient="records", double_precision=15) elif isinstance(data, dict): if "values" not in data: msg = "values expected in data dict, but not present." From b0ca54de1c29363fefc295730ac231c1c1339357 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 14 Jul 2024 12:43:59 +0100 Subject: [PATCH 16/22] reorder imports --- altair/utils/_vegafusion_data.py | 3 +++ altair/utils/data.py | 16 ++++++++-------- altair/vegalite/v5/api.py | 10 ++++++++-- tests/utils/test_data.py | 2 +- 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py index 13480034a..bfd81486c 100644 --- a/altair/utils/_vegafusion_data.py +++ b/altair/utils/_vegafusion_data.py @@ -70,6 +70,9 @@ def vegafusion_data_transformer( data: DataType | None = None, max_rows: int = 100000 ) -> Callable[..., Any] | _VegaFusionReturnType: """VegaFusion Data Transformer""" + # Vegafusion does not support Narwhals, so if `data` is a Narwhals + # object, we make sure to extract the native object and let Vegafusion handle it. + data = nw.to_native(data, strict=False) if data is None: return vegafusion_data_transformer elif isinstance(data, DataFrameLike) and not isinstance(data, SupportsGeoInterface): diff --git a/altair/utils/data.py b/altair/utils/data.py index 2d740e93d..506a3217a 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -31,6 +31,7 @@ sanitize_pandas_dataframe, DataFrameLike, sanitize_narwhals_dataframe, + to_eager_narwhals_dataframe, ) from .core import sanitize_geo_interface from .plugin_registry import PluginRegistry @@ -149,7 +150,6 @@ def raise_max_rows_error(): else: return data else: - from altair.utils.core import to_eager_narwhals_dataframe data = to_eager_narwhals_dataframe(data) values = data @@ -355,18 +355,18 @@ def _data_to_json_string(data: DataType) -> str: """Return a JSON string representation of the input data""" check_data_type(data) data_native = nw.to_native(data, strict=False) - if isinstance(data, SupportsGeoInterface): - if _is_pandas_dataframe(data): - data = sanitize_pandas_dataframe(data) + if isinstance(data_native, SupportsGeoInterface): + if _is_pandas_dataframe(data_native): + data_native = sanitize_pandas_dataframe(data_native) # Maybe the type could be further clarified here that it is # SupportGeoInterface and then the ignore statement is not needed? - data = sanitize_geo_interface(data.__geo_interface__) # type: ignore[arg-type] - return json.dumps(data) + data_native = sanitize_geo_interface(data_native.__geo_interface__) # type: ignore[arg-type] + return json.dumps(data_native) elif _is_pandas_dataframe(data_native): data = sanitize_pandas_dataframe(data_native) return data_native.to_json(orient="records", double_precision=15) - elif isinstance(data, dict): - if "values" not in data: + elif isinstance(data_native, dict): + if "values" not in data_native: msg = "values expected in data dict, but not present." raise KeyError(msg) return json.dumps(data["values"], sort_keys=True) diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 23138613f..92aee216c 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -25,6 +25,7 @@ ) from ...utils.data import DataType, is_data_type as _is_data_type from ...utils.deprecation import AltairDeprecationWarning +from ...utils.core import to_eager_narwhals_dataframe if TYPE_CHECKING: from ...utils.core import DataFrameLike @@ -1007,10 +1008,15 @@ def to_dict( # Altair is set up this should hold. Too complex to type hint right now copy = self.copy(deep=False) # type: ignore[attr-defined] original_data = getattr(copy, "data", Undefined) - copy.data = _prepare_data(original_data, context) + try: + data = to_eager_narwhals_dataframe(original_data) + except TypeError: + # Non-narwhalifiable type support by Altair, such as dict + data = original_data + copy.data = _prepare_data(data, context) if original_data is not Undefined: - context["data"] = original_data + context["data"] = data # remaining to_dict calls are not at top level context["top_level"] = False diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py index 3be90f3bf..f58fc9f10 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -35,7 +35,7 @@ def _create_data_with_values(N): def test_limit_rows(): """Test the limit_rows data transformer.""" - data = _create_dataframe(10) + data = nw.from_native(_create_dataframe(10), eager_only=True) result = limit_rows(data, max_rows=20) assert data is result with pytest.raises(MaxRowsError): From d0417df7b49b703e7d72fb93c638173ac56b4950 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 14 Jul 2024 13:02:23 +0100 Subject: [PATCH 17/22] typing --- altair/utils/core.py | 12 ++++++------ altair/utils/data.py | 8 +++----- altair/vegalite/v5/api.py | 2 +- tests/utils/test_data.py | 2 +- tests/utils/test_utils.py | 2 +- 5 files changed, 12 insertions(+), 14 deletions(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index 6392512d9..81212e792 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -44,7 +44,7 @@ from altair.vegalite.v5.schema._typing import StandardType_T as InferredVegaLiteType from altair.utils._dfi_types import DataFrame as DfiDataFrame from altair.utils.data import DataType - from narwhals.typing import IntoExpr, IntoDataFrameT + from narwhals.typing import IntoExpr, IntoDataFrameT, IntoDataFrame import pandas as pd V = TypeVar("V") @@ -468,7 +468,7 @@ def sanitize_narwhals_dataframe( return data.select(columns) -def to_eager_narwhals_dataframe(data: DataType) -> nw.DataFrame[Any]: +def to_eager_narwhals_dataframe(data: IntoDataFrame) -> nw.DataFrame[Any]: """Wrap `data` in `narwhals.DataFrame`. If `data` is not supported by Narwhals, but it is convertible @@ -480,8 +480,8 @@ def to_eager_narwhals_dataframe(data: DataType) -> nw.DataFrame[Any]: # If Narwhals' support for `data`'s class is only metadata-level, then we # use the interchange protocol to convert to a PyArrow Table. from altair.utils.data import arrow_table_from_dfi_dataframe - pa_table = arrow_table_from_dfi_dataframe(data) - data_nw = nw.from_native(pa_table, eager_only=True) + pa_table = arrow_table_from_dfi_dataframe(data) # type: ignore[arg-type] + data_nw = nw.from_native(pa_table, eager_or_interchange_only=True) elif isinstance(data_nw, nw.LazyFrame): msg = ( "Lazy objects which do not implement the dataframe interchange protocol " @@ -638,11 +638,11 @@ def parse_shorthand( # if data is specified and type is not, infer type from data if "type" not in attrs and is_data_type(data): unescaped_field = attrs["field"].replace("\\", "") - data_nw = nw.from_native(data) + data_nw = nw.from_native(data, eager_or_interchange_only=True) schema = data_nw.schema if unescaped_field in schema: column = data_nw[unescaped_field] - if column.dtype in {nw.Object, nw.Unknown} and _is_pandas_dataframe( + if schema[unescaped_field] in {nw.Object, nw.Unknown} and _is_pandas_dataframe( nw.to_native(data_nw) ): attrs["type"] = infer_vegalite_type_for_pandas(nw.to_native(column)) diff --git a/altair/utils/data.py b/altair/utils/data.py index 506a3217a..725fbddad 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -195,7 +195,7 @@ def sample( else: # Maybe this should raise an error or return something useful? return None - data = nw.from_native(data, eager_only=True) + data = nw.from_native(data, eager_or_interchange_only=True) if not n: if frac is None: msg = "frac cannot be None if n is None with this data input type" @@ -358,9 +358,7 @@ def _data_to_json_string(data: DataType) -> str: if isinstance(data_native, SupportsGeoInterface): if _is_pandas_dataframe(data_native): data_native = sanitize_pandas_dataframe(data_native) - # Maybe the type could be further clarified here that it is - # SupportGeoInterface and then the ignore statement is not needed? - data_native = sanitize_geo_interface(data_native.__geo_interface__) # type: ignore[arg-type] + data_native = sanitize_geo_interface(data_native.__geo_interface__) return json.dumps(data_native) elif _is_pandas_dataframe(data_native): data = sanitize_pandas_dataframe(data_native) @@ -369,7 +367,7 @@ def _data_to_json_string(data: DataType) -> str: if "values" not in data_native: msg = "values expected in data dict, but not present." raise KeyError(msg) - return json.dumps(data["values"], sort_keys=True) + return json.dumps(data_native["values"], sort_keys=True) elif isinstance(data, nw.DataFrame): return json.dumps(data.rows(named=True)) else: diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 92aee216c..2a75d2357 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -1009,7 +1009,7 @@ def to_dict( copy = self.copy(deep=False) # type: ignore[attr-defined] original_data = getattr(copy, "data", Undefined) try: - data = to_eager_narwhals_dataframe(original_data) + data: Any = to_eager_narwhals_dataframe(original_data) # type: ignore[arg-type] except TypeError: # Non-narwhalifiable type support by Altair, such as dict data = original_data diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py index f58fc9f10..c0359ba31 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -35,7 +35,7 @@ def _create_data_with_values(N): def test_limit_rows(): """Test the limit_rows data transformer.""" - data = nw.from_native(_create_dataframe(10), eager_only=True) + data = nw.from_native(_create_dataframe(10), eager_or_interchange_only=True) result = limit_rows(data, max_rows=20) assert data is result with pytest.raises(MaxRowsError): diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index d9569f0b1..fc2daa67b 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -162,7 +162,7 @@ def test_sanitize_pyarrow_table_columns() -> None: ] ), ) - sanitized = sanitize_narwhals_dataframe(nw.from_native(pa_table, eager_only=True)) + sanitized = sanitize_narwhals_dataframe(nw.from_native(pa_table, eager_or_interchange_only=True)) values = sanitized.rows(named=True) assert values[0] == { From 5d54bc4263dd5bdba600819b316dda8c592bdd8f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 14 Jul 2024 19:52:20 +0100 Subject: [PATCH 18/22] bump version --- altair/utils/_vegafusion_data.py | 1 - altair/utils/core.py | 23 ++++++++--------------- altair/utils/data.py | 2 +- altair/vegalite/v5/api.py | 4 ++-- pyproject.toml | 2 +- tests/utils/test_data.py | 2 +- tests/utils/test_utils.py | 2 +- 7 files changed, 14 insertions(+), 22 deletions(-) diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py index bfd81486c..84257328b 100644 --- a/altair/utils/_vegafusion_data.py +++ b/altair/utils/_vegafusion_data.py @@ -26,7 +26,6 @@ if TYPE_CHECKING: from narwhals.typing import IntoDataFrame - from altair.utils.core import DataFrameLike from vegafusion.runtime import ChartState # type: ignore # Temporary storage for dataframes that have been extracted diff --git a/altair/utils/core.py b/altair/utils/core.py index 81212e792..c4d6e71a2 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -43,7 +43,6 @@ import typing as t from altair.vegalite.v5.schema._typing import StandardType_T as InferredVegaLiteType from altair.utils._dfi_types import DataFrame as DfiDataFrame - from altair.utils.data import DataType from narwhals.typing import IntoExpr, IntoDataFrameT, IntoDataFrame import pandas as pd @@ -475,21 +474,14 @@ def to_eager_narwhals_dataframe(data: IntoDataFrame) -> nw.DataFrame[Any]: to a PyArrow table, then first convert to a PyArrow Table, and then wrap in `narwhals.DataFrame`. """ - data_nw = nw.from_native(data) - if nw.get_level(data_nw) == 'metadata': + data_nw = nw.from_native(data, eager_or_interchange_only=True) + if nw.get_level(data_nw) == "interchange": # If Narwhals' support for `data`'s class is only metadata-level, then we # use the interchange protocol to convert to a PyArrow Table. from altair.utils.data import arrow_table_from_dfi_dataframe - pa_table = arrow_table_from_dfi_dataframe(data) # type: ignore[arg-type] - data_nw = nw.from_native(pa_table, eager_or_interchange_only=True) - elif isinstance(data_nw, nw.LazyFrame): - msg = ( - "Lazy objects which do not implement the dataframe interchange protocol " - "are not supported. Please collect your lazy object into an eager one " - "first." - ) - raise NotImplementedError(msg) + pa_table = arrow_table_from_dfi_dataframe(data) # type: ignore[arg-type] + data_nw = nw.from_native(pa_table, eager_only=True) return data_nw @@ -642,9 +634,10 @@ def parse_shorthand( schema = data_nw.schema if unescaped_field in schema: column = data_nw[unescaped_field] - if schema[unescaped_field] in {nw.Object, nw.Unknown} and _is_pandas_dataframe( - nw.to_native(data_nw) - ): + if schema[unescaped_field] in { + nw.Object, + nw.Unknown, + } and _is_pandas_dataframe(nw.to_native(data_nw)): attrs["type"] = infer_vegalite_type_for_pandas(nw.to_native(column)) else: attrs["type"] = infer_vegalite_type_for_narwhals(column) diff --git a/altair/utils/data.py b/altair/utils/data.py index 725fbddad..9f690fc9f 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -195,7 +195,7 @@ def sample( else: # Maybe this should raise an error or return something useful? return None - data = nw.from_native(data, eager_or_interchange_only=True) + data = nw.from_native(data, eager_only=True) if not n: if frac is None: msg = "frac cannot be None if n is None with this data input type" diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 2a75d2357..ef08b96c7 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -25,7 +25,7 @@ ) from ...utils.data import DataType, is_data_type as _is_data_type from ...utils.deprecation import AltairDeprecationWarning -from ...utils.core import to_eager_narwhals_dataframe +from ...utils.core import to_eager_narwhals_dataframe as _to_eager_narwhals_dataframe if TYPE_CHECKING: from ...utils.core import DataFrameLike @@ -1009,7 +1009,7 @@ def to_dict( copy = self.copy(deep=False) # type: ignore[attr-defined] original_data = getattr(copy, "data", Undefined) try: - data: Any = to_eager_narwhals_dataframe(original_data) # type: ignore[arg-type] + data: Any = _to_eager_narwhals_dataframe(original_data) # type: ignore[arg-type] except TypeError: # Non-narwhalifiable type support by Altair, such as dict data = original_data diff --git a/pyproject.toml b/pyproject.toml index 6f6d3b4a9..04430f426 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ # If you update the minimum required jsonschema version, also update it in build.yml "jsonschema>=3.0", "packaging", - "narwhals>=1.0.0" + "narwhals>=1.1.0" ] description = "Vega-Altair: A declarative statistical visualization library for Python." readme = "README.md" diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py index c0359ba31..f58fc9f10 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -35,7 +35,7 @@ def _create_data_with_values(N): def test_limit_rows(): """Test the limit_rows data transformer.""" - data = nw.from_native(_create_dataframe(10), eager_or_interchange_only=True) + data = nw.from_native(_create_dataframe(10), eager_only=True) result = limit_rows(data, max_rows=20) assert data is result with pytest.raises(MaxRowsError): diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index fc2daa67b..d9569f0b1 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -162,7 +162,7 @@ def test_sanitize_pyarrow_table_columns() -> None: ] ), ) - sanitized = sanitize_narwhals_dataframe(nw.from_native(pa_table, eager_or_interchange_only=True)) + sanitized = sanitize_narwhals_dataframe(nw.from_native(pa_table, eager_only=True)) values = sanitized.rows(named=True) assert values[0] == { From b52eaca2eb327ebe365d61eb12946dfd69e041f0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 14 Jul 2024 23:03:36 +0100 Subject: [PATCH 19/22] typo, `strict=False` comments, use consistent typevar name --- altair/utils/_vegafusion_data.py | 1 + altair/utils/core.py | 7 ++++--- altair/utils/data.py | 2 ++ altair/vegalite/v5/api.py | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py index 84257328b..c9f127378 100644 --- a/altair/utils/_vegafusion_data.py +++ b/altair/utils/_vegafusion_data.py @@ -71,6 +71,7 @@ def vegafusion_data_transformer( """VegaFusion Data Transformer""" # Vegafusion does not support Narwhals, so if `data` is a Narwhals # object, we make sure to extract the native object and let Vegafusion handle it. + # `strict=False` passes `data` through as-is if it is not a Narwhals object. data = nw.to_native(data, strict=False) if data is None: return vegafusion_data_transformer diff --git a/altair/utils/core.py b/altair/utils/core.py index c4d6e71a2..d689863da 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -43,11 +43,12 @@ import typing as t from altair.vegalite.v5.schema._typing import StandardType_T as InferredVegaLiteType from altair.utils._dfi_types import DataFrame as DfiDataFrame - from narwhals.typing import IntoExpr, IntoDataFrameT, IntoDataFrame + from narwhals.typing import IntoExpr, IntoDataFrame import pandas as pd V = TypeVar("V") P = ParamSpec("P") +TIntoDataFrame = TypeVar("TIntoDataFrame", bound=IntoDataFrame) @runtime_checkable @@ -438,8 +439,8 @@ def to_list_if_array(val): def sanitize_narwhals_dataframe( - data: nw.DataFrame[IntoDataFrameT], -) -> nw.DataFrame[IntoDataFrameT]: + data: nw.DataFrame[TIntoDataFrame], +) -> nw.DataFrame[TIntoDataFrame]: """Sanitize narwhals.DataFrame for JSON serialization""" schema = data.schema columns: list[IntoExpr] = [] diff --git a/altair/utils/data.py b/altair/utils/data.py index 9f690fc9f..daf37393d 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -313,6 +313,7 @@ def _to_text_kwds(prefix: str, extension: str, filename: str, urlpath: str, /) - def to_values(data: DataType) -> ToValuesReturnType: """Replace a DataFrame by a data model with values.""" check_data_type(data) + # `strict=False` passes `data` through as-is if it is not a Narwhals object. data_native = nw.to_native(data, strict=False) if isinstance(data_native, SupportsGeoInterface): if _is_pandas_dataframe(data_native): @@ -354,6 +355,7 @@ def _compute_data_hash(data_str: str) -> str: def _data_to_json_string(data: DataType) -> str: """Return a JSON string representation of the input data""" check_data_type(data) + # `strict=False` passes `data` through as-is if it is not a Narwhals object. data_native = nw.to_native(data, strict=False) if isinstance(data_native, SupportsGeoInterface): if _is_pandas_dataframe(data_native): diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index ef08b96c7..635577db8 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -1011,7 +1011,7 @@ def to_dict( try: data: Any = _to_eager_narwhals_dataframe(original_data) # type: ignore[arg-type] except TypeError: - # Non-narwhalifiable type support by Altair, such as dict + # Non-narwhalifiable type supported by Altair, such as dict data = original_data copy.data = _prepare_data(data, context) From 8b4b3dbeadbe4459eecbe0ad4459e6fb701dcdfe Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sun, 14 Jul 2024 23:21:48 +0100 Subject: [PATCH 20/22] lint --- altair/utils/core.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index d689863da..d10c8d093 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -28,7 +28,8 @@ import jsonschema import narwhals.stable.v1 as nw -from narwhals.dependencies import is_pandas_dataframe as _is_pandas_dataframe +from narwhals.dependencies import is_pandas_dataframe +from narwhals.typing import IntoDataFrame from altair.utils.schemapi import SchemaBase, Undefined @@ -43,7 +44,7 @@ import typing as t from altair.vegalite.v5.schema._typing import StandardType_T as InferredVegaLiteType from altair.utils._dfi_types import DataFrame as DfiDataFrame - from narwhals.typing import IntoExpr, IntoDataFrame + from narwhals.typing import IntoExpr import pandas as pd V = TypeVar("V") @@ -638,7 +639,7 @@ def parse_shorthand( if schema[unescaped_field] in { nw.Object, nw.Unknown, - } and _is_pandas_dataframe(nw.to_native(data_nw)): + } and is_pandas_dataframe(nw.to_native(data_nw)): attrs["type"] = infer_vegalite_type_for_pandas(nw.to_native(column)) else: attrs["type"] = infer_vegalite_type_for_narwhals(column) From cd813854d1edf0bc63e85babc6c51a6701a62059 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 15 Jul 2024 12:50:46 +0100 Subject: [PATCH 21/22] special-case Polars in `.dt.to_string` for Date --- altair/utils/core.py | 6 ++++-- tests/utils/test_utils.py | 2 +- tests/vegalite/v5/test_api.py | 6 +++--- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/altair/utils/core.py b/altair/utils/core.py index d10c8d093..9046a75d5 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -28,7 +28,7 @@ import jsonschema import narwhals.stable.v1 as nw -from narwhals.dependencies import is_pandas_dataframe +from narwhals.dependencies import is_pandas_dataframe, get_polars from narwhals.typing import IntoDataFrame from altair.utils.schemapi import SchemaBase, Undefined @@ -448,12 +448,14 @@ def sanitize_narwhals_dataframe( # See https://github.com/vega/altair/issues/1027 for why this is necessary. local_iso_fmt_string = "%Y-%m-%dT%H:%M:%S" for name, dtype in schema.items(): - if dtype == nw.Date: + if dtype == nw.Date and nw.get_native_namespace(data) is get_polars(): # Polars doesn't allow formatting `Date` with time directives. # The date -> datetime cast is extremely fast compared with `to_string` columns.append( nw.col(name).cast(nw.Datetime).dt.to_string(local_iso_fmt_string) ) + elif dtype == nw.Date: + columns.append(nw.col(name).dt.to_string(local_iso_fmt_string)) elif dtype == nw.Datetime: columns.append(nw.col(name).dt.to_string(f"{local_iso_fmt_string}%.f")) elif dtype == nw.Duration: diff --git a/tests/utils/test_utils.py b/tests/utils/test_utils.py index d9569f0b1..c3a73acb7 100644 --- a/tests/utils/test_utils.py +++ b/tests/utils/test_utils.py @@ -170,7 +170,7 @@ def test_sanitize_pyarrow_table_columns() -> None: "f": 0.0, "i": 0, "b": True, - "d": "2012-01-01T00:00:00.000000", + "d": "2012-01-01T00:00:00", "c": "a", "p": "2012-01-01T00:00:00.000000000", } diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 0c9a720c6..d00cc9849 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -1107,7 +1107,7 @@ def test_ibis_with_date_32(): tbl = ibis.memtable(df) result = alt.Chart(tbl).mark_line().encode(x="a", y="b").to_dict() assert next(iter(result["datasets"].values())) == [ - {"a": 1, "b": "2020-01-01T00:00:00.000000"}, - {"a": 2, "b": "2020-01-02T00:00:00.000000"}, - {"a": 3, "b": "2020-01-03T00:00:00.000000"}, + {"a": 1, "b": "2020-01-01T00:00:00"}, + {"a": 2, "b": "2020-01-02T00:00:00"}, + {"a": 3, "b": "2020-01-03T00:00:00"}, ] From bec9bc240210eefcc2ec3d8acb6932bec06f3141 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 15 Jul 2024 16:44:08 +0100 Subject: [PATCH 22/22] rename test file --- .../{test_dataframe_interchange.py => test_to_values_narwhals.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/utils/{test_dataframe_interchange.py => test_to_values_narwhals.py} (100%) diff --git a/tests/utils/test_dataframe_interchange.py b/tests/utils/test_to_values_narwhals.py similarity index 100% rename from tests/utils/test_dataframe_interchange.py rename to tests/utils/test_to_values_narwhals.py