Skip to content

Commit

Permalink
Hard-code the Workflow output dtypes in Triton (#1468)
Browse files Browse the repository at this point in the history
Since HugeCTR always expects the same three fields, we don't have to consult the `Workflow`'s output schema to determine the dtypes. We can just hard-code them.

Partially addresses NVIDIA-Merlin/Merlin#125
  • Loading branch information
karlhigley authored Mar 22, 2022
1 parent 8b3efa0 commit 9d87ffa
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions nvtabular/inference/triton/workflow_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import os
from typing import List

import numpy as np
from triton_python_backend_utils import (
InferenceRequest,
InferenceResponse,
Expand Down Expand Up @@ -68,12 +69,15 @@ def initialize(self, args):
self.input_dtypes, self.input_multihots = _parse_input_dtypes(input_dtypes)

self.output_dtypes = dict()
for col_name, col_schema in self.workflow.output_schema.column_schemas.items():
if col_schema.is_list and col_schema.is_ragged:
self._set_output_dtype(col_name + "__nnzs")
self._set_output_dtype(col_name + "__values")
else:
self._set_output_dtype(col_name)
if model_framework == "hugectr":
self.output_dtypes = {"DES": np.float32, "CATCOLUMN": np.int64, "ROWINDEX": np.int32}
else:
for col_name, col_schema in self.workflow.output_schema.column_schemas.items():
if col_schema.is_list and col_schema.is_ragged:
self._set_output_dtype(col_name + "__nnzs")
self._set_output_dtype(col_name + "__values")
else:
self._set_output_dtype(col_name)

if model_framework == "hugectr":
runner_class = HugeCTRWorkflowRunner
Expand Down

0 comments on commit 9d87ffa

Please sign in to comment.