Skip to content

Commit

Permalink
fix: perform predictions in inference mode
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasczz committed Dec 20, 2022
1 parent 6d57ee0 commit 8049ecb
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
6 changes: 4 additions & 2 deletions river_torch/classification/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ def predict_proba_one(self, x: dict) -> Dict[ClfTarget, float]:
self.initialize_module(**self.kwargs)
x_t = dict2tensor(x, device=self.device)
self.module.eval()
y_pred = self.module(x_t)
with torch.inference_mode():
y_pred = self.module(x_t)
return output2proba(
y_pred, self.observed_classes, self.output_is_logit
)
Expand Down Expand Up @@ -306,7 +307,8 @@ def predict_proba_many(self, X: pd.DataFrame) -> pd.DataFrame:
self.initialize_module(**self.kwargs)
X_t = df2tensor(X, device=self.device)
self.module.eval()
y_preds = self.module(X_t)
with torch.inference_mode():
y_preds = self.module(X_t)
return pd.Dataframe(output2proba(y_preds, self.observed_classes))

def _adapt_output_dim(self):
Expand Down
8 changes: 6 additions & 2 deletions river_torch/regression/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ def predict_one(self, x: dict) -> RegTarget:
self.initialize_module(**self.kwargs)
x_t = dict2tensor(x, self.device)
self.module.eval()
return self.module(x_t).item()
with torch.inference_mode():
y_pred = self.module(x_t).item()
return y_pred

def learn_many(self, X: pd.DataFrame, y: List) -> "Regressor":
"""
Expand Down Expand Up @@ -223,4 +225,6 @@ def predict_many(self, X: pd.DataFrame) -> List:

X = df2tensor(X, device=self.device)
self.module.eval()
return self.module(X).detach().tolist()
with torch.inference_mode():
y_preds = self.module(X).detach().tolist()
return y_preds
3 changes: 2 additions & 1 deletion river_torch/utils/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ def __call__(self, module, input, output):
self.ordered_modules.append(module)


def apply_hooks(module, hook, handles):
def apply_hooks(module, hook, handles=[]):
for child in module.children():
apply_hooks(child, hook, handles)
handle = module.register_forward_hook(hook)
handles.append(handle)
return handles

0 comments on commit 8049ecb

Please sign in to comment.