From ea4e09da89d0bea422a0de7cff4d3f4e60f0429a Mon Sep 17 00:00:00 2001
From: Markus Hennerbichler
Date: Tue, 18 Apr 2023 12:28:35 +0100
Subject: [PATCH] Add optional target_prefix input

---
 src/ctranslate2.cc | 313 +++++++++++++++++++++++++++++++--------------
 1 file changed, 214 insertions(+), 99 deletions(-)

diff --git a/src/ctranslate2.cc b/src/ctranslate2.cc
index cce46ec..fc2207e 100644
--- a/src/ctranslate2.cc
+++ b/src/ctranslate2.cc
@@ -27,6 +27,10 @@
 #include
 #include
 #include
+#include <algorithm>
+#include <optional>
+#include <sstream>
+#include <utility>
 
 #include "triton/backend/backend_common.h"
 #include "triton/backend/backend_input_collector.h"
@@ -66,6 +70,15 @@ ReadParameter(const triton::common::TritonJson::Value &params,
   return nullptr; // success
 }
 
+TRITONSERVER_Error *
+ReadParameter(const triton::common::TritonJson::Value &params,
+              const std::string &key, size_t *param) {
+  std::string tmp;
+  RETURN_IF_ERROR(ReadParameter(params, key, &tmp));
+  *param = static_cast<size_t>(std::stoi(tmp));
+  return nullptr; // success
+}
+
 TRITONSERVER_Error *
 ReadParameter(const triton::common::TritonJson::Value &params,
               const std::string &key, float *param) {
@@ -106,9 +119,10 @@ class ModelState : public BackendModel {
     RETURN_IF_ERROR(ModelConfig().MemberAsArray("output", &outputs));
 
-    // The model must have exactly 1 input and 1 output.
-    RETURN_ERROR_IF_FALSE(inputs.ArraySize() == 1,
-                          TRITONSERVER_ERROR_INVALID_ARG,
-                          std::string("model configuration must have 1 input"));
+    // The model must have 1 or 2 inputs and exactly 1 output.
+    RETURN_ERROR_IF_FALSE(
+        inputs.ArraySize() == 1 || inputs.ArraySize() == 2,
+        TRITONSERVER_ERROR_INVALID_ARG,
+        std::string("model configuration must have 1 or 2 inputs"));
     RETURN_ERROR_IF_FALSE(
         outputs.ArraySize() == 1, TRITONSERVER_ERROR_INVALID_ARG,
         std::string("model configuration must have 1 output"));
@@ -123,6 +137,15 @@ class ModelState : public BackendModel {
     RETURN_IF_ERROR(input.MemberAsString("name", &input_name, &input_name_len));
     input_name_ = std::string(input_name);
 
+    if (inputs.ArraySize() == 2) {
+      RETURN_IF_ERROR(inputs.IndexAsObject(1, &input));
+      const char *target_prefix_input_name;
+      size_t target_prefix_input_name_len;
+      RETURN_IF_ERROR(input.MemberAsString("name", &target_prefix_input_name,
+                                           &target_prefix_input_name_len));
+      target_prefix_input_name_ = std::string(target_prefix_input_name);
+    }
+
     const char *output_name;
     size_t output_name_len;
     RETURN_IF_ERROR(
@@ -148,13 +171,32 @@ class ModelState : public BackendModel {
                                  compute_type_str)
                      .c_str());
     }
+    if (params.Find("max_decoding_length_multiple")) {
+      size_t max_decode_length_multiple;
+      RETURN_IF_ERROR(ReadParameter(params, "max_decoding_length_multiple",
+                                    &max_decode_length_multiple));
+      max_decode_length_multiple_ = max_decode_length_multiple;
+    }
+    if (params.Find("beam_size")) {
+      RETURN_IF_ERROR(ReadParameter(
+          params, "beam_size", &(default_translation_options_.beam_size)));
+    }
   }
 
   return nullptr;
 }
 
   const std::string &InputTensorName() const { return input_name_; }
+  const std::optional<std::string> &TargetPrefixInputName() const {
+    return target_prefix_input_name_;
+  }
   const std::string &OutputTensorName() const { return output_name_; }
   TRITONSERVER_DataType OutputDataType() const { return output_type_; }
+  ::ctranslate2::TranslationOptions DefaultTranslationOptions() const {
+    return default_translation_options_;
+  }
+  const std::optional<size_t> &MaxDecodeLengthMultiple() const {
+    return max_decode_length_multiple_;
+  }
 
   TRITONSERVER_Error *
   LoadModel(const ::ctranslate2::Device device, std::int32_t device_index,
@@ -183,10 +225,13 @@ //
   TRITONBACKEND_Model *triton_model_;
   triton::common::TritonJson::Value model_config_;
   std::string input_name_;
+  std::optional<std::string> target_prefix_input_name_;
   std::string output_name_;
   TRITONSERVER_DataType output_type_;
   ::ctranslate2::ComputeType compute_type_ =
       ::ctranslate2::ComputeType::DEFAULT;
+  ::ctranslate2::TranslationOptions default_translation_options_;
+  std::optional<size_t> max_decode_length_multiple_;
   std::string model_path_;
   std::shared_ptr<::ctranslate2::models::ModelReader> model_reader_;
   std::map<std::pair<::ctranslate2::Device, std::int32_t>,
@@ -413,6 +458,133 @@ TRITONSERVER_Error *ToOutBuffer(const std::vector<std::string> &out_tokens,
   return nullptr;
 }
 
+std::string
+TranslationOptionsToString(const ::ctranslate2::TranslationOptions &options) {
+  std::stringstream ss;
+  ss << "TranslationOptions("
+     << "beam_size=" << options.beam_size << ", "
+     << "patience=" << options.patience << ", "
+     << "length_penalty=" << options.length_penalty << ", "
+     << "coverage_penalty=" << options.coverage_penalty << ", "
+     << "repetition_penalty=" << options.repetition_penalty << ", "
+     << "no_repeat_ngram_size=" << options.no_repeat_ngram_size << ", "
+     << "disable_unk=" << options.disable_unk << ", "
+     << "size(suppress_sequences)=" << options.suppress_sequences.size() << ", "
+     << "prefix_bias_beta=" << options.prefix_bias_beta << ", "
+     << "end_token=\"" << options.end_token << "\", "
+     << "max_input_length=" << options.max_input_length << ", "
+     << "max_decoding_length=" << options.max_decoding_length << ", "
+     << "min_decoding_length=" << options.min_decoding_length << ", "
+     << "sampling_topk=" << options.sampling_topk << ", "
+     << "sampling_temperature=" << options.sampling_temperature << ", "
+     << "use_vmap=" << options.use_vmap << ", "
+     << "num_hypotheses=" << options.num_hypotheses << ", "
+     << "return_scores=" << options.return_scores << ", "
+     << "return_attention=" << options.return_attention << ", "
+     << "return_alternatives=" << options.return_alternatives << ", "
+     << "min_alternative_expansion_prob="
+     << options.min_alternative_expansion_prob << ", "
+     << "replace_unknowns=" << options.replace_unknowns << ")";
+  return ss.str();
+}
+
+TRITONSERVER_Error *InputBufferToRaggedTokens(
+    size_t total_batch_size, TRITONBACKEND_Request **requests,
+    const uint32_t request_count,
+    std::vector<TRITONBACKEND_Response *> *responses,
+    BackendInputCollector *collector,
+    std::vector<std::vector<size_t>> *ragged_tokens,
+    size_t *max_sequence_length, const std::string &input_name,
+    bool is_ragged_input = true, bool supports_batching = true) {
+  std::vector<std::vector<size_t>> tokens;
+  tokens.reserve(request_count);
+
+  const char *input_buffer;
+  size_t batchn_byte_size;
+  TRITONSERVER_MemoryType memory_type;
+  int64_t memory_type_id;
+
+  // TODO support data straight from GPU
+  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_preference = {
+      {TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
+
+  RETURN_IF_ERROR(collector->ProcessTensor(
+      input_name.c_str(), nullptr, 0, alloc_preference, &input_buffer,
+      &batchn_byte_size, &memory_type, &memory_type_id));
+
+  size_t max_seq_length = 0;
+  if (is_ragged_input) {
+    int64_t total_elements = 0;
+    for (size_t request_idx = 0; request_idx < request_count; request_idx++) {
+      TRITONBACKEND_Input *input;
+      RESPOND_AND_SET_NULL_IF_ERROR(
+          &((*responses)[request_idx]),
+          TRITONBACKEND_RequestInput(requests[request_idx], input_name.c_str(),
+                                     &input));
+
+      TRITONSERVER_DataType input_dt;
+      const int64_t *input_shape;
+      uint32_t input_dims_count;
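+      // Query this request's datatype and shape; the trailing nullptr
+      // arguments skip byte size and buffer count, which are not needed here.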
+      RETURN_IF_ERROR(
+          TRITONBACKEND_InputProperties(input, nullptr, &input_dt, &input_shape,
+                                        &input_dims_count, nullptr, nullptr));
+
+      auto element_count = GetElementCount(input_shape, input_dims_count);
+      LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
+                  (std::string("Element count for request ") +
+                   std::to_string(request_idx) + std::string(": ") +
+                   std::to_string(element_count))
+                      .c_str());
+      max_seq_length =
+          std::max(max_seq_length, static_cast<size_t>(element_count));
+
+      std::vector<size_t> ids;
+      ToIdVector(input_buffer, input_dt, &ids, total_elements, element_count);
+      total_elements += element_count;
+      tokens.emplace_back(ids);
+    }
+  } else {
+    // input type is the same for all
+    TRITONBACKEND_Input *input;
+    RETURN_IF_ERROR(
+        TRITONBACKEND_RequestInput(requests[0], input_name.c_str(), &input));
+
+    TRITONSERVER_DataType input_dt;
+    const int64_t *input_shape;
+    uint32_t input_dims_count;
+    RETURN_IF_ERROR(
+        TRITONBACKEND_InputProperties(input, nullptr, &input_dt, &input_shape,
+                                      &input_dims_count, nullptr, nullptr));
+
+    if (input_dims_count > 2) {
+      return TRITONSERVER_ErrorNew(
+          TRITONSERVER_ERROR_INVALID_ARG,
+          std::string("Inputs with more than two dimensions unsupported")
+              .c_str());
+    }
+
+    std::vector<int64_t> batchn_shape =
+        std::vector<int64_t>(input_shape, input_shape + input_dims_count);
+    if (supports_batching) {
+      batchn_shape[0] = total_batch_size;
+    }
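+    // Dense (non-ragged) case: all sequences share the same length, so the
+    // batched buffer is split into equal rows of batchn_shape[1] tokens.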
+    for (size_t vector_idx = 0; vector_idx < total_batch_size; vector_idx++) {
+      std::vector<size_t> ids;
+      ToIdVector(input_buffer, input_dt, &ids, vector_idx * batchn_shape[1],
+                 (vector_idx + 1) * batchn_shape[1]);
+      tokens.emplace_back(ids);
+    }
+    max_seq_length = static_cast<size_t>(batchn_shape[1]);
+  }
+
+  *ragged_tokens = tokens;
+  *max_sequence_length = max_seq_length;
+
+  return nullptr;
+}
 
 /////////////
 
 //
@@ -460,101 +632,31 @@ class ModelInstanceState : public BackendModelInstance {
               std::vector<TRITONBACKEND_Response *> *responses,
               BackendInputCollector *collector,
               const ::ctranslate2::Vocabulary &source_vocab,
-              std::vector<std::vector<std::string>> *input_tokens) {
-
-    const char *input_buffer;
-    size_t batchn_byte_size;
-    TRITONSERVER_MemoryType memory_type;
-    int64_t memory_type_id;
-
-    // TODO support data straight from GPU
-    std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> alloc_preference =
-        {{TRITONSERVER_MEMORY_CPU_PINNED, 0}, {TRITONSERVER_MEMORY_CPU, 0}};
-
-    RETURN_IF_ERROR(collector->ProcessTensor(
-        StateForModel()->InputTensorName().c_str(), nullptr, 0,
-        alloc_preference, &input_buffer, &batchn_byte_size, &memory_type,
-        &memory_type_id));
-
-    std::vector<std::vector<size_t>> token_ids;
-    token_ids.reserve(request_count);
-
-    bool is_ragged =
-        StateForModel()->IsInputRagged(StateForModel()->InputTensorName());
-
-    std::string tstr;
-    IGNORE_ERROR(BufferAsTypedString(tstr, input_buffer, batchn_byte_size,
-                                     TRITONSERVER_TYPE_INT32));
-    LOG_MESSAGE(
-        TRITONSERVER_LOG_VERBOSE,
-        (std::string("Input ragged: ") + std::to_string(is_ragged)).c_str());
-
-    if (is_ragged) {
-      int64_t total_elements = 0;
-      for (size_t request_idx = 0; request_idx < request_count; request_idx++) {
-        TRITONBACKEND_Input *input;
-        RESPOND_AND_SET_NULL_IF_ERROR(
-            &((*responses)[request_idx]),
-            TRITONBACKEND_RequestInput(
-                requests[request_idx],
-                StateForModel()->InputTensorName().c_str(), &input));
-
-        TRITONSERVER_DataType input_dt;
-        const int64_t *input_shape;
-        uint32_t input_dims_count;
-        RETURN_IF_ERROR(TRITONBACKEND_InputProperties(
-            input, nullptr, &input_dt, &input_shape, &input_dims_count, nullptr,
-            nullptr));
-
-        auto element_count = GetElementCount(input_shape, input_dims_count);
-        LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
-                    (std::string("Element count for request ") +
-                     std::to_string(request_idx) + std::string(": ") +
-                     std::to_string(element_count))
-                        .c_str());
-
-        std::vector<size_t> ids;
-
-        ToIdVector(input_buffer, input_dt, &ids, total_elements, element_count);
-        total_elements += element_count;
-        token_ids.emplace_back(ids);
-      }
-    } else {
-      // input type is the same for all
-      TRITONBACKEND_Input *input;
-      RETURN_IF_ERROR(TRITONBACKEND_RequestInput(
-          requests[0], StateForModel()->InputTensorName().c_str(), &input));
-
-      TRITONSERVER_DataType input_dt;
-      const int64_t *input_shape;
-      uint32_t input_dims_count;
-      RETURN_IF_ERROR(
-          TRITONBACKEND_InputProperties(input, nullptr, &input_dt, &input_shape,
-                                        &input_dims_count, nullptr, nullptr));
-
-      if (input_dims_count > 2) {
-        return TRITONSERVER_ErrorNew(
-            TRITONSERVER_ERROR_INVALID_ARG,
-            std::string("Inputs with more than two dimensions unsupported")
-                .c_str());
-      }
-
-      std::vector<int64_t> batchn_shape =
-          std::vector<int64_t>(input_shape, input_shape + input_dims_count);
-      if (supports_batching_) {
-        batchn_shape[0] = total_batch_size;
-      }
-
-      for (size_t vector_idx = 0; vector_idx < total_batch_size; vector_idx++) {
-        std::vector<size_t> ids;
-        ToIdVector(input_buffer, input_dt, &ids, vector_idx * batchn_shape[1],
-                   (vector_idx + 1) * batchn_shape[1]);
-        token_ids.emplace_back(ids);
-      }
+              const ::ctranslate2::Vocabulary &target_vocab,
+              std::vector<std::vector<std::string>> *input_tokens,
+              std::vector<std::vector<std::string>> *input_target_prefix,
+              size_t *max_sequence_length) {
+
+    std::vector<std::vector<size_t>> input_token_ids;
+    RETURN_IF_ERROR(InputBufferToRaggedTokens(
+        total_batch_size, requests, request_count, responses, collector,
+        &input_token_ids, max_sequence_length,
+        StateForModel()->InputTensorName(),
+        StateForModel()->IsInputRagged(StateForModel()->InputTensorName()),
+        supports_batching_));
+    *input_tokens = source_vocab.to_tokens(input_token_ids);
+    if (StateForModel()->TargetPrefixInputName()) {
+      std::vector<std::vector<size_t>> target_prefix_token_ids;
+      size_t discard_seq_length;
+      RETURN_IF_ERROR(InputBufferToRaggedTokens(
+          total_batch_size, requests, request_count, responses, collector,
+          &target_prefix_token_ids, &discard_seq_length,
+          *(StateForModel()->TargetPrefixInputName()),
+          StateForModel()->IsInputRagged(
+              *(StateForModel()->TargetPrefixInputName())),
+          supports_batching_));
+      *input_target_prefix = target_vocab.to_tokens(target_prefix_token_ids);
     }
-
-    *input_tokens = source_vocab.to_tokens(token_ids);
-
     return nullptr;
   }
@@ -642,10 +744,13 @@ class ModelInstanceState : public BackendModelInstance {
                                               nullptr /* stream*/);
 
     std::vector<std::vector<std::string>> inputs;
+    std::vector<std::vector<std::string>> target_prefix;
+    size_t max_input_seq_length;
     RESPOND_ALL_AND_SET_NULL_IF_ERROR(
         responses, request_count,
        CreateInput(total_batch_size, requests, request_count, &responses,
-                    collector.get(), source_vocab, &inputs));
+                    collector.get(), source_vocab, target_vocab, &inputs,
+                    &target_prefix, &max_input_seq_length));
 
     std::unique_ptr<::ctranslate2::models::SequenceToSequenceReplica>
         seq2seq_replica = model_->as_sequence_to_sequence();
@@ -663,8 +768,18 @@ class ModelInstanceState : public BackendModelInstance {
     uint64_t compute_start_ns = 0;
     SET_TIMESTAMP(compute_start_ns);
 
+    ::ctranslate2::TranslationOptions options =
+        StateForModel()->DefaultTranslationOptions();
+    auto max_decode_length_multiple =
+        StateForModel()->MaxDecodeLengthMultiple();
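+    // If max_decoding_length_multiple is configured, scale the decoding
+    // budget with the longest source sequence in this batch.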
+    if (max_decode_length_multiple) {
+      options.max_decoding_length =
+          *max_decode_length_multiple * max_input_seq_length;
+    }
+    LOG_MESSAGE(TRITONSERVER_LOG_VERBOSE,
+                TranslationOptionsToString(options).c_str());
     std::vector<::ctranslate2::TranslationResult> translation_results =
-        seq2seq_replica->translate(inputs);
+        seq2seq_replica->translate(inputs, target_prefix, options);
 
     uint64_t compute_end_ns = 0;
     SET_TIMESTAMP(compute_end_ns);
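
Usage sketch (illustrative, not part of the patch): the backend identifies the
optional target-prefix input purely by position, reading its name from index 1
of the "input" array in the model configuration. A config.pbtxt along the
following lines would exercise the new path. The tensor names, the INT32
token-id type, and the ragged-batch flags are assumptions for this sketch;
only the input ordering and the parameter keys ("max_decoding_length_multiple",
"beam_size") come from the code above.

  input [
    {
      name: "INPUT_IDS"        # source token ids (input index 0)
      data_type: TYPE_INT32
      dims: [ -1 ]
      allow_ragged_batch: true
    },
    {
      name: "TARGET_PREFIX"    # optional decoder prefix (input index 1)
      data_type: TYPE_INT32
      dims: [ -1 ]
      allow_ragged_batch: true
    }
  ]
  parameters {
    key: "max_decoding_length_multiple"
    value: { string_value: "2" }
  }
  parameters {
    key: "beam_size"
    value: { string_value: "2" }
  }

When the second input is present, its ids are converted through the target
vocabulary and passed to translate() as the target prefix, and
max_decoding_length becomes max_decoding_length_multiple times the longest
source sequence in the batch.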