Skip to content

Commit

Permalink
Add new method when bins not specified
Browse files: browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Jun 14, 2024
1 parent 1308bf4 commit 449843d
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 53 deletions.
51 changes: 19 additions & 32 deletions crates/polars-ops/src/chunked_array/hist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::fmt::Write;
use num_traits::ToPrimitive;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_polars_type;
use polars_utils::float::IsFloat;
use polars_utils::total_ord::ToTotalOrd;

fn compute_hist<T>(
Expand All @@ -17,6 +16,7 @@ where
T: PolarsNumericType,
ChunkedArray<T>: ChunkAgg<T::Native>,
{
let mut lower_bound: f64;
let (breaks, count) = if let Some(bins) = bins {
let mut breaks = Vec::with_capacity(bins.len() + 1);
breaks.extend_from_slice(bins);
Expand All @@ -31,7 +31,7 @@ where

// We start with the lower garbage bin.
// (-inf, B0]
let mut lower_bound = f64::NEG_INFINITY;
lower_bound = f64::NEG_INFINITY;
let mut upper_bound = *breaks_iter.next().unwrap();

for chunk in sorted.downcast_iter() {
Expand Down Expand Up @@ -60,17 +60,17 @@ where
while count.len() < breaks.len() {
count.push(0)
}
// Push lower bound to infinity
lower_bound = f64::NEG_INFINITY;
(breaks, count)
} else if ca.null_count() == ca.len() {
lower_bound = f64::NEG_INFINITY;
let breaks: Vec<f64> = vec![f64::INFINITY];
let count: Vec<IdxSize> = vec![0];
(breaks, count)
} else {
let min = ChunkAgg::min(ca).unwrap().to_f64().unwrap();
let max = ChunkAgg::max(ca).unwrap().to_f64().unwrap();

let start = min.floor() - 1.0;
let end = max.ceil() + 1.0;
let start = ChunkAgg::min(ca).unwrap().to_f64().unwrap();
let end = ChunkAgg::max(ca).unwrap().to_f64().unwrap();

// If bin_count is omitted, default to the difference between start and stop (unit bins)
let bin_count = if let Some(bin_count) = bin_count {
Expand All @@ -79,45 +79,32 @@ where
(end - start).round() as usize
};

// Calculate the breakpoints and make the array
// Calculate the breakpoints and make the array. The breakpoints form the RHS of the bins.
let interval = (end - start) / (bin_count as f64);

let breaks_iter = (0..(bin_count)).map(|b| start + (b as f64) * interval);

let breaks_iter = (1..(bin_count)).map(|b| start + (b as f64) * interval);
let mut breaks = Vec::with_capacity(breaks_iter.size_hint().0 + 1);
breaks.extend(breaks_iter);
breaks.push(f64::INFINITY);

let mut count: Vec<IdxSize> = vec![0; breaks.len()];
let end_idx = count.len() - 1;
// Extend the left-most edge by 0.1% of the total range to include the minimum value.
let margin = (end - start) * 0.001;
lower_bound = start - margin;
breaks.push(end);

// start is the closed rhs of the interval, so we subtract the bucket width
let start_range = start - interval;
let mut count: Vec<IdxSize> = vec![0; bin_count];
let max_bin = breaks.len() - 1;
for chunk in ca.downcast_iter() {
for item in chunk.non_null_values_iter() {
let item = item.to_f64().unwrap() - start_range;

// This is needed for numeric stability.
// Only for integers.
// we can fall directly on a boundary with an integer.
let item = item / interval;
let item = if !T::Native::is_float() && (item.round() - item).abs() < 0.0000001 {
item.round() - 1.0
} else {
item.ceil() - 1.0
};

let idx = item as usize;
let idx = std::cmp::min(idx, end_idx);
count[idx] += 1;
let item = item.to_f64().unwrap();
let bin = ((((item - start) / interval).ceil() - 1.0) as usize).min(max_bin);
count[bin] += 1;
}
}
(breaks, count)
};
let mut fields = Vec::with_capacity(3);
if include_category {
// Use AnyValue for formatting.
let mut lower = AnyValue::Float64(f64::NEG_INFINITY);
let mut lower = AnyValue::Float64(lower_bound);
let mut categories = StringChunkedBuilder::new("category", breaks.len());

let mut buf = String::new();
Expand Down
25 changes: 12 additions & 13 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2354,7 +2354,7 @@ def hist(
If None given, we determine the boundaries based on the data.
bin_count
If no bins provided, this will be used to determine
the distance of the bins
the width of the bins.
include_breakpoint
Include a column that indicates the upper breakpoint.
include_category
Expand All @@ -2368,18 +2368,17 @@ def hist(
--------
>>> a = pl.Series("a", [1, 3, 8, 8, 2, 1, 3])
>>> a.hist(bin_count=4)
shape: (5, 3)
┌────────────┬─────────────┬───────┐
│ breakpoint ┆ category ┆ count │
│ --- ┆ --- ┆ --- │
│ f64 ┆ cat ┆ u32 │
╞════════════╪═════════════╪═══════╡
│ 0.0 ┆ (-inf, 0.0] ┆ 0 │
│ 2.25 ┆ (0.0, 2.25] ┆ 3 │
│ 4.5 ┆ (2.25, 4.5] ┆ 2 │
│ 6.75 ┆ (4.5, 6.75] ┆ 0 │
│ inf ┆ (6.75, inf] ┆ 2 │
└────────────┴─────────────┴───────┘
shape: (4, 3)
┌────────────┬───────────────┬───────┐
│ breakpoint ┆ category ┆ count │
│ --- ┆ --- ┆ --- │
│ f64 ┆ cat ┆ u32 │
╞════════════╪═══════════════╪═══════╡
│ 2.75 ┆ (0.993, 2.75] ┆ 3 │
│ 4.5 ┆ (2.75, 4.5] ┆ 2 │
│ 6.25 ┆ (4.5, 6.25] ┆ 0 │
│ 8.0 ┆ (6.25, 8.0] ┆ 2 │
└────────────┴───────────────┴───────┘
"""
out = (
self.to_frame()
Expand Down
22 changes: 14 additions & 8 deletions py-polars/tests/unit/operations/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

import polars as pl
from polars import StringCache
from polars.testing import assert_frame_equal


Expand Down Expand Up @@ -37,16 +38,21 @@ def test_corr_nan() -> None:
assert str(df.select(pl.corr("a", "b", ddof=1))[0, 0]) == "nan"


@StringCache()
def test_hist() -> None:
a = pl.Series("a", [1, 3, 8, 8, 2, 1, 3])
assert (
str(a.hist(bin_count=4).to_dict(as_series=False))
== "{'breakpoint': [0.0, 2.25, 4.5, 6.75, inf], 'category': ['(-inf, 0.0]', '(0.0, 2.25]', '(2.25, 4.5]', '(4.5, 6.75]', '(6.75, inf]'], 'count': [0, 3, 2, 0, 2]}"
s = pl.Series("a", [1, 3, 8, 8, 2, 1, 3])
out = s.hist(bin_count=4)
expected = pl.DataFrame(
{
"breakpoint": pl.Series([2.75, 4.5, 6.25, 8.0], dtype=pl.Float64),
"category": pl.Series(
["(0.993, 2.75]", "(2.75, 4.5]", "(4.5, 6.25]", "(6.25, 8.0]"],
dtype=pl.Categorical,
),
"count": pl.Series([3, 2, 0, 2], dtype=pl.get_index_type()),
}
)

assert a.hist(
bins=[0, 2], include_category=False, include_breakpoint=False
).to_series().to_list() == [0, 3, 4]
assert_frame_equal(out, expected, categorical_as_str=True)


@pytest.mark.parametrize("values", [[], [None]])
Expand Down

0 comments on commit 449843d

Please sign in to comment.