diff --git a/build/dependencies.props b/build/dependencies.props index f96f84e714..6f5e1830a5 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -9,6 +9,7 @@ 2.1.0-preview1-15549 2.1.0-preview1-27579 2.1.0-preview1-27579 + 2.1.0-preview1-27475 2.1.0-preview1-27579 2.1.0-preview1-27579 2.1.0-preview1-27579 diff --git a/samples/ClientSample/HubSample.cs b/samples/ClientSample/HubSample.cs index 397cb8f433..ebd805f5b0 100644 --- a/samples/ClientSample/HubSample.cs +++ b/samples/ClientSample/HubSample.cs @@ -6,7 +6,6 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Client; -using Microsoft.AspNetCore.Sockets.Client; using Microsoft.Extensions.CommandLineUtils; using Microsoft.Extensions.Logging; diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionBuilder.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionBuilder.cs index d40232dac0..9be6b2993b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionBuilder.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionBuilder.cs @@ -42,9 +42,11 @@ public HubConnection Build() throw new InvalidOperationException("Cannot create IConnection instance. The connection factory was not configured."); } + IHubConnectionBuilder builder = this; var connection = _connectionFactoryDelegate(); - var loggerFactory = ((IHubConnectionBuilder)this).GetLoggerFactory(); - var hubProtocol = ((IHubConnectionBuilder)this).GetHubProtocol(); + + var loggerFactory = builder.GetLoggerFactory(); + var hubProtocol = builder.GetHubProtocol(); return new HubConnection(connection, hubProtocol ?? new JsonHubProtocol(), loggerFactory); } diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionBuilderHttpExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionBuilderHttpExtensions.cs index 5ba00304ec..cdab76ef41 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionBuilderHttpExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionBuilderHttpExtensions.cs @@ -2,9 +2,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Net.Http; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Client.Http; namespace Microsoft.AspNetCore.SignalR.Client { @@ -12,6 +15,8 @@ public static class HubConnectionBuilderHttpExtensions { public static readonly string TransportTypeKey = "TransportType"; public static readonly string HttpMessageHandlerKey = "HttpMessageHandler"; + public static readonly string HeadersKey = "Headers"; + public static readonly string JwtBearerTokenFactoryKey = "JwtBearerTokenFactory"; public static IHubConnectionBuilder WithUrl(this IHubConnectionBuilder hubConnectionBuilder, string url) { @@ -32,10 +37,18 @@ public static IHubConnectionBuilder WithUrl(this IHubConnectionBuilder hubConnec hubConnectionBuilder.ConfigureConnectionFactory(() => { + var headers = hubConnectionBuilder.GetHeaders(); + var httpOptions = new HttpOptions + { + HttpMessageHandler = hubConnectionBuilder.GetMessageHandler(), + Headers = headers != null ? new ReadOnlyDictionary(headers) : null, + JwtBearerTokenFactory = hubConnectionBuilder.GetJwtBearerTokenFactory() + }; + return new HttpConnection(url, hubConnectionBuilder.GetTransport(), hubConnectionBuilder.GetLoggerFactory(), - hubConnectionBuilder.GetMessageHandler()); + httpOptions); }); return hubConnectionBuilder; } @@ -52,6 +65,37 @@ public static IHubConnectionBuilder WithMessageHandler(this IHubConnectionBuilde return hubConnectionBuilder; } + public static IHubConnectionBuilder WithHeader(this IHubConnectionBuilder hubConnectionBuilder, string name, string value) + { + if (string.IsNullOrEmpty(name)) + { + throw new ArgumentException("Header name cannot be null or empty string.", nameof(name)); + } + + var headers = hubConnectionBuilder.GetHeaders(); + if (headers == null) + { + headers = new Dictionary(); + hubConnectionBuilder.AddSetting(HeadersKey, headers); + } + + headers.Add(name, value); + + return hubConnectionBuilder; + } + + public static IHubConnectionBuilder WithJwtBearer(this IHubConnectionBuilder hubConnectionBuilder, Func jwtBearerTokenFactory) + { + if (jwtBearerTokenFactory == null) + { + throw new ArgumentNullException(nameof(jwtBearerTokenFactory)); + } + + hubConnectionBuilder.AddSetting(JwtBearerTokenFactoryKey, jwtBearerTokenFactory); + + return hubConnectionBuilder; + } + public static TransportType GetTransport(this IHubConnectionBuilder hubConnectionBuilder) { if (hubConnectionBuilder.TryGetSetting(TransportTypeKey, out var transportType)) @@ -67,5 +111,25 @@ public static HttpMessageHandler GetMessageHandler(this IHubConnectionBuilder hu hubConnectionBuilder.TryGetSetting(HttpMessageHandlerKey, out var messageHandler); return messageHandler; } + + public static IDictionary GetHeaders(this IHubConnectionBuilder hubConnectionBuilder) + { + if (hubConnectionBuilder.TryGetSetting>(HeadersKey, out var headers)) + { + return headers; + } + + return null; + } + + public static Func GetJwtBearerTokenFactory(this IHubConnectionBuilder hubConnectionBuilder) + { + if (hubConnectionBuilder.TryGetSetting>(JwtBearerTokenFactoryKey, out var factory)) + { + return factory; + } + + return null; + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/DefaultTransportFactory.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/DefaultTransportFactory.cs index 155348e692..2f1f93c8da 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/DefaultTransportFactory.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/DefaultTransportFactory.cs @@ -3,6 +3,7 @@ using System; using System.Net.Http; +using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets.Client @@ -10,11 +11,12 @@ namespace Microsoft.AspNetCore.Sockets.Client public class DefaultTransportFactory : ITransportFactory { private readonly HttpClient _httpClient; + private readonly HttpOptions _httpOptions; private readonly TransportType _requestedTransportType; private readonly ILoggerFactory _loggerFactory; private static volatile bool _websocketsSupported = true; - public DefaultTransportFactory(TransportType requestedTransportType, ILoggerFactory loggerFactory, HttpClient httpClient) + public DefaultTransportFactory(TransportType requestedTransportType, ILoggerFactory loggerFactory, HttpClient httpClient, HttpOptions httpOptions) { if (requestedTransportType <= 0 || requestedTransportType > TransportType.All) { @@ -29,6 +31,7 @@ public DefaultTransportFactory(TransportType requestedTransportType, ILoggerFact _requestedTransportType = requestedTransportType; _loggerFactory = loggerFactory; _httpClient = httpClient; + _httpOptions = httpOptions; } public ITransport CreateTransport(TransportType availableServerTransports) @@ -37,7 +40,7 @@ public ITransport CreateTransport(TransportType availableServerTransports) { try { - return new WebSocketsTransport(_loggerFactory); + return new WebSocketsTransport(_httpOptions, _loggerFactory); } catch (PlatformNotSupportedException) { @@ -47,12 +50,12 @@ public ITransport CreateTransport(TransportType availableServerTransports) if ((availableServerTransports & TransportType.ServerSentEvents & _requestedTransportType) == TransportType.ServerSentEvents) { - return new ServerSentEventsTransport(_httpClient, _loggerFactory); + return new ServerSentEventsTransport(_httpClient, _httpOptions, _loggerFactory); } if ((availableServerTransports & TransportType.LongPolling & _requestedTransportType) == TransportType.LongPolling) { - return new LongPollingTransport(_httpClient, _loggerFactory); + return new LongPollingTransport(_httpClient, _httpOptions, _loggerFactory); } throw new InvalidOperationException("No requested transports available on the server."); diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 7187db3a57..e312bad199 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -31,6 +31,7 @@ public class HttpConnection : IConnection private volatile int _connectionState = ConnectionState.Initial; private volatile ChannelConnection _transportChannel; private readonly HttpClient _httpClient; + private readonly HttpOptions _httpOptions; private volatile ITransport _transport; private volatile Task _receiveLoopTask; private TaskCompletionSource _startTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -54,49 +55,47 @@ public HttpConnection(Uri url) : this(url, TransportType.All) { } - public HttpConnection(Uri url, HttpMessageHandler httpMessageHandler) - : this(url, TransportType.All, loggerFactory: null, httpMessageHandler: httpMessageHandler) - { } - public HttpConnection(Uri url, TransportType transportType) : this(url, transportType, loggerFactory: null) { } public HttpConnection(Uri url, ILoggerFactory loggerFactory) - : this(url, TransportType.All, loggerFactory, httpMessageHandler: null) + : this(url, TransportType.All, loggerFactory, httpOptions: null) { } public HttpConnection(Uri url, TransportType transportType, ILoggerFactory loggerFactory) - : this(url, transportType, loggerFactory, httpMessageHandler: null) + : this(url, transportType, loggerFactory, httpOptions: null) { } - public HttpConnection(Uri url, TransportType transportType, ILoggerFactory loggerFactory, HttpMessageHandler httpMessageHandler) + public HttpConnection(Uri url, TransportType transportType, ILoggerFactory loggerFactory, HttpOptions httpOptions) { Url = url ?? throw new ArgumentNullException(nameof(url)); _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); + _httpOptions = httpOptions; _requestedTransportType = transportType; if (_requestedTransportType != TransportType.WebSockets) { - _httpClient = httpMessageHandler == null ? new HttpClient() : new HttpClient(httpMessageHandler); + _httpClient = httpOptions?.HttpMessageHandler == null ? new HttpClient() : new HttpClient(httpOptions.HttpMessageHandler); _httpClient.Timeout = HttpClientTimeout; } - _transportFactory = new DefaultTransportFactory(transportType, _loggerFactory, _httpClient); + _transportFactory = new DefaultTransportFactory(transportType, _loggerFactory, _httpClient, httpOptions); } - public HttpConnection(Uri url, ITransportFactory transportFactory, ILoggerFactory loggerFactory, HttpMessageHandler httpMessageHandler) + public HttpConnection(Uri url, ITransportFactory transportFactory, ILoggerFactory loggerFactory, HttpOptions httpOptions) { Url = url ?? throw new ArgumentNullException(nameof(url)); _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); - _httpClient = httpMessageHandler == null ? new HttpClient() : new HttpClient(httpMessageHandler); - _httpClient.Timeout = HttpClientTimeout; + _httpOptions = httpOptions; + _httpClient = _httpOptions?.HttpMessageHandler == null ? new HttpClient() : new HttpClient(_httpOptions?.HttpMessageHandler); + _httpClient.Timeout = HttpClientTimeout; _transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory)); } @@ -214,7 +213,7 @@ private async Task StartAsyncInternal() } } - private async static Task Negotiate(Uri url, HttpClient httpClient, ILogger logger) + private async Task Negotiate(Uri url, HttpClient httpClient, ILogger logger) { try { @@ -229,7 +228,8 @@ private async static Task Negotiate(Uri url, HttpClient htt using (var request = new HttpRequestMessage(HttpMethod.Post, urlBuilder.Uri)) { - request.Headers.UserAgent.Add(Constants.UserAgentHeader); + SendUtils.PrepareHttpRequest(request, _httpOptions); + using (var response = await httpClient.SendAsync(request)) { response.EnsureSuccessStatusCode(); diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs new file mode 100644 index 0000000000..7ccfacef2c --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Net.Http; + +namespace Microsoft.AspNetCore.Sockets.Client.Http +{ + public class HttpOptions + { + public HttpMessageHandler HttpMessageHandler { get; set; } + public IReadOnlyCollection> Headers { get; set; } + public Func JwtBearerTokenFactory { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs index a059a81230..b2c8dec3ae 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Net; using System.Net.Http; using System.Threading; @@ -17,6 +18,7 @@ namespace Microsoft.AspNetCore.Sockets.Client public class LongPollingTransport : ITransport { private readonly HttpClient _httpClient; + private readonly HttpOptions _httpOptions; private readonly ILogger _logger; private Channel _application; private Task _sender; @@ -30,12 +32,13 @@ public class LongPollingTransport : ITransport public TransferMode? Mode { get; private set; } public LongPollingTransport(HttpClient httpClient) - : this(httpClient, null) + : this(httpClient, null, null) { } - public LongPollingTransport(HttpClient httpClient, ILoggerFactory loggerFactory) + public LongPollingTransport(HttpClient httpClient, HttpOptions httpOptions, ILoggerFactory loggerFactory) { _httpClient = httpClient; + _httpOptions = httpOptions; _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } @@ -54,7 +57,7 @@ public Task StartAsync(Uri url, Channel application, Transf // Start sending and polling (ask for binary if the server supports it) _poller = Poll(url, _transportCts.Token); - _sender = SendUtils.SendMessages(url, _application, _httpClient, _transportCts, _logger, _connectionId); + _sender = SendUtils.SendMessages(url, _application, _httpClient, _httpOptions, _transportCts, _logger, _connectionId); Running = Task.WhenAll(_sender, _poller).ContinueWith(t => { @@ -90,7 +93,7 @@ private async Task Poll(Uri pollUrl, CancellationToken cancellationToken) while (!cancellationToken.IsCancellationRequested) { var request = new HttpRequestMessage(HttpMethod.Get, pollUrl); - request.Headers.UserAgent.Add(Constants.UserAgentHeader); + SendUtils.PrepareHttpRequest(request, _httpOptions); HttpResponseMessage response; diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs index af05b12060..023054ab30 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs @@ -17,7 +17,7 @@ namespace Microsoft.AspNetCore.Sockets.Client internal static class SendUtils { public static async Task SendMessages(Uri sendUrl, Channel application, HttpClient httpClient, - CancellationTokenSource transportCts, ILogger logger, string connectionId) + HttpOptions httpOptions, CancellationTokenSource transportCts, ILogger logger, string connectionId) { logger.SendStarted(connectionId); IList messages = null; @@ -39,7 +39,7 @@ public static async Task SendMessages(Uri sendUrl, Channel // Send them in a single post var request = new HttpRequestMessage(HttpMethod.Post, sendUrl); - request.Headers.UserAgent.Add(Constants.UserAgentHeader); + PrepareHttpRequest(request, httpOptions); // TODO: We can probably use a pipeline here or some kind of pooled memory. // But where do we get the pool from? ArrayBufferPool.Instance? @@ -107,5 +107,22 @@ public static async Task SendMessages(Uri sendUrl, Channel logger.SendStopped(connectionId); } + + public static void PrepareHttpRequest(HttpRequestMessage request, HttpOptions httpOptions) + { + if (httpOptions?.Headers != null) + { + foreach (var header in httpOptions.Headers) + { + request.Headers.Add(header.Key, header.Value); + } + } + request.Headers.UserAgent.Add(Constants.UserAgentHeader); + + if (httpOptions?.JwtBearerTokenFactory != null) + { + request.Headers.Add("Authorization", $"Bearer {httpOptions.JwtBearerTokenFactory()}"); + } + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs index 62e92d12c2..2f288562a1 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs @@ -9,6 +9,7 @@ using System.Threading; using System.Threading.Tasks; using System.Threading.Channels; +using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.Extensions.Logging; @@ -20,6 +21,7 @@ public class ServerSentEventsTransport : ITransport { private static readonly MemoryPool _memoryPool = new MemoryPool(); private readonly HttpClient _httpClient; + private readonly HttpOptions _httpOptions; private readonly ILogger _logger; private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); private readonly ServerSentEventsMessageParser _parser = new ServerSentEventsMessageParser(); @@ -32,10 +34,10 @@ public class ServerSentEventsTransport : ITransport public TransferMode? Mode { get; private set; } public ServerSentEventsTransport(HttpClient httpClient) - : this(httpClient, null) + : this(httpClient, null, null) { } - public ServerSentEventsTransport(HttpClient httpClient, ILoggerFactory loggerFactory) + public ServerSentEventsTransport(HttpClient httpClient, HttpOptions httpOptions, ILoggerFactory loggerFactory) { if (httpClient == null) { @@ -43,6 +45,7 @@ public ServerSentEventsTransport(HttpClient httpClient, ILoggerFactory loggerFac } _httpClient = httpClient; + _httpOptions = httpOptions; _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } @@ -59,7 +62,7 @@ public Task StartAsync(Uri url, Channel application, Transf _logger.StartTransport(_connectionId, Mode.Value); - var sendTask = SendUtils.SendMessages(url, _application, _httpClient, _transportCts, _logger, _connectionId); + var sendTask = SendUtils.SendMessages(url, _application, _httpClient, _httpOptions, _transportCts, _logger, _connectionId); var receiveTask = OpenConnection(_application, url, _transportCts.Token); Running = Task.WhenAll(sendTask, receiveTask).ContinueWith(t => @@ -78,6 +81,7 @@ private async Task OpenConnection(Channel application, Uri _logger.StartReceive(_connectionId); var request = new HttpRequestMessage(HttpMethod.Get, url); + SendUtils.PrepareHttpRequest(request, _httpOptions); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs index 5b8874b6c5..e4ee9a1fb1 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs @@ -6,8 +6,9 @@ using System.Diagnostics; using System.Net.WebSockets; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -16,7 +17,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { public class WebSocketsTransport : ITransport { - private readonly ClientWebSocket _webSocket = new ClientWebSocket(); + private readonly ClientWebSocket _webSocket; private Channel _application; private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); private readonly CancellationTokenSource _receiveCts = new CancellationTokenSource(); @@ -28,12 +29,26 @@ public class WebSocketsTransport : ITransport public TransferMode? Mode { get; private set; } public WebSocketsTransport() - : this(null) + : this(null, null) { } - public WebSocketsTransport(ILoggerFactory loggerFactory) + public WebSocketsTransport(HttpOptions httpOptions, ILoggerFactory loggerFactory) { + _webSocket = new ClientWebSocket(); + if (httpOptions?.Headers != null) + { + foreach (var header in httpOptions.Headers) + { + _webSocket.Options.SetRequestHeader(header.Key, header.Value); + } + } + + if (httpOptions?.JwtBearerTokenFactory != null) + { + _webSocket.Options.SetRequestHeader("Authorization", $"Bearer {httpOptions.JwtBearerTokenFactory()}"); + } + _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs index 97756e3613..5d6b551066 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs @@ -6,8 +6,8 @@ using System.Diagnostics; using System.Net.WebSockets; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Logging; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 95ce2a9893..594d674bc3 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -3,6 +3,8 @@ using System; using System.Collections.Generic; +using System.Linq; +using System.Net.Http; using System.Threading; using System.Threading.Tasks; using System.Threading.Channels; @@ -555,6 +557,70 @@ public async Task ServerThrowsHubExceptionIfBuildingAsyncEnumeratorIsNotPossible } } + [Theory] + [MemberData(nameof(TransportTypes))] + public async Task ClientCanUseJwtBearerTokenForAuthentication(TransportType transportType) + { + using (StartLog(out var loggerFactory)) + { + var httpResponse = await new HttpClient().GetAsync(_serverFixture.Url + "/generateJwtToken"); + httpResponse.EnsureSuccessStatusCode(); + var token = await httpResponse.Content.ReadAsStringAsync(); + + var hubConnection = new HubConnectionBuilder() + .WithUrl(_serverFixture.Url + "/authorizedhub") + .WithTransport(transportType) + .WithLoggerFactory(loggerFactory) + .WithJwtBearer(() => token) + .Build(); + try + { + await hubConnection.StartAsync().OrTimeout(); + var message = await hubConnection.InvokeAsync("Echo", "Hello, World!").OrTimeout(); + Assert.Equal("Hello, World!", message); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + throw; + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + } + } + } + + [Theory] + [MemberData(nameof(TransportTypes))] + public async Task ClientCanSendHeaders(TransportType transportType) + { + using (StartLog(out var loggerFactory)) + { + var hubConnection = new HubConnectionBuilder() + .WithUrl(_serverFixture.Url + "/default") + .WithTransport(transportType) + .WithLoggerFactory(loggerFactory) + .WithHeader("X-test", "42") + .WithHeader("X-42", "test") + .Build(); + try + { + await hubConnection.StartAsync().OrTimeout(); + var headerValues = await hubConnection.InvokeAsync("GetHeaderValues", new object[] { new[] { "X-test", "X-42" } }).OrTimeout(); + Assert.Equal(new[] { "42", "test" }, headerValues); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + throw; + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + } + } + } public static IEnumerable HubProtocolsAndTransportsAndHubPaths { @@ -562,7 +628,7 @@ public static IEnumerable HubProtocolsAndTransportsAndHubPaths { foreach (var protocol in HubProtocols) { - foreach (var transport in TransportTypes()) + foreach (var transport in TransportTypes().SelectMany(t => t)) { foreach (var hubPath in HubPaths) { @@ -582,14 +648,14 @@ public static IEnumerable HubProtocolsAndTransportsAndHubPaths new MessagePackHubProtocol(), }; - public static IEnumerable TransportTypes() + public static IEnumerable TransportTypes() { if (TestHelpers.IsWebSocketsSupported()) { - yield return TransportType.WebSockets; + yield return new object[] { TransportType.WebSockets }; } - yield return TransportType.ServerSentEvents; - yield return TransportType.LongPolling; + yield return new object[] { TransportType.ServerSentEvents }; + yield return new object[] { TransportType.LongPolling }; } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index 37b3e44dda..0fc7da25cf 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -2,10 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Linq; using System.Reactive.Linq; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.AspNetCore.Authorization; namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { @@ -30,6 +33,12 @@ public async Task CallHandlerThatDoesntExist() { await Clients.Client(Context.ConnectionId).InvokeAsync("NoClientHandler"); } + + public IEnumerable GetHeaderValues(string[] headerNames) + { + var headers = Context.Connection.GetHttpContext().Request.Headers; + return headerNames.Select(h => (string)headers[h]); + } } public class DynamicTestHub : DynamicHub @@ -111,4 +120,10 @@ public interface ITestHub Task Send(string message); Task NoClientHandler(); } + + [Authorize(JwtBearerDefaults.AuthenticationScheme)] + public class HubWithAuthorization : Hub + { + public string Echo(string message) => TestHubMethodsImpl.Echo(message); + } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj index c3782767d0..5d81375a97 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Microsoft.AspNetCore.SignalR.Client.FunctionalTests.csproj @@ -23,6 +23,7 @@ + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs index 2c5b944dfe..e0d9d09e5e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs @@ -1,26 +1,77 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.IdentityModel.Tokens.Jwt; +using System.Security.Claims; +using Microsoft.AspNetCore.Authentication.JwtBearer; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; +using Microsoft.IdentityModel.Tokens; namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { public class Startup { + private readonly SymmetricSecurityKey SecurityKey = new SymmetricSecurityKey(Guid.NewGuid().ToByteArray()); + private readonly JwtSecurityTokenHandler JwtTokenHandler = new JwtSecurityTokenHandler(); + public void ConfigureServices(IServiceCollection services) { services.AddSignalR(); + services.AddAuthorization(options => + { + options.AddPolicy(JwtBearerDefaults.AuthenticationScheme, policy => + { + policy.AddAuthenticationSchemes(JwtBearerDefaults.AuthenticationScheme); + policy.RequireClaim(ClaimTypes.NameIdentifier); + }); + }); + services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme) + .AddJwtBearer(options => + { + options.TokenValidationParameters = + new TokenValidationParameters + { + ValidateAudience = false, + ValidateIssuer = false, + ValidateActor = false, + ValidateLifetime = true, + IssuerSigningKey = SecurityKey + }; + }); } public void Configure(IApplicationBuilder app) { + app.UseAuthentication(); + app.UseSignalR(routes => { routes.MapHub("default"); routes.MapHub("dynamic"); routes.MapHub("hubT"); + routes.MapHub("authorizedhub"); }); + + app.Run(async (context) => + { + if (context.Request.Path.StartsWithSegments("/generateJwtToken")) + { + await context.Response.WriteAsync(GenerateJwtToken()); + return; + } + }); + + } + + private string GenerateJwtToken() + { + var claims = new[] { new Claim(ClaimTypes.NameIdentifier, "testuser") }; + var credentials = new SigningCredentials(SecurityKey, SecurityAlgorithms.HmacSha256); + var token = new JwtSecurityToken("SignalRTestServer", "SignalRTests", claims, expires: DateTime.Now.AddSeconds(5), signingCredentials: credentials); + return JwtTokenHandler.WriteToken(token); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 3645a03eeb..588b504484 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -10,6 +10,7 @@ using System.Threading.Channels; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Logging; using Moq; @@ -58,7 +59,8 @@ public async Task CannotStartRunningConnection() : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); try { await connection.StartAsync(); @@ -87,7 +89,8 @@ public async Task CannotStartStoppedConnection() : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); await connection.StartAsync(); await connection.DisposeAsync(); @@ -139,7 +142,8 @@ public async Task CanStopStartingConnection() var transport = new Mock(); transport.Setup(t => t.StopAsync()).Returns(async () => { await releaseDisposeTcs.Task; }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(transport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(transport.Object), loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); var startTask = connection.StartAsync(); await allowDisposeTcs.Task; @@ -178,7 +182,8 @@ public async Task SendThrowsIfConnectionIsDisposed() : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); await connection.StartAsync(); await connection.DisposeAsync(); @@ -202,7 +207,8 @@ public async Task ClosedEventRaisedWhenTheClientIsBeingStopped() : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); await connection.StartAsync().OrTimeout(); @@ -228,7 +234,8 @@ public async Task ClosedEventRaisedWhenConnectionToServerLost() : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); try { @@ -274,7 +281,8 @@ public async Task ReceivedCallbackNotRaisedAfterConnectionIsDisposed() }); mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); var onReceivedInvoked = false; connection.OnReceived( _ => @@ -321,7 +329,8 @@ public async Task EventsAreNotRunningOnMainLoop() var callbackInvokedTcs = new TaskCompletionSource(); var closedTcs = new TaskCompletionSource(); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); connection.OnReceived(_ => { @@ -376,7 +385,8 @@ public async Task EventQueueTimeout() var blockReceiveCallbackTcs = new TaskCompletionSource(); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); connection.OnReceived(_ => blockReceiveCallbackTcs.Task); await connection.StartAsync(); @@ -420,7 +430,8 @@ public async Task EventQueueTimeoutWithException() var callbackInvokedTcs = new TaskCompletionSource(); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); connection.OnReceived( _ => { throw new OperationCanceledException(); @@ -461,8 +472,9 @@ public async Task TransportIsStoppedWhenConnectionIsStopped() using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(longPollingTransport), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var longPollingTransport = new LongPollingTransport(httpClient, null, new LoggerFactory()); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(longPollingTransport), loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); try { @@ -504,7 +516,8 @@ public async Task CanSendData() return ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); try { await connection.StartAsync(); @@ -550,7 +563,8 @@ public async Task SendAsyncThrowsIfConnectionIsDisposed() : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); await connection.StartAsync(); await connection.DisposeAsync(); @@ -578,7 +592,8 @@ public async Task CallerReceivesExceptionsFromSendAsync() : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); await connection.StartAsync(); var exception = await Assert.ThrowsAsync( @@ -609,7 +624,8 @@ public async Task CanReceiveData() : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); try { var receiveTcs = new TaskCompletionSource(); @@ -664,7 +680,8 @@ public async Task CanReceiveDataEvenIfExceptionThrownFromPreviousReceivedEvent() : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); try { var receiveTcs = new TaskCompletionSource(); @@ -727,7 +744,8 @@ public async Task CanReceiveDataEvenIfExceptionThrownSynchronouslyFromPreviousRe : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); try { var receiveTcs = new TaskCompletionSource(); @@ -785,7 +803,8 @@ public async Task CannotSendAfterReceiveThrewException() : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); try { @@ -818,7 +837,8 @@ public async Task StartThrowsFormatExceptionIfNegotiationResponseIsInvalid(strin return ResponseUtils.CreateResponse(HttpStatusCode.OK, negotiatePayload); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); var exception = await Assert.ThrowsAsync( () => connection.StartAsync()); @@ -838,7 +858,8 @@ public async Task StartThrowsFormatExceptionIfNegotiationResponseHasNoConnection ResponseUtils.CreateNegotiationResponse(connectionId: null)); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); var exception = await Assert.ThrowsAsync( () => connection.StartAsync()); @@ -858,7 +879,8 @@ public async Task StartThrowsFormatExceptionIfNegotiationResponseHasNoTransports ResponseUtils.CreateNegotiationResponse(transportTypes: null)); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); var exception = await Assert.ThrowsAsync( () => connection.StartAsync()); @@ -880,7 +902,8 @@ public async Task ConnectionCannotBeStartedIfNoCommonTransportsBetweenClientAndS ResponseUtils.CreateNegotiationResponse(transportTypes: serverTransports)); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); var exception = await Assert.ThrowsAsync( () => connection.StartAsync()); @@ -918,7 +941,7 @@ public async Task CanStartConnectionWithoutSettingTransferModeFeature() mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Binary); var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), - loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + loggerFactory: null, httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); await connection.StartAsync().OrTimeout(); var transferModeFeature = connection.Features.Get(); @@ -950,7 +973,8 @@ public async Task query(string requested, string expectedNegotiate) ResponseUtils.CreateNegotiationResponse()); }); - var connection = new HttpConnection(new Uri(requested), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var connection = new HttpConnection(new Uri(requested), TransportType.LongPolling, loggerFactory: null, + httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); await connection.StartAsync().OrTimeout(); await connection.DisposeAsync().OrTimeout(); } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs index a8a356abbd..ff127785f5 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs @@ -20,7 +20,7 @@ public class DefaultTransportFactoryTests public void DefaultTransportFactoryCannotBeCreatedWithInvalidTransportType(TransportType transportType) { Assert.Throws( - () => new DefaultTransportFactory(transportType, new LoggerFactory(), new HttpClient())); + () => new DefaultTransportFactory(transportType, new LoggerFactory(), new HttpClient(), httpOptions: null)); } [Theory] @@ -32,7 +32,7 @@ public void DefaultTransportFactoryCannotBeCreatedWithInvalidTransportType(Trans public void DefaultTransportFactoryCannotBeCreatedWithoutHttpClient(TransportType transportType) { var exception = Assert.Throws( - () => new DefaultTransportFactory(transportType, new LoggerFactory(), httpClient: null)); + () => new DefaultTransportFactory(transportType, new LoggerFactory(), httpClient: null, httpOptions: null)); Assert.Equal("httpClient", exception.ParamName); } @@ -40,7 +40,7 @@ public void DefaultTransportFactoryCannotBeCreatedWithoutHttpClient(TransportTyp [Fact] public void DefaultTransportFactoryCanBeCreatedWithoutHttpClientIfWebSocketsTransportRequestedExplicitly() { - new DefaultTransportFactory(TransportType.WebSockets, new LoggerFactory(), httpClient: null); + new DefaultTransportFactory(TransportType.WebSockets, new LoggerFactory(), httpClient: null, httpOptions: null); } [ConditionalTheory] @@ -50,7 +50,7 @@ public void DefaultTransportFactoryCanBeCreatedWithoutHttpClientIfWebSocketsTran [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable(TransportType requestedTransport, Type expectedTransportType) { - var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient()); + var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpOptions: null); Assert.IsType(expectedTransportType, transportFactory.CreateTransport(TransportType.All)); } @@ -63,7 +63,7 @@ public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable(Transpor public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport(TransportType requestedTransport) { var transportFactory = - new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient()); + new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpOptions: null); var ex = Assert.Throws( () => transportFactory.CreateTransport(~requestedTransport)); @@ -75,7 +75,7 @@ public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport(Tran public void DefaultTransportFactoryCreatesWebSocketsTransportIfAvailable() { Assert.IsType( - new DefaultTransportFactory(TransportType.All, loggerFactory: null, httpClient: new HttpClient()) + new DefaultTransportFactory(TransportType.All, loggerFactory: null, httpClient: new HttpClient(), httpOptions: null) .CreateTransport(TransportType.All)); } @@ -87,7 +87,7 @@ public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable_Win7(Tra { if (!TestHelpers.IsWebSocketsSupported()) { - var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient()); + var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpOptions: null); Assert.IsType(expectedTransportType, transportFactory.CreateTransport(TransportType.All)); } @@ -100,7 +100,7 @@ public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport_Win7 if (!TestHelpers.IsWebSocketsSupported()) { var transportFactory = - new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient()); + new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpOptions: null); var ex = Assert.Throws( () => transportFactory.CreateTransport(TransportType.All)); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index cc659add72..5195cebef6 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -12,6 +12,7 @@ using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging; @@ -114,7 +115,7 @@ public async Task HTTPRequestsNotSentWhenWebSocketsTransportRequested() .Returns( (request, cancellationToken) => Task.FromException(new InvalidOperationException("HTTP requests should not be sent."))); - var connection = new HttpConnection(new Uri(url), TransportType.WebSockets, loggerFactory, mockHttpHandler.Object); + var connection = new HttpConnection(new Uri(url), TransportType.WebSockets, loggerFactory, new HttpOptions { HttpMessageHandler = mockHttpHandler.Object}); try { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index edb9099fdb..1644c86ccc 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -40,7 +40,7 @@ public async Task WebSocketsTransportStopsSendAndReceiveLoopsWhenTransportIsStop var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - var webSocketsTransport = new WebSocketsTransport(loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, TransferMode.Binary, connectionId: string.Empty).OrTimeout(); await webSocketsTransport.StopAsync().OrTimeout(); @@ -58,7 +58,7 @@ public async Task WebSocketsTransportStopsWhenConnectionChannelClosed() var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - var webSocketsTransport = new WebSocketsTransport(loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, TransferMode.Binary, connectionId: string.Empty); connectionToTransport.Writer.TryComplete(); @@ -78,7 +78,7 @@ public async Task WebSocketsTransportStopsWhenConnectionClosedByTheServer(Transf var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - var webSocketsTransport = new WebSocketsTransport(loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, transferMode, connectionId: string.Empty); var sendTcs = new TaskCompletionSource(); @@ -116,7 +116,7 @@ public async Task WebSocketsTransportSetsTransferMode(TransferMode transferMode) var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - var webSocketsTransport = new WebSocketsTransport(loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); Assert.Null(webSocketsTransport.Mode); await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, @@ -138,7 +138,7 @@ public async Task WebSocketsTransportThrowsForInvalidTransferMode() var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - var webSocketsTransport = new WebSocketsTransport(loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); var exception = await Assert.ThrowsAsync(() => webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty));