Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Use appropriate bins in hist when bin_count specified #16942

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 19 additions & 32 deletions crates/polars-ops/src/chunked_array/hist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::fmt::Write;
use num_traits::ToPrimitive;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;
use polars_utils::float::IsFloat;
use polars_utils::total_ord::ToTotalOrd;

fn compute_hist<T>(
Expand All @@ -17,6 +16,7 @@ where
T: PolarsNumericType,
ChunkedArray<T>: ChunkAgg<T::Native>,
{
let mut lower_bound: f64;
let (breaks, count) = if let Some(bins) = bins {
let mut breaks = Vec::with_capacity(bins.len() + 1);
breaks.extend_from_slice(bins);
Expand All @@ -31,7 +31,7 @@ where

// We start with the lower garbage bin.
// (-inf, B0]
let mut lower_bound = f64::NEG_INFINITY;
lower_bound = f64::NEG_INFINITY;
let mut upper_bound = *breaks_iter.next().unwrap();

for chunk in sorted.downcast_iter() {
Expand Down Expand Up @@ -60,17 +60,17 @@ where
while count.len() < breaks.len() {
count.push(0)
}
// Reset the lower bound to negative infinity so the first category is (-inf, B0]
lower_bound = f64::NEG_INFINITY;
(breaks, count)
} else if ca.null_count() == ca.len() {
lower_bound = f64::NEG_INFINITY;
let breaks: Vec<f64> = vec![f64::INFINITY];
let count: Vec<IdxSize> = vec![0];
(breaks, count)
} else {
let min = ChunkAgg::min(ca).unwrap().to_f64().unwrap();
let max = ChunkAgg::max(ca).unwrap().to_f64().unwrap();

let start = min.floor() - 1.0;
let end = max.ceil() + 1.0;
let start = ChunkAgg::min(ca).unwrap().to_f64().unwrap();
let end = ChunkAgg::max(ca).unwrap().to_f64().unwrap();

// If bin_count is omitted, default to the difference between start and end (unit bins)
let bin_count = if let Some(bin_count) = bin_count {
Expand All @@ -79,45 +79,32 @@ where
(end - start).round() as usize
};

// Calculate the breakpoints and make the array
// Calculate the breakpoints and make the array. The breakpoints form the RHS of the bins.
let interval = (end - start) / (bin_count as f64);

let breaks_iter = (0..(bin_count)).map(|b| start + (b as f64) * interval);

let breaks_iter = (1..(bin_count)).map(|b| start + (b as f64) * interval);
let mut breaks = Vec::with_capacity(breaks_iter.size_hint().0 + 1);
breaks.extend(breaks_iter);
breaks.push(f64::INFINITY);

let mut count: Vec<IdxSize> = vec![0; breaks.len()];
let end_idx = count.len() - 1;
// Extend the left-most edge by 0.1% of the total range to include the minimum value.
let margin = (end - start) * 0.001;
lower_bound = start - margin;
breaks.push(end);

// start is the closed rhs of the interval, so we subtract the bucket width
let start_range = start - interval;
let mut count: Vec<IdxSize> = vec![0; bin_count];
let max_bin = breaks.len() - 1;
for chunk in ca.downcast_iter() {
for item in chunk.non_null_values_iter() {
let item = item.to_f64().unwrap() - start_range;

// This is needed for numeric stability.
// Only for integers.
// we can fall directly on a boundary with an integer.
let item = item / interval;
let item = if !T::Native::is_float() && (item.round() - item).abs() < 0.0000001 {
item.round() - 1.0
} else {
item.ceil() - 1.0
};

let idx = item as usize;
let idx = std::cmp::min(idx, end_idx);
count[idx] += 1;
let item = item.to_f64().unwrap();
let bin = ((((item - start) / interval).ceil() - 1.0) as usize).min(max_bin);
count[bin] += 1;
}
}
(breaks, count)
};
let mut fields = Vec::with_capacity(3);
if include_category {
// Use AnyValue for formatting.
let mut lower = AnyValue::Float64(f64::NEG_INFINITY);
let mut lower = AnyValue::Float64(lower_bound);
let mut categories = StringChunkedBuilder::new("category", breaks.len());

let mut buf = String::new();
Expand Down
25 changes: 12 additions & 13 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2404,7 +2404,7 @@ def hist(
If None given, we determine the boundaries based on the data.
bin_count
If no bins provided, this will be used to determine
the distance of the bins
the distance of the bins.
include_breakpoint
Include a column that indicates the upper breakpoint.
include_category
Expand All @@ -2418,18 +2418,17 @@ def hist(
--------
>>> a = pl.Series("a", [1, 3, 8, 8, 2, 1, 3])
>>> a.hist(bin_count=4)
shape: (5, 3)
┌────────────┬─────────────┬───────┐
│ breakpoint ┆ category ┆ count │
│ --- ┆ --- ┆ --- │
│ f64 ┆ cat ┆ u32 │
╞════════════╪═════════════╪═══════╡
│ 0.0 ┆ (-inf, 0.0] ┆ 0 │
│ 2.25 ┆ (0.0, 2.25] ┆ 3 │
│ 4.5 ┆ (2.25, 4.5] ┆ 2 │
│ 6.75 ┆ (4.5, 6.75] ┆ 0 │
│ inf ┆ (6.75, inf] ┆ 2 │
└────────────┴─────────────┴───────┘
shape: (4, 3)
┌────────────┬───────────────┬───────┐
│ breakpoint ┆ category ┆ count │
│ --- ┆ --- ┆ --- │
│ f64 ┆ cat ┆ u32 │
╞════════════╪═══════════════╪═══════╡
│ 2.75 ┆ (0.993, 2.75] ┆ 3 │
│ 4.5 ┆ (2.75, 4.5] ┆ 2 │
│ 6.25 ┆ (4.5, 6.25] ┆ 0 │
│ 8.0 ┆ (6.25, 8.0] ┆ 2 │
└────────────┴───────────────┴───────┘
"""
out = (
self.to_frame()
Expand Down
22 changes: 14 additions & 8 deletions py-polars/tests/unit/operations/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

import polars as pl
from polars import StringCache
from polars.testing import assert_frame_equal


Expand Down Expand Up @@ -37,16 +38,21 @@ def test_corr_nan() -> None:
assert str(df.select(pl.corr("a", "b", ddof=1))[0, 0]) == "nan"


@StringCache()
def test_hist() -> None:
a = pl.Series("a", [1, 3, 8, 8, 2, 1, 3])
assert (
str(a.hist(bin_count=4).to_dict(as_series=False))
== "{'breakpoint': [0.0, 2.25, 4.5, 6.75, inf], 'category': ['(-inf, 0.0]', '(0.0, 2.25]', '(2.25, 4.5]', '(4.5, 6.75]', '(6.75, inf]'], 'count': [0, 3, 2, 0, 2]}"
s = pl.Series("a", [1, 3, 8, 8, 2, 1, 3])
out = s.hist(bin_count=4)
expected = pl.DataFrame(
{
"breakpoint": pl.Series([2.75, 4.5, 6.25, 8.0], dtype=pl.Float64),
"category": pl.Series(
["(0.993, 2.75]", "(2.75, 4.5]", "(4.5, 6.25]", "(6.25, 8.0]"],
dtype=pl.Categorical,
),
"count": pl.Series([3, 2, 0, 2], dtype=pl.get_index_type()),
}
)

assert a.hist(
bins=[0, 2], include_category=False, include_breakpoint=False
).to_series().to_list() == [0, 3, 4]
assert_frame_equal(out, expected, categorical_as_str=True)


@pytest.mark.parametrize("values", [[], [None]])
Expand Down
Loading