-
Notifications
You must be signed in to change notification settings - Fork 215
/
nllb.rs
51 lines (44 loc) · 2.12 KB
/
nllb.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
use rust_bert::nllb::{
NLLBConfigResources, NLLBLanguages, NLLBMergeResources, NLLBResources, NLLBVocabResources,
};
use rust_bert::pipelines::common::{ModelResource, ModelType};
use rust_bert::pipelines::translation::{Language, TranslationConfig, TranslationModel};
use rust_bert::resources::RemoteResource;
use tch::Device;
/// End-to-end check of the distilled 600M NLLB checkpoint through the generic
/// `TranslationModel` pipeline.
///
/// One English sentence is translated into French, Spanish and Hindi. Each
/// translation is compared against a pinned reference string.
///
/// All model files (weights, config, vocabulary, merges) are obtained through
/// `RemoteResource::from_pretrained`, and the model runs on `Device::Cpu`. The
/// test is therefore expensive and presumably needs network access the first
/// time it runs. It is gated behind the `all-tests` feature; that gate had been
/// commented out, which made a plain `cargo test` fetch and run the model.
#[test]
#[cfg_attr(not(feature = "all-tests"), ignore)]
fn nllb_translation() -> anyhow::Result<()> {
    let model_resource = ModelResource::Torch(Box::new(RemoteResource::from_pretrained(
        NLLBResources::NLLB_600M_DISTILLED,
    )));
    let config_resource = RemoteResource::from_pretrained(NLLBConfigResources::NLLB_600M_DISTILLED);
    let vocab_resource = RemoteResource::from_pretrained(NLLBVocabResources::NLLB_600M_DISTILLED);
    let merges_resource = RemoteResource::from_pretrained(NLLBMergeResources::NLLB_600M_DISTILLED);

    // Source and target sets are both the full NLLB language list, so any pair
    // from that list may be requested below.
    let source_languages = NLLBLanguages::NLLB;
    let target_languages = NLLBLanguages::NLLB;

    let translation_config = TranslationConfig::new(
        ModelType::NLLB,
        model_resource,
        config_resource,
        vocab_resource,
        Some(merges_resource),
        source_languages,
        target_languages,
        Device::Cpu,
    );
    let model = TranslationModel::new(translation_config)?;

    let source_sentence = "This sentence will be translated in multiple languages.";

    // (target language, expected translation). The leading space is kept because
    // the pinned outputs include it. NOTE(review): it looks like a decoding
    // artifact of the tokenizer; confirm before normalising it away.
    let cases = [
        (
            Language::French,
            " Cette phrase sera traduite en plusieurs langues.",
        ),
        (
            Language::Spanish,
            " Esta frase será traducida a varios idiomas.",
        ),
        (
            Language::Hindi,
            " यह वाक्य कई भाषाओं में अनुवादित किया जाएगा।",
        ),
    ];

    for (target_language, expected) in cases {
        let outputs = model.translate(&[source_sentence], Language::English, target_language)?;
        // One input sentence must yield exactly one translation, equal to the
        // pinned reference; comparing against a 1-element array checks both.
        assert_eq!(outputs, [expected]);
    }
    Ok(())
}