Skip to content

Commit

Permalink
Add test/proof-of-concept
Browse files Browse the repository at this point in the history
  • Loading branch information
Enkidu93 committed Sep 30, 2024
1 parent a3e10ed commit 84d78ed
Show file tree
Hide file tree
Showing 8 changed files with 547 additions and 4 deletions.
298 changes: 298 additions & 0 deletions src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,210 @@ CancellationToken cancellationToken
return (trainCount, pretranslateCount);
}

public record ParallelCorpusSubcorpus
{
public required string Id { get; set; }
public required string Language { get; set; }
public required IReadOnlyList<CorpusFile> Files { get; set; }
public HashSet<string>? TrainOnTextIds { get; set; }
public Dictionary<string, HashSet<int>>? TrainOnChapters { get; set; }
public HashSet<string>? PretranslateTextIds { get; set; }
public Dictionary<string, HashSet<int>>? PretranslateChapters { get; set; }
}

public record ParallelCorpus
{
public required string Id { get; set; }
public IReadOnlyList<ParallelCorpusSubcorpus> SourceCorpora { get; set; } = new List<ParallelCorpusSubcorpus>();
public IReadOnlyList<ParallelCorpusSubcorpus> TargetCorpora { get; set; } = new List<ParallelCorpusSubcorpus>();
}

public async Task<(int TrainCount, int PretranslateCount)> WriteDataFilesAsync(
string buildId,
IReadOnlyList<ParallelCorpus> corpora,
string? buildOptions,
CancellationToken cancellationToken
)
{
JsonObject? buildOptionsObject = null;
if (buildOptions is not null)
buildOptionsObject = JsonSerializer.Deserialize<JsonObject>(buildOptions);
await using StreamWriter sourceTrainWriter =
new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.src.txt", cancellationToken));
await using StreamWriter targetTrainWriter =
new(await _sharedFileService.OpenWriteAsync($"builds/{buildId}/train.trg.txt", cancellationToken));

await using Stream pretranslateStream = await _sharedFileService.OpenWriteAsync(
$"builds/{buildId}/pretranslate.src.json",
cancellationToken
);
await using Utf8JsonWriter pretranslateWriter = new(pretranslateStream, PretranslateWriterOptions);

int trainCount = 0;
int pretranslateCount = 0;
pretranslateWriter.WriteStartArray();
foreach (ParallelCorpus corpus in corpora)
{
(ParallelCorpusSubcorpus Subcorpus, ITextCorpus TextCorpus)[] sourceCorpora = corpus
.SourceCorpora.SelectMany(c => _corpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc)))
.ToArray();
ITextCorpus[] sourceTrainingCorpora = sourceCorpora
.Select(sc =>
(
sc.Subcorpus,
sc.TextCorpus.FilterTexts(
(sc.Subcorpus.TrainOnTextIds ?? new()).Union((sc.Subcorpus.TrainOnChapters ?? new()).Keys)
)
)
)
.Select(sc =>
sc.Item2.FilterTextRows(row =>

Check failure on line 265 in src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs

View workflow job for this annotation

GitHub Actions / Build

'ITextCorpus' does not contain a definition for 'FilterTextRows' and no accessible extension method 'FilterTextRows' accepting a first argument of type 'ITextCorpus' could be found (are you missing a using directive or an assembly reference?)

Check failure on line 265 in src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs

View workflow job for this annotation

GitHub Actions / Build

'ITextCorpus' does not contain a definition for 'FilterTextRows' and no accessible extension method 'FilterTextRows' accepting a first argument of type 'ITextCorpus' could be found (are you missing a using directive or an assembly reference?)
row.Ref is not ScriptureRef sr
|| (
(sc.Subcorpus.TrainOnChapters ?? new()).TryGetValue(sr.Book, out HashSet<int>? chapters)
&& chapters != null
&& (chapters.Count == 0 || chapters.Contains(sr.ChapterNum))
)
)
)
.ToArray();
ITextCorpus[] sourcePretranslateCorpora = sourceCorpora
.Select(sc =>
(
sc.Subcorpus,
sc.TextCorpus.FilterTexts(
(sc.Subcorpus.PretranslateTextIds ?? new()).Union(
(sc.Subcorpus.PretranslateChapters ?? new()).Keys
)
)
)
)
.Select(sc =>
sc.Item2.FilterTextRows(row =>

Check failure on line 287 in src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs

View workflow job for this annotation

GitHub Actions / Build

'ITextCorpus' does not contain a definition for 'FilterTextRows' and no accessible extension method 'FilterTextRows' accepting a first argument of type 'ITextCorpus' could be found (are you missing a using directive or an assembly reference?)

Check failure on line 287 in src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs

View workflow job for this annotation

GitHub Actions / Build

'ITextCorpus' does not contain a definition for 'FilterTextRows' and no accessible extension method 'FilterTextRows' accepting a first argument of type 'ITextCorpus' could be found (are you missing a using directive or an assembly reference?)
row.Ref is not ScriptureRef sr
|| (
(sc.Subcorpus.PretranslateChapters ?? new()).TryGetValue(
sr.Book,
out HashSet<int>? chapters
)
&& chapters != null
&& (chapters.Count == 0 || chapters.Contains(sr.ChapterNum))
)
)
)
.ToArray();

(ParallelCorpusSubcorpus Subcorpus, ITextCorpus TextCorpus)[] targetCorpora = corpus
.TargetCorpora.SelectMany(c => _corpusService.CreateTextCorpora(c.Files).Select(tc => (c, tc)))
.ToArray();
ITextCorpus[] targetTrainingCorpora = targetCorpora
.Select(tc =>
(
tc.Subcorpus,
tc.TextCorpus.FilterTexts(
(tc.Subcorpus.TrainOnTextIds ?? new()).Union((tc.Subcorpus.TrainOnChapters ?? new()).Keys)
)
)
)
.Select(tc =>
tc.Item2.FilterTextRows(row =>

Check failure on line 314 in src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs

View workflow job for this annotation

GitHub Actions / Build

'ITextCorpus' does not contain a definition for 'FilterTextRows' and no accessible extension method 'FilterTextRows' accepting a first argument of type 'ITextCorpus' could be found (are you missing a using directive or an assembly reference?)

Check failure on line 314 in src/Machine/src/Serval.Machine.Shared/Services/PreprocessBuildJob.cs

View workflow job for this annotation

GitHub Actions / Build

'ITextCorpus' does not contain a definition for 'FilterTextRows' and no accessible extension method 'FilterTextRows' accepting a first argument of type 'ITextCorpus' could be found (are you missing a using directive or an assembly reference?)
row.Ref is not ScriptureRef sr
|| (
(tc.Subcorpus.TrainOnChapters ?? new()).TryGetValue(sr.Book, out HashSet<int>? chapters)
&& chapters != null
&& (chapters.Count == 0 || chapters.Contains(sr.ChapterNum))
)
)
)
.ToArray();

if (sourceCorpora.Length == 0)
continue;

int skipCount = 0;
foreach (Row?[] rows in AlignTrainCorpus(sourceTrainingCorpora, targetTrainingCorpora))
{
if (skipCount > 0)
{
skipCount--;
continue;
}

Row[] trainRows = rows.Where(r => r is not null && IsInTrain(r, corpus)).Cast<Row>().ToArray();
if (trainRows.Length > 0)
{
Row row = trainRows[0];
if (rows.Length > 1)
{
Row[] nonEmptyRows = trainRows.Where(r => r.SourceSegment.Length > 0).ToArray();
Row[] targetNonEmptyRows = nonEmptyRows.Where(r => r.TargetSegment.Length > 0).ToArray();
if (targetNonEmptyRows.Length > 0)
nonEmptyRows = targetNonEmptyRows;
if (nonEmptyRows.Length > 0)
{
nonEmptyRows = nonEmptyRows
.GroupBy(r => r.SourceSegment)
.Select(group => group.First())
.ToArray();
row = nonEmptyRows[_random.Next(nonEmptyRows.Length)];
}
}

await sourceTrainWriter.WriteAsync($"{row.SourceSegment}\n");
await targetTrainWriter.WriteAsync($"{row.TargetSegment}\n");
skipCount = row.RowCount - 1;
if (row.SourceSegment.Length > 0 && row.TargetSegment.Length > 0)
trainCount++;
}
}

if ((bool?)buildOptionsObject?["use_key_terms"] ?? true)
{
ITextCorpus? sourceTermCorpus = _corpusService
.CreateTermCorpora(corpus.SourceCorpora.SelectMany(sc => sc.Files).ToList())
.FirstOrDefault();
ITextCorpus? targetTermCorpus = _corpusService
.CreateTermCorpora(corpus.TargetCorpora.SelectMany(tc => tc.Files).ToList())
.FirstOrDefault();
if (sourceTermCorpus is not null && targetTermCorpus is not null)
{
IParallelTextCorpus parallelKeyTermsCorpus = sourceTermCorpus.AlignRows(targetTermCorpus);
foreach (ParallelTextRow row in parallelKeyTermsCorpus)
{
await sourceTrainWriter.WriteAsync($"{row.SourceText}\n");
await targetTrainWriter.WriteAsync($"{row.TargetText}\n");
trainCount++;
}
}
}

foreach (Row row in AlignPretranslateCorpus(sourcePretranslateCorpora, targetCorpora[0].TextCorpus))
{
if (
IsInPretranslate(row, corpus)
&& row.SourceSegment.Length > 0
&& (row.TargetSegment.Length == 0 || !IsInTrain(row, corpus))
)
{
pretranslateWriter.WriteStartObject();
pretranslateWriter.WriteString("corpusId", corpus.Id);
pretranslateWriter.WriteString("textId", row.TextId);
pretranslateWriter.WriteStartArray("refs");
foreach (object rowRef in row.Refs)
pretranslateWriter.WriteStringValue(rowRef.ToString());
pretranslateWriter.WriteEndArray();
pretranslateWriter.WriteString("translation", row.SourceSegment);
pretranslateWriter.WriteEndObject();
pretranslateCount++;
}
}
}

pretranslateWriter.WriteEndArray();

return (trainCount, pretranslateCount);
}

protected override async Task CleanupAsync(
string engineId,
string buildId,
Expand All @@ -231,11 +435,22 @@ private static bool IsInTrain(Row row, Corpus corpus)
return IsIncluded(row, corpus.TrainOnTextIds, corpus.TrainOnChapters);
}

private static bool IsInTrain(Row row, ParallelCorpus corpus)
{
return corpus.SourceCorpora.Any(sc => IsIncluded(row, sc.TrainOnTextIds, sc.TrainOnChapters))
&& corpus.TargetCorpora.Any(tc => IsIncluded(row, tc.TrainOnTextIds, tc.TrainOnChapters));
}

private static bool IsInPretranslate(Row row, Corpus corpus)
{
return IsIncluded(row, corpus.PretranslateTextIds, corpus.PretranslateChapters);
}

private static bool IsInPretranslate(Row row, ParallelCorpus corpus)
{
return corpus.SourceCorpora.Any(sc => IsIncluded(row, sc.PretranslateTextIds, sc.PretranslateChapters));
}

private static bool IsIncluded(
Row? row,
IReadOnlySet<string>? textIds,
Expand Down Expand Up @@ -302,6 +517,45 @@ ITextCorpus trgCorpus
.Where(rows => rows.Any(r => r.SourceSegment.Length > 0 || r.TargetSegment.Length > 0));
}

private static IEnumerable<Row?[]> AlignTrainCorpus(
IReadOnlyList<ITextCorpus> srcCorpora,
IReadOnlyList<ITextCorpus> trgCorpora
)
{
srcCorpora = srcCorpora.Select(sc => sc.Transform(CleanSegment)).ToArray();
trgCorpora = trgCorpora.Select(tc => tc.Transform(CleanSegment)).ToArray();

if (trgCorpora.All(tc => tc.IsScripture()))
{
return srcCorpora
.SelectMany(sc => trgCorpora.Select(tc => AlignScripture(sc, tc)))
.ZipMany(rows => rows.ToArray())
// filter out every list that only contains completely empty rows
.Where(rows => rows.Any(r => r is null || r.SourceSegment.Length > 0 || r.TargetSegment.Length > 0));
}

IEnumerable<Row[]> sourceOnlyRows = srcCorpora
.SelectMany(sc => trgCorpora.Select(tc => sc.AlignRows(tc, allSourceRows: true)))
.ZipMany(rows =>
rows.Where(r => r.TargetSegment.Count == 0)
.Select(r => new Row(r.TextId, r.Refs, r.SourceText, r.TargetText, 1))
.ToArray()
);

IEnumerable<Row[]> targetRows = srcCorpora
.SelectMany(sc => trgCorpora.Select(tc => sc.AlignRows(tc, allTargetRows: true)))
.ZipMany(rows =>
rows.Where(r => r.TargetSegment.Count > 0)
.Select(r => new Row(r.TextId, r.Refs, r.SourceText, r.TargetText, 1))
.ToArray()
);

return sourceOnlyRows
.Concat(targetRows)
// filter out every list that only contains completely empty rows
.Where(rows => rows.Any(r => r.SourceSegment.Length > 0 || r.TargetSegment.Length > 0));
}

private static IEnumerable<Row?> AlignScripture(ITextCorpus srcCorpus, ITextCorpus trgCorpus)
{
int rowCount = 0;
Expand Down Expand Up @@ -433,6 +687,50 @@ private static IEnumerable<Row> AlignPretranslateCorpus(Corpus corpus, ITextCorp
yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1);
}

private static IEnumerable<Row> AlignPretranslateCorpus(ITextCorpus[] srcCorpora, ITextCorpus trgCorpus)
{
int rowCount = 0;
StringBuilder srcSegBuffer = new();
StringBuilder trgSegBuffer = new();
List<object> refs = [];
string textId = "";
foreach (ParallelTextRow row in srcCorpora.SelectMany(sc => sc.AlignRows(trgCorpus, allSourceRows: true)))
{
if (!row.IsTargetRangeStart && row.IsTargetInRange)
{
refs.AddRange(row.TargetRefs);
if (row.SourceText.Length > 0)
{
if (srcSegBuffer.Length > 0)
srcSegBuffer.Append(' ');
srcSegBuffer.Append(row.SourceText);
}
rowCount++;
}
else
{
if (rowCount > 0)
{
yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1);
textId = "";
srcSegBuffer.Clear();
trgSegBuffer.Clear();
refs.Clear();
rowCount = 0;
}

textId = row.TextId;
refs.AddRange(row.TargetRefs);
srcSegBuffer.Append(row.SourceText);
trgSegBuffer.Append(row.TargetText);
rowCount++;
}
}

if (rowCount > 0)
yield return new(textId, refs, srcSegBuffer.ToString(), trgSegBuffer.ToString(), 1);
}

private record Row(
string TextId,
IReadOnlyList<object> Refs,
Expand Down
Loading

0 comments on commit 84d78ed

Please sign in to comment.