Skip to content

Commit

Permalink
feat(python): add tooltip by default to charts
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Sep 9, 2024
1 parent aa3b2c3 commit ab0ba24
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 34 deletions.
73 changes: 46 additions & 27 deletions py-polars/polars/dataframe/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,24 @@

import altair as alt
from altair.typing import (
ChannelColor,
ChannelOrder,
ChannelSize,
ChannelTooltip,
ChannelX,
ChannelY,
ChannelColor as Color,
)
from altair.typing import (
ChannelOrder as Order,
)
from altair.typing import (
ChannelSize as Size,
)
from altair.typing import (
ChannelTooltip as Tooltip,
)
from altair.typing import (
ChannelX as X,
)
from altair.typing import (
ChannelY as Y,
)
from altair.typing import (
EncodeKwds,
)

Expand All @@ -29,12 +41,15 @@

Encodings: TypeAlias = Dict[
str,
Union[
ChannelX, ChannelY, ChannelColor, ChannelOrder, ChannelSize, ChannelTooltip
],
Union[X, Y, Color, Order, Size, Tooltip],
]


def _add_tooltip(chart: alt.Chart) -> alt.Chart:
chart.mark = {"type": chart.mark, "tooltip": True}
return chart


class DataFramePlot:
"""DataFrame.plot namespace."""

Expand All @@ -45,10 +60,10 @@ def __init__(self, df: DataFrame) -> None:

def bar(
self,
x: ChannelX | None = None,
y: ChannelY | None = None,
color: ChannelColor | None = None,
tooltip: ChannelTooltip | None = None,
x: X | None = None,
y: Y | None = None,
color: Color | None = None,
tooltip: Tooltip | None = None,
/,
**kwargs: Unpack[EncodeKwds],
) -> alt.Chart:
Expand Down Expand Up @@ -104,15 +119,17 @@ def bar(
encodings["color"] = color
if tooltip is not None:
encodings["tooltip"] = tooltip
return self._chart.mark_bar().encode(**encodings, **kwargs).interactive()
return _add_tooltip(
self._chart.mark_bar().encode(**encodings, **kwargs).interactive()
)

def line(
self,
x: ChannelX | None = None,
y: ChannelY | None = None,
color: ChannelColor | None = None,
order: ChannelOrder | None = None,
tooltip: ChannelTooltip | None = None,
x: X | None = None,
y: Y | None = None,
color: Color | None = None,
order: Order | None = None,
tooltip: Tooltip | None = None,
/,
**kwargs: Unpack[EncodeKwds],
) -> alt.Chart:
Expand Down Expand Up @@ -170,15 +187,17 @@ def line(
encodings["order"] = order
if tooltip is not None:
encodings["tooltip"] = tooltip
return self._chart.mark_line().encode(**encodings, **kwargs).interactive()
return _add_tooltip(
self._chart.mark_line().encode(**encodings, **kwargs).interactive()
)

def point(
self,
x: ChannelX | None = None,
y: ChannelY | None = None,
color: ChannelColor | None = None,
size: ChannelSize | None = None,
tooltip: ChannelTooltip | None = None,
x: X | None = None,
y: Y | None = None,
color: Color | None = None,
size: Size | None = None,
tooltip: Tooltip | None = None,
/,
**kwargs: Unpack[EncodeKwds],
) -> alt.Chart:
Expand Down Expand Up @@ -236,7 +255,7 @@ def point(
encodings["size"] = size
if tooltip is not None:
encodings["tooltip"] = tooltip
return (
return _add_tooltip(
self._chart.mark_point()
.encode(
**encodings,
Expand All @@ -253,4 +272,4 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]:
if method is None:
msg = "Altair has no method 'mark_{attr}'"
raise AttributeError(msg)
return lambda **kwargs: method().encode(**kwargs).interactive()
return lambda **kwargs: _add_tooltip(method().encode(**kwargs).interactive())
13 changes: 6 additions & 7 deletions py-polars/polars/series/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import TYPE_CHECKING, Callable

from polars.dataframe.plotting import add_tooltip
from polars.dependencies import altair as alt

if TYPE_CHECKING:
Expand Down Expand Up @@ -62,7 +63,7 @@ def hist(
if self._series_name == "count()":
msg = "Cannot use `plot.hist` when Series name is `'count()'`"
raise ValueError(msg)
return (
return add_tooltip(
alt.Chart(self._df)
.mark_bar()
.encode(x=alt.X(f"{self._series_name}:Q", bin=True), y="count()", **kwargs) # type: ignore[misc]
Expand Down Expand Up @@ -104,7 +105,7 @@ def kde(
if self._series_name == "density":
msg = "Cannot use `plot.kde` when Series name is `'density'`"
raise ValueError(msg)
return (
return add_tooltip(
alt.Chart(self._df)
.transform_density(self._series_name, as_=[self._series_name, "density"])
.mark_area()
Expand Down Expand Up @@ -147,7 +148,7 @@ def line(
if self._series_name == "index":
msg = "Cannot call `plot.line` when Series name is 'index'"
raise ValueError(msg)
return (
return add_tooltip(
alt.Chart(self._df.with_row_index())
.mark_line()
.encode(x="index", y=self._series_name, **kwargs) # type: ignore[misc]
Expand All @@ -165,8 +166,6 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]:
if method is None:
msg = "Altair has no method 'mark_{attr}'"
raise AttributeError(msg)
return (
lambda **kwargs: method()
.encode(x="index", y=self._series_name, **kwargs)
.interactive()
return lambda **kwargs: add_tooltip(
method().encode(x="index", y=self._series_name, **kwargs).interactive()
)

0 comments on commit ab0ba24

Please sign in to comment.