Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Add tooltip by default to charts #18625

Merged
merged 2 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 34 additions & 39 deletions py-polars/polars/dataframe/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
import sys

import altair as alt
from altair.typing import (
ChannelColor,
ChannelOrder,
ChannelSize,
ChannelTooltip,
ChannelX,
ChannelY,
EncodeKwds,
)
from altair.typing import ChannelColor as Color
from altair.typing import ChannelOrder as Order
from altair.typing import ChannelSize as Size
from altair.typing import ChannelTooltip as Tooltip
from altair.typing import ChannelX as X
from altair.typing import ChannelY as Y
from altair.typing import EncodeKwds

from polars import DataFrame

Expand All @@ -29,12 +27,15 @@

Encodings: TypeAlias = Dict[
str,
Union[
ChannelX, ChannelY, ChannelColor, ChannelOrder, ChannelSize, ChannelTooltip
],
Union[X, Y, Color, Order, Size, Tooltip],
]


def _add_tooltip(encodings: Encodings, /, **kwargs: Unpack[EncodeKwds]) -> None:
if "tooltip" not in kwargs:
encodings["tooltip"] = [*encodings.values(), *kwargs.values()] # type: ignore[assignment]


class DataFramePlot:
"""DataFrame.plot namespace."""

Expand All @@ -45,10 +46,9 @@ def __init__(self, df: DataFrame) -> None:

def bar(
self,
x: ChannelX | None = None,
y: ChannelY | None = None,
color: ChannelColor | None = None,
tooltip: ChannelTooltip | None = None,
x: X | None = None,
y: Y | None = None,
color: Color | None = None,
/,
**kwargs: Unpack[EncodeKwds],
) -> alt.Chart:
Expand Down Expand Up @@ -77,8 +77,6 @@ def bar(
Column with y-coordinates of bars.
color
Column to color bars by.
tooltip
Columns to show values of when hovering over bars with pointer.
**kwargs
Additional keyword arguments passed to Altair.
Comment on lines -80 to 81
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that this is a backwards-compatible change, as anyone passing tooltip will still have their code work due to **kwargs


Expand All @@ -102,17 +100,15 @@ def bar(
encodings["y"] = y
if color is not None:
encodings["color"] = color
if tooltip is not None:
encodings["tooltip"] = tooltip
_add_tooltip(encodings, **kwargs)
return self._chart.mark_bar().encode(**encodings, **kwargs).interactive()

def line(
self,
x: ChannelX | None = None,
y: ChannelY | None = None,
color: ChannelColor | None = None,
order: ChannelOrder | None = None,
tooltip: ChannelTooltip | None = None,
x: X | None = None,
y: Y | None = None,
color: Color | None = None,
order: Order | None = None,
/,
**kwargs: Unpack[EncodeKwds],
) -> alt.Chart:
Expand Down Expand Up @@ -142,8 +138,6 @@ def line(
Column to color lines by.
order
Column to use for order of data points in lines.
tooltip
Columns to show values of when hovering over lines with pointer.
**kwargs
Additional keyword arguments passed to Altair.

Expand All @@ -168,17 +162,15 @@ def line(
encodings["color"] = color
if order is not None:
encodings["order"] = order
if tooltip is not None:
encodings["tooltip"] = tooltip
_add_tooltip(encodings, **kwargs)
return self._chart.mark_line().encode(**encodings, **kwargs).interactive()

def point(
self,
x: ChannelX | None = None,
y: ChannelY | None = None,
color: ChannelColor | None = None,
size: ChannelSize | None = None,
tooltip: ChannelTooltip | None = None,
x: X | None = None,
y: Y | None = None,
color: Color | None = None,
size: Size | None = None,
/,
**kwargs: Unpack[EncodeKwds],
) -> alt.Chart:
Expand Down Expand Up @@ -209,8 +201,6 @@ def point(
Column to color points by.
size
Column which determines points' sizes.
tooltip
Columns to show values of when hovering over points with pointer.
**kwargs
Additional keyword arguments passed to Altair.

Expand All @@ -234,8 +224,7 @@ def point(
encodings["color"] = color
if size is not None:
encodings["size"] = size
if tooltip is not None:
encodings["tooltip"] = tooltip
_add_tooltip(encodings, **kwargs)
return (
self._chart.mark_point()
.encode(
Expand All @@ -253,4 +242,10 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]:
if method is None:
msg = "Altair has no method 'mark_{attr}'"
raise AttributeError(msg)
return lambda **kwargs: method().encode(**kwargs).interactive()
encodings: Encodings = {}

def func(**kwargs: EncodeKwds) -> alt.Chart:
_add_tooltip(encodings, **kwargs)
return method().encode(**encodings, **kwargs).interactive()

return func
33 changes: 22 additions & 11 deletions py-polars/polars/series/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@

from typing import TYPE_CHECKING, Callable

from polars.dataframe.plotting import _add_tooltip
from polars.dependencies import altair as alt

if TYPE_CHECKING:
import sys

from altair.typing import EncodeKwds

from polars.dataframe.plotting import Encodings

if sys.version_info >= (3, 11):
from typing import Unpack
else:
Expand Down Expand Up @@ -62,11 +65,13 @@ def hist(
if self._series_name == "count()":
msg = "Cannot use `plot.hist` when Series name is `'count()'`"
raise ValueError(msg)
encodings: Encodings = {
"x": alt.X(f"{self._series_name}:Q", bin=True),
"y": "count()",
}
_add_tooltip(encodings, **kwargs)
return (
alt.Chart(self._df)
.mark_bar()
.encode(x=alt.X(f"{self._series_name}:Q", bin=True), y="count()", **kwargs) # type: ignore[misc]
.interactive()
alt.Chart(self._df).mark_bar().encode(**encodings, **kwargs).interactive()
)

def kde(
Expand Down Expand Up @@ -104,11 +109,13 @@ def kde(
if self._series_name == "density":
msg = "Cannot use `plot.kde` when Series name is `'density'`"
raise ValueError(msg)
encodings: Encodings = {"x": self._series_name, "y": "density:Q"}
_add_tooltip(encodings, **kwargs)
return (
alt.Chart(self._df)
.transform_density(self._series_name, as_=[self._series_name, "density"])
.mark_area()
.encode(x=self._series_name, y="density:Q", **kwargs) # type: ignore[misc]
.encode(**encodings, **kwargs)
.interactive()
)

Expand Down Expand Up @@ -147,10 +154,12 @@ def line(
if self._series_name == "index":
msg = "Cannot call `plot.line` when Series name is 'index'"
raise ValueError(msg)
encodings: Encodings = {"x": "index", "y": self._series_name}
_add_tooltip(encodings, **kwargs)
return (
alt.Chart(self._df.with_row_index())
.mark_line()
.encode(x="index", y=self._series_name, **kwargs) # type: ignore[misc]
.encode(**encodings, **kwargs)
.interactive()
)

Expand All @@ -165,8 +174,10 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]:
if method is None:
msg = "Altair has no method 'mark_{attr}'"
raise AttributeError(msg)
return (
lambda **kwargs: method()
.encode(x="index", y=self._series_name, **kwargs)
.interactive()
)
encodings: Encodings = {"x": "index", "y": self._series_name}

def func(**kwargs: EncodeKwds) -> alt.Chart:
_add_tooltip(encodings, **kwargs)
return method().encode(**encodings, **kwargs).interactive()

return func
34 changes: 34 additions & 0 deletions py-polars/tests/unit/operations/namespaces/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,29 @@ def test_dataframe_plot() -> None:
df.plot.area(x="length", y="width", color="species").to_json()


def test_dataframe_plot_tooltip() -> None:
df = pl.DataFrame(
{
"length": [1, 4, 6],
"width": [4, 5, 6],
"species": ["setosa", "setosa", "versicolor"],
}
)
result = df.plot.line(x="length", y="width", color="species").to_dict()
assert result["encoding"]["tooltip"] == [
{"field": "length", "type": "quantitative"},
{"field": "width", "type": "quantitative"},
{"field": "species", "type": "nominal"},
]
result = df.plot.line(
x="length", y="width", color="species", tooltip=["length", "width"]
).to_dict()
assert result["encoding"]["tooltip"] == [
{"field": "length", "type": "quantitative"},
{"field": "width", "type": "quantitative"},
]


def test_series_plot() -> None:
# dry-run, check nothing errors
s = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6])
Expand All @@ -26,6 +49,17 @@ def test_series_plot() -> None:
s.plot.point().to_json()


def test_series_plot_tooltip() -> None:
s = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6])
result = s.plot.line().to_dict()
assert result["encoding"]["tooltip"] == [
{"field": "index", "type": "quantitative"},
{"field": "a", "type": "quantitative"},
]
result = s.plot.line(tooltip=["a"]).to_dict()
assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}]


def test_empty_dataframe() -> None:
pl.DataFrame({"a": [], "b": []}).plot.point(x="a", y="b")

Expand Down