Skip to content

Commit

Permalink
refactor predict_proba_many and tensor_conversion.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Cedric Kulbach committed Dec 23, 2022
1 parent e872e11 commit ad452fa
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 15 deletions.
2 changes: 1 addition & 1 deletion deep_river/classification/rolling_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def predict_proba_many(self, X: pd.DataFrame) -> pd.DataFrame:
probas = [default_proba] * len(X)
return pd.DataFrame(probas)

def _get_default_proba(self)-> List[Dict[ClfTarget, float]]:
def _get_default_proba(self) -> List[Dict[ClfTarget, float]]:
if len(self.observed_classes) > 0:
mean_proba = (
1 / len(self.observed_classes)
Expand Down
6 changes: 3 additions & 3 deletions deep_river/utils/tensor_conversion.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Collection, Deque, Dict, Optional, Union
from typing import Deque, Dict, List, Optional, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -146,7 +146,7 @@ def labels2onehot(

def output2proba(
preds: torch.Tensor, classes: OrderedSet, with_logits=False
) -> Collection[Dict[ClfTarget, float]]:
) -> List[Dict[ClfTarget, float]]:
if with_logits:
if preds.shape[-1] >= 1:
preds = torch.softmax(preds, dim=-1)
Expand All @@ -168,4 +168,4 @@ def output2proba(
if preds_np.shape[0] == 1
else [dict(zip(classes, pred)) for pred in preds_np]
)
return [probas] if isinstance(probas, dict) else probas
return [probas] if isinstance(probas, dict) else list(probas)
28 changes: 17 additions & 11 deletions deep_river/utils/test_tensor_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def test_output2proba():
def assert_dicts_almost_equal(d1, d2):
for i in range(len(d1)):
for k in d1[i]:
assert np.isclose(d1[i][k], d2[i][k]), f"{d1[i][k]} != {d2[i][k]}"
assert np.isclose(
d1[i][k], d2[i][k]
), f"{d1[i][k]} != {d2[i][k]}"

y = torch.tensor([[0.1, 0.2, 0.7]])
classes = ["first class", "second class", "third class"]
Expand All @@ -92,20 +94,24 @@ def assert_dicts_almost_equal(d1, d2):
classes = ["first class"]
assert_dicts_almost_equal(
output2proba(y, classes),
[dict(
zip(
["first class", "unobserved 0"],
np.array([0.6, 0.4], dtype=np.float32),
[
dict(
zip(
["first class", "unobserved 0"],
np.array([0.6, 0.4], dtype=np.float32),
)
)
)],
],
)
y = torch.tensor([[0.6, 0.4, 0.0]])
assert_dicts_almost_equal(
output2proba(y, classes),
[dict(
zip(
["first class", "unobserved 0", "unobserved 1"],
np.array([0.6, 0.4, 0.0], dtype=np.float32),
[
dict(
zip(
["first class", "unobserved 0", "unobserved 1"],
np.array([0.6, 0.4, 0.0], dtype=np.float32),
)
)
)]
],
)

0 comments on commit ad452fa

Please sign in to comment.