From 9d87ffad466b786d2f0ace7be47af181daddfe0b Mon Sep 17 00:00:00 2001 From: Karl Higley Date: Tue, 22 Mar 2022 19:09:10 -0400 Subject: [PATCH] Hard-code the `Workflow` output dtypes in Triton (#1468) Since HugeCTR always expects the same three fields, we don't have to consult the `Workflow`'s output schema to determine the dtypes. We can just hard-code them. Partially addresses NVIDIA-Merlin/Merlin#125 --- nvtabular/inference/triton/workflow_model.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/nvtabular/inference/triton/workflow_model.py b/nvtabular/inference/triton/workflow_model.py index 457f56d47b3..a880121ceca 100644 --- a/nvtabular/inference/triton/workflow_model.py +++ b/nvtabular/inference/triton/workflow_model.py @@ -29,6 +29,7 @@ import os from typing import List +import numpy as np from triton_python_backend_utils import ( InferenceRequest, InferenceResponse, @@ -68,12 +69,15 @@ def initialize(self, args): self.input_dtypes, self.input_multihots = _parse_input_dtypes(input_dtypes) self.output_dtypes = dict() - for col_name, col_schema in self.workflow.output_schema.column_schemas.items(): - if col_schema.is_list and col_schema.is_ragged: - self._set_output_dtype(col_name + "__nnzs") - self._set_output_dtype(col_name + "__values") - else: - self._set_output_dtype(col_name) + if model_framework == "hugectr": + self.output_dtypes = {"DES": np.float32, "CATCOLUMN": np.int64, "ROWINDEX": np.int32} + else: + for col_name, col_schema in self.workflow.output_schema.column_schemas.items(): + if col_schema.is_list and col_schema.is_ragged: + self._set_output_dtype(col_name + "__nnzs") + self._set_output_dtype(col_name + "__values") + else: + self._set_output_dtype(col_name) if model_framework == "hugectr": runner_class = HugeCTRWorkflowRunner