Add optional target_prefix input
HennerM committed Apr 18, 2023
1 parent ae0fa67 commit ea4e09d
src/ctranslate2.cc: 214 additions & 99 deletions
@@ -27,6 +27,10 @@
#include <cstddef>
#include <cstring>
#include <memory>
#include <optional>
#include <sstream>
#include <string>
#include <utility>

#include "triton/backend/backend_common.h"
#include "triton/backend/backend_input_collector.h"
@@ -66,6 +70,15 @@ ReadParameter(const triton::common::TritonJson::Value &params,
return nullptr; // success
}

TRITONSERVER_Error *
ReadParameter(const triton::common::TritonJson::Value &params,
const std::string &key, size_t *param) {
std::string tmp;
RETURN_IF_ERROR(ReadParameter(params, key, &tmp));
*param = static_cast<size_t>(std::stoi(tmp));
return nullptr; // success
}

TRITONSERVER_Error *
ReadParameter(const triton::common::TritonJson::Value &params,
const std::string &key, float *param) {
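
The new size_t overload above delegates to the string reader and converts with std::stoi, so a non-numeric parameter value throws std::invalid_argument rather than surfacing as a Triton error. A minimal standalone sketch of that conversion path, using a plain std::map in place of TritonJson (helper name is hypothetical; the parameter key matches the one this commit introduces):

// Toy sketch, not backend code: mirrors the size_t ReadParameter overload.
#include <cstddef>
#include <map>
#include <string>

bool ReadSizeTParam(const std::map<std::string, std::string> &params,
                    const std::string &key, std::size_t *out) {
  const auto it = params.find(key);
  if (it == params.end())
    return false;
  // Like the overload above: std::stoi throws on non-numeric input,
  // and negative values wrap around in the unsigned cast.
  *out = static_cast<std::size_t>(std::stoi(it->second));
  return true;
}

int main() {
  const std::map<std::string, std::string> params{
      {"max_decoding_length_multiple", "3"}};
  std::size_t multiple = 0;
  return ReadSizeTParam(params, "max_decoding_length_multiple", &multiple) ? 0
                                                                           : 1;
}
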
@@ -106,9 +119,10 @@ class ModelState : public BackendModel {
RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &outputs));

// The model must have exactly 1 input and 1 output.
RETURN_ERROR_IF_FALSE(inputs.ArraySize() == 1,
TRITONSERVER_ERROR_INVALID_ARG,
std::string("model configuration must have 1 input"));
RETURN_ERROR_IF_FALSE(
inputs.ArraySize() == 1 || inputs.ArraySize() == 2,
TRITONSERVER_ERROR_INVALID_ARG,
std::string("model configuration must have 1 or 2 inputs"));
RETURN_ERROR_IF_FALSE(
outputs.ArraySize() == 1, TRITONSERVER_ERROR_INVALID_ARG,
std::string("model configuration must have 1 output"));
@@ -123,6 +137,15 @@ class ModelState : public BackendModel {
RETURN_IF_ERROR(input.MemberAsString("name", &input_name, &input_name_len));
input_name_ = std::string(input_name);

if (inputs.ArraySize() == 2) {
RETURN_IF_ERROR(inputs.IndexAsObject(1, &input));
const char *target_prefix_input_name;
size_t target_prefix_input_name_len;
RETURN_IF_ERROR(input.MemberAsString("name", &target_prefix_input_name,
&target_prefix_input_name_len));
target_prefix_input_name_ = std::string(target_prefix_input_name);
}

const char *output_name;
size_t output_name_len;
RETURN_IF_ERROR(
@@ -148,13 +171,32 @@ class ModelState : public BackendModel {
compute_type_str)
.c_str());
}
if (params.Find("max_decoding_length_multiple")) {
size_t max_decode_length_multiple;
RETURN_IF_ERROR(ReadParameter(params, "max_decoding_length_multiple",
&max_decode_length_multiple));
max_decode_length_multiple_ = max_decode_length_multiple;
}
if (params.Find("beam_size")) {
RETURN_IF_ERROR(ReadParameter(
params, "beam_size", &(default_translation_options_.beam_size)));
}
}
return nullptr;
}

const std::string &InputTensorName() const { return input_name_; }
const std::optional<std::string> &TargetPrefixInputName() const {
return target_prefix_input_name_;
}
const std::string &OutputTensorName() const { return output_name_; }
TRITONSERVER_DataType OutputDataType() const { return output_type_; }
::ctranslate2::TranslationOptions DefaultTranslationOptions() const {
return default_translation_options_;
}
const std::optional<size_t> &MaxDecodeLengthMultiple() const {
return max_decode_length_multiple_;
}

TRITONSERVER_Error *
LoadModel(const ::ctranslate2::Device device, std::int32_t device_index,
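
Taken together, the validation and parameter hunks above admit a model configuration like the following sketch (a hypothetical config.pbtxt: tensor names, dims, and data types are placeholders, since the backend reads whatever names the config declares; beam_size and max_decoding_length_multiple are the parameters this commit reads):

backend: "ctranslate2"
max_batch_size: 64
input [
  {
    name: "INPUT_IDS"
    data_type: TYPE_INT32
    dims: [ -1 ]
    allow_ragged_batch: true
  },
  {
    name: "TARGET_PREFIX"
    data_type: TYPE_INT32
    dims: [ -1 ]
    allow_ragged_batch: true
  }
]
output [
  {
    name: "OUTPUT_IDS"
    data_type: TYPE_INT32
    dims: [ -1 ]
  }
]
parameters {
  key: "beam_size"
  value: { string_value: "4" }
}
parameters {
  key: "max_decoding_length_multiple"
  value: { string_value: "2" }
}
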
@@ -183,10 +225,13 @@ class ModelState : public BackendModel {
// TRITONBACKEND_Model *triton_model_;
triton::common::TritonJson::Value model_config_;
std::string input_name_;
std::optional<std::string> target_prefix_input_name_;
std::string output_name_;
TRITONSERVER_DataType output_type_;
::ctranslate2::ComputeType compute_type_ =
::ctranslate2::ComputeType::DEFAULT;
::ctranslate2::TranslationOptions default_translation_options_;
std::optional<size_t> max_decode_length_multiple_;
std::string model_path_;
std::shared_ptr<::ctranslate2::models::ModelReader> model_reader_;
std::map<std::pair<::ctranslate2::Device, std::int32_t>,
@@ -413,6 +458,133 @@ TRITONSERVER_Error *ToOutBuffer(const std::vector<std::size_t> &out_tokens,
return nullptr;
}

std::string
TranslationOptionsToString(const ::ctranslate2::TranslationOptions &options) {
std::stringstream ss;
ss << "TranslationOptions("
<< "beam_size=" << options.beam_size << ", "
<< "patience=" << options.patience << ", "
<< "length_penalty=" << options.length_penalty << ", "
<< "coverage_penalty=" << options.coverage_penalty << ", "
<< "repetition_penalty=" << options.repetition_penalty << ", "
<< "no_repeat_ngram_size=" << options.no_repeat_ngram_size << ", "
<< "disable_unk=" << options.disable_unk << ", "
<< "size(suppress_sequences)=" << options.suppress_sequences.size() << ", "
<< "prefix_bias_beta=" << options.prefix_bias_beta << ", "
<< "end_token=\"" << options.end_token << "\", "
<< "max_input_length=" << options.max_input_length << ", "
<< "max_decoding_length=" << options.max_decoding_length << ", "
<< "min_decoding_length=" << options.min_decoding_length << ", "
<< "sampling_topk=" << options.sampling_topk << ", "
<< "sampling_temperature=" << options.sampling_temperature << ", "
<< "use_vmap=" << options.use_vmap << ", "
<< "num_hypotheses=" << options.num_hypotheses << ", "
<< "return_scores=" << options.return_scores << ", "
<< "return_attention=" << options.return_attention << ", "
<< "return_alternatives=" << options.return_alternatives << ", "
<< "min_alternative_expansion_prob="
<< options.min_alternative_expansion_prob << ", "
<< "replace_unknowns=" << options.replace_unknowns << ")";
return ss.str();
}
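
This helper only feeds the verbose log added further down in execute; a standalone sketch of calling it with tweaked defaults (assuming CTranslate2's public translation.h header and that the helper is visible from this translation unit):

// Sketch only: exercises TranslationOptionsToString from this file.
#include <iostream>
#include <ctranslate2/translation.h>

int main() {
  ::ctranslate2::TranslationOptions opts;
  opts.beam_size = 4;              // override the library default
  opts.max_decoding_length = 512;
  std::cout << TranslationOptionsToString(opts) << std::endl;
  return 0;
}
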

TRITONSERVER_Error *InputBufferToRaggedTokens(
size_t total_batch_size, TRITONBACKEND_Request **requests,
const uint32_t request_count,
std::vector<TRITONBACKEND_Response *> *responses,
BackendInputCollector *collector,
std::vector<std::vector<size_t>> *ragged_tokens,
size_t *max_sequence_length, const std::string &input_name,
bool is_ragged_input = true, bool supports_batching = true) {
std::vector<std::vector<size_t>> tokens;
tokens.reserve(request_count);

const char *input_buffer;
size_t batchn_byte_size;
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;

// TODO support data straight from GPU
std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_preference = {
{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};

RETURN_IF_ERROR(collector->ProcessTensor(
input_name.c_str(), nullptr, 0, alloc_preference, &input_buffer,
&batchn_byte_size, &memory_type, &memory_type_id));

size_t max_seq_length = 0;
if (is_ragged_input) {
int64_t total_elements = 0;
for (size_t request_idx = 0; request_idx < request_count; request_idx++) {
TRITONBACKEND_Input *input;
RESPOND_AND_SET_NULL_IF_ERROR(
&((*responses)[request_idx]),
TRITONBACKEND_RequestInput(requests[request_idx], input_name.c_str(),
&input));

TRITONSERVER_DataType input_dt;
const int64_t *input_shape;
uint32_t input_dims_count;
RETURN_IF_ERROR(
TRITONBACKEND_InputProperties(input, nullptr, &input_dt, &input_shape,
&input_dims_count, nullptr, nullptr));

auto element_count = GetElementCount(input_shape, input_dims_count);
LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
(std::string("Element count for request ") +
std::to_string(request_idx) + std::string(": ") +
std::to_string(element_count))
.c_str());
max_seq_length =
std::max(max_seq_length, static_cast<size_t>(element_count));

std::vector<size_t> ids;
ToIdVector(input_buffer, input_dt, &ids, total_elements, element_count);
total_elements += element_count;
tokens.emplace_back(ids);
}
} else {
// input type is the same for all
TRITONBACKEND_Input *input;
RETURN_IF_ERROR(
TRITONBACKEND_RequestInput(requests[0], input_name.c_str(), &input));

TRITONSERVER_DataType input_dt;
const int64_t *input_shape;
uint32_t input_dims_count;
RETURN_IF_ERROR(
TRITONBACKEND_InputProperties(input, nullptr, &input_dt, &input_shape,
&input_dims_count, nullptr, nullptr));

if (input_dims_count > 2) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string("Inputs with more than two dimensions unsupported")
.c_str());
}

std::vector<int64_t> batchn_shape =
std::vector<int64_t>(input_shape, input_shape + input_dims_count);
if (supports_batching) {
batchn_shape[0] = total_batch_size;
}

for (size_t vector_idx = 0; vector_idx < total_batch_size; vector_idx++) {
std::vector<size_t> ids;
ToIdVector(input_buffer, input_dt, &ids, vector_idx * batchn_shape[1],
(vector_idx + 1) * batchn_shape[1]);
tokens.emplace_back(ids);
}
max_seq_length = static_cast<size_t>(batchn_shape[1]);
}

*ragged_tokens = tokens;
*max_sequence_length = max_seq_length;

return nullptr;
}
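
In the ragged branch above, the whole batch arrives as one flat buffer and each request's element count advances the read offset. A toy, self-contained illustration of that slicing arithmetic (hypothetical helper, not backend code):

#include <cstddef>
#include <cstdint>
#include <vector>

// Counts {3, 2} split the flat buffer [7 8 9 | 4 5] into two
// per-request token-id vectors, mirroring total_elements above.
std::vector<std::vector<std::size_t>> SliceRagged(
    const std::vector<std::int32_t> &flat,
    const std::vector<std::size_t> &counts) {
  std::vector<std::vector<std::size_t>> out;
  std::size_t offset = 0;
  for (const std::size_t n : counts) {
    out.emplace_back(flat.begin() + offset, flat.begin() + offset + n);
    offset += n;
  }
  return out;
}
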
/////////////

//
@@ -460,101 +632,31 @@ class ModelInstanceState : public BackendModelInstance {
std::vector<TRITONBACKEND_Response *> *responses,
BackendInputCollector *collector,
const ::ctranslate2::Vocabulary &source_vocab,
std::vector<std::vector<std::string>> *input_tokens) {

const char *input_buffer;
size_t batchn_byte_size;
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;

// TODO support data straight from GPU
std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_preference =
{{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};

RETURN_IF_ERROR(collector->ProcessTensor(
StateForModel()->InputTensorName().c_str(), nullptr, 0,
alloc_preference, &input_buffer, &batchn_byte_size, &memory_type,
&memory_type_id));

std::vector<std::vector<size_t>> token_ids;
token_ids.reserve(request_count);

bool is_ragged =
StateForModel()->IsInputRagged(StateForModel()->InputTensorName());

std::string tstr;
IGNORE_ERROR(BufferAsTypedString(tstr, input_buffer, batchn_byte_size,
TRITONSERVER_TYPE_INT32));
LOG_MESSAGE(
TRITONSERVER_LOG_VERBOSE,
(std::string("Input ragged: ") + std::to_string(is_ragged)).c_str());

if (is_ragged) {
int64_t total_elements = 0;
for (size_t request_idx = 0; request_idx < request_count; request_idx++) {
TRITONBACKEND_Input *input;
RESPOND_AND_SET_NULL_IF_ERROR(
&((*responses)[request_idx]),
TRITONBACKEND_RequestInput(
requests[request_idx],
StateForModel()->InputTensorName().c_str(), &input));

TRITONSERVER_DataType input_dt;
const int64_t *input_shape;
uint32_t input_dims_count;
RETURN_IF_ERROR(TRITONBACKEND_InputProperties(
input, nullptr, &input_dt, &input_shape, &input_dims_count, nullptr,
nullptr));

auto element_count = GetElementCount(input_shape, input_dims_count);
LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
(std::string("Element count for request ") +
std::to_string(request_idx) + std::string(": ") +
std::to_string(element_count))
.c_str());

std::vector<size_t> ids;

ToIdVector(input_buffer, input_dt, &ids, total_elements, element_count);
total_elements += element_count;
token_ids.emplace_back(ids);
}
} else {
// input type is the same for all
TRITONBACKEND_Input *input;
RETURN_IF_ERROR(TRITONBACKEND_RequestInput(
requests[0], StateForModel()->InputTensorName().c_str(), &input));

TRITONSERVER_DataType input_dt;
const int64_t *input_shape;
uint32_t input_dims_count;
RETURN_IF_ERROR(
TRITONBACKEND_InputProperties(input, nullptr, &input_dt, &input_shape,
&input_dims_count, nullptr, nullptr));

if (input_dims_count > 2) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string("Inputs with more than two dimensions unsupported")
.c_str());
}

std::vector<int64_t> batchn_shape =
std::vector<int64_t>(input_shape, input_shape + input_dims_count);
if (supports_batching_) {
batchn_shape[0] = total_batch_size;
}

for (size_t vector_idx = 0; vector_idx < total_batch_size; vector_idx++) {
std::vector<size_t> ids;
ToIdVector(input_buffer, input_dt, &ids, vector_idx * batchn_shape[1],
(vector_idx + 1) * batchn_shape[1]);
token_ids.emplace_back(ids);
}
const ::ctranslate2::Vocabulary &target_vocab,
std::vector<std::vector<std::string>> *input_tokens,
std::vector<std::vector<std::string>> *input_target_prefix,
size_t *max_sequence_length) {

std::vector<std::vector<size_t>> input_token_ids;
RETURN_IF_ERROR(InputBufferToRaggedTokens(
total_batch_size, requests, request_count, responses, collector,
&input_token_ids, max_sequence_length,
StateForModel()->InputTensorName(),
StateForModel()->IsInputRagged(StateForModel()->InputTensorName()),
supports_batching_));
*input_tokens = source_vocab.to_tokens(input_token_ids);
if (StateForModel()->TargetPrefixInputName()) {
std::vector<std::vector<size_t>> target_prefix_token_ids;
size_t discard_seq_length;
RETURN_IF_ERROR(InputBufferToRaggedTokens(
total_batch_size, requests, request_count, responses, collector,
&target_prefix_token_ids, &discard_seq_length,
*(StateForModel()->TargetPrefixInputName()),
StateForModel()->IsInputRagged(
*(StateForModel()->TargetPrefixInputName())),
supports_batching_));
*input_target_prefix = target_vocab.to_tokens(target_prefix_token_ids);
}

*input_tokens = source_vocab.to_tokens(token_ids);

return nullptr;
}

@@ -642,10 +744,13 @@ class ModelInstanceState : public BackendModelInstance {
nullptr /* stream*/);

std::vector<std::vector<std::string>> inputs;
std::vector<std::vector<std::string>> target_prefix;
size_t max_input_seq_length;
RESPOND_ALL_AND_SET_NULL_IF_ERROR(
responses, request_count,
CreateInput(total_batch_size, requests, request_count, &responses,
collector.get(), source_vocab, &inputs));
collector.get(), source_vocab, target_vocab, &inputs,
&target_prefix, &max_input_seq_length));

std::unique_ptr<::ctranslate2::models::SequenceToSequenceReplica>
seq2seq_replica = model_->as_sequence_to_sequence();
@@ -663,8 +768,18 @@

uint64_t compute_start_ns = 0;
SET_TIMESTAMP(compute_start_ns);
::ctranslate2::TranslationOptions options =
StateForModel()->DefaultTranslationOptions();
auto max_decode_length_multiple =
StateForModel()->MaxDecodeLengthMultiple();
if (max_decode_length_multiple) {
options.max_decoding_length =
*max_decode_length_multiple * max_input_seq_length;
}
LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
TranslationOptionsToString(options).c_str());
std::vector<::ctranslate2::TranslationResult> translation_results =
seq2seq_replica->translate(inputs);
seq2seq_replica->translate(inputs, target_prefix, options);

uint64_t compute_end_ns = 0;
SET_TIMESTAMP(compute_end_ns);
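
A worked example of the new max_decoding_length computation above, with hypothetical numbers: max_decoding_length_multiple = 2 and a longest source sequence of 50 tokens in the batch yields options.max_decoding_length = 100 for the whole batch.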