-
-
Notifications
You must be signed in to change notification settings - Fork 15
/
NmtEngineService.cs
210 lines (192 loc) · 7.75 KB
/
NmtEngineService.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
namespace SIL.Machine.AspNetCore.Services;
/// <summary>
/// String identifiers for the stages of an NMT build.
/// </summary>
/// <remarks>
/// <see cref="Preprocess"/> is the stage a new build is started in by <c>NmtEngineService.StartBuildAsync</c>.
/// <see cref="Train"/> and <see cref="Postprocess"/> are not referenced in this file; presumably they are
/// consumed by the build job pipeline — TODO confirm.
/// </remarks>
public static class NmtBuildStages
{
    /// <summary>Stage identifier <c>"preprocess"</c>; the initial stage of a new NMT build.</summary>
    public const string Preprocess = "preprocess";

    /// <summary>Stage identifier <c>"train"</c>.</summary>
    public const string Train = "train";

    /// <summary>Stage identifier <c>"postprocess"</c>.</summary>
    public const string Postprocess = "postprocess";
}
/// <summary>
/// <see cref="ITranslationEngineService"/> implementation for NMT engines.
/// </summary>
/// <remarks>
/// Build start, build cancellation and engine deletion are serialized per engine by taking the engine's
/// distributed writer lock. Interactive operations (<see cref="TranslateAsync"/>,
/// <see cref="GetWordGraphAsync"/>, <see cref="TrainSegmentPairAsync"/>) are not supported by this engine
/// type and throw <see cref="NotSupportedException"/>.
/// </remarks>
public class NmtEngineService(
    IPlatformService platformService,
    IDistributedReaderWriterLockFactory lockFactory,
    IDataAccessContext dataAccessContext,
    IRepository<TranslationEngine> engines,
    IBuildJobService buildJobService,
    ILanguageTagService languageTagService,
    ClearMLMonitorService clearMLMonitorService,
    ISharedFileService sharedFileService
) : ITranslationEngineService
{
    private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory;
    private readonly IPlatformService _platformService = platformService;
    private readonly IDataAccessContext _dataAccessContext = dataAccessContext;
    private readonly IRepository<TranslationEngine> _engines = engines;
    private readonly IBuildJobService _buildJobService = buildJobService;
    private readonly ClearMLMonitorService _clearMLMonitorService = clearMLMonitorService;
    private readonly ILanguageTagService _languageTagService = languageTagService;
    private readonly ISharedFileService _sharedFileService = sharedFileService;

    /// <inheritdoc/>
    public TranslationEngineType Type => TranslationEngineType.Nmt;

    // Lifetime, in minutes, requested for presigned model download URLs.
    private const int MinutesToExpire = 60;

    /// <summary>
    /// Inserts the engine record and registers the engine with the build job service for both CPU and GPU
    /// build job types, inside a single data-access transaction.
    /// </summary>
    /// <param name="engineId">Identifier of the new engine.</param>
    /// <param name="engineName">Optional display name, forwarded to the build job service only.</param>
    /// <param name="sourceLanguage">Source language tag stored on the engine record.</param>
    /// <param name="targetLanguage">Target language tag stored on the engine record.</param>
    /// <param name="isModelPersisted">Whether built models are kept so they can later be downloaded.</param>
    /// <param name="cancellationToken">Cancels the transaction start and the two writes (not the commit).</param>
    public async Task CreateAsync(
        string engineId,
        string? engineName,
        string sourceLanguage,
        string targetLanguage,
        bool isModelPersisted = false,
        CancellationToken cancellationToken = default
    )
    {
        // NOTE(review): there is no explicit abort on failure — confirm that IDataAccessContext rolls back a
        // transaction that was begun but never committed.
        await _dataAccessContext.BeginTransactionAsync(cancellationToken);
        await _engines.InsertAsync(
            new TranslationEngine
            {
                EngineId = engineId,
                SourceLanguage = sourceLanguage,
                TargetLanguage = targetLanguage,
                IsModelPersisted = isModelPersisted
            },
            cancellationToken
        );
        await _buildJobService.CreateEngineAsync(
            [BuildJobType.Cpu, BuildJobType.Gpu],
            engineId,
            engineName,
            cancellationToken
        );
        // Presumably uncancellable so that, once both writes have succeeded, the commit is not abandoned.
        await _dataAccessContext.CommitTransactionAsync(CancellationToken.None);
    }

    /// <summary>
    /// Cancels any in-progress build, deletes the engine record and its build-job registrations, and finally
    /// deletes the engine's distributed lock.
    /// </summary>
    /// <param name="engineId">Identifier of the engine to delete.</param>
    /// <param name="cancellationToken">Cancels lock acquisition, build cancellation and record deletion.</param>
    public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default)
    {
        IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken);
        await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken))
        {
            await CancelBuildJobAsync(engineId, cancellationToken);
            await _engines.DeleteAsync(e => e.EngineId == engineId, cancellationToken);
            // Uncancellable: once the engine record is gone, its build-job registrations are also removed.
            await _buildJobService.DeleteEngineAsync(
                [BuildJobType.Cpu, BuildJobType.Gpu],
                engineId,
                CancellationToken.None
            );
        }
        // The lock is deleted only after the writer lock above has been released.
        await _lockFactory.DeleteAsync(engineId, CancellationToken.None);
    }

    /// <summary>
    /// Starts a CPU build job for the engine in the <see cref="NmtBuildStages.Preprocess"/> stage.
    /// </summary>
    /// <param name="engineId">Identifier of the engine to build.</param>
    /// <param name="buildId">Identifier assigned to the new build.</param>
    /// <param name="buildOptions">Optional build options, passed through to the build job service.</param>
    /// <param name="corpora">Corpora to build from.</param>
    /// <param name="cancellationToken">Cancels lock acquisition and the build job start.</param>
    /// <exception cref="InvalidOperationException">
    /// The engine already has a build that is pending, running, or being canceled.
    /// </exception>
    public async Task StartBuildAsync(
        string engineId,
        string buildId,
        string? buildOptions,
        IReadOnlyList<Corpus> corpora,
        CancellationToken cancellationToken = default
    )
    {
        IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken);
        await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken))
        {
            // Only one build may be active per engine; reject the request instead of starting another.
            if (await _buildJobService.IsEngineBuilding(engineId, cancellationToken))
                throw new InvalidOperationException("The engine is already building or in the process of canceling.");
            await _buildJobService.StartBuildJobAsync(
                BuildJobType.Cpu,
                TranslationEngineType.Nmt,
                engineId,
                buildId,
                NmtBuildStages.Preprocess,
                corpora,
                buildOptions,
                cancellationToken
            );
        }
    }

    /// <summary>
    /// Cancels the engine's current build job, if any, while holding the engine's writer lock.
    /// </summary>
    /// <param name="engineId">Identifier of the engine whose build should be canceled.</param>
    /// <param name="cancellationToken">Cancels lock acquisition and the cancel request.</param>
    public async Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default)
    {
        IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, cancellationToken);
        await using (await @lock.WriterLockAsync(cancellationToken: cancellationToken))
        {
            await CancelBuildJobAsync(engineId, cancellationToken);
        }
    }

    /// <summary>
    /// Returns a presigned URL for downloading the engine's most recently built model.
    /// </summary>
    /// <param name="engineId">Identifier of the engine.</param>
    /// <param name="cancellationToken">Cancels the engine lookup and the existence check.</param>
    /// <returns>The download URL, the model's build revision, and the URL's UTC expiration time.</returns>
    /// <exception cref="InvalidOperationException">
    /// The engine does not exist, was not created with <c>IsModelPersisted</c>, or has never been built.
    /// </exception>
    /// <exception cref="FileNotFoundException">The model file for the current build revision is missing.</exception>
    public async Task<ModelDownloadUrl> GetModelDownloadUrlAsync(
        string engineId,
        CancellationToken cancellationToken = default
    )
    {
        TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken);
        if (!engine.IsModelPersisted)
            throw new InvalidOperationException(
                "The model cannot be downloaded. "
                    + "To enable downloading the model, recreate the engine with IsModelPersisted property to true."
            );
        if (engine.BuildRevision == 0)
            throw new InvalidOperationException("The engine has not been built yet.");

        // Models are stored as "<ModelDirectory><engineId>_<buildRevision>.tar.gz".
        string modelPath = ISharedFileService.ModelDirectory + $"{engineId}_{engine.BuildRevision}.tar.gz";
        if (!await _sharedFileService.ExistsAsync(modelPath, cancellationToken))
            throw new FileNotFoundException(
                $"The model should exist to be downloaded but is not there for BuildRevision {engine.BuildRevision}."
            );

        // Take the timestamp before the URL is signed. The URL's real expiry is presumably measured from signing
        // time, so the reported expiration can then never fall after it.
        DateTime expiresAt = DateTime.UtcNow.AddMinutes(MinutesToExpire);
        string url = (await _sharedFileService.GetPresignedUrlAsync(modelPath, MinutesToExpire)).ToString();

        return new ModelDownloadUrl
        {
            Url = url,
            ModelRevision = engine.BuildRevision,
            ExipiresAt = expiresAt
        };
    }

    /// <summary>Not supported for NMT engines.</summary>
    /// <exception cref="NotSupportedException">Always.</exception>
    public Task<IReadOnlyList<TranslationResult>> TranslateAsync(
        string engineId,
        int n,
        string segment,
        CancellationToken cancellationToken = default
    )
    {
        throw new NotSupportedException();
    }

    /// <summary>Not supported for NMT engines.</summary>
    /// <exception cref="NotSupportedException">Always.</exception>
    public Task<WordGraph> GetWordGraphAsync(
        string engineId,
        string segment,
        CancellationToken cancellationToken = default
    )
    {
        throw new NotSupportedException();
    }

    /// <summary>Not supported for NMT engines.</summary>
    /// <exception cref="NotSupportedException">Always.</exception>
    public Task TrainSegmentPairAsync(
        string engineId,
        string sourceSegment,
        string targetSegment,
        bool sentenceStart,
        CancellationToken cancellationToken = default
    )
    {
        throw new NotSupportedException();
    }

    /// <summary>
    /// Returns the queue size currently reported by the ClearML monitor service; completes synchronously.
    /// </summary>
    public Task<int> GetQueueSizeAsync(CancellationToken cancellationToken = default)
    {
        return Task.FromResult(_clearMLMonitorService.QueueSize);
    }

    /// <summary>
    /// Delegates to <see cref="ILanguageTagService"/>'s FLORES-200 conversion for <paramref name="language"/>.
    /// </summary>
    /// <param name="language">Language tag to check.</param>
    /// <param name="internalCode">The converted code produced by the language tag service.</param>
    /// <returns>The result of <c>ConvertToFlores200Code</c>.</returns>
    public bool IsLanguageNativeToModel(string language, out string internalCode)
    {
        return _languageTagService.ConvertToFlores200Code(language, out internalCode);
    }

    // Asks the build job service to cancel the engine's build. It notifies the platform directly only when a
    // build id is returned and the job state is None (presumably the job never started, so no job will report
    // the cancellation itself — TODO confirm). The notification is uncancellable once cancellation has happened.
    private async Task CancelBuildJobAsync(string engineId, CancellationToken cancellationToken)
    {
        (string? buildId, BuildJobState jobState) = await _buildJobService.CancelBuildJobAsync(
            engineId,
            cancellationToken
        );
        if (buildId is not null && jobState is BuildJobState.None)
            await _platformService.BuildCanceledAsync(buildId, CancellationToken.None);
    }

    // Loads the engine record, throwing InvalidOperationException when no engine has the given id.
    private async Task<TranslationEngine> GetEngineAsync(string engineId, CancellationToken cancellationToken)
    {
        TranslationEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken);
        if (engine is null)
            throw new InvalidOperationException($"The engine {engineId} does not exist.");
        return engine;
    }
}