diff --git a/src/TesApi.Tests/BatchSchedulerTests.cs b/src/TesApi.Tests/BatchSchedulerTests.cs index 560fd2026..0ba80b85b 100644 --- a/src/TesApi.Tests/BatchSchedulerTests.cs +++ b/src/TesApi.Tests/BatchSchedulerTests.cs @@ -247,12 +247,14 @@ public async Task TestIfVmSizeIsAvailable(string vmSize, bool preemptible) task.Resources.Preemptible = preemptible; task.Resources.BackendParameters = new() { { "vm_size", vmSize } }; + var config = GetMockConfig(false)(); using var serviceProvider = GetServiceProvider( - GetMockConfig(false)(), + config, GetMockAzureProxy(AzureProxyReturnValues.Defaults), GetMockQuotaProvider(AzureProxyReturnValues.Defaults), GetMockSkuInfoProvider(AzureProxyReturnValues.Defaults), - GetContainerRegistryInfoProvider(AzureProxyReturnValues.Defaults)); + GetContainerRegistryInfoProvider(AzureProxyReturnValues.Defaults), + GetMockAllowedVms(config)); var batchScheduler = serviceProvider.GetT(); var size = await ((BatchScheduler)batchScheduler).GetVmSizeAsync(task); @@ -620,12 +622,14 @@ void Validator(TesTask _1, IEnumerable<(LogLevel logLevel, Exception exception)> public async Task BatchJobContainsExpectedBatchPoolInformation() { var tesTask = GetTesTask(); + var config = GetMockConfig(false)(); using var serviceProvider = GetServiceProvider( - GetMockConfig(false)(), + config, GetMockAzureProxy(AzureProxyReturnValues.Defaults), GetMockQuotaProvider(AzureProxyReturnValues.Defaults), GetMockSkuInfoProvider(AzureProxyReturnValues.Defaults), - GetContainerRegistryInfoProvider(AzureProxyReturnValues.Defaults)); + GetContainerRegistryInfoProvider(AzureProxyReturnValues.Defaults), + GetMockAllowedVms(config)); var batchScheduler = serviceProvider.GetT(); await batchScheduler.ProcessTesTaskAsync(tesTask); @@ -1372,6 +1376,7 @@ public async Task PoolIsCreatedWithoutPublicIpWhenSubnetAndDisableBatchNodesPubl GetMockQuotaProvider(azureProxyReturnValues), GetMockSkuInfoProvider(azureProxyReturnValues), GetContainerRegistryInfoProvider(azureProxyReturnValues), + GetMockAllowedVms(configuration), additionalActions: additionalActions); var batchScheduler = serviceProvider.GetT(); @@ -1389,6 +1394,20 @@ public async Task PoolIsCreatedWithoutPublicIpWhenSubnetAndDisableBatchNodesPubl return (jobId, cloudTask, poolInformation, batchPoolsModel); } + private static Action> GetMockAllowedVms(IEnumerable<(string Key, string Value)> configuration) + => new(proxy => + { + var allowedVmsConfig = configuration.FirstOrDefault(x => x.Key == "AllowedVmSizes").Value; + var allowedVms = new List(); + if (!string.IsNullOrWhiteSpace(allowedVmsConfig)) + { + allowedVms = allowedVmsConfig.Split(",").ToList(); + } + proxy.Setup(p => p.GetAllowedVmSizes()) + .ReturnsAsync(allowedVms); + }); + + private static Action> GetMockSkuInfoProvider(AzureProxyReturnValues azureProxyReturnValues) => new(proxy => proxy.Setup(p => p.GetVmSizesAndPricesAsync(It.IsAny())) @@ -1431,8 +1450,8 @@ private static Action> GetMockQuotaProvider(AzureProxy new(batchQuotas.ActiveJobAndJobScheduleQuota, batchQuotas.PoolQuota, batchQuotas.DedicatedCoreQuota, batchQuotas.LowPriorityCoreQuota))); }); - private static TestServices.TestServiceProvider GetServiceProvider(IEnumerable<(string Key, string Value)> configuration, Action> azureProxy, Action> quotaProvider, Action> skuInfoProvider, Action> containerRegistryProviderSetup, Action additionalActions = default) - => new(wrapAzureProxy: true, configuration: configuration, azureProxy: azureProxy, batchQuotaProvider: quotaProvider, batchSkuInformationProvider: skuInfoProvider, accountResourceInformation: GetNewBatchResourceInfo(), containerRegistryProviderSetup: containerRegistryProviderSetup, additionalActions: additionalActions); + private static TestServices.TestServiceProvider GetServiceProvider(IEnumerable<(string Key, string Value)> configuration, Action> azureProxy, Action> quotaProvider, Action> skuInfoProvider, Action> containerRegistryProviderSetup, Action> allowedVmSizesServiceSetup, Action additionalActions = default) + => new(wrapAzureProxy: true, configuration: configuration, azureProxy: azureProxy, batchQuotaProvider: quotaProvider, batchSkuInformationProvider: skuInfoProvider, accountResourceInformation: GetNewBatchResourceInfo(), containerRegistryProviderSetup: containerRegistryProviderSetup, allowedVmSizesServiceSetup: allowedVmSizesServiceSetup, additionalActions: additionalActions); private static async Task GetNewTesTaskStateAsync(TesTask tesTask, AzureProxyReturnValues azureProxyReturnValues) { @@ -1572,13 +1591,15 @@ private static IEnumerable GetFilesToDownload(Mock private static TestServices.TestServiceProvider GetServiceProvider(AzureProxyReturnValues azureProxyReturn = default) { azureProxyReturn ??= AzureProxyReturnValues.Defaults; + var config = GetMockConfig(false)(); return new( wrapAzureProxy: true, accountResourceInformation: new("defaultbatchaccount", "defaultresourcegroup", "defaultsubscription", "defaultregion"), - configuration: GetMockConfig(false)(), + configuration: config, azureProxy: GetMockAzureProxy(azureProxyReturn), batchQuotaProvider: GetMockQuotaProvider(azureProxyReturn), - batchSkuInformationProvider: GetMockSkuInfoProvider(azureProxyReturn)); + batchSkuInformationProvider: GetMockSkuInfoProvider(azureProxyReturn), + allowedVmSizesServiceSetup: GetMockAllowedVms(config)); } private static async Task AddPool(BatchScheduler batchScheduler) diff --git a/src/TesApi.Tests/ConfigurationUtilsTests.cs b/src/TesApi.Tests/ConfigurationUtilsTests.cs index fed161189..e41f86f51 100644 --- a/src/TesApi.Tests/ConfigurationUtilsTests.cs +++ b/src/TesApi.Tests/ConfigurationUtilsTests.cs @@ -76,9 +76,9 @@ public async Task UnsupportedVmSizeInAllowedVmSizesFileIsIgnoredAndTaggedWithWar batchQuotaProvider: GetMockQuotaProvider()); var configurationUtils = serviceProvider.GetT(); - await configurationUtils.ProcessAllowedVmSizesConfigurationFileAsync(); + var result = await configurationUtils.ProcessAllowedVmSizesConfigurationFileAsync(); - Assert.AreEqual("VmSize1,VmSize2,VmFamily3", serviceProvider.Configuration["AllowedVmSizes"]); + Assert.AreEqual("VmSize1,VmSize2,VmFamily3", string.Join(",", result)); var expectedAllowedVmSizesFileContent = "VmSize1\n" + diff --git a/src/TesApi.Tests/TestServices/TestServiceProvider.cs b/src/TesApi.Tests/TestServices/TestServiceProvider.cs index c13f216ca..3494c9602 100644 --- a/src/TesApi.Tests/TestServices/TestServiceProvider.cs +++ b/src/TesApi.Tests/TestServices/TestServiceProvider.cs @@ -42,10 +42,13 @@ internal TestServiceProvider( Action> batchQuotaProvider = default, (Func>> expression, Action> action) armBatchQuotaProvider = default, //added so config utils gets the arm implementation, to be removed once config utils is refactored. Action> containerRegistryProviderSetup = default, + Action> allowedVmSizesServiceSetup = default, Action additionalActions = default) { Configuration = GetConfiguration(configuration); provider = new ServiceCollection() + .AddSingleton() + .AddSingleton(_ => GetAllowedVmSizesServiceProviderProvider(allowedVmSizesServiceSetup).Object) .AddSingleton(_ => GetContainerRegisterProvider(containerRegistryProviderSetup).Object) .AddSingleton(Configuration) .AddSingleton(BindHelper(BatchAccountOptions.SectionName)) @@ -100,6 +103,7 @@ internal TestServiceProvider( internal Mock> TesTaskRepository { get; private set; } internal Mock StorageAccessProvider { get; private set; } internal Mock ContainerRegistryProvider { get; private set; } + internal Mock AllowedVmSizesServiceProvider { get; private set; } internal T GetT() => GetT(Array.Empty(), Array.Empty()); @@ -162,6 +166,13 @@ private Mock GetAzureProxy(Action> action) return AzureProxy = proxy; } + private Mock GetAllowedVmSizesServiceProviderProvider(Action> action) + { + var proxy = new Mock(); + action?.Invoke(proxy); + return AllowedVmSizesServiceProvider = proxy; + } + private Mock GetContainerRegisterProvider(Action> action) { var proxy = new Mock(); diff --git a/src/TesApi.Web/AllowedVmSizesService.cs b/src/TesApi.Web/AllowedVmSizesService.cs new file mode 100644 index 000000000..28cc2b792 --- /dev/null +++ b/src/TesApi.Web/AllowedVmSizesService.cs @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Hosting; +using Microsoft.Extensions.Logging; + +namespace TesApi.Web +{ + /// + /// Service that periodically fetches the allowed vms list from storage, and updates the supported vms list. + /// + public class AllowedVmSizesService : BackgroundService, IAllowedVmSizesService + { + private readonly TimeSpan refreshInterval = TimeSpan.FromHours(24); + private readonly ILogger logger; + private readonly ConfigurationUtils configUtils; + private List allowedVmSizes; + private Task firstTask; + + /// + /// Service that periodically fetches the allowed vms list from storage, and updates the supported vms list. + /// + /// + /// + public AllowedVmSizesService(ConfigurationUtils configUtils, ILogger logger) + { + ArgumentNullException.ThrowIfNull(configUtils); + ArgumentNullException.ThrowIfNull(logger); + + this.configUtils = configUtils; + this.logger = logger; + } + + private async Task GetAllowedVmSizesImpl() + { + try + { + logger.LogInformation("Executing allowed vm sizes config setup"); + allowedVmSizes = await configUtils.ProcessAllowedVmSizesConfigurationFileAsync(); + } + catch (Exception e) + { + logger.LogError(e, "Failed to execute allowed vm sizes config setup"); + throw; + } + } + + /// + protected override async Task ExecuteAsync(CancellationToken stoppingToken) + { + firstTask = GetAllowedVmSizesImpl(); + await firstTask; + + using PeriodicTimer timer = new(refreshInterval); + + try + { + while (await timer.WaitForNextTickAsync(stoppingToken)) + { + await GetAllowedVmSizesImpl(); + } + } + catch (OperationCanceledException) + { + logger.LogInformation("AllowedVmSizes Service is stopping."); + } + } + + /// + /// Awaits start up and then return allowed vm sizes. + /// + /// List of allowed vms. + public async Task> GetAllowedVmSizes() + { + if (allowedVmSizes == null) + { + while (firstTask is null) + { + await Task.Delay(TimeSpan.FromSeconds(1)); + } + await firstTask; + } + + return allowedVmSizes; + } + } +} diff --git a/src/TesApi.Web/BatchScheduler.cs b/src/TesApi.Web/BatchScheduler.cs index f13c37dcb..96303ece5 100644 --- a/src/TesApi.Web/BatchScheduler.cs +++ b/src/TesApi.Web/BatchScheduler.cs @@ -66,7 +66,6 @@ public partial class BatchScheduler : IBatchScheduler private readonly ILogger logger; private readonly IAzureProxy azureProxy; private readonly IStorageAccessProvider storageAccessProvider; - private readonly IEnumerable allowedVmSizes; private readonly IBatchQuotaVerifier quotaVerifier; private readonly IBatchSkuInformationProvider skuInformationProvider; private readonly List tesTaskStateTransitions; @@ -87,6 +86,7 @@ public partial class BatchScheduler : IBatchScheduler private readonly IBatchPoolFactory _batchPoolFactory; private readonly string[] taskRunScriptContent; private readonly string[] taskCleanupScriptContent; + private readonly IAllowedVmSizesService allowedVmSizesService; private HashSet onlyLogBatchTaskStateOnce = new(); @@ -108,6 +108,7 @@ public partial class BatchScheduler : IBatchScheduler /// Sku information provider /// Container registry information /// Batch pool factory + /// Service to get allowed vm sizes. public BatchScheduler( ILogger logger, IOptions batchGen1Options, @@ -117,16 +118,15 @@ public BatchScheduler( IOptions batchImageNameOptions, IOptions batchNodesOptions, IOptions batchSchedulingOptions, - IConfiguration configuration, IAzureProxy azureProxy, IStorageAccessProvider storageAccessProvider, IBatchQuotaVerifier quotaVerifier, IBatchSkuInformationProvider skuInformationProvider, ContainerRegistryProvider containerRegistryProvider, - IBatchPoolFactory poolFactory) + IBatchPoolFactory poolFactory, + IAllowedVmSizesService allowedVmSizesService) { ArgumentNullException.ThrowIfNull(logger); - ArgumentNullException.ThrowIfNull(configuration); ArgumentNullException.ThrowIfNull(azureProxy); ArgumentNullException.ThrowIfNull(storageAccessProvider); ArgumentNullException.ThrowIfNull(quotaVerifier); @@ -141,7 +141,6 @@ public BatchScheduler( this.skuInformationProvider = skuInformationProvider; this.containerRegistryProvider = containerRegistryProvider; - this.allowedVmSizes = GetStringValue(configuration, "AllowedVmSizes", null)?.Split(',', StringSplitOptions.RemoveEmptyEntries).ToList(); this.usePreemptibleVmsOnly = batchSchedulingOptions.Value.UsePreemptibleVmsOnly; this.batchNodesSubnetId = batchNodesOptions.Value.SubnetId; this.dockerInDockerImageName = batchImageNameOptions.Value.Docker; @@ -158,6 +157,7 @@ public BatchScheduler( this.marthaSecretName = marthaOptions.Value.SecretName; this.globalStartTaskPath = StandardizeStartTaskPath(batchNodesOptions.Value.GlobalStartTask, this.defaultStorageAccountName); this.globalManagedIdentity = batchNodesOptions.Value.GlobalManagedIdentity; + this.allowedVmSizesService = allowedVmSizesService; if (!this.enableBatchAutopool) { @@ -188,8 +188,6 @@ public BatchScheduler( logger.LogInformation($"usePreemptibleVmsOnly: {usePreemptibleVmsOnly}"); - static string GetStringValue(IConfiguration configuration, string key, string defaultValue = "") => string.IsNullOrWhiteSpace(configuration[key]) ? defaultValue : configuration[key]; - static bool tesTaskIsQueuedInitializingOrRunning(TesTask tesTask) => tesTask.State == TesState.QUEUEDEnum || tesTask.State == TesState.INITIALIZINGEnum || tesTask.State == TesState.RUNNINGEnum; static bool tesTaskIsInitializingOrRunning(TesTask tesTask) => tesTask.State == TesState.INITIALIZINGEnum || tesTask.State == TesState.RUNNINGEnum; static bool tesTaskIsQueuedOrInitializing(TesTask tesTask) => tesTask.State == TesState.QUEUEDEnum || tesTask.State == TesState.INITIALIZINGEnum; @@ -1617,6 +1615,7 @@ private static string RemoveQueryStringsFromLocalFilePaths(string originalString /// The virtual machine info public async Task GetVmSizeAsync(TesTask tesTask, bool forcePreemptibleVmsOnly = false) { + var allowedVmSizes = await allowedVmSizesService.GetAllowedVmSizes(); bool allowedVmSizesFilter(VirtualMachineInformation vm) => allowedVmSizes is null || !allowedVmSizes.Any() || allowedVmSizes.Contains(vm.VmSize, StringComparer.OrdinalIgnoreCase) || allowedVmSizes.Contains(vm.VmFamily, StringComparer.OrdinalIgnoreCase); var tesResources = tesTask.Resources; diff --git a/src/TesApi.Web/ConfigurationUtils.cs b/src/TesApi.Web/ConfigurationUtils.cs index 38e02c21d..0498bda1a 100644 --- a/src/TesApi.Web/ConfigurationUtils.cs +++ b/src/TesApi.Web/ConfigurationUtils.cs @@ -72,7 +72,7 @@ public ConfigurationUtils( /// entries in the allowed-vm-sizes file with a warning. Sets the AllowedVmSizes configuration key. /// /// - public async Task ProcessAllowedVmSizesConfigurationFileAsync() + public async Task> ProcessAllowedVmSizesConfigurationFileAsync() { var supportedVmSizesFilePath = $"/{defaultStorageAccountName}/configuration/supported-vm-sizes"; var allowedVmSizesFilePath = $"/{defaultStorageAccountName}/configuration/allowed-vm-sizes"; @@ -95,7 +95,7 @@ public async Task ProcessAllowedVmSizesConfigurationFileAsync() if (allowedVmSizesFileContent is null) { logger.LogWarning($"Unable to read from {allowedVmSizesFilePath}. All supported VM sizes will be eligible for Azure Batch task scheduling."); - return; + return new List(); } // Read the allowed-vm-sizes configuration file and remove any previous warnings (those start with "<" following the VM size or family name) @@ -141,10 +141,7 @@ public async Task ProcessAllowedVmSizesConfigurationFileAsync() } } - if (allowedAndSupportedVmSizes.Any()) - { - this.configuration["AllowedVmSizes"] = string.Join(',', allowedAndSupportedVmSizes); - } + return allowedAndSupportedVmSizes; } /// diff --git a/src/TesApi.Web/DoOnceAtStartUpService.cs b/src/TesApi.Web/DoOnceAtStartUpService.cs deleted file mode 100644 index 8d02bf1d7..000000000 --- a/src/TesApi.Web/DoOnceAtStartUpService.cs +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.Extensions.Hosting; -using Microsoft.Extensions.Logging; - -namespace TesApi.Web -{ - /// - /// Hosted service that executes one-time set up tasks at start up. - /// - public class DoOnceAtStartUpService : BackgroundService - { - private readonly ILogger logger; - private readonly ConfigurationUtils configUtils; - - /// - /// Hosted service that executes one-time set-up tasks at start up. - /// - /// - /// - public DoOnceAtStartUpService(ConfigurationUtils configUtils, ILogger logger) - { - ArgumentNullException.ThrowIfNull(configUtils); - ArgumentNullException.ThrowIfNull(logger); - - this.configUtils = configUtils; - this.logger = logger; - } - - /// - protected override async Task ExecuteAsync(CancellationToken stoppingToken) - { - using (logger.BeginScope("Executing Start Up tasks")) - { - try - { - logger.LogInformation("Executing Configuration Utils Setup"); - await configUtils.ProcessAllowedVmSizesConfigurationFileAsync(); - } - catch (Exception e) - { - logger.LogError(e, "Failed to execute start up tasks"); - throw; - } - } - } - } -} diff --git a/src/TesApi.Web/IAllowedVmSizesService.cs b/src/TesApi.Web/IAllowedVmSizesService.cs new file mode 100644 index 000000000..920799e47 --- /dev/null +++ b/src/TesApi.Web/IAllowedVmSizesService.cs @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace TesApi.Web +{ + /// + /// Interface to get allowed vm sizes for TES. + /// + public interface IAllowedVmSizesService + { + /// + /// Gets allowed vm sizes. + /// + /// A list of allowed vm sizes. + Task> GetAllowedVmSizes(); + } +} diff --git a/src/TesApi.Web/Startup.cs b/src/TesApi.Web/Startup.cs index 51c02a0d4..4fc732a40 100644 --- a/src/TesApi.Web/Startup.cs +++ b/src/TesApi.Web/Startup.cs @@ -104,6 +104,7 @@ public void ConfigureServices(IServiceCollection services) .AddSingleton(CreateBatchQuotaProviderFromConfiguration) .AddSingleton() .AddSingleton() + .AddSingleton() .AddSingleton(s => new DefaultAzureCredential()) .AddSwaggerGen(c => @@ -132,7 +133,7 @@ public void ConfigureServices(IServiceCollection services) }) // Order is important for hosted services - .AddHostedService() + .AddHostedService(sp => (AllowedVmSizesService)sp.GetRequiredService(typeof(IAllowedVmSizesService))) .AddHostedService() .AddHostedService() .AddHostedService()