diff --git a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs index 442cb06a..74e5031c 100644 --- a/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs +++ b/src/Machine/test/Serval.Machine.Shared.Tests/Services/PreprocessBuildJobTests.cs @@ -65,7 +65,7 @@ public async Task RunAsync_TrainAndPretranslateAll() await env.RunBuildJobAsync(corpus1); - Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(4)); + Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(2)); } [Test] @@ -76,7 +76,7 @@ public async Task RunAsync_PretranslateAll() await env.RunBuildJobAsync(corpus1); - Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(4)); + Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(2)); } [Test] @@ -87,7 +87,7 @@ public async Task RunAsync_PretranslateTextIds() await env.RunBuildJobAsync(corpus1); - Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(4)); + Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(2)); } [Test] @@ -206,7 +206,11 @@ public async Task RunAsync_MixedSource_Paratext() Assert.That(trgCount, Is.EqualTo(1)); Assert.That(termCount, Is.EqualTo(0)); }); - Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(56)); + Assert.That( + await env.GetPretranslateCountAsync(), + Is.EqualTo(11), + JsonSerializer.Serialize(await env.GetPretranslationsAsync()) + ); } [Test] @@ -225,7 +229,7 @@ public async Task RunAsync_MixedSource_Text() Assert.That(trgCount, Is.EqualTo(1)); Assert.That(termCount, Is.EqualTo(0)); }); - Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(9)); + Assert.That(await env.GetPretranslateCountAsync(), Is.EqualTo(3)); } [Test] @@ -488,10 +492,10 @@ await env.GetTargetExtractAsync(), }); JsonArray? pretranslations = await env.GetPretranslationsAsync(); Assert.That(pretranslations, Is.Not.Null); - Assert.That(pretranslations!.Count, Is.EqualTo(37), pretranslations.ToJsonString()); + Assert.That(pretranslations!.Count, Is.EqualTo(3), pretranslations.ToJsonString()); Assert.That( pretranslations[2]!["translation"]!.ToString(), - Is.EqualTo("Source one, chapter twelve, verse one.") + Is.EqualTo("Source one, chapter thirteen, verse one.") ); } diff --git a/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/ParallelCorpusPreprocessor.cs b/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/ParallelCorpusPreprocessor.cs index e4a4dba8..0c728f0b 100644 --- a/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/ParallelCorpusPreprocessor.cs +++ b/src/ServiceToolkit/src/SIL.ServiceToolkit/Utils/ParallelCorpusPreprocessor.cs @@ -52,7 +52,7 @@ public void Preprocess( { ITextCorpus textCorpus = sc.TextCorpus; if (sc.Corpus.TrainOnTextIds is not null) - textCorpus = textCorpus.FilterTexts(sc.Corpus.TrainOnTextIds); + return textCorpus = textCorpus.FilterTexts(sc.Corpus.TrainOnTextIds); return textCorpus.Where(row => row.Ref is not ScriptureRef sr || sc.Corpus.TrainOnChapters is null @@ -66,7 +66,7 @@ row.Ref is not ScriptureRef sr ITextCorpus textCorpus = sc.TextCorpus; if (sc.Corpus.PretranslateTextIds is not null) { - return textCorpus.FilterTexts( + return textCorpus = textCorpus.FilterTexts( sc.Corpus.PretranslateTextIds.Except(sc.Corpus.TrainOnTextIds ?? new()) ); } @@ -90,7 +90,7 @@ row.Ref is not ScriptureRef sr { ITextCorpus textCorpus = tc.TextCorpus; if (tc.Corpus.TrainOnTextIds is not null) - textCorpus = textCorpus.FilterTexts(tc.Corpus.TrainOnTextIds); + return textCorpus = textCorpus.FilterTexts(tc.Corpus.TrainOnTextIds); return textCorpus.Where(row => row.Ref is not ScriptureRef sr || tc.Corpus.TrainOnChapters is null @@ -309,9 +309,9 @@ private static IEnumerable AlignPretranslateCorpus(ITextCorpus[] srcCorpora foreach ( ParallelTextRow? row in srcCorpora .SelectMany(sc => trgCorpora.Select(tc => sc.AlignRows(tc, allSourceRows: true))) - .ZipMany(rows => - rows.Where(r => r.SourceSegment.Count > 0 && r.TargetSegment.Count == 0).FirstOrDefault() - ) + .ZipMany(rows => rows.ToArray()) + .Where(rows => rows.All(r => r.TargetSegment.Count == 0)) + .Select(rows => rows.Where(r => r.SourceSegment.Count > 0).FirstOrDefault()) ) { if (row is null)