From 525950f0f102dd4bed426865f74567d1a6d054b1 Mon Sep 17 00:00:00 2001 From: John Lambert Date: Fri, 9 Feb 2024 15:18:21 -0500 Subject: [PATCH] Nmt download (#164) * Presigned URL code cleaning script IsModelPersisted nullable Return IsModelPersistedState to Serval Check for modelRevision + 1 in cleanup and just delete without keeping internal states. * Reviewer comments * update to most recent api --- .../IMachineBuilderExtensions.cs | 35 ++++--- .../Models/ModelDownloadUrl.cs | 11 +++ .../Models/TranslationEngine.cs | 1 + .../SIL.Machine.AspNetCore.csproj | 2 +- .../Services/HangfireHealthCheck.cs | 12 +-- .../Services/IFileStorage.cs | 2 + .../Services/ISharedFileService.cs | 8 ++ .../Services/ITranslationEngineService.cs | 5 +- .../Services/InMemoryStorage.cs | 9 ++ .../Services/LocalStorage.cs | 9 ++ .../Services/ModelCleanupService.cs | 59 +++++++++++ .../Services/NmtClearMLBuildJobFactory.cs | 31 +++--- .../Services/NmtEngineService.cs | 75 +++++++++++--- .../Services/NmtTrainBuildJob.cs | 7 ++ .../Services/S3FileStorage.cs | 22 +++++ .../ServalTranslationEngineServiceV1.cs | 32 +++++- .../Services/SharedFileService.cs | 14 +++ .../Services/SmtTransferEngineService.cs | 84 ++++++++-------- src/SIL.Machine.AspNetCore/Usings.cs | 2 + .../Program.cs | 3 + .../Services/ModelCleanupServiceTests.cs | 97 +++++++++++++++++++ .../Services/NmtEngineServiceTests.cs | 11 ++- .../Services/SmtTransferEngineServiceTests.cs | 10 +- tests/SIL.Machine.AspNetCore.Tests/Usings.cs | 1 + 24 files changed, 437 insertions(+), 105 deletions(-) create mode 100644 src/SIL.Machine.AspNetCore/Models/ModelDownloadUrl.cs create mode 100644 src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs create mode 100644 tests/SIL.Machine.AspNetCore.Tests/Services/ModelCleanupServiceTests.cs diff --git a/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs b/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs index d1fe54c1..0f17f80e 100644 --- a/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs +++ b/src/SIL.Machine.AspNetCore/Configuration/IMachineBuilderExtensions.cs @@ -151,6 +151,21 @@ private static IMachineBuilder AddHangfireBuildJobRunner(this IMachineBuilder bu return builder; } + private static MongoStorageOptions GetMongoStorageOptions() + { + var mongoStorageOptions = new MongoStorageOptions + { + MigrationOptions = new MongoMigrationOptions + { + MigrationStrategy = new MigrateMongoMigrationStrategy(), + BackupStrategy = new CollectionMongoBackupStrategy() + }, + CheckConnection = true, + CheckQueuedJobsStrategy = CheckQueuedJobsStrategy.TailNotificationsCollection, + }; + return mongoStorageOptions; + } + public static IMachineBuilder AddMongoHangfireJobClient( this IMachineBuilder builder, string? connectionString = null @@ -164,19 +179,7 @@ public static IMachineBuilder AddMongoHangfireJobClient( c.SetDataCompatibilityLevel(CompatibilityLevel.Version_170) .UseSimpleAssemblyNameTypeSerializer() .UseRecommendedSerializerSettings() - .UseMongoStorage( - connectionString, - new MongoStorageOptions - { - MigrationOptions = new MongoMigrationOptions - { - MigrationStrategy = new MigrateMongoMigrationStrategy(), - BackupStrategy = new CollectionMongoBackupStrategy() - }, - CheckConnection = true, - CheckQueuedJobsStrategy = CheckQueuedJobsStrategy.TailNotificationsCollection, - } - ) + .UseMongoStorage(connectionString, GetMongoStorageOptions()) .UseFilter(new AutomaticRetryAttribute { Attempts = 0 }) ); builder.Services.AddHealthChecks().AddCheck(name: "Hangfire"); @@ -402,6 +405,12 @@ public static IMachineBuilder AddBuildJobService(this IMachineBuilder builder) return builder; } + public static IMachineBuilder AddModelCleanupService(this IMachineBuilder builder) + { + builder.Services.AddHostedService(); + return builder; + } + private static IMachineBuilder AddBuildJobService(this IMachineBuilder builder, BuildJobOptions options) { builder.Services.AddScoped(); diff --git a/src/SIL.Machine.AspNetCore/Models/ModelDownloadUrl.cs b/src/SIL.Machine.AspNetCore/Models/ModelDownloadUrl.cs new file mode 100644 index 00000000..f48d7301 --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Models/ModelDownloadUrl.cs @@ -0,0 +1,11 @@ +using System; + +namespace SIL.Machine.AspNetCore.Models +{ + public class ModelDownloadUrl + { + public string Url { get; set; } = default!; + public int ModelRevision { get; set; } = default!; + public DateTime ExipiresAt { get; set; } = default!; + } +} diff --git a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs index ffc639fc..07b94f75 100644 --- a/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs +++ b/src/SIL.Machine.AspNetCore/Models/TranslationEngine.cs @@ -7,6 +7,7 @@ public class TranslationEngine : IEntity public string EngineId { get; set; } = default!; public string SourceLanguage { get; set; } = default!; public string TargetLanguage { get; set; } = default!; + public bool IsModelPersisted { get; set; } public int BuildRevision { get; set; } public Build? CurrentBuild { get; set; } } diff --git a/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj b/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj index c06937b5..d45b3e2b 100644 --- a/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj +++ b/src/SIL.Machine.AspNetCore/SIL.Machine.AspNetCore.csproj @@ -35,7 +35,7 @@ - + diff --git a/src/SIL.Machine.AspNetCore/Services/HangfireHealthCheck.cs b/src/SIL.Machine.AspNetCore/Services/HangfireHealthCheck.cs index 73bffff3..1d011c57 100644 --- a/src/SIL.Machine.AspNetCore/Services/HangfireHealthCheck.cs +++ b/src/SIL.Machine.AspNetCore/Services/HangfireHealthCheck.cs @@ -1,15 +1,9 @@ namespace SIL.Machine.AspNetCore.Services; -public class HangfireHealthCheck : IHealthCheck +public class HangfireHealthCheck(JobStorage jobStorage, IOptions options) : IHealthCheck { - private readonly JobStorage _jobStorage; - private readonly IOptions _options; - - public HangfireHealthCheck(JobStorage jobStorage, IOptions options) - { - _jobStorage = jobStorage; - _options = options; - } + private readonly JobStorage _jobStorage = jobStorage; + private readonly IOptions _options = options; public Task CheckHealthAsync( HealthCheckContext context, diff --git a/src/SIL.Machine.AspNetCore/Services/IFileStorage.cs b/src/SIL.Machine.AspNetCore/Services/IFileStorage.cs index 89a15ccc..3417cffa 100644 --- a/src/SIL.Machine.AspNetCore/Services/IFileStorage.cs +++ b/src/SIL.Machine.AspNetCore/Services/IFileStorage.cs @@ -14,5 +14,7 @@ Task> ListFilesAsync( Task OpenWriteAsync(string path, CancellationToken cancellationToken = default); + Task GetDownloadUrlAsync(string path, DateTime expiresAt, CancellationToken cancellationToken = default); + Task DeleteAsync(string path, bool recurse = false, CancellationToken cancellationToken = default); } diff --git a/src/SIL.Machine.AspNetCore/Services/ISharedFileService.cs b/src/SIL.Machine.AspNetCore/Services/ISharedFileService.cs index acbac068..f082a79c 100644 --- a/src/SIL.Machine.AspNetCore/Services/ISharedFileService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ISharedFileService.cs @@ -6,6 +6,14 @@ public interface ISharedFileService Uri GetResolvedUri(string path); + Task GetDownloadUrlAsync(string path, DateTime expiresAt); + + Task> ListFilesAsync( + string path, + bool recurse = false, + CancellationToken cancellationToken = default + ); + Task OpenReadAsync(string path, CancellationToken cancellationToken = default); Task OpenWriteAsync(string path, CancellationToken cancellationToken = default); diff --git a/src/SIL.Machine.AspNetCore/Services/ITranslationEngineService.cs b/src/SIL.Machine.AspNetCore/Services/ITranslationEngineService.cs index 2fea9430..c1238b1e 100644 --- a/src/SIL.Machine.AspNetCore/Services/ITranslationEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/ITranslationEngineService.cs @@ -4,11 +4,12 @@ public interface ITranslationEngineService { TranslationEngineType Type { get; } - Task CreateAsync( + Task CreateAsync( string engineId, string? engineName, string sourceLanguage, string targetLanguage, + bool? isModelPersisted = null, CancellationToken cancellationToken = default ); Task DeleteAsync(string engineId, CancellationToken cancellationToken = default); @@ -40,6 +41,8 @@ Task StartBuildAsync( Task CancelBuildAsync(string engineId, CancellationToken cancellationToken = default); + Task GetModelDownloadUrlAsync(string engineId, CancellationToken cancellationToken = default); + Task GetQueueSizeAsync(CancellationToken cancellationToken = default); bool IsLanguageNativeToModel(string language, out string internalCode); diff --git a/src/SIL.Machine.AspNetCore/Services/InMemoryStorage.cs b/src/SIL.Machine.AspNetCore/Services/InMemoryStorage.cs index 76755f24..e92109a3 100644 --- a/src/SIL.Machine.AspNetCore/Services/InMemoryStorage.cs +++ b/src/SIL.Machine.AspNetCore/Services/InMemoryStorage.cs @@ -96,6 +96,15 @@ public Task> ListFilesAsync( ); } + public Task GetDownloadUrlAsync( + string path, + DateTime expiresAt, + CancellationToken cancellationToken = default + ) + { + throw new NotSupportedException(); + } + public Task OpenReadAsync(string path, CancellationToken cancellationToken = default) { if (!_memoryStreams.TryGetValue(Normalize(path), out Entry? ret)) diff --git a/src/SIL.Machine.AspNetCore/Services/LocalStorage.cs b/src/SIL.Machine.AspNetCore/Services/LocalStorage.cs index 6826869e..9fc26c09 100644 --- a/src/SIL.Machine.AspNetCore/Services/LocalStorage.cs +++ b/src/SIL.Machine.AspNetCore/Services/LocalStorage.cs @@ -36,6 +36,15 @@ public Task> ListFilesAsync( ); } + public Task GetDownloadUrlAsync( + string path, + DateTime expiresAt, + CancellationToken cancellationToken = default + ) + { + throw new NotSupportedException(); + } + public Task OpenReadAsync(string path, CancellationToken cancellationToken = default) { Uri pathUri = new(_basePath, Normalize(path)); diff --git a/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs b/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs new file mode 100644 index 00000000..21cf753d --- /dev/null +++ b/src/SIL.Machine.AspNetCore/Services/ModelCleanupService.cs @@ -0,0 +1,59 @@ +namespace SIL.Machine.AspNetCore.Services; + +public class ModelCleanupService( + IServiceProvider services, + ISharedFileService sharedFileService, + IRepository engines, + ILogger logger +) : RecurrentTask("Model Cleanup Service", services, RefreshPeriod, logger) +{ + private readonly ISharedFileService _sharedFileService = sharedFileService; + private readonly ILogger _logger = logger; + private readonly IRepository _engines = engines; + private static readonly TimeSpan RefreshPeriod = TimeSpan.FromDays(1); + + protected override async Task DoWorkAsync(IServiceScope scope, CancellationToken cancellationToken) + { + await CheckModelsAsync(cancellationToken); + } + + private async Task CheckModelsAsync(CancellationToken cancellationToken) + { + _logger.LogInformation("Running model cleanup job"); + IReadOnlyCollection paths = await _sharedFileService.ListFilesAsync( + NmtEngineService.ModelDirectory, + cancellationToken: cancellationToken + ); + // Get all engine ids from the database + IReadOnlyList? allEngines = await _engines.GetAllAsync(cancellationToken: cancellationToken); + IEnumerable validFilenames = allEngines.Select(e => + NmtEngineService.GetModelPath(e.EngineId, e.BuildRevision) + ); + // If there is a currently running build that creates and pushes a new file, but the database has not + // updated yet, don't delete the new file. + IEnumerable validFilenamesForNextBuild = allEngines.Select(e => + NmtEngineService.GetModelPath(e.EngineId, e.BuildRevision + 1) + ); + HashSet filenameFilter = validFilenames.Concat(validFilenamesForNextBuild).ToHashSet(); + + foreach (string path in paths) + { + if (!filenameFilter.Contains(path)) + { + await DeleteFileAsync( + path, + $"file in S3 bucket not found in database. It may be an old rev, etc.", + cancellationToken + ); + } + } + } + + private async Task DeleteFileAsync(string path, string message, CancellationToken cancellationToken = default) + { + // This may delete a file while it is being downloaded, but the chance is rare + // enough and the solution easy enough (just download again) to just live with it. + _logger.LogInformation("Deleting old model file {filename}: {message}", path, message); + await _sharedFileService.DeleteAsync(path, cancellationToken); + } +} diff --git a/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs b/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs index 927310e7..dfc8423e 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtClearMLBuildJobFactory.cs @@ -1,24 +1,16 @@ namespace SIL.Machine.AspNetCore.Services; -public class NmtClearMLBuildJobFactory : IClearMLBuildJobFactory +public class NmtClearMLBuildJobFactory( + ISharedFileService sharedFileService, + ILanguageTagService languageTagService, + IRepository engines, + IOptionsMonitor options +) : IClearMLBuildJobFactory { - private readonly ISharedFileService _sharedFileService; - private readonly ILanguageTagService _languageTagService; - private readonly IRepository _engines; - private readonly IOptionsMonitor _options; - - public NmtClearMLBuildJobFactory( - ISharedFileService sharedFileService, - ILanguageTagService languageTagService, - IRepository engines, - IOptionsMonitor options - ) - { - _sharedFileService = sharedFileService; - _languageTagService = languageTagService; - _engines = engines; - _options = options; - } + private readonly ISharedFileService _sharedFileService = sharedFileService; + private readonly ILanguageTagService _languageTagService = languageTagService; + private readonly IRepository _engines = engines; + private readonly IOptionsMonitor _options = options; public TranslationEngineType EngineType => TranslationEngineType.Nmt; @@ -52,6 +44,9 @@ public async Task CreateJobScriptAsync( + $" 'shared_file_uri': '{baseUri}',\n" + $" 'shared_file_folder': '{folder}',\n" + (buildOptions is not null ? $" 'build_options': '''{buildOptions}''',\n" : "") + // buildRevision + 1 because the build revision is incremented after the build job + // is finished successfully but the file should be saved with the new revision number + + (engine.IsModelPersisted ? $" 'save_model': '{engineId}_{engine.BuildRevision + 1}',\n" : $"") + $" 'clearml': True\n" + "}\n" + "run(args)\n"; diff --git a/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs index 8ca475e7..355e042a 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtEngineService.cs @@ -14,7 +14,8 @@ public class NmtEngineService( IRepository engines, IBuildJobService buildJobService, ILanguageTagService languageTagService, - ClearMLMonitorService clearMLMonitorService + ClearMLMonitorService clearMLMonitorService, + ISharedFileService sharedFileService ) : ITranslationEngineService { private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; @@ -22,36 +23,47 @@ ClearMLMonitorService clearMLMonitorService private readonly IDataAccessContext _dataAccessContext = dataAccessContext; private readonly IRepository _engines = engines; private readonly IBuildJobService _buildJobService = buildJobService; - private readonly ILanguageTagService _languageTagService = languageTagService; private readonly ClearMLMonitorService _clearMLMonitorService = clearMLMonitorService; + private readonly ILanguageTagService _languageTagService = languageTagService; + private readonly ISharedFileService _sharedFileService = sharedFileService; + + public const string ModelDirectory = "models/"; + + public static string GetModelPath(string engineId, int buildRevision) + { + return $"{ModelDirectory}{engineId}_{buildRevision}.tar.gz"; + } public TranslationEngineType Type => TranslationEngineType.Nmt; - public async Task CreateAsync( + private const int MinutesToExpire = 60; + + public async Task CreateAsync( string engineId, string? engineName, string sourceLanguage, string targetLanguage, + bool? isModelPersisted = null, CancellationToken cancellationToken = default ) { await _dataAccessContext.BeginTransactionAsync(cancellationToken); - await _engines.InsertAsync( - new TranslationEngine - { - EngineId = engineId, - SourceLanguage = sourceLanguage, - TargetLanguage = targetLanguage - }, - cancellationToken - ); + var translationEngine = new TranslationEngine + { + EngineId = engineId, + SourceLanguage = sourceLanguage, + TargetLanguage = targetLanguage, + IsModelPersisted = isModelPersisted ?? false // models are not persisted if not specified + }; + await _engines.InsertAsync(translationEngine, cancellationToken); await _buildJobService.CreateEngineAsync( - new[] { BuildJobType.Cpu, BuildJobType.Gpu }, + [BuildJobType.Cpu, BuildJobType.Gpu], engineId, engineName, cancellationToken ); await _dataAccessContext.CommitTransactionAsync(CancellationToken.None); + return translationEngine; } public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) @@ -109,6 +121,35 @@ public async Task CancelBuildAsync(string engineId, CancellationToken cancellati } } + public async Task GetModelDownloadUrlAsync( + string engineId, + CancellationToken cancellationToken = default + ) + { + TranslationEngine engine = await GetEngineAsync(engineId, cancellationToken); + if (engine.IsModelPersisted != true) + throw new NotSupportedException( + "The model cannot be downloaded. " + + "To enable downloading the model, recreate the engine with IsModelPersisted property to true." + ); + if (engine.BuildRevision == 0) + throw new InvalidOperationException("The engine has not been built yet."); + string filepath = GetModelPath(engineId, engine.BuildRevision); + bool fileExists = await _sharedFileService.ExistsAsync(filepath, cancellationToken); + if (!fileExists) + throw new FileNotFoundException( + $"The model should exist to be downloaded but is not there for BuildRevision {engine.BuildRevision}." + ); + var expiresAt = DateTime.UtcNow.AddMinutes(MinutesToExpire); + var modelInfo = new ModelDownloadUrl + { + Url = await _sharedFileService.GetDownloadUrlAsync(filepath, expiresAt), + ModelRevision = engine.BuildRevision, + ExipiresAt = expiresAt + }; + return modelInfo; + } + public Task> TranslateAsync( string engineId, int n, @@ -159,4 +200,12 @@ private async Task CancelBuildJobAsync(string engineId, CancellationToken await _platformService.BuildCanceledAsync(buildId, CancellationToken.None); return buildId is not null; } + + private async Task GetEngineAsync(string engineId, CancellationToken cancellationToken) + { + TranslationEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken); + if (engine is null) + throw new InvalidOperationException($"The engine {engineId} does not exist."); + return engine; + } } diff --git a/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs b/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs index ad6d6061..8ed1020d 100644 --- a/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs +++ b/src/SIL.Machine.AspNetCore/Services/NmtTrainBuildJob.cs @@ -68,6 +68,13 @@ await PipInstallModuleAsync( + $" 'trg_lang': '{ConvertLanguageTag(engine.TargetLanguage)}',\n" + $" 'shared_file_uri': '{_sharedFileService.GetBaseUri()}',\n" + (buildOptions is not null ? $" 'build_options': '''{buildOptions}''',\n" : "") + // buildRevision + 1 because the build revision is incremented after the build job + // is finished successfully but the file should be saved with the new revision number + + ( + engine.IsModelPersisted + ? $" 'save_model': '{engine.Id}_{engine.BuildRevision + 1}',\n" + : "" + ) + $" 'clearml': False\n" + "}\n" + "run(args)\n" diff --git a/src/SIL.Machine.AspNetCore/Services/S3FileStorage.cs b/src/SIL.Machine.AspNetCore/Services/S3FileStorage.cs index 3df6c67c..17683140 100644 --- a/src/SIL.Machine.AspNetCore/Services/S3FileStorage.cs +++ b/src/SIL.Machine.AspNetCore/Services/S3FileStorage.cs @@ -65,6 +65,28 @@ public async Task> ListFilesAsync( return response.S3Objects.Select(s3Obj => s3Obj.Key[_basePath.Length..]).ToList(); } + public Task GetDownloadUrlAsync( + string path, + DateTime expiresAt, + CancellationToken cancellationToken = default + ) + { + return Task.FromResult( + _client.GetPreSignedURL( + new GetPreSignedUrlRequest + { + BucketName = _bucketName, + Key = _basePath + Normalize(path), + Expires = expiresAt, + ResponseHeaderOverrides = new ResponseHeaderOverrides + { + ContentDisposition = new ContentDisposition() { FileName = Path.GetFileName(path) }.ToString() + } + } + ) + ); + } + public async Task OpenReadAsync(string path, CancellationToken cancellationToken = default) { GetObjectRequest request = new() { BucketName = _bucketName, Key = _basePath + Normalize(path) }; diff --git a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs index c1c7fbf5..680ec5f6 100644 --- a/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs +++ b/src/SIL.Machine.AspNetCore/Services/ServalTranslationEngineServiceV1.cs @@ -15,17 +15,18 @@ HealthCheckService healthCheckService private readonly HealthCheckService _healthCheckService = healthCheckService; - public override async Task Create(CreateRequest request, ServerCallContext context) + public override async Task Create(CreateRequest request, ServerCallContext context) { ITranslationEngineService engineService = GetEngineService(request.EngineType); - await engineService.CreateAsync( + TranslationEngine translationEngine = await engineService.CreateAsync( request.EngineId, request.HasEngineName ? request.EngineName : null, request.SourceLanguage, request.TargetLanguage, + request.HasIsModelPersisted ? request.IsModelPersisted : null, context.CancellationToken ); - return Empty; + return new CreateResponse { IsModelPersisted = translationEngine.IsModelPersisted }; } public override async Task Delete(DeleteRequest request, ServerCallContext context) @@ -126,6 +127,31 @@ public override async Task CancelBuild(CancelBuildRequest request, Server return Empty; } + public override async Task GetModelDownloadUrl( + GetModelDownloadUrlRequest request, + ServerCallContext context + ) + { + try + { + ITranslationEngineService engineService = GetEngineService(request.EngineType); + ModelDownloadUrl modelDownloadUrl = await engineService.GetModelDownloadUrlAsync( + request.EngineId, + context.CancellationToken + ); + return new GetModelDownloadUrlResponse + { + Url = modelDownloadUrl.Url, + ModelRevision = modelDownloadUrl.ModelRevision, + ExpiresAt = modelDownloadUrl.ExipiresAt.ToTimestamp() + }; + } + catch (InvalidOperationException e) + { + throw new RpcException(new Status(StatusCode.Aborted, e.Message)); + } + } + public override async Task GetQueueSize( GetQueueSizeRequest request, ServerCallContext context diff --git a/src/SIL.Machine.AspNetCore/Services/SharedFileService.cs b/src/SIL.Machine.AspNetCore/Services/SharedFileService.cs index db349dba..b4244211 100644 --- a/src/SIL.Machine.AspNetCore/Services/SharedFileService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SharedFileService.cs @@ -55,6 +55,20 @@ public Uri GetResolvedUri(string path) return new Uri(_baseUri, path); } + public async Task GetDownloadUrlAsync(string path, DateTime expiresAt) + { + return await _fileStorage.GetDownloadUrlAsync(path, expiresAt); + } + + public Task> ListFilesAsync( + string path, + bool recurse = false, + CancellationToken cancellationToken = default + ) + { + return _fileStorage.ListFilesAsync(path, recurse, cancellationToken); + } + public Task OpenReadAsync(string path, CancellationToken cancellationToken = default) { return _fileStorage.OpenReadAsync(path, cancellationToken); diff --git a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs index d7d03d9f..ec77dc88 100644 --- a/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs +++ b/src/SIL.Machine.AspNetCore/Services/SmtTransferEngineService.cs @@ -5,59 +5,52 @@ public static class SmtTransferBuildStages public const string Train = "train"; } -public class SmtTransferEngineService : ITranslationEngineService +public class SmtTransferEngineService( + IDistributedReaderWriterLockFactory lockFactory, + IPlatformService platformService, + IDataAccessContext dataAccessContext, + IRepository engines, + IRepository trainSegmentPairs, + SmtTransferEngineStateService stateService, + IBuildJobService buildJobService, + JobStorage jobStorage +) : ITranslationEngineService { - private readonly IDistributedReaderWriterLockFactory _lockFactory; - private readonly IPlatformService _platformService; - private readonly IDataAccessContext _dataAccessContext; - private readonly IRepository _engines; - private readonly IRepository _trainSegmentPairs; - private readonly SmtTransferEngineStateService _stateService; - private readonly IBuildJobService _buildJobService; - private readonly JobStorage _jobStorage; - - public SmtTransferEngineService( - IDistributedReaderWriterLockFactory lockFactory, - IPlatformService platformService, - IDataAccessContext dataAccessContext, - IRepository engines, - IRepository trainSegmentPairs, - SmtTransferEngineStateService stateService, - IBuildJobService buildJobService, - JobStorage jobStorage - ) - { - _lockFactory = lockFactory; - _platformService = platformService; - _dataAccessContext = dataAccessContext; - _engines = engines; - _trainSegmentPairs = trainSegmentPairs; - _stateService = stateService; - _buildJobService = buildJobService; - _jobStorage = jobStorage; - } + private readonly IDistributedReaderWriterLockFactory _lockFactory = lockFactory; + private readonly IPlatformService _platformService = platformService; + private readonly IDataAccessContext _dataAccessContext = dataAccessContext; + private readonly IRepository _engines = engines; + private readonly IRepository _trainSegmentPairs = trainSegmentPairs; + private readonly SmtTransferEngineStateService _stateService = stateService; + private readonly IBuildJobService _buildJobService = buildJobService; + private readonly JobStorage _jobStorage = jobStorage; public TranslationEngineType Type => TranslationEngineType.SmtTransfer; - public async Task CreateAsync( + public async Task CreateAsync( string engineId, string? engineName, string sourceLanguage, string targetLanguage, + bool? isModelPersisted = null, CancellationToken cancellationToken = default ) { + if (isModelPersisted == false) + throw new NotSupportedException( + "SMT transfer engines do not support non-persisted models." + + "Please remove the isModelPersisted parameter or set it to true." + ); await _dataAccessContext.BeginTransactionAsync(cancellationToken); - await _engines.InsertAsync( - new TranslationEngine - { - EngineId = engineId, - SourceLanguage = sourceLanguage, - TargetLanguage = targetLanguage - }, - cancellationToken - ); - await _buildJobService.CreateEngineAsync(new[] { BuildJobType.Cpu }, engineId, engineName, cancellationToken); + var translationEngine = new TranslationEngine + { + EngineId = engineId, + SourceLanguage = sourceLanguage, + TargetLanguage = targetLanguage, + IsModelPersisted = isModelPersisted ?? true // models are persisted if not specified + }; + await _engines.InsertAsync(translationEngine, cancellationToken); + await _buildJobService.CreateEngineAsync([BuildJobType.Cpu], engineId, engineName, cancellationToken); await _dataAccessContext.CommitTransactionAsync(CancellationToken.None); IDistributedReaderWriterLock @lock = await _lockFactory.CreateAsync(engineId, CancellationToken.None); @@ -66,6 +59,7 @@ await _engines.InsertAsync( SmtTransferEngineState state = _stateService.Get(engineId); state.InitNew(); } + return translationEngine; } public async Task DeleteAsync(string engineId, CancellationToken cancellationToken = default) @@ -227,6 +221,14 @@ private async Task CancelBuildJobAsync(string engineId, CancellationToken return buildId is not null; } + public Task GetModelDownloadUrlAsync( + string engineId, + CancellationToken cancellationToken = default + ) + { + throw new NotSupportedException(); + } + private async Task GetEngineAsync(string engineId, CancellationToken cancellationToken) { TranslationEngine? engine = await _engines.GetAsync(e => e.EngineId == engineId, cancellationToken); diff --git a/src/SIL.Machine.AspNetCore/Usings.cs b/src/SIL.Machine.AspNetCore/Usings.cs index 3994bcdf..5bbd614c 100644 --- a/src/SIL.Machine.AspNetCore/Usings.cs +++ b/src/SIL.Machine.AspNetCore/Usings.cs @@ -2,9 +2,11 @@ global using System.Data; global using System.Diagnostics; global using System.Diagnostics.CodeAnalysis; +global using System.Globalization; global using System.IO.Compression; global using System.Linq.Expressions; global using System.Net; +global using System.Net.Mime; global using System.Reflection; global using System.Runtime.CompilerServices; global using System.Security.Cryptography; diff --git a/src/SIL.Machine.Serval.EngineServer/Program.cs b/src/SIL.Machine.Serval.EngineServer/Program.cs index 863598c8..5140b15a 100644 --- a/src/SIL.Machine.Serval.EngineServer/Program.cs +++ b/src/SIL.Machine.Serval.EngineServer/Program.cs @@ -1,5 +1,6 @@ using Hangfire; using OpenTelemetry.Trace; +using SIL.Machine.AspNetCore.Services; var builder = WebApplication.CreateBuilder(args); @@ -10,7 +11,9 @@ .AddMongoHangfireJobClient() .AddServalTranslationEngineService() .AddBuildJobService() + .AddModelCleanupService() .AddClearMLService(); + if (builder.Environment.IsDevelopment()) builder .Services.AddOpenTelemetry() diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/ModelCleanupServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/ModelCleanupServiceTests.cs new file mode 100644 index 00000000..4215236b --- /dev/null +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/ModelCleanupServiceTests.cs @@ -0,0 +1,97 @@ +namespace SIL.Machine.AspNetCore.Services; + +[TestFixture] +public class ModelCleanupServiceTests +{ + private readonly ISharedFileService _sharedFileService = new SharedFileService(Substitute.For()); + private readonly MemoryRepository _engines = new MemoryRepository(); + private static readonly List validFiles = + [ + "models/engineId1_1.tar.gz", + "models/engineId2_2.tar.gz", + "models/engineId2_3.tar.gz" // only one build ahead - keep + ]; + private static readonly List invalidFiles = + [ + "models/engineId2_1.targ.gz", // 1 build behind + "models/engineId2_4.tar.gz", // 2 builds ahead + "models/worngId_1.tar.gz", + "models/engineId1_badbuildnumber.tar.gz", + "models/noBuildNumber.tar.gz", + "models/engineId1_1.differentExtension" + ]; + + private async Task SetUpAsync() + { + _engines.Add( + new TranslationEngine + { + Id = "engine1", + EngineId = "engineId1", + SourceLanguage = "es", + TargetLanguage = "en", + BuildRevision = 1, + IsModelPersisted = true + } + ); + _engines.Add( + new TranslationEngine + { + Id = "engine2", + EngineId = "engineId2", + SourceLanguage = "es", + TargetLanguage = "en", + BuildRevision = 2, + IsModelPersisted = true + } + ); + async Task WriteFileStub(string path, string content) + { + using StreamWriter streamWriter = + new(await _sharedFileService.OpenWriteAsync(path, CancellationToken.None)); + await streamWriter.WriteAsync(content); + } + foreach (string path in validFiles) + { + await WriteFileStub(path, "content"); + } + foreach (string path in invalidFiles) + { + await WriteFileStub(path, "content"); + } + } + + public class TestModelCleanupService( + IServiceProvider serviceProvider, + ISharedFileService sharedFileService, + IRepository engines, + ILogger logger + ) : ModelCleanupService(serviceProvider, sharedFileService, engines, logger) + { + public async Task DoWorkAsync() => + await base.DoWorkAsync(Substitute.For(), CancellationToken.None); + } + + [Test] + public async Task DoWorkAsync_ValidFiles() + { + await SetUpAsync(); + + var cleanupJob = new TestModelCleanupService( + Substitute.For(), + _sharedFileService, + _engines, + Substitute.For>() + ); + Assert.That( + _sharedFileService.ListFilesAsync("models").Result.ToHashSet(), + Is.EquivalentTo(validFiles.Concat(invalidFiles).ToHashSet()) + ); + await cleanupJob.DoWorkAsync(); + // only valid files exist after running service + Assert.That( + _sharedFileService.ListFilesAsync("models").Result.ToHashSet(), + Is.EquivalentTo(validFiles.ToHashSet()) + ); + } +} diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs index 0794a914..b46340c5 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/NmtEngineServiceTests.cs @@ -12,8 +12,12 @@ public async Task StartBuildAsync() await env.Service.StartBuildAsync("engine1", "build1", "{}", Array.Empty()); await env.WaitForBuildToFinishAsync(); engine = env.Engines.Get("engine1"); - Assert.That(engine.CurrentBuild, Is.Null); - Assert.That(engine.BuildRevision, Is.EqualTo(2)); + Assert.Multiple(() => + { + Assert.That(engine.CurrentBuild, Is.Null); + Assert.That(engine.BuildRevision, Is.EqualTo(2)); + Assert.That(engine.IsModelPersisted, Is.False); + }); } [Test] @@ -204,7 +208,8 @@ private NmtEngineService CreateService() Engines, BuildJobService, new LanguageTagService(), - ClearMLMonitorService + ClearMLMonitorService, + SharedFileService ); } diff --git a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs index c0763560..0a3d93bf 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Services/SmtTransferEngineServiceTests.cs @@ -9,9 +9,13 @@ public async Task CreateAsync() using var env = new TestEnvironment(); await env.Service.CreateAsync("engine2", "Engine 2", "es", "en"); TranslationEngine? engine = await env.Engines.GetAsync(e => e.EngineId == "engine2"); - Assert.That(engine, Is.Not.Null); - Assert.That(engine.EngineId, Is.EqualTo("engine2")); - Assert.That(engine.BuildRevision, Is.EqualTo(0)); + Assert.Multiple(() => + { + Assert.That(engine, Is.Not.Null); + Assert.That(engine?.EngineId, Is.EqualTo("engine2")); + Assert.That(engine?.BuildRevision, Is.EqualTo(0)); + Assert.That(engine?.IsModelPersisted, Is.True); + }); env.SmtModelFactory.Received().InitNew("engine2"); env.TransferEngineFactory.Received().InitNew("engine2"); } diff --git a/tests/SIL.Machine.AspNetCore.Tests/Usings.cs b/tests/SIL.Machine.AspNetCore.Tests/Usings.cs index 222a7a74..4aafc901 100644 --- a/tests/SIL.Machine.AspNetCore.Tests/Usings.cs +++ b/tests/SIL.Machine.AspNetCore.Tests/Usings.cs @@ -3,6 +3,7 @@ global using System.Text.Json.Nodes; global using Hangfire; global using Hangfire.Storage; +global using Microsoft.Extensions.DependencyInjection; global using Microsoft.Extensions.Hosting; global using Microsoft.Extensions.Hosting.Internal; global using Microsoft.Extensions.Logging;