Skip to content

Commit

Permalink
Add support for batch_input
Browse files Browse the repository at this point in the history
  • Loading branch information
HennerM authored and mc-nv committed Jun 7, 2023
1 parent fd29c6e commit bfca5d0
Showing 1 changed file with 9 additions and 20 deletions.
29 changes: 9 additions & 20 deletions src/libtorch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include <stdint.h>

#include <cstdint>
#include <exception>

Expand Down Expand Up @@ -506,8 +505,9 @@ class ModelInstanceState : public BackendModelInstance {
const std::string& control_kind, bool required, bool* have_control);
TRITONSERVER_Error* ValidateInputs(const size_t expected_input_cnt);
void AddInputToMap(
NamingConvention naming_convention,
const std::vector<std::string> allowed_inputs, const std::string& io_name,
NamingConvention naming_convention,
const std::vector<std::string> allowed_inputs,
const std::string &io_name,
const uint32_t index);
TRITONSERVER_Error* ValidateOutputs();
void Execute(
Expand Down Expand Up @@ -771,12 +771,7 @@ ModelInstanceState::ValidateTypedSequenceControl(
return nullptr; // success
}

void
ModelInstanceState::AddInputToMap(
NamingConvention naming_convention,
const std::vector<std::string> allowed_inputs, const std::string& io_name,
const uint32_t index)
{
void ModelInstanceState::AddInputToMap(NamingConvention naming_convention, const std::vector<std::string> allowed_inputs, const std::string &io_name, const uint32_t index) {
std::string deliminator = "__";

if (is_dict_input_) {
Expand Down Expand Up @@ -929,13 +924,11 @@ ModelInstanceState::ValidateInputs(const size_t expected_input_cnt)
}

triton::common::TritonJson::Value batch_inputs;
RETURN_IF_ERROR(
model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
RETURN_IF_ERROR(model_state_->ModelConfig().MemberAsArray("batch_input", &batch_inputs));
size_t i = 0;
for (const auto& batch_input : StateForModel()->BatchInputs()) {
for (const auto& input_name : batch_input.TargetNames()) {
AddInputToMap(
naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
AddInputToMap(naming_convention, allowed_inputs, input_name, i + ios.ArraySize());
i++;
}
}
Expand Down Expand Up @@ -1883,16 +1876,12 @@ ModelInstanceState::SetInputTensors(
&dst_buffer, &dst_buffer_byte_size, &dst_memory_type,
&dst_memory_type_id));

const auto torch_dtype =
ConvertDataTypeToTorchType(batch_input.DataType());
const auto torch_dtype = ConvertDataTypeToTorchType(batch_input.DataType());
torch::TensorOptions options{torch_dtype.second};
auto updated_options = (dst_memory_type == TRITONSERVER_MEMORY_GPU)
? options.device(torch::kCUDA, device_.index())
: options.device(torch::kCPU);
auto updated_options = options.device(torch::kCPU);

torch::Tensor input_tensor = torch::from_blob(
const_cast<char*>(dst_buffer), shape,
updated_options.dtype(torch_dtype.second));
const_cast<char*>(dst_buffer), shape, updated_options);
(*input_tensors)[input_index_map_[input_name]] = input_tensor;
}
}
Expand Down

0 comments on commit bfca5d0

Please sign in to comment.