Skip to content

Commit 933023b

Browse files
authored
Prevent block inside ResolveAsync from blocking PollingResolver.Refresh (#2385)
1 parent f29d927 commit 933023b

File tree

7 files changed

+106
-7
lines changed

7 files changed

+106
-7
lines changed

src/Grpc.Net.Client/Balancer/PollingResolver.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,9 @@ public sealed override void Refresh()
135135

136136
if (_resolveTask.IsCompleted)
137137
{
138-
_resolveTask = ResolveNowAsync(_cts.Token);
138+
// Run ResolveAsync in a background task.
139+
// This is done to prevent a synchronous block inside ResolveAsync from blocking future Refresh calls.
140+
_resolveTask = Task.Run(() => ResolveNowAsync(_cts.Token), _cts.Token);
139141
_resolveTask.ContinueWith(static (t, state) =>
140142
{
141143
var pollingResolver = (PollingResolver)state!;

test/Grpc.Net.Client.Tests/Balancer/ConnectionManagerTests.cs

+16-1
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,7 @@ public async Task PickAsync_DoesNotDeadlockAfterReconnect_WithResolverError()
409409
services.AddNUnitLogger();
410410
await using var serviceProvider = services.BuildServiceProvider();
411411
var loggerFactory = serviceProvider.GetRequiredService<ILoggerFactory>();
412+
var logger = loggerFactory.CreateLogger(GetType());
412413

413414
var resolver = new TestResolver(loggerFactory);
414415

@@ -427,22 +428,29 @@ public async Task PickAsync_DoesNotDeadlockAfterReconnect_WithResolverError()
427428
clientChannel));
428429

429430
// Act
431+
logger.LogInformation("Client connecting.");
430432
var connectTask = clientChannel.ConnectAsync(waitForReady: true, cancellationToken: CancellationToken.None);
433+
434+
logger.LogInformation("Starting pick on connecting channel.");
431435
var pickTask = clientChannel.PickAsync(
432436
new PickContext { Request = new HttpRequestMessage() },
433437
waitForReady: true,
434438
CancellationToken.None).AsTask();
435439

440+
logger.LogInformation("Waiting for resolve to complete.");
441+
await resolver.HasResolvedTask.DefaultTimeout();
442+
436443
resolver.UpdateAddresses(new List<BalancerAddress>
437444
{
438445
new BalancerAddress("localhost", 80)
439446
});
440447
await Task.WhenAll(connectTask, pickTask).DefaultTimeout();
441448

442-
// Simulate transport/network issue
449+
logger.LogInformation("Simulate transport/network issue.");
443450
transportFactory.Transports.ForEach(t => t.Disconnect());
444451
resolver.UpdateError(new Status(StatusCode.Unavailable, "Test error"));
445452

453+
logger.LogInformation("Starting pick on disconnected channel.");
446454
pickTask = clientChannel.PickAsync(
447455
new PickContext { Request = new HttpRequestMessage() },
448456
waitForReady: true,
@@ -454,7 +462,10 @@ public async Task PickAsync_DoesNotDeadlockAfterReconnect_WithResolverError()
454462

455463
// Assert
456464
// Should not timeout (deadlock)
465+
logger.LogInformation("Wait for pick task to complete.");
457466
await pickTask.DefaultTimeout();
467+
468+
logger.LogInformation("Done.");
458469
}
459470

460471
[Test]
@@ -489,6 +500,8 @@ public async Task PickAsync_DoesNotDeadlockAfterReconnect_WithZeroAddressResolve
489500
waitForReady: true,
490501
CancellationToken.None).AsTask();
491502

503+
await resolver.HasResolvedTask.DefaultTimeout();
504+
492505
resolver.UpdateAddresses(new List<BalancerAddress>
493506
{
494507
new BalancerAddress("localhost", 80)
@@ -560,6 +573,8 @@ public async Task PickAsync_ExecutionContext_DoesNotCaptureAsyncLocalsInConnect(
560573
waitForReady: true,
561574
CancellationToken.None).AsTask();
562575

576+
await resolver.HasResolvedTask.DefaultTimeout();
577+
563578
resolver.UpdateAddresses(new List<BalancerAddress>
564579
{
565580
new BalancerAddress("localhost", 80)

test/Grpc.Net.Client.Tests/Balancer/PickFirstBalancerTests.cs

+2
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,8 @@ public async Task ResolverError_HasFailedSubchannel_SubchannelShutdown()
196196
_ = channel.ConnectAsync();
197197

198198
// Assert
199+
await resolver.HasResolvedTask.DefaultTimeout();
200+
199201
var subchannels = channel.ConnectionManager.GetSubchannels();
200202
Assert.AreEqual(1, subchannels.Count);
201203

test/Grpc.Net.Client.Tests/Balancer/ResolverTests.cs

+68
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,74 @@ namespace Grpc.Net.Client.Tests.Balancer;
4242
[TestFixture]
4343
public class ResolverTests
4444
{
45+
// Checks that at least one of two Refresh calls completes within the default timeout
// while ResolveAsync of the resolver under test is blocked on a wait handle.
[Test]
46+
public async Task Refresh_BlockInsideResolveAsync_ResolverNotBlocked()
47+
{
48+
// Arrange
49+
// Created unsignalled; LockingPollingResolver.ResolveAsync waits on it for its first caller.
var waitHandle = new ManualResetEvent(false);
50+
51+
var services = new ServiceCollection();
52+
var testSink = new TestSink();
53+
services.AddLogging(b =>
54+
{
55+
b.AddProvider(new TestLoggerProvider(testSink));
56+
});
57+
services.AddNUnitLogger();
58+
var loggerFactory = services.BuildServiceProvider().GetRequiredService<ILoggerFactory>();
59+
60+
var logger = loggerFactory.CreateLogger<ResolverTests>();
61+
logger.LogInformation("Starting.");
62+
63+
var lockingResolver = new LockingPollingResolver(loggerFactory, waitHandle);
64+
lockingResolver.Start(result => { });
65+
66+
// Act
67+
// Each Refresh call runs via Task.Run, so a blocked call does not stall the test method itself.
logger.LogInformation("Refresh call 1. This should block.");
68+
var refreshTask1 = Task.Run(lockingResolver.Refresh);
69+
70+
logger.LogInformation("Refresh call 2. This should complete.");
71+
var refreshTask2 = Task.Run(lockingResolver.Refresh);
72+
73+
// Assert
74+
// Succeeds as soon as either Refresh call finishes. DefaultTimeout is a helper not shown here;
// presumably it fails the test if nothing completes in time.
await Task.WhenAny(refreshTask1, refreshTask2).DefaultTimeout();
75+
76+
logger.LogInformation("Setting wait handle.");
77+
// Signal the event so the ResolveAsync call waiting on it can return.
waitHandle.Set();
78+
79+
logger.LogInformation("Finishing.");
80+
}
81+
82+
// Test-only PollingResolver whose ResolveAsync synchronously blocks its first caller on a
// wait handle and then reports a single fixed address.
private class LockingPollingResolver : PollingResolver
83+
{
84+
// Handle the first ResolveAsync call waits on; set to null afterwards so later calls skip the wait.
private ManualResetEvent? _waitHandle;
85+
// Held while checking and waiting on _waitHandle, so concurrent ResolveAsync calls
// queue behind a first call that is still blocked.
private readonly object _lock = new();
86+
87+
public LockingPollingResolver(ILoggerFactory loggerFactory, ManualResetEvent waitHandle) : base(loggerFactory)
88+
{
89+
_waitHandle = waitHandle;
90+
}
91+
92+
// Does all of its work synchronously and returns an already-completed task.
protected override Task ResolveAsync(CancellationToken cancellationToken)
93+
{
94+
lock (_lock)
95+
{
96+
// Block the first caller.
97+
if (_waitHandle != null)
98+
{
99+
_waitHandle.WaitOne();
100+
// Clear the handle so subsequent calls do not wait.
_waitHandle = null;
101+
}
102+
}
103+
104+
// Pass a successful result with one address (localhost:80) to Listener,
// which is defined outside this snippet.
Listener(ResolverResult.ForResult(new List<BalancerAddress>
105+
{
106+
new BalancerAddress("localhost", 80)
107+
}));
108+
109+
return Task.CompletedTask;
110+
}
111+
}
112+
45113
[Test]
46114
public async Task Resolver_ResolveNameFromServices_Success()
47115
{

test/Grpc.Net.Client.Tests/Balancer/RoundRobinBalancerTests.cs

+3
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,10 @@ public async Task ResolverError_HasFailedSubchannel_SubchannelShutdown()
187187
_ = channel.ConnectAsync();
188188

189189
// Assert
190+
await resolver.HasResolvedTask.DefaultTimeout();
191+
190192
var subchannels = channel.ConnectionManager.GetSubchannels();
193+
191194
Assert.AreEqual(1, subchannels.Count);
192195

193196
Assert.AreEqual(1, subchannels[0]._addresses.Count);

test/Grpc.Net.Client.Tests/Balancer/WaitForReadyTests.cs

+9-4
Original file line numberDiff line numberDiff line change
@@ -58,21 +58,26 @@ public async Task ResolverReturnsNoAddresses_CallWithWaitForReady_Wait()
5858
});
5959

6060
var services = new ServiceCollection();
61+
services.AddNUnitLogger();
6162

62-
var resolver = new TestResolver();
63-
64-
services.AddSingleton<ResolverFactory>(new TestResolverFactory(resolver));
63+
services.AddSingleton<TestResolver>();
64+
services.AddSingleton<ResolverFactory>(s => new TestResolverFactory(s.GetRequiredService<TestResolver>()));
6565
services.AddSingleton<ISubchannelTransportFactory>(new TestSubchannelTransportFactory());
66+
var serviceProvider = services.BuildServiceProvider();
67+
68+
var resolver = serviceProvider.GetRequiredService<TestResolver>();
6669

6770
var invoker = HttpClientCallInvokerFactory.Create(testMessageHandler, "test:///localhost", configure: o =>
6871
{
6972
o.Credentials = ChannelCredentials.Insecure;
70-
o.ServiceProvider = services.BuildServiceProvider();
73+
o.ServiceProvider = serviceProvider;
7174
});
7275

7376
// Act
7477
var call = invoker.AsyncUnaryCall<HelloRequest, HelloReply>(ClientTestHelpers.ServiceMethod, string.Empty, new CallOptions().WithWaitForReady(), new HelloRequest());
7578

79+
await resolver.HasResolvedTask.DefaultTimeout();
80+
7681
var responseTask = call.ResponseAsync;
7782

7883
Assert.IsFalse(responseTask.IsCompleted);

test/Shared/TestResolver.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -34,6 +34,7 @@ internal class TestResolver : PollingResolver
3434
{
3535
private readonly Func<Task>? _onRefreshAsync;
3636
private readonly TaskCompletionSource<object?> _hasResolvedTcs;
37+
private readonly ILogger _logger;
3738
private ResolverResult? _result;
3839

3940
public Task HasResolvedTask => _hasResolvedTcs.Task;
@@ -46,15 +47,18 @@ public TestResolver(ILoggerFactory? loggerFactory = null, Func<Task>? onRefreshA
4647
{
4748
_onRefreshAsync = onRefreshAsync;
4849
_hasResolvedTcs = new TaskCompletionSource<object?>(TaskCreationOptions.RunContinuationsAsynchronously);
50+
_logger = (ILogger?)loggerFactory?.CreateLogger<TestResolver>() ?? NullLogger.Instance;
4951
}
5052

5153
public void UpdateAddresses(List<BalancerAddress> addresses, ServiceConfig? serviceConfig = null, Status? serviceConfigStatus = null)
5254
{
55+
_logger.LogInformation("Updating result addresses: {Addresses}", string.Join(", ", addresses));
5356
UpdateResult(ResolverResult.ForResult(addresses, serviceConfig, serviceConfigStatus));
5457
}
5558

5659
public void UpdateError(Status status)
5760
{
61+
_logger.LogInformation("Updating result error: {Status}", status);
5862
UpdateResult(ResolverResult.ForFailure(status));
5963
}
6064

0 commit comments

Comments
 (0)