Skip to content

Commit 04d62dc

Browse files
committed
Fix capturing ExecutionContext by timers and background tasks
1 parent 2dc971e commit 04d62dc

26 files changed

+392
-140
lines changed

src/Grpc.AspNetCore.Server/Grpc.AspNetCore.Server.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
<Compile Include="..\Shared\Server\UnaryServerMethodInvoker.cs" Link="Model\Internal\UnaryServerMethodInvoker.cs" />
3131
<Compile Include="..\Shared\NullableAttributes.cs" Link="Internal\NullableAttributes.cs" />
3232
<Compile Include="..\Shared\CodeAnalysisAttributes.cs" Link="Internal\CodeAnalysisAttributes.cs" />
33+
<Compile Include="..\Shared\NonCapturingTimer.cs" Link="Internal\NonCapturingTimer.cs" />
3334
</ItemGroup>
3435

3536
<ItemGroup>

src/Grpc.AspNetCore.Server/Internal/ServerCallDeadlineManager.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -91,12 +91,12 @@ public ServerCallDeadlineManager(HttpContextServerCallContext serverCallContext,
9191
// Ensures there is no weird situation where the timer triggers
9292
// before the field is set. Shouldn't happen because only long deadlines
9393
// will take this path but better to be safe than sorry.
94-
_longDeadlineTimer = new Timer(DeadlineExceededLongDelegate, (this, maxTimerDueTime), Timeout.Infinite, Timeout.Infinite);
94+
_longDeadlineTimer = NonCapturingTimer.Create(DeadlineExceededLongDelegate, (this, maxTimerDueTime), Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
9595
_longDeadlineTimer.Change(timerMilliseconds, Timeout.Infinite);
9696
}
9797
else
9898
{
99-
_longDeadlineTimer = new Timer(DeadlineExceededDelegate, this, timerMilliseconds, Timeout.Infinite);
99+
_longDeadlineTimer = NonCapturingTimer.Create(DeadlineExceededDelegate, this, TimeSpan.FromMilliseconds(timerMilliseconds), Timeout.InfiniteTimeSpan);
100100
}
101101
}
102102

src/Grpc.Net.Client/Balancer/DnsResolver.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ protected override void OnStarted()
6868

6969
if (_refreshInterval != Timeout.InfiniteTimeSpan)
7070
{
71-
_timer = new Timer(OnTimerCallback, null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
71+
_timer = NonCapturingTimer.Create(OnTimerCallback, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
7272
_timer.Change(_refreshInterval, _refreshInterval);
7373
}
7474
}

src/Grpc.Net.Client/Balancer/Internal/SocketConnectivitySubchannelTransport.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -77,7 +77,7 @@ public SocketConnectivitySubchannelTransport(
7777
ConnectTimeout = connectTimeout;
7878
_socketConnect = socketConnect ?? OnConnect;
7979
_activeStreams = new List<ActiveStream>();
80-
_socketConnectedTimer = new Timer(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
80+
_socketConnectedTimer = NonCapturingTimer.Create(OnCheckSocketConnection, state: null, Timeout.InfiniteTimeSpan, Timeout.InfiniteTimeSpan);
8181
}
8282

8383
private object Lock => _subchannel.Lock;

src/Grpc.Net.Client/Balancer/PollingResolver.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -86,7 +86,7 @@ protected PollingResolver(ILoggerFactory loggerFactory, IBackoffPolicyFactory? b
8686
/// </para>
8787
/// </summary>
8888
/// <param name="listener">The callback used to receive updates on the target.</param>
89-
public override sealed void Start(Action<ResolverResult> listener)
89+
public sealed override void Start(Action<ResolverResult> listener)
9090
{
9191
if (listener == null)
9292
{

src/Grpc.Net.Client/Balancer/Subchannel.cs

+16-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -233,7 +233,21 @@ public void RequestConnection()
233233
}
234234
}
235235

236-
_ = ConnectTransportAsync();
236+
// Don't capture the current ExecutionContext and its AsyncLocals onto the connect
237+
bool restoreFlow = false;
238+
if (!ExecutionContext.IsFlowSuppressed())
239+
{
240+
ExecutionContext.SuppressFlow();
241+
restoreFlow = true;
242+
}
243+
244+
_ = Task.Run(ConnectTransportAsync);
245+
246+
// Restore the current ExecutionContext
247+
if (restoreFlow)
248+
{
249+
ExecutionContext.RestoreFlow();
250+
}
237251
}
238252

239253
private void CancelInProgressConnect()

src/Grpc.Net.Client/Grpc.Net.Client.csproj

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
<Project Sdk="Microsoft.NET.Sdk">
1+
<Project Sdk="Microsoft.NET.Sdk">
22

33
<PropertyGroup>
44
<Description>.NET client for gRPC</Description>
@@ -35,6 +35,7 @@
3535
<Compile Include="..\Shared\NullableAttributes.cs" Link="Internal\NullableAttributes.cs" />
3636
<Compile Include="..\Shared\Http2ErrorCode.cs" Link="Internal\Http2ErrorCode.cs" />
3737
<Compile Include="..\Shared\Http3ErrorCode.cs" Link="Internal\Http3ErrorCode.cs" />
38+
<Compile Include="..\Shared\NonCapturingTimer.cs" Link="Internal\NonCapturingTimer.cs" />
3839
<Compile Include="..\Shared\NonDisposableMemoryStream.cs" Link="Internal\NonDisposableMemoryStream.cs" />
3940
</ItemGroup>
4041

src/Grpc.Net.Client/GrpcChannel.cs

+16-7
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#endregion
1818

1919
using System.Collections.Concurrent;
20+
using System.Net.Mail;
2021
using Grpc.Core;
2122
#if SUPPORT_LOAD_BALANCING
2223
using Grpc.Net.Client.Balancer;
@@ -54,8 +55,7 @@ public sealed class GrpcChannel : ChannelBase, IDisposable
5455
private readonly Dictionary<MethodKey, MethodConfig>? _serviceConfigMethods;
5556
private readonly bool _isSecure;
5657
private readonly List<CallCredentials>? _callCredentials;
57-
// Internal for testing
58-
internal readonly HashSet<IDisposable> ActiveCalls;
58+
private readonly HashSet<IDisposable> _activeCalls;
5959

6060
internal Uri Address { get; }
6161
internal HttpMessageInvoker HttpInvoker { get; }
@@ -165,7 +165,7 @@ internal GrpcChannel(Uri address, GrpcChannelOptions channelOptions) : base(addr
165165
ThrowOperationCanceledOnCancellation = channelOptions.ThrowOperationCanceledOnCancellation;
166166
UnsafeUseInsecureChannelCallCredentials = channelOptions.UnsafeUseInsecureChannelCallCredentials;
167167
_createMethodInfoFunc = CreateMethodInfo;
168-
ActiveCalls = new HashSet<IDisposable>();
168+
_activeCalls = new HashSet<IDisposable>();
169169
if (channelOptions.ServiceConfig is { } serviceConfig)
170170
{
171171
RetryThrottling = serviceConfig.RetryThrottling != null ? CreateChannelRetryThrottling(serviceConfig.RetryThrottling) : null;
@@ -490,15 +490,15 @@ internal void RegisterActiveCall(IDisposable grpcCall)
490490
throw new ObjectDisposedException(nameof(GrpcChannel));
491491
}
492492

493-
ActiveCalls.Add(grpcCall);
493+
_activeCalls.Add(grpcCall);
494494
}
495495
}
496496

497497
internal void FinishActiveCall(IDisposable grpcCall)
498498
{
499499
lock (_lock)
500500
{
501-
ActiveCalls.Remove(grpcCall);
501+
_activeCalls.Remove(grpcCall);
502502
}
503503
}
504504

@@ -749,9 +749,9 @@ public void Dispose()
749749
return;
750750
}
751751

752-
if (ActiveCalls.Count > 0)
752+
if (_activeCalls.Count > 0)
753753
{
754-
activeCallsCopy = ActiveCalls.ToArray();
754+
activeCallsCopy = _activeCalls.ToArray();
755755
}
756756

757757
Disposed = true;
@@ -807,6 +807,15 @@ internal int GetRandomNumber(int minValue, int maxValue)
807807
}
808808
}
809809

810+
// Internal for testing
811+
internal IDisposable[] GetActiveCalls()
812+
{
813+
lock (_lock)
814+
{
815+
return _activeCalls.ToArray();
816+
}
817+
}
818+
810819
#if SUPPORT_LOAD_BALANCING
811820
private sealed class SubChannelTransportFactory : ISubchannelTransportFactory
812821
{

src/Grpc.Net.Client/Internal/GrpcCall.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,7 @@ public Exception CreateFailureStatusException(Status status)
838838
GrpcCallLog.StartingDeadlineTimeout(Logger, timeout.Value);
839839

840840
var dueTime = CommonGrpcProtocolHelpers.GetTimerDueTime(timeout.Value, Channel.MaxTimerDueTime);
841-
_deadlineTimer = new Timer(DeadlineExceededCallback, null, dueTime, Timeout.Infinite);
841+
_deadlineTimer = NonCapturingTimer.Create(DeadlineExceededCallback, state: null, TimeSpan.FromMilliseconds(dueTime), Timeout.InfiniteTimeSpan);
842842
}
843843
}
844844

src/Shared/NonCapturingTimer.cs

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
namespace Grpc.Shared;
5+
6+
// A convenience API for interacting with System.Threading.Timer in a way
7+
// that doesn't capture the ExecutionContext. We should be using this (or equivalent)
8+
// everywhere we use timers to avoid rooting any values stored in asynclocals.
9+
internal static class NonCapturingTimer
10+
{
11+
public static Timer Create(TimerCallback callback, object? state, TimeSpan dueTime, TimeSpan period)
12+
{
13+
if (callback is null)
14+
{
15+
throw new ArgumentNullException(nameof(callback));
16+
}
17+
18+
// Don't capture the current ExecutionContext and its AsyncLocals onto the timer
19+
bool restoreFlow = false;
20+
try
21+
{
22+
if (!ExecutionContext.IsFlowSuppressed())
23+
{
24+
ExecutionContext.SuppressFlow();
25+
restoreFlow = true;
26+
}
27+
28+
return new Timer(callback, state, dueTime, period);
29+
}
30+
finally
31+
{
32+
// Restore the current ExecutionContext
33+
if (restoreFlow)
34+
{
35+
ExecutionContext.RestoreFlow();
36+
}
37+
}
38+
}
39+
}

test/FunctionalTests/Balancer/BalancerHelpers.cs

+1-69
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -185,74 +185,6 @@ public static async Task<GrpcChannel> CreateChannel(
185185
return channel;
186186
}
187187

188-
public static Task WaitForChannelStateAsync(ILogger logger, GrpcChannel channel, ConnectivityState state, int channelId = 1)
189-
{
190-
return WaitForChannelStatesAsync(logger, channel, new[] { state }, channelId);
191-
}
192-
193-
public static async Task WaitForChannelStatesAsync(ILogger logger, GrpcChannel channel, ConnectivityState[] states, int channelId = 1)
194-
{
195-
var statesText = string.Join(", ", states.Select(s => $"'{s}'"));
196-
logger.LogInformation($"Channel id {channelId}: Waiting for channel states {statesText}.");
197-
198-
var currentState = channel.State;
199-
200-
while (!states.Contains(currentState))
201-
{
202-
logger.LogInformation($"Channel id {channelId}: Current channel state '{currentState}' doesn't match expected states {statesText}.");
203-
204-
await channel.WaitForStateChangedAsync(currentState).DefaultTimeout();
205-
currentState = channel.State;
206-
}
207-
208-
logger.LogInformation($"Channel id {channelId}: Current channel state '{currentState}' matches expected states {statesText}.");
209-
}
210-
211-
public static async Task<Subchannel> WaitForSubchannelToBeReadyAsync(ILogger logger, GrpcChannel channel, Func<SubchannelPicker?, Subchannel[]>? getPickerSubchannels = null)
212-
{
213-
var subChannel = (await WaitForSubchannelsToBeReadyAsync(logger, channel, 1)).Single();
214-
return subChannel;
215-
}
216-
217-
public static async Task<Subchannel[]> WaitForSubchannelsToBeReadyAsync(ILogger logger, GrpcChannel channel, int expectedCount, Func<SubchannelPicker?, Subchannel[]>? getPickerSubchannels = null)
218-
{
219-
if (getPickerSubchannels == null)
220-
{
221-
getPickerSubchannels = (picker) =>
222-
{
223-
return picker switch
224-
{
225-
RoundRobinPicker roundRobinPicker => roundRobinPicker._subchannels.ToArray(),
226-
PickFirstPicker pickFirstPicker => new[] { pickFirstPicker.Subchannel },
227-
EmptyPicker emptyPicker => Array.Empty<Subchannel>(),
228-
null => Array.Empty<Subchannel>(),
229-
_ => throw new Exception("Unexpected picker type: " + picker.GetType().FullName)
230-
};
231-
};
232-
}
233-
234-
logger.LogInformation($"Waiting for subchannel ready count: {expectedCount}");
235-
236-
Subchannel[]? subChannelsCopy = null;
237-
await TestHelpers.AssertIsTrueRetryAsync(() =>
238-
{
239-
var picker = channel.ConnectionManager._picker;
240-
subChannelsCopy = getPickerSubchannels(picker);
241-
logger.LogInformation($"Current subchannel ready count: {subChannelsCopy.Length}");
242-
for (var i = 0; i < subChannelsCopy.Length; i++)
243-
{
244-
logger.LogInformation($"Ready subchannel: {subChannelsCopy[i]}");
245-
}
246-
247-
return subChannelsCopy.Length == expectedCount;
248-
}, "Wait for all subconnections to be connected.");
249-
250-
logger.LogInformation($"Finished waiting for subchannel ready.");
251-
252-
Debug.Assert(subChannelsCopy != null);
253-
return subChannelsCopy;
254-
}
255-
256188
public static T? GetInnerLoadBalancer<T>(GrpcChannel channel) where T : LoadBalancer
257189
{
258190
var balancer = (ChildHandlerLoadBalancer)channel.ConnectionManager._balancer!;

test/FunctionalTests/Balancer/ConnectionTests.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -352,7 +352,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
352352

353353
await channel.ConnectAsync().DefaultTimeout();
354354

355-
await BalancerHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, 2).DefaultTimeout();
355+
await BalancerWaitHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, 2).DefaultTimeout();
356356

357357
var client = TestClientFactory.Create(channel, endpoint1.Method);
358358

test/FunctionalTests/Balancer/LeastUsedBalancerTests.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -67,7 +67,7 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
6767

6868
var channel = await BalancerHelpers.CreateChannel(LoggerFactory, new LoadBalancingConfig("least_used"), new[] { endpoint1.Address, endpoint2.Address }, connect: true);
6969

70-
await BalancerHelpers.WaitForSubchannelsToBeReadyAsync(
70+
await BalancerWaitHelpers.WaitForSubchannelsToBeReadyAsync(
7171
Logger,
7272
channel,
7373
expectedCount: 2,

test/FunctionalTests/Balancer/PickFirstBalancerTests.cs

+7-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#region Copyright notice and license
1+
#region Copyright notice and license
22

33
// Copyright 2019 The gRPC Authors
44
//
@@ -238,8 +238,8 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
238238
Logger.LogInformation("Ending " + endpoint1.Address);
239239
endpoint1.Dispose();
240240

241-
await BalancerHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, expectedCount: 1,
242-
getPickerSubchannels: picker=>
241+
await BalancerWaitHelpers.WaitForSubchannelsToBeReadyAsync(Logger, channel, expectedCount: 1,
242+
getPickerSubchannels: picker =>
243243
{
244244
// We want a subchannel that has no current address
245245
if (picker is PickFirstPicker pickFirstPicker)
@@ -293,8 +293,7 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
293293
Assert.AreEqual(ConnectivityState.Ready, channel.State);
294294

295295
// Wait for pooled connection to timeout and return to idle
296-
await channel.WaitForStateChangedAsync(channel.State).DefaultTimeout();
297-
Assert.AreEqual(ConnectivityState.Idle, channel.State);
296+
await BalancerWaitHelpers.WaitForChannelStateAsync(Logger, channel, ConnectivityState.Idle).DefaultTimeout();
298297

299298
reply = await client.UnaryCall(new HelloRequest { Name = "Balancer" }).ResponseAsync.DefaultTimeout();
300299
Assert.AreEqual("Balancer", reply.Message);
@@ -355,7 +354,7 @@ async Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext conte
355354

356355
Logger.LogInformation($"All gRPC calls on server");
357356

358-
await BalancerHelpers.WaitForChannelStateAsync(Logger, channel, ConnectivityState.Ready).DefaultTimeout();
357+
await BalancerWaitHelpers.WaitForChannelStateAsync(Logger, channel, ConnectivityState.Ready).DefaultTimeout();
359358

360359
var balancer = BalancerHelpers.GetInnerLoadBalancer<PickFirstBalancer>(channel)!;
361360
var subchannel = balancer._subchannel!;
@@ -468,8 +467,8 @@ Task<HelloReply> UnaryMethod(HelloRequest request, ServerCallContext context)
468467
endpoint.Dispose();
469468

470469
await Task.WhenAll(
471-
BalancerHelpers.WaitForChannelStateAsync(Logger, channel1, ConnectivityState.Idle, channelId: 1),
472-
BalancerHelpers.WaitForChannelStateAsync(Logger, channel2, ConnectivityState.Idle, channelId: 2)).DefaultTimeout();
470+
BalancerWaitHelpers.WaitForChannelStateAsync(Logger, channel1, ConnectivityState.Idle, channelId: 1),
471+
BalancerWaitHelpers.WaitForChannelStateAsync(Logger, channel2, ConnectivityState.Idle, channelId: 2)).DefaultTimeout();
473472

474473
Logger.LogInformation("Restarting");
475474
using var endpointNew = BalancerHelpers.CreateGrpcEndpoint<HelloRequest, HelloReply>(50051, UnaryMethod, nameof(UnaryMethod));

0 commit comments

Comments
 (0)