From 7015a7780e8b268a0809e121e9fe205dbe857bbe Mon Sep 17 00:00:00 2001 From: Damien Daspit Date: Thu, 16 May 2024 17:31:24 -0500 Subject: [PATCH] Do not train/pretranslate all if textIds is an empty array - only train/pretranslate all if both textIds and scriptureRange are not specified - fixes #383 --- .../Services/EngineService.cs | 7 +- .../Services/EngineServiceTests.cs | 381 +++++++++++++++++- 2 files changed, 363 insertions(+), 25 deletions(-) diff --git a/src/Serval.Translation/Services/EngineService.cs b/src/Serval.Translation/Services/EngineService.cs index decd7dea..46a4f575 100644 --- a/src/Serval.Translation/Services/EngineService.cs +++ b/src/Serval.Translation/Services/EngineService.cs @@ -223,7 +223,6 @@ Dictionary> GetChapters(V1.Corpus corpus, string scriptureRang EngineType = engine.Type, EngineId = engine.Id, BuildId = build.Id, - Options = JsonSerializer.Serialize(build.Options), Corpora = { engine.Corpora.Select(c => @@ -232,7 +231,7 @@ Dictionary> GetChapters(V1.Corpus corpus, string scriptureRang if (pretranslate?.TryGetValue(c.Id, out PretranslateCorpus? pretranslateCorpus) ?? false) { corpus.PretranslateAll = - pretranslateCorpus.TextIds is null || pretranslateCorpus.TextIds.Count == 0; + pretranslateCorpus.TextIds is null && pretranslateCorpus.ScriptureRange is null; if (pretranslateCorpus.TextIds is not null) corpus.PretranslateTextIds.Add(pretranslateCorpus.TextIds); if (pretranslateCorpus.ScriptureRange is not null) @@ -262,7 +261,7 @@ Dictionary> GetChapters(V1.Corpus corpus, string scriptureRang } if (trainOn?.TryGetValue(c.Id, out TrainingCorpus? trainingCorpus) ?? false) { - corpus.TrainOnAll = trainingCorpus.TextIds is null || trainingCorpus.TextIds.Count == 0; + corpus.TrainOnAll = trainingCorpus.TextIds is null && trainingCorpus.ScriptureRange is null; if (trainingCorpus.TextIds is not null) corpus.TrainOnTextIds.Add(trainingCorpus.TextIds); if (trainingCorpus.ScriptureRange is not null) @@ -298,6 +297,8 @@ Dictionary> GetChapters(V1.Corpus corpus, string scriptureRang }) } }; + if (build.Options is not null) + request.Options = JsonSerializer.Serialize(build.Options); // Log the build request summary try diff --git a/tests/Serval.Translation.Tests/Services/EngineServiceTests.cs b/tests/Serval.Translation.Tests/Services/EngineServiceTests.cs index 7ee6bf75..64d38246 100644 --- a/tests/Serval.Translation.Tests/Services/EngineServiceTests.cs +++ b/tests/Serval.Translation.Tests/Services/EngineServiceTests.cs @@ -19,7 +19,7 @@ public void TranslateAsync_EngineDoesNotExist() public async Task TranslateAsync_EngineExists() { var env = new TestEnvironment(); - string engineId = (await env.CreateEngineAsync()).Id; + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; Models.TranslationResult? result = await env.Service.TranslateAsync(engineId, "esto es una prueba."); Assert.That(result, Is.Not.Null); Assert.That(result!.Translation, Is.EqualTo("this is a test.")); @@ -38,7 +38,7 @@ public void GetWordGraphAsync_EngineDoesNotExist() public async Task GetWordGraphAsync_EngineExists() { var env = new TestEnvironment(); - string engineId = (await env.CreateEngineAsync()).Id; + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; Models.WordGraph? result = await env.Service.GetWordGraphAsync(engineId, "esto es una prueba."); Assert.That(result, Is.Not.Null); Assert.That(result!.Arcs.SelectMany(a => a.TargetTokens), Is.EqualTo("this is a test .".Split())); @@ -57,7 +57,7 @@ public void TrainSegmentAsync_EngineDoesNotExist() public async Task TrainSegmentAsync_EngineExists() { var env = new TestEnvironment(); - string engineId = (await env.CreateEngineAsync()).Id; + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; Assert.DoesNotThrowAsync( () => env.Service.TrainSegmentPairAsync(engineId, "esto es una prueba.", "this is a test.", true) ); @@ -88,7 +88,7 @@ public async Task CreateAsync() public async Task DeleteAsync_EngineExists() { var env = new TestEnvironment(); - string engineId = (await env.CreateEngineAsync()).Id; + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; await env.Service.DeleteAsync("engine1"); Engine? engine = await env.Engines.GetAsync(engineId); Assert.That(engine, Is.Null); @@ -98,23 +98,299 @@ public async Task DeleteAsync_EngineExists() public async Task DeleteAsync_ProjectDoesNotExist() { var env = new TestEnvironment(); - await env.CreateEngineAsync(); + await env.CreateEngineWithTextFilesAsync(); Assert.ThrowsAsync(() => env.Service.DeleteAsync("engine3")); } [Test] - public async Task StartBuildAsync_EngineExists() + public async Task StartBuildAsync_TrainOnNotSpecified() { var env = new TestEnvironment(); - string engineId = (await env.CreateEngineAsync()).Id; - Assert.DoesNotThrowAsync(() => env.Service.StartBuildAsync(new Build { Id = BUILD1_ID, EngineRef = engineId })); + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; + await env.Service.StartBuildAsync(new Build { Id = BUILD1_ID, EngineRef = engineId }); + _ = env.TranslationServiceClient.Received() + .StartBuildAsync( + new StartBuildRequest + { + BuildId = BUILD1_ID, + EngineId = engineId, + EngineType = "Smt", + Corpora = + { + new V1.Corpus + { + Id = "corpus1", + SourceLanguage = "es", + TargetLanguage = "en", + TrainOnAll = true, + SourceFiles = + { + new V1.CorpusFile + { + Location = "file1.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + TargetFiles = + { + new V1.CorpusFile + { + Location = "file2.txt", + Format = FileFormat.Text, + TextId = "text1" + } + } + } + } + } + ); + } + + [Test] + public async Task StartBuildAsync_TextIdsEmpty() + { + var env = new TestEnvironment(); + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; + await env.Service.StartBuildAsync( + new Build + { + Id = BUILD1_ID, + EngineRef = engineId, + TrainOn = [new TrainingCorpus { CorpusRef = "corpus1", TextIds = [] }] + } + ); + _ = env.TranslationServiceClient.Received() + .StartBuildAsync( + new StartBuildRequest + { + BuildId = BUILD1_ID, + EngineId = engineId, + EngineType = "Smt", + Corpora = + { + new V1.Corpus + { + Id = "corpus1", + SourceLanguage = "es", + TargetLanguage = "en", + TrainOnAll = false, + TrainOnTextIds = { }, + SourceFiles = + { + new V1.CorpusFile + { + Location = "file1.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + TargetFiles = + { + new V1.CorpusFile + { + Location = "file2.txt", + Format = FileFormat.Text, + TextId = "text1" + } + } + } + } + } + ); + } + + [Test] + public async Task StartBuildAsync_TextIdsPopulated() + { + var env = new TestEnvironment(); + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; + await env.Service.StartBuildAsync( + new Build + { + Id = BUILD1_ID, + EngineRef = engineId, + TrainOn = [new TrainingCorpus { CorpusRef = "corpus1", TextIds = ["text1"] }] + } + ); + _ = env.TranslationServiceClient.Received() + .StartBuildAsync( + new StartBuildRequest + { + BuildId = BUILD1_ID, + EngineId = engineId, + EngineType = "Smt", + Corpora = + { + new V1.Corpus + { + Id = "corpus1", + SourceLanguage = "es", + TargetLanguage = "en", + TrainOnAll = false, + TrainOnTextIds = { "text1" }, + SourceFiles = + { + new V1.CorpusFile + { + Location = "file1.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + TargetFiles = + { + new V1.CorpusFile + { + Location = "file2.txt", + Format = FileFormat.Text, + TextId = "text1" + } + } + } + } + } + ); + } + + [Test] + public async Task StartBuildAsync_TextIdsNotSpecified() + { + var env = new TestEnvironment(); + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; + await env.Service.StartBuildAsync( + new Build + { + Id = BUILD1_ID, + EngineRef = engineId, + TrainOn = [new TrainingCorpus { CorpusRef = "corpus1" }] + } + ); + _ = env.TranslationServiceClient.Received() + .StartBuildAsync( + new StartBuildRequest + { + BuildId = BUILD1_ID, + EngineId = engineId, + EngineType = "Smt", + Corpora = + { + new V1.Corpus + { + Id = "corpus1", + SourceLanguage = "es", + TargetLanguage = "en", + TrainOnAll = true, + SourceFiles = + { + new V1.CorpusFile + { + Location = "file1.txt", + Format = FileFormat.Text, + TextId = "text1" + } + }, + TargetFiles = + { + new V1.CorpusFile + { + Location = "file2.txt", + Format = FileFormat.Text, + TextId = "text1" + } + } + } + } + } + ); + } + + [Test] + public async Task StartBuildAsync_TextFilesScriptureRangeSpecified() + { + var env = new TestEnvironment(); + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; + Assert.ThrowsAsync( + () => + env.Service.StartBuildAsync( + new Build + { + Id = BUILD1_ID, + EngineRef = engineId, + TrainOn = [new TrainingCorpus { CorpusRef = "corpus1", ScriptureRange = "MAT" }] + } + ) + ); + } + + [Test] + public async Task StartBuildAsync_ScriptureRangeSpecified() + { + var env = new TestEnvironment(); + string engineId = (await env.CreateEngineWithParatextProjectAsync()).Id; + await env.Service.StartBuildAsync( + new Build + { + Id = BUILD1_ID, + EngineRef = engineId, + TrainOn = [new TrainingCorpus { CorpusRef = "corpus1", ScriptureRange = "MAT 1;MRK" }] + } + ); + _ = env.TranslationServiceClient.Received() + .StartBuildAsync( + new StartBuildRequest + { + BuildId = BUILD1_ID, + EngineId = engineId, + EngineType = "Smt", + Corpora = + { + new V1.Corpus + { + Id = "corpus1", + SourceLanguage = "es", + TargetLanguage = "en", + TrainOnAll = false, + TrainOnChapters = + { + { + "MAT", + new ScriptureChapters { Chapters = { 1 } } + }, + { + "MRK", + new ScriptureChapters { Chapters = { } } + } + }, + SourceFiles = + { + new V1.CorpusFile + { + Location = "file1.zip", + Format = FileFormat.Paratext, + TextId = "file1.zip" + } + }, + TargetFiles = + { + new V1.CorpusFile + { + Location = "file2.zip", + Format = FileFormat.Paratext, + TextId = "file2.zip" + } + } + } + } + } + ); } [Test] public async Task CancelBuildAsync_EngineExistsNotBuilding() { var env = new TestEnvironment(); - string engineId = (await env.CreateEngineAsync()).Id; + string engineId = (await env.CreateEngineWithTextFilesAsync()).Id; await env.Service.CancelBuildAsync(engineId); } @@ -122,7 +398,7 @@ public async Task CancelBuildAsync_EngineExistsNotBuilding() public async Task UpdateCorpusAsync() { var env = new TestEnvironment(); - Engine engine = await env.CreateEngineAsync(); + Engine engine = await env.CreateEngineWithTextFilesAsync(); string corpusId = engine.Corpora[0].Id; Models.Corpus? corpus = await env.Service.UpdateCorpusAsync( @@ -160,8 +436,7 @@ private class TestEnvironment public TestEnvironment() { Engines = new MemoryRepository(); - TranslationEngineApi.TranslationEngineApiClient translationServiceClient = - Substitute.For(); + TranslationServiceClient = Substitute.For(); var translationResult = new V1.TranslationResult { Translation = "this is a test.", @@ -195,7 +470,7 @@ public TestEnvironment() } }; var translateResponse = new TranslateResponse { Results = { translationResult } }; - translationServiceClient + TranslationServiceClient .TranslateAsync(Arg.Any()) .Returns(CreateAsyncUnaryCall(translateResponse)); var wordGraph = new V1.WordGraph @@ -254,28 +529,46 @@ public TestEnvironment() } }; var getWordGraphResponse = new GetWordGraphResponse { WordGraph = wordGraph }; - translationServiceClient + TranslationServiceClient .GetWordGraphAsync(Arg.Any()) .Returns(CreateAsyncUnaryCall(getWordGraphResponse)); - translationServiceClient + TranslationServiceClient .CancelBuildAsync(Arg.Any()) .Returns(CreateAsyncUnaryCall(new Empty())); - translationServiceClient + TranslationServiceClient .CreateAsync(Arg.Any()) .Returns(CreateAsyncUnaryCall(new CreateResponse())); - translationServiceClient.DeleteAsync(Arg.Any()).Returns(CreateAsyncUnaryCall(new Empty())); - translationServiceClient + TranslationServiceClient.DeleteAsync(Arg.Any()).Returns(CreateAsyncUnaryCall(new Empty())); + TranslationServiceClient .StartBuildAsync(Arg.Any()) .Returns(CreateAsyncUnaryCall(new Empty())); - translationServiceClient + TranslationServiceClient .TrainSegmentPairAsync(Arg.Any()) .Returns(CreateAsyncUnaryCall(new Empty())); GrpcClientFactory grpcClientFactory = Substitute.For(); grpcClientFactory .CreateClient("Smt") - .Returns(translationServiceClient); + .Returns(TranslationServiceClient); IOptionsMonitor dataFileOptions = Substitute.For>(); dataFileOptions.CurrentValue.Returns(new DataFileOptions()); + var scriptureDataFileService = Substitute.For(); + scriptureDataFileService + .GetParatextProjectSettings(Arg.Any()) + .Returns( + new ParatextProjectSettings( + name: "Tst", + fullName: "Test", + encoding: Encoding.UTF8, + versification: ScrVers.English, + stylesheet: new UsfmStylesheet("usfm.sty"), + fileNamePrefix: "TST", + fileNameForm: "MAT", + fileNameSuffix: ".USFM", + biblicalTermsListType: "BiblicalTerms", + biblicalTermsProjectName: "", + biblicalTermsFileName: "BiblicalTerms.xml" + ) + ); Service = new EngineService( Engines, @@ -285,14 +578,15 @@ public TestEnvironment() dataFileOptions, new MemoryDataAccessContext(), new LoggerFactory(), - new ScriptureDataFileService(new FileSystem(), dataFileOptions) + scriptureDataFileService ); } public EngineService Service { get; } public IRepository Engines { get; } + public TranslationEngineApi.TranslationEngineApiClient TranslationServiceClient { get; } - public async Task CreateEngineAsync() + public async Task CreateEngineWithTextFilesAsync() { var engine = new Engine { @@ -335,6 +629,49 @@ public async Task CreateEngineAsync() return engine; } + public async Task CreateEngineWithParatextProjectAsync() + { + var engine = new Engine + { + Id = "engine1", + Owner = "owner1", + SourceLanguage = "es", + TargetLanguage = "en", + Type = "Smt", + Corpora = new Models.Corpus[] + { + new() + { + Id = "corpus1", + SourceLanguage = "es", + TargetLanguage = "en", + SourceFiles = + [ + new() + { + Id = "file1", + Filename = "file1.zip", + Format = Shared.Contracts.FileFormat.Paratext, + TextId = "file1.zip" + } + ], + TargetFiles = + [ + new() + { + Id = "file2", + Filename = "file2.zip", + Format = Shared.Contracts.FileFormat.Paratext, + TextId = "file2.zip" + } + ], + } + } + }; + await Engines.InsertAsync(engine); + return engine; + } + private static TranslationSources[] GetSources(int count, bool isUnknown) { var sources = new TranslationSources[count];