Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix allowedVmSizes #210

Merged
merged 11 commits into from
May 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 29 additions & 8 deletions src/TesApi.Tests/BatchSchedulerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -247,12 +247,14 @@ public async Task TestIfVmSizeIsAvailable(string vmSize, bool preemptible)
task.Resources.Preemptible = preemptible;
task.Resources.BackendParameters = new() { { "vm_size", vmSize } };

var config = GetMockConfig(false)();
using var serviceProvider = GetServiceProvider(
GetMockConfig(false)(),
config,
GetMockAzureProxy(AzureProxyReturnValues.Defaults),
GetMockQuotaProvider(AzureProxyReturnValues.Defaults),
GetMockSkuInfoProvider(AzureProxyReturnValues.Defaults),
GetContainerRegistryInfoProvider(AzureProxyReturnValues.Defaults));
GetContainerRegistryInfoProvider(AzureProxyReturnValues.Defaults),
GetMockAllowedVms(config));
var batchScheduler = serviceProvider.GetT();

var size = await ((BatchScheduler)batchScheduler).GetVmSizeAsync(task);
Expand Down Expand Up @@ -620,12 +622,14 @@ void Validator(TesTask _1, IEnumerable<(LogLevel logLevel, Exception exception)>
public async Task BatchJobContainsExpectedBatchPoolInformation()
{
var tesTask = GetTesTask();
var config = GetMockConfig(false)();
using var serviceProvider = GetServiceProvider(
GetMockConfig(false)(),
config,
GetMockAzureProxy(AzureProxyReturnValues.Defaults),
GetMockQuotaProvider(AzureProxyReturnValues.Defaults),
GetMockSkuInfoProvider(AzureProxyReturnValues.Defaults),
GetContainerRegistryInfoProvider(AzureProxyReturnValues.Defaults));
GetContainerRegistryInfoProvider(AzureProxyReturnValues.Defaults),
GetMockAllowedVms(config));
var batchScheduler = serviceProvider.GetT();

await batchScheduler.ProcessTesTaskAsync(tesTask);
Expand Down Expand Up @@ -1372,6 +1376,7 @@ public async Task PoolIsCreatedWithoutPublicIpWhenSubnetAndDisableBatchNodesPubl
GetMockQuotaProvider(azureProxyReturnValues),
GetMockSkuInfoProvider(azureProxyReturnValues),
GetContainerRegistryInfoProvider(azureProxyReturnValues),
GetMockAllowedVms(configuration),
additionalActions: additionalActions);
var batchScheduler = serviceProvider.GetT();

Expand All @@ -1389,6 +1394,20 @@ public async Task PoolIsCreatedWithoutPublicIpWhenSubnetAndDisableBatchNodesPubl
return (jobId, cloudTask, poolInformation, batchPoolsModel);
}

/// <summary>
/// Builds the setup action for the <see cref="IAllowedVmSizesService"/> mock from the test configuration.
/// </summary>
/// <param name="configuration">Test configuration key/value pairs. The "AllowedVmSizes" entry, when present, is read as a comma-separated list.</param>
/// <returns>An action that makes <c>GetAllowedVmSizes()</c> return the parsed list, or an empty list when the key is missing or blank.</returns>
private static Action<Mock<IAllowedVmSizesService>> GetMockAllowedVms(IEnumerable<(string Key, string Value)> configuration)
    => new(proxy =>
    {
        // FirstOrDefault yields a default tuple (null Value) when the key is absent.
        var allowedVmsConfig = configuration.FirstOrDefault(x => x.Key == "AllowedVmSizes").Value;

        // Drop empty entries so that values such as "," or a trailing comma do not produce
        // a non-empty list of blank names, which would filter out every VM size.
        var allowedVms = string.IsNullOrWhiteSpace(allowedVmsConfig)
            ? new List<string>()
            : allowedVmsConfig.Split(',', StringSplitOptions.RemoveEmptyEntries).ToList();

        proxy.Setup(p => p.GetAllowedVmSizes())
            .ReturnsAsync(allowedVms);
    });


private static Action<Mock<IBatchSkuInformationProvider>> GetMockSkuInfoProvider(AzureProxyReturnValues azureProxyReturnValues)
=> new(proxy =>
proxy.Setup(p => p.GetVmSizesAndPricesAsync(It.IsAny<string>()))
Expand Down Expand Up @@ -1431,8 +1450,8 @@ private static Action<Mock<IBatchQuotaProvider>> GetMockQuotaProvider(AzureProxy
new(batchQuotas.ActiveJobAndJobScheduleQuota, batchQuotas.PoolQuota, batchQuotas.DedicatedCoreQuota, batchQuotas.LowPriorityCoreQuota)));
});

private static TestServices.TestServiceProvider<IBatchScheduler> GetServiceProvider(IEnumerable<(string Key, string Value)> configuration, Action<Mock<IAzureProxy>> azureProxy, Action<Mock<IBatchQuotaProvider>> quotaProvider, Action<Mock<IBatchSkuInformationProvider>> skuInfoProvider, Action<Mock<ContainerRegistryProvider>> containerRegistryProviderSetup, Action<IServiceCollection> additionalActions = default)
=> new(wrapAzureProxy: true, configuration: configuration, azureProxy: azureProxy, batchQuotaProvider: quotaProvider, batchSkuInformationProvider: skuInfoProvider, accountResourceInformation: GetNewBatchResourceInfo(), containerRegistryProviderSetup: containerRegistryProviderSetup, additionalActions: additionalActions);
/// <summary>
/// Creates a test service provider for <see cref="IBatchScheduler"/> from the supplied configuration and mock setup actions.
/// </summary>
private static TestServices.TestServiceProvider<IBatchScheduler> GetServiceProvider(
    IEnumerable<(string Key, string Value)> configuration,
    Action<Mock<IAzureProxy>> azureProxy,
    Action<Mock<IBatchQuotaProvider>> quotaProvider,
    Action<Mock<IBatchSkuInformationProvider>> skuInfoProvider,
    Action<Mock<ContainerRegistryProvider>> containerRegistryProviderSetup,
    Action<Mock<IAllowedVmSizesService>> allowedVmSizesServiceSetup,
    Action<IServiceCollection> additionalActions = default)
{
    // Arguments are passed by name, in the same order as before, so evaluation order is unchanged.
    return new(
        wrapAzureProxy: true,
        configuration: configuration,
        azureProxy: azureProxy,
        batchQuotaProvider: quotaProvider,
        batchSkuInformationProvider: skuInfoProvider,
        accountResourceInformation: GetNewBatchResourceInfo(),
        containerRegistryProviderSetup: containerRegistryProviderSetup,
        allowedVmSizesServiceSetup: allowedVmSizesServiceSetup,
        additionalActions: additionalActions);
}

private static async Task<TesState> GetNewTesTaskStateAsync(TesTask tesTask, AzureProxyReturnValues azureProxyReturnValues)
{
Expand Down Expand Up @@ -1572,13 +1591,15 @@ private static IEnumerable<FileToDownload> GetFilesToDownload(Mock<IAzureProxy>
private static TestServices.TestServiceProvider<IBatchScheduler> GetServiceProvider(AzureProxyReturnValues azureProxyReturn = default)
{
azureProxyReturn ??= AzureProxyReturnValues.Defaults;
var config = GetMockConfig(false)();
return new(
wrapAzureProxy: true,
accountResourceInformation: new("defaultbatchaccount", "defaultresourcegroup", "defaultsubscription", "defaultregion"),
configuration: GetMockConfig(false)(),
configuration: config,
azureProxy: GetMockAzureProxy(azureProxyReturn),
batchQuotaProvider: GetMockQuotaProvider(azureProxyReturn),
batchSkuInformationProvider: GetMockSkuInfoProvider(azureProxyReturn));
batchSkuInformationProvider: GetMockSkuInfoProvider(azureProxyReturn),
allowedVmSizesServiceSetup: GetMockAllowedVms(config));
}

private static async Task<BatchPool> AddPool(BatchScheduler batchScheduler)
Expand Down
4 changes: 2 additions & 2 deletions src/TesApi.Tests/ConfigurationUtilsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ public async Task UnsupportedVmSizeInAllowedVmSizesFileIsIgnoredAndTaggedWithWar
batchQuotaProvider: GetMockQuotaProvider());
var configurationUtils = serviceProvider.GetT();

await configurationUtils.ProcessAllowedVmSizesConfigurationFileAsync();
var result = await configurationUtils.ProcessAllowedVmSizesConfigurationFileAsync();

Assert.AreEqual("VmSize1,VmSize2,VmFamily3", serviceProvider.Configuration["AllowedVmSizes"]);
Assert.AreEqual("VmSize1,VmSize2,VmFamily3", string.Join(",", result));

var expectedAllowedVmSizesFileContent =
"VmSize1\n" +
Expand Down
11 changes: 11 additions & 0 deletions src/TesApi.Tests/TestServices/TestServiceProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,13 @@ internal TestServiceProvider(
Action<Mock<IBatchQuotaProvider>> batchQuotaProvider = default,
(Func<IServiceProvider, System.Linq.Expressions.Expression<Func<ArmBatchQuotaProvider>>> expression, Action<Mock<ArmBatchQuotaProvider>> action) armBatchQuotaProvider = default, //added so config utils gets the arm implementation, to be removed once config utils is refactored.
Action<Mock<ContainerRegistryProvider>> containerRegistryProviderSetup = default,
Action<Mock<IAllowedVmSizesService>> allowedVmSizesServiceSetup = default,
Action<IServiceCollection> additionalActions = default)
{
Configuration = GetConfiguration(configuration);
provider = new ServiceCollection()
.AddSingleton<ConfigurationUtils>()
.AddSingleton(_ => GetAllowedVmSizesServiceProviderProvider(allowedVmSizesServiceSetup).Object)
.AddSingleton(_ => GetContainerRegisterProvider(containerRegistryProviderSetup).Object)
.AddSingleton(Configuration)
.AddSingleton(BindHelper<BatchAccountOptions>(BatchAccountOptions.SectionName))
Expand Down Expand Up @@ -100,6 +103,7 @@ internal TestServiceProvider(
internal Mock<IRepository<TesTask>> TesTaskRepository { get; private set; }
internal Mock<IStorageAccessProvider> StorageAccessProvider { get; private set; }
internal Mock<ContainerRegistryProvider> ContainerRegistryProvider { get; private set; }
internal Mock<IAllowedVmSizesService> AllowedVmSizesServiceProvider { get; private set; }

internal T GetT()
=> GetT(Array.Empty<Type>(), Array.Empty<object>());
Expand Down Expand Up @@ -162,6 +166,13 @@ private Mock<IAzureProxy> GetAzureProxy(Action<Mock<IAzureProxy>> action)
return AzureProxy = proxy;
}

/// <summary>
/// Creates the <see cref="IAllowedVmSizesService"/> mock, applies the optional setup action,
/// and records the mock in <see cref="AllowedVmSizesServiceProvider"/>.
/// </summary>
/// <param name="action">Optional setup to run against the new mock; may be null.</param>
/// <returns>The newly created mock.</returns>
private Mock<IAllowedVmSizesService> GetAllowedVmSizesServiceProviderProvider(Action<Mock<IAllowedVmSizesService>> action)
{
    Mock<IAllowedVmSizesService> mock = new();

    if (action is not null)
    {
        action(mock);
    }

    AllowedVmSizesServiceProvider = mock;
    return mock;
}

private Mock<ContainerRegistryProvider> GetContainerRegisterProvider(Action<Mock<ContainerRegistryProvider>> action)
{
var proxy = new Mock<ContainerRegistryProvider>();
Expand Down
91 changes: 91 additions & 0 deletions src/TesApi.Web/AllowedVmSizesService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;

namespace TesApi.Web
{
    /// <summary>
    /// Service that periodically fetches the allowed vms list from storage, and updates the supported vms list.
    /// </summary>
    public class AllowedVmSizesService : BackgroundService, IAllowedVmSizesService
    {
        private readonly TimeSpan refreshInterval = TimeSpan.FromHours(24);
        private readonly ILogger logger;
        private readonly ConfigurationUtils configUtils;

        // Signals the outcome of the first load attempt. Callers of GetAllowedVmSizes() await this
        // instead of polling. Continuations run asynchronously so they never execute inline on the
        // background loop.
        private readonly TaskCompletionSource firstLoad = new(TaskCreationOptions.RunContinuationsAsynchronously);

        // Written by the background loop and read by callers on other threads, hence volatile.
        private volatile List<string> allowedVmSizes;

        /// <summary>
        /// Service that periodically fetches the allowed vms list from storage, and updates the supported vms list.
        /// </summary>
        /// <param name="configUtils">Configuration utilities used to read and process the allowed vm sizes file.</param>
        /// <param name="logger">Logger for this service.</param>
        /// <exception cref="ArgumentNullException">Thrown when <paramref name="configUtils"/> or <paramref name="logger"/> is null.</exception>
        public AllowedVmSizesService(ConfigurationUtils configUtils, ILogger<AllowedVmSizesService> logger)
        {
            ArgumentNullException.ThrowIfNull(configUtils);
            ArgumentNullException.ThrowIfNull(logger);

            this.configUtils = configUtils;
            this.logger = logger;
        }

        /// <summary>
        /// Loads the allowed vm sizes and stores the result. Failures are logged and rethrown.
        /// </summary>
        private async Task GetAllowedVmSizesImpl()
        {
            try
            {
                logger.LogInformation("Executing allowed vm sizes config setup");
                allowedVmSizes = await configUtils.ProcessAllowedVmSizesConfigurationFileAsync();
            }
            catch (Exception e)
            {
                logger.LogError(e, "Failed to execute allowed vm sizes config setup");
                throw;
            }
        }

        /// <summary>
        /// Performs the initial load, then refreshes the list once per <see cref="refreshInterval"/> until stopped.
        /// </summary>
        /// <param name="stoppingToken">Signalled when the service is asked to stop.</param>
        /// <inheritdoc/>
        protected override async Task ExecuteAsync(CancellationToken stoppingToken)
        {
            try
            {
                await GetAllowedVmSizesImpl();
                firstLoad.TrySetResult();
            }
            catch (Exception e)
            {
                // Surface the initial failure to anyone waiting in GetAllowedVmSizes(), then
                // propagate it exactly as before.
                firstLoad.TrySetException(e);
                throw;
            }

            using PeriodicTimer timer = new(refreshInterval);

            try
            {
                while (await timer.WaitForNextTickAsync(stoppingToken))
                {
                    try
                    {
                        await GetAllowedVmSizesImpl();
                    }
                    catch (Exception) when (!stoppingToken.IsCancellationRequested)
                    {
                        // A list is already loaded at this point. A failed refresh must not end
                        // the refresh loop; keep the last good list and try again on the next tick.
                        logger.LogWarning("Allowed vm sizes refresh failed. Keeping the previously loaded list until the next refresh.");
                    }
                }
            }
            catch (OperationCanceledException)
            {
                logger.LogInformation("AllowedVmSizes Service is stopping.");
            }
        }

        /// <summary>
        /// Awaits start up and then return allowed vm sizes.
        /// </summary>
        /// <remarks>
        /// If the initial load failed, awaiting this method rethrows that failure.
        /// </remarks>
        /// <returns>List of allowed vms.</returns>
        public async Task<List<string>> GetAllowedVmSizes()
        {
            if (allowedVmSizes is null)
            {
                // Completes as soon as the first load finishes; no polling delay.
                await firstLoad.Task;
            }

            return allowedVmSizes;
        }
    }
}
13 changes: 6 additions & 7 deletions src/TesApi.Web/BatchScheduler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ public partial class BatchScheduler : IBatchScheduler
private readonly ILogger logger;
private readonly IAzureProxy azureProxy;
private readonly IStorageAccessProvider storageAccessProvider;
private readonly IEnumerable<string> allowedVmSizes;
private readonly IBatchQuotaVerifier quotaVerifier;
private readonly IBatchSkuInformationProvider skuInformationProvider;
private readonly List<TesTaskStateTransition> tesTaskStateTransitions;
Expand All @@ -87,6 +86,7 @@ public partial class BatchScheduler : IBatchScheduler
private readonly IBatchPoolFactory _batchPoolFactory;
private readonly string[] taskRunScriptContent;
private readonly string[] taskCleanupScriptContent;
private readonly IAllowedVmSizesService allowedVmSizesService;

private HashSet<string> onlyLogBatchTaskStateOnce = new();

Expand All @@ -108,6 +108,7 @@ public partial class BatchScheduler : IBatchScheduler
/// <param name="skuInformationProvider">Sku information provider <see cref="IBatchSkuInformationProvider"/></param>
/// <param name="containerRegistryProvider">Container registry information <see cref="ContainerRegistryProvider"/></param>
/// <param name="poolFactory">Batch pool factory <see cref="IBatchPoolFactory"/></param>
/// <param name="allowedVmSizesService">Service to get allowed vm sizes.</param>
public BatchScheduler(
ILogger<BatchScheduler> logger,
IOptions<Options.BatchImageGeneration1Options> batchGen1Options,
Expand All @@ -117,16 +118,15 @@ public BatchScheduler(
IOptions<Options.BatchImageNameOptions> batchImageNameOptions,
IOptions<Options.BatchNodesOptions> batchNodesOptions,
IOptions<Options.BatchSchedulingOptions> batchSchedulingOptions,
IConfiguration configuration,
IAzureProxy azureProxy,
IStorageAccessProvider storageAccessProvider,
IBatchQuotaVerifier quotaVerifier,
IBatchSkuInformationProvider skuInformationProvider,
ContainerRegistryProvider containerRegistryProvider,
IBatchPoolFactory poolFactory)
IBatchPoolFactory poolFactory,
IAllowedVmSizesService allowedVmSizesService)
{
ArgumentNullException.ThrowIfNull(logger);
ArgumentNullException.ThrowIfNull(configuration);
ArgumentNullException.ThrowIfNull(azureProxy);
ArgumentNullException.ThrowIfNull(storageAccessProvider);
ArgumentNullException.ThrowIfNull(quotaVerifier);
Expand All @@ -141,7 +141,6 @@ public BatchScheduler(
this.skuInformationProvider = skuInformationProvider;
this.containerRegistryProvider = containerRegistryProvider;

this.allowedVmSizes = GetStringValue(configuration, "AllowedVmSizes", null)?.Split(',', StringSplitOptions.RemoveEmptyEntries).ToList();
this.usePreemptibleVmsOnly = batchSchedulingOptions.Value.UsePreemptibleVmsOnly;
this.batchNodesSubnetId = batchNodesOptions.Value.SubnetId;
this.dockerInDockerImageName = batchImageNameOptions.Value.Docker;
Expand All @@ -158,6 +157,7 @@ public BatchScheduler(
this.marthaSecretName = marthaOptions.Value.SecretName;
this.globalStartTaskPath = StandardizeStartTaskPath(batchNodesOptions.Value.GlobalStartTask, this.defaultStorageAccountName);
this.globalManagedIdentity = batchNodesOptions.Value.GlobalManagedIdentity;
this.allowedVmSizesService = allowedVmSizesService;

if (!this.enableBatchAutopool)
{
Expand Down Expand Up @@ -188,8 +188,6 @@ public BatchScheduler(

logger.LogInformation($"usePreemptibleVmsOnly: {usePreemptibleVmsOnly}");

static string GetStringValue(IConfiguration configuration, string key, string defaultValue = "") => string.IsNullOrWhiteSpace(configuration[key]) ? defaultValue : configuration[key];

static bool tesTaskIsQueuedInitializingOrRunning(TesTask tesTask) => tesTask.State == TesState.QUEUEDEnum || tesTask.State == TesState.INITIALIZINGEnum || tesTask.State == TesState.RUNNINGEnum;
static bool tesTaskIsInitializingOrRunning(TesTask tesTask) => tesTask.State == TesState.INITIALIZINGEnum || tesTask.State == TesState.RUNNINGEnum;
static bool tesTaskIsQueuedOrInitializing(TesTask tesTask) => tesTask.State == TesState.QUEUEDEnum || tesTask.State == TesState.INITIALIZINGEnum;
Expand Down Expand Up @@ -1617,6 +1615,7 @@ private static string RemoveQueryStringsFromLocalFilePaths(string originalString
/// <returns>The virtual machine info</returns>
public async Task<VirtualMachineInformation> GetVmSizeAsync(TesTask tesTask, bool forcePreemptibleVmsOnly = false)
{
var allowedVmSizes = await allowedVmSizesService.GetAllowedVmSizes();
bool allowedVmSizesFilter(VirtualMachineInformation vm) => allowedVmSizes is null || !allowedVmSizes.Any() || allowedVmSizes.Contains(vm.VmSize, StringComparer.OrdinalIgnoreCase) || allowedVmSizes.Contains(vm.VmFamily, StringComparer.OrdinalIgnoreCase);

var tesResources = tesTask.Resources;
Expand Down
9 changes: 3 additions & 6 deletions src/TesApi.Web/ConfigurationUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public ConfigurationUtils(
/// entries in the allowed-vm-sizes file with a warning.
/// </summary>
/// <returns>The allowed VM sizes and families that are also supported, or an empty list when the allowed-vm-sizes file cannot be read.</returns>
public async Task ProcessAllowedVmSizesConfigurationFileAsync()
public async Task<List<string>> ProcessAllowedVmSizesConfigurationFileAsync()
{
var supportedVmSizesFilePath = $"/{defaultStorageAccountName}/configuration/supported-vm-sizes";
var allowedVmSizesFilePath = $"/{defaultStorageAccountName}/configuration/allowed-vm-sizes";
Expand All @@ -95,7 +95,7 @@ public async Task ProcessAllowedVmSizesConfigurationFileAsync()
if (allowedVmSizesFileContent is null)
{
logger.LogWarning($"Unable to read from {allowedVmSizesFilePath}. All supported VM sizes will be eligible for Azure Batch task scheduling.");
return;
return new List<string>();
}

// Read the allowed-vm-sizes configuration file and remove any previous warnings (those start with "<" following the VM size or family name)
Expand Down Expand Up @@ -141,10 +141,7 @@ public async Task ProcessAllowedVmSizesConfigurationFileAsync()
}
}

if (allowedAndSupportedVmSizes.Any())
{
this.configuration["AllowedVmSizes"] = string.Join(',', allowedAndSupportedVmSizes);
}
return allowedAndSupportedVmSizes;
}

/// <summary>
Expand Down
Loading