Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent block inside ResolveAsync from blocking PollingResolver.Refresh #2385

Merged
merged 4 commits into from
Mar 12, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/Grpc.Net.Client/Balancer/PollingResolver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,9 @@ public sealed override void Refresh()

if (_resolveTask.IsCompleted)
{
_resolveTask = ResolveNowAsync(_cts.Token);
// Run ResolveAsync in a background task.
// This is done to prevent a synchronous block inside ResolveAsync from blocking future Refresh calls.
_resolveTask = Task.Run(() => ResolveNowAsync(_cts.Token), _cts.Token);
_resolveTask.ContinueWith(static (t, state) =>
{
var pollingResolver = (PollingResolver)state!;
Expand Down
17 changes: 16 additions & 1 deletion test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ public async Task PickAsync_DoesNotDeadlockAfterReconnect_WithResolverError()
services.AddNUnitLogger();
await using var serviceProvider = services.BuildServiceProvider();
var loggerFactory = serviceProvider.GetRequiredService<ILoggerFactory>();
var logger = loggerFactory.CreateLogger(GetType());

var resolver = new TestResolver(loggerFactory);

Expand All @@ -427,22 +428,29 @@ public async Task PickAsync_DoesNotDeadlockAfterReconnect_WithResolverError()
clientChannel));

// Act
logger.LogInformation("Client connecting.");
var connectTask = clientChannel.ConnectAsync(waitForReady: true, cancellationToken: CancellationToken.None);

logger.LogInformation("Starting pick on connecting channel.");
var pickTask = clientChannel.PickAsync(
new PickContext { Request = new HttpRequestMessage() },
waitForReady: true,
CancellationToken.None).AsTask();

logger.LogInformation("Waiting for resolve to complete.");
await resolver.HasResolvedTask.DefaultTimeout();

resolver.UpdateAddresses(new List<BalancerAddress>
{
new BalancerAddress("localhost", 80)
});
await Task.WhenAll(connectTask, pickTask).DefaultTimeout();

// Simulate transport/network issue
logger.LogInformation("Simulate transport/network issue.");
transportFactory.Transports.ForEach(t => t.Disconnect());
resolver.UpdateError(new Status(StatusCode.Unavailable, "Test error"));

logger.LogInformation("Starting pick on disconnected channel.");
pickTask = clientChannel.PickAsync(
new PickContext { Request = new HttpRequestMessage() },
waitForReady: true,
Expand All @@ -454,7 +462,10 @@ public async Task PickAsync_DoesNotDeadlockAfterReconnect_WithResolverError()

// Assert
// Should not timeout (deadlock)
logger.LogInformation("Wait for pick task to complete.");
await pickTask.DefaultTimeout();

logger.LogInformation("Done.");
}

[Test]
Expand Down Expand Up @@ -489,6 +500,8 @@ public async Task PickAsync_DoesNotDeadlockAfterReconnect_WithZeroAddressResolve
waitForReady: true,
CancellationToken.None).AsTask();

await resolver.HasResolvedTask.DefaultTimeout();

resolver.UpdateAddresses(new List<BalancerAddress>
{
new BalancerAddress("localhost", 80)
Expand Down Expand Up @@ -560,6 +573,8 @@ public async Task PickAsync_ExecutionContext_DoesNotCaptureAsyncLocalsInConnect(
waitForReady: true,
CancellationToken.None).AsTask();

await resolver.HasResolvedTask.DefaultTimeout();

resolver.UpdateAddresses(new List<BalancerAddress>
{
new BalancerAddress("localhost", 80)
Expand Down
2 changes: 2 additions & 0 deletions test/Grpc.Net.Client.Tests/Balancer/PickFirstBalancerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ public async Task ResolverError_HasFailedSubchannel_SubchannelShutdown()
_ = channel.ConnectAsync();

// Assert
await resolver.HasResolvedTask.DefaultTimeout();

var subchannels = channel.ConnectionManager.GetSubchannels();
Assert.AreEqual(1, subchannels.Count);

Expand Down
65 changes: 65 additions & 0 deletions test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,71 @@ namespace Grpc.Net.Client.Tests.Balancer;
[TestFixture]
public class ResolverTests
{
/// <summary>
/// Verifies that a synchronous block inside <c>ResolveAsync</c> does not stop
/// <c>Refresh</c> from returning. Two refreshes are started while the resolver's
/// first <c>ResolveAsync</c> call is parked on a wait handle; at least one of the
/// <c>Refresh</c> calls must return before the handle is released.
/// </summary>
[Test]
public async Task Refresh_BlockInsideResolveAsync_ResolverNotBlocked()
{
    // Upper bound on how long a Refresh call may take before the test is failed.
    // Without a bound, a regression would hang the test run instead of failing it.
    const int RefreshTimeoutMilliseconds = 10_000;

    // Arrange
    var waitHandle = new ManualResetEvent(false);

    var services = new ServiceCollection();
    var testSink = new TestSink();
    services.AddLogging(b =>
    {
        b.AddProvider(new TestLoggerProvider(testSink));
    });
    services.AddNUnitLogger();
    var loggerFactory = services.BuildServiceProvider().GetRequiredService<ILoggerFactory>();

    var logger = loggerFactory.CreateLogger<ResolverTests>();
    logger.LogInformation("Starting.");

    var resolver = new LockingPollingResolver(loggerFactory, waitHandle);
    resolver.Start(result => { });

    try
    {
        // Act
        logger.LogInformation("Refresh call 1. This should block.");
        var refreshTask1 = Task.Run(resolver.Refresh);

        logger.LogInformation("Refresh call 2. This should complete.");
        var refreshTask2 = Task.Run(resolver.Refresh);

        // Assert
        var timeoutTask = Task.Delay(RefreshTimeoutMilliseconds);
        var firstCompleted = await Task.WhenAny(refreshTask1, refreshTask2, timeoutTask);
        Assert.AreNotSame(
            timeoutTask,
            firstCompleted,
            "Neither Refresh call returned while ResolveAsync was blocked.");

        // Awaiting the finished task surfaces any exception thrown by Refresh,
        // so a faulted call cannot make the test pass vacuously.
        await firstCompleted;
    }
    finally
    {
        // Always release the blocked ResolveAsync call, even when the test fails,
        // so no thread is left parked on the wait handle.
        logger.LogInformation("Setting wait handle.");
        waitHandle.Set();
    }

    logger.LogInformation("Finishing.");
}

/// <summary>
/// Test resolver whose first <c>ResolveAsync</c> call blocks synchronously on the
/// supplied wait handle while holding a private lock. Once the handle has been
/// signaled, that call and all later calls report a single fixed address.
/// </summary>
private class LockingPollingResolver : PollingResolver
{
    private readonly object _lock = new();

    // Non-null until the first ResolveAsync call has finished waiting on it.
    private ManualResetEvent? _waitHandle;

    public LockingPollingResolver(ILoggerFactory loggerFactory, ManualResetEvent waitHandle)
        : base(loggerFactory)
    {
        _waitHandle = waitHandle;
    }

    /// <summary>
    /// Blocks the first caller until the wait handle is set, then publishes a
    /// result containing <c>localhost:80</c> to the listener.
    /// </summary>
    protected override Task ResolveAsync(CancellationToken cancellationToken)
    {
        lock (_lock)
        {
            // Only the first caller finds a handle to wait on; it is cleared
            // afterwards so subsequent calls pass straight through.
            var pendingHandle = _waitHandle;
            if (pendingHandle is not null)
            {
                pendingHandle.WaitOne();
                _waitHandle = null;
            }

            var addresses = new List<BalancerAddress>
            {
                new BalancerAddress("localhost", 80)
            };
            Listener(ResolverResult.ForResult(addresses));
        }

        return Task.CompletedTask;
    }
}

[Test]
public async Task Resolver_ResolveNameFromServices_Success()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,10 @@ public async Task ResolverError_HasFailedSubchannel_SubchannelShutdown()
_ = channel.ConnectAsync();

// Assert
await resolver.HasResolvedTask.DefaultTimeout();

var subchannels = channel.ConnectionManager.GetSubchannels();

Assert.AreEqual(1, subchannels.Count);

Assert.AreEqual(1, subchannels[0]._addresses.Count);
Expand Down
13 changes: 9 additions & 4 deletions test/Grpc.Net.Client.Tests/Balancer/WaitForReadyTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,26 @@ public async Task ResolverReturnsNoAddresses_CallWithWaitForReady_Wait()
});

var services = new ServiceCollection();
services.AddNUnitLogger();

var resolver = new TestResolver();

services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
services.AddSingleton<TestResolver>();
services.AddSingleton<ResolverFactory>(s => new TestResolverFactory(s.GetRequiredService<TestResolver>()));
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory());
var serviceProvider = services.BuildServiceProvider();

var resolver = serviceProvider.GetRequiredService<TestResolver>();

var invoker = HttpClientCallInvokerFactory.Create(testMessageHandler, "test:///localhost", configure: o =>
{
o.Credentials = ChannelCredentials.Insecure;
o.ServiceProvider = services.BuildServiceProvider();
o.ServiceProvider = serviceProvider;
});

// Act
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions().WithWaitForReady(), new HelloRequest());

await resolver.HasResolvedTask.DefaultTimeout();

var responseTask = call.ResponseAsync;

Assert.IsFalse(responseTask.IsCompleted);
Expand Down
6 changes: 5 additions & 1 deletion test/Shared/TestResolver.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#region Copyright notice and license
#region Copyright notice and license

// Copyright 2019 The gRPC Authors
//
Expand Down Expand Up @@ -34,6 +34,7 @@ internal class TestResolver : PollingResolver
{
private readonly Func<Task>? _onRefreshAsync;
private readonly TaskCompletionSource<object?> _hasResolvedTcs;
private readonly ILogger _logger;
private ResolverResult? _result;

public Task HasResolvedTask => _hasResolvedTcs.Task;
Expand All @@ -46,15 +47,18 @@ public TestResolver(ILoggerFactory? loggerFactory = null, Func<Task>? onRefreshA
{
_onRefreshAsync = onRefreshAsync;
_hasResolvedTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
_logger = (ILogger?)loggerFactory?.CreateLogger<TestResolver>() ?? NullLogger.Instance;
}

/// <summary>
/// Logs the supplied addresses and forwards them, with the optional service
/// config and service config status, to <c>UpdateResult</c> as a resolver result.
/// </summary>
public void UpdateAddresses(List<BalancerAddress> addresses, ServiceConfig? serviceConfig = null, Status? serviceConfigStatus = null)
{
    var formattedAddresses = string.Join(", ", addresses);
    _logger.LogInformation("Updating result addresses: {Addresses}", formattedAddresses);

    var resolverResult = ResolverResult.ForResult(addresses, serviceConfig, serviceConfigStatus);
    UpdateResult(resolverResult);
}

/// <summary>
/// Logs the supplied status and forwards it to <c>UpdateResult</c> as a
/// failed resolver result.
/// </summary>
public void UpdateError(Status status)
{
    _logger.LogInformation("Updating result error: {Status}", status);

    var failureResult = ResolverResult.ForFailure(status);
    UpdateResult(failureResult);
}

Expand Down
Loading