Skip to content
This repository has been archived by the owner on Dec 18, 2018. It is now read-only.

Commit

Permalink
Adding support for JWT in the C# client
Browse files Browse the repository at this point in the history
Fixes: #1018

(Bonus: also enabling passing headers)
  • Loading branch information
moozzyk committed Nov 22, 2017
1 parent fadd6f8 commit 0bafb30
Show file tree
Hide file tree
Showing 20 changed files with 363 additions and 81 deletions.
1 change: 1 addition & 0 deletions build/dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
<InternalAspNetCoreSdkPackageVersion>2.1.0-preview1-15549</InternalAspNetCoreSdkPackageVersion>
<MicrosoftAspNetCoreAuthenticationCookiesPackageVersion>2.1.0-preview1-27579</MicrosoftAspNetCoreAuthenticationCookiesPackageVersion>
<MicrosoftAspNetCoreAuthenticationCorePackageVersion>2.1.0-preview1-27579</MicrosoftAspNetCoreAuthenticationCorePackageVersion>
<MicrosoftAspNetCoreAuthenticationJwtBearerPackageVersion>2.1.0-preview1-27475</MicrosoftAspNetCoreAuthenticationJwtBearerPackageVersion>
<MicrosoftAspNetCoreAuthorizationPackageVersion>2.1.0-preview1-27579</MicrosoftAspNetCoreAuthorizationPackageVersion>
<MicrosoftAspNetCoreAuthorizationPolicyPackageVersion>2.1.0-preview1-27579</MicrosoftAspNetCoreAuthorizationPolicyPackageVersion>
<MicrosoftAspNetCoreCorsPackageVersion>2.1.0-preview1-27579</MicrosoftAspNetCoreCorsPackageVersion>
Expand Down
1 change: 0 additions & 1 deletion samples/ClientSample/HubSample.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.CommandLineUtils;
using Microsoft.Extensions.Logging;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ public HubConnection Build()
throw new InvalidOperationException("Cannot create IConnection instance. The connection factory was not configured.");
}

IHubConnectionBuilder builder = this;
var connection = _connectionFactoryDelegate();
var loggerFactory = ((IHubConnectionBuilder)this).GetLoggerFactory();
var hubProtocol = ((IHubConnectionBuilder)this).GetHubProtocol();

var loggerFactory = builder.GetLoggerFactory();
var hubProtocol = builder.GetHubProtocol();

return new HubConnection(connection, hubProtocol ?? new JsonHubProtocol(), loggerFactory);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,21 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Net.Http;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Client.Http;

namespace Microsoft.AspNetCore.SignalR.Client
{
public static class HubConnectionBuilderHttpExtensions
{
public static readonly string TransportTypeKey = "TransportType";
public static readonly string HttpMessageHandlerKey = "HttpMessageHandler";
public static readonly string HeadersKey = "Headers";
public static readonly string JwtBearerTokenFactoryKey = "JwtBearerTokenFactory";

public static IHubConnectionBuilder WithUrl(this IHubConnectionBuilder hubConnectionBuilder, string url)
{
Expand All @@ -32,10 +37,18 @@ public static IHubConnectionBuilder WithUrl(this IHubConnectionBuilder hubConnec

hubConnectionBuilder.ConfigureConnectionFactory(() =>
{
var headers = hubConnectionBuilder.GetHeaders();
var httpOptions = new HttpOptions
{
HttpMessageHandler = hubConnectionBuilder.GetMessageHandler(),
Headers = headers != null ? new ReadOnlyDictionary<string, string>(headers) : null,
JwtBearerTokenFactory = hubConnectionBuilder.GetJwtBearerTokenFactory()
};

return new HttpConnection(url,
hubConnectionBuilder.GetTransport(),
hubConnectionBuilder.GetLoggerFactory(),
hubConnectionBuilder.GetMessageHandler());
httpOptions);
});
return hubConnectionBuilder;
}
Expand All @@ -52,6 +65,37 @@ public static IHubConnectionBuilder WithMessageHandler(this IHubConnectionBuilde
return hubConnectionBuilder;
}

public static IHubConnectionBuilder WithHeader(this IHubConnectionBuilder hubConnectionBuilder, string name, string value)
{
if (string.IsNullOrEmpty(name))
{
throw new ArgumentException("Header name cannot be null or empty string.", nameof(name));
}

var headers = hubConnectionBuilder.GetHeaders();
if (headers == null)
{
headers = new Dictionary<string, string>();
hubConnectionBuilder.AddSetting(HeadersKey, headers);
}

headers.Add(name, value);

return hubConnectionBuilder;
}

public static IHubConnectionBuilder WithJwtBearer(this IHubConnectionBuilder hubConnectionBuilder, Func<string> jwtBearerTokenFactory)
{
if (jwtBearerTokenFactory == null)
{
throw new ArgumentNullException(nameof(jwtBearerTokenFactory));
}

hubConnectionBuilder.AddSetting(JwtBearerTokenFactoryKey, jwtBearerTokenFactory);

return hubConnectionBuilder;
}

public static TransportType GetTransport(this IHubConnectionBuilder hubConnectionBuilder)
{
if (hubConnectionBuilder.TryGetSetting<TransportType>(TransportTypeKey, out var transportType))
Expand All @@ -67,5 +111,25 @@ public static HttpMessageHandler GetMessageHandler(this IHubConnectionBuilder hu
hubConnectionBuilder.TryGetSetting<HttpMessageHandler>(HttpMessageHandlerKey, out var messageHandler);
return messageHandler;
}

public static IDictionary<string, string> GetHeaders(this IHubConnectionBuilder hubConnectionBuilder)
{
if (hubConnectionBuilder.TryGetSetting<IDictionary<string, string>>(HeadersKey, out var headers))
{
return headers;
}

return null;
}

public static Func<string> GetJwtBearerTokenFactory(this IHubConnectionBuilder hubConnectionBuilder)
{
if (hubConnectionBuilder.TryGetSetting<Func<string>>(JwtBearerTokenFactoryKey, out var factory))
{
return factory;
}

return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@

using System;
using System.Net.Http;
using Microsoft.AspNetCore.Sockets.Client.Http;
using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.Sockets.Client
{
public class DefaultTransportFactory : ITransportFactory
{
private readonly HttpClient _httpClient;
private readonly HttpOptions _httpOptions;
private readonly TransportType _requestedTransportType;
private readonly ILoggerFactory _loggerFactory;
private static volatile bool _websocketsSupported = true;

public DefaultTransportFactory(TransportType requestedTransportType, ILoggerFactory loggerFactory, HttpClient httpClient)
public DefaultTransportFactory(TransportType requestedTransportType, ILoggerFactory loggerFactory, HttpClient httpClient, HttpOptions httpOptions)
{
if (requestedTransportType <= 0 || requestedTransportType > TransportType.All)
{
Expand All @@ -29,6 +31,7 @@ public DefaultTransportFactory(TransportType requestedTransportType, ILoggerFact
_requestedTransportType = requestedTransportType;
_loggerFactory = loggerFactory;
_httpClient = httpClient;
_httpOptions = httpOptions;
}

public ITransport CreateTransport(TransportType availableServerTransports)
Expand All @@ -37,7 +40,7 @@ public ITransport CreateTransport(TransportType availableServerTransports)
{
try
{
return new WebSocketsTransport(_loggerFactory);
return new WebSocketsTransport(_httpOptions, _loggerFactory);
}
catch (PlatformNotSupportedException)
{
Expand All @@ -47,12 +50,12 @@ public ITransport CreateTransport(TransportType availableServerTransports)

if ((availableServerTransports & TransportType.ServerSentEvents & _requestedTransportType) == TransportType.ServerSentEvents)
{
return new ServerSentEventsTransport(_httpClient, _loggerFactory);
return new ServerSentEventsTransport(_httpClient, _httpOptions, _loggerFactory);
}

if ((availableServerTransports & TransportType.LongPolling & _requestedTransportType) == TransportType.LongPolling)
{
return new LongPollingTransport(_httpClient, _loggerFactory);
return new LongPollingTransport(_httpClient, _httpOptions, _loggerFactory);
}

throw new InvalidOperationException("No requested transports available on the server.");
Expand Down
28 changes: 14 additions & 14 deletions src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class HttpConnection : IConnection
private volatile int _connectionState = ConnectionState.Initial;
private volatile ChannelConnection<byte[], SendMessage> _transportChannel;
private readonly HttpClient _httpClient;
private readonly HttpOptions _httpOptions;
private volatile ITransport _transport;
private volatile Task _receiveLoopTask;
private TaskCompletionSource<object> _startTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
Expand All @@ -54,49 +55,47 @@ public HttpConnection(Uri url)
: this(url, TransportType.All)
{ }

public HttpConnection(Uri url, HttpMessageHandler httpMessageHandler)
: this(url, TransportType.All, loggerFactory: null, httpMessageHandler: httpMessageHandler)
{ }

public HttpConnection(Uri url, TransportType transportType)
: this(url, transportType, loggerFactory: null)
{
}

public HttpConnection(Uri url, ILoggerFactory loggerFactory)
: this(url, TransportType.All, loggerFactory, httpMessageHandler: null)
: this(url, TransportType.All, loggerFactory, httpOptions: null)
{
}

public HttpConnection(Uri url, TransportType transportType, ILoggerFactory loggerFactory)
: this(url, transportType, loggerFactory, httpMessageHandler: null)
: this(url, transportType, loggerFactory, httpOptions: null)
{
}

public HttpConnection(Uri url, TransportType transportType, ILoggerFactory loggerFactory, HttpMessageHandler httpMessageHandler)
public HttpConnection(Uri url, TransportType transportType, ILoggerFactory loggerFactory, HttpOptions httpOptions)
{
Url = url ?? throw new ArgumentNullException(nameof(url));

_loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
_logger = _loggerFactory.CreateLogger<HttpConnection>();
_httpOptions = httpOptions;

_requestedTransportType = transportType;
if (_requestedTransportType != TransportType.WebSockets)
{
_httpClient = httpMessageHandler == null ? new HttpClient() : new HttpClient(httpMessageHandler);
_httpClient = httpOptions?.HttpMessageHandler == null ? new HttpClient() : new HttpClient(httpOptions.HttpMessageHandler);
_httpClient.Timeout = HttpClientTimeout;
}

_transportFactory = new DefaultTransportFactory(transportType, _loggerFactory, _httpClient);
_transportFactory = new DefaultTransportFactory(transportType, _loggerFactory, _httpClient, httpOptions);
}

public HttpConnection(Uri url, ITransportFactory transportFactory, ILoggerFactory loggerFactory, HttpMessageHandler httpMessageHandler)
public HttpConnection(Uri url, ITransportFactory transportFactory, ILoggerFactory loggerFactory, HttpOptions httpOptions)
{
Url = url ?? throw new ArgumentNullException(nameof(url));
_loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
_logger = _loggerFactory.CreateLogger<HttpConnection>();
_httpClient = httpMessageHandler == null ? new HttpClient() : new HttpClient(httpMessageHandler);
_httpClient.Timeout = HttpClientTimeout;
_httpOptions = httpOptions;
_httpClient = _httpOptions?.HttpMessageHandler == null ? new HttpClient() : new HttpClient(_httpOptions?.HttpMessageHandler);
_httpClient.Timeout = HttpClientTimeout;
_transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory));
}

Expand Down Expand Up @@ -214,7 +213,7 @@ private async Task StartAsyncInternal()
}
}

private async static Task<NegotiationResponse> Negotiate(Uri url, HttpClient httpClient, ILogger logger)
private async Task<NegotiationResponse> Negotiate(Uri url, HttpClient httpClient, ILogger logger)
{
try
{
Expand All @@ -229,7 +228,8 @@ private async static Task<NegotiationResponse> Negotiate(Uri url, HttpClient htt

using (var request = new HttpRequestMessage(HttpMethod.Post, urlBuilder.Uri))
{
request.Headers.UserAgent.Add(Constants.UserAgentHeader);
SendUtils.PrepareHttpRequest(request, _httpOptions);

using (var response = await httpClient.SendAsync(request))
{
response.EnsureSuccessStatusCode();
Expand Down
16 changes: 16 additions & 0 deletions src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Net.Http;

namespace Microsoft.AspNetCore.Sockets.Client.Http
{
public class HttpOptions
{
public HttpMessageHandler HttpMessageHandler { get; set; }
public IReadOnlyCollection<KeyValuePair<string, string>> Headers { get; set; }
public Func<string> JwtBearerTokenFactory { get; set; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Net;
using System.Net.Http;
using System.Threading;
Expand All @@ -17,6 +18,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
public class LongPollingTransport : ITransport
{
private readonly HttpClient _httpClient;
private readonly HttpOptions _httpOptions;
private readonly ILogger _logger;
private Channel<byte[], SendMessage> _application;
private Task _sender;
Expand All @@ -30,12 +32,13 @@ public class LongPollingTransport : ITransport
public TransferMode? Mode { get; private set; }

public LongPollingTransport(HttpClient httpClient)
: this(httpClient, null)
: this(httpClient, null, null)
{ }

public LongPollingTransport(HttpClient httpClient, ILoggerFactory loggerFactory)
public LongPollingTransport(HttpClient httpClient, HttpOptions httpOptions, ILoggerFactory loggerFactory)
{
_httpClient = httpClient;
_httpOptions = httpOptions;
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<LongPollingTransport>();
}

Expand All @@ -54,7 +57,7 @@ public Task StartAsync(Uri url, Channel<byte[], SendMessage> application, Transf

// Start sending and polling (ask for binary if the server supports it)
_poller = Poll(url, _transportCts.Token);
_sender = SendUtils.SendMessages(url, _application, _httpClient, _transportCts, _logger, _connectionId);
_sender = SendUtils.SendMessages(url, _application, _httpClient, _httpOptions, _transportCts, _logger, _connectionId);

Running = Task.WhenAll(_sender, _poller).ContinueWith(t =>
{
Expand Down Expand Up @@ -90,7 +93,7 @@ private async Task Poll(Uri pollUrl, CancellationToken cancellationToken)
while (!cancellationToken.IsCancellationRequested)
{
var request = new HttpRequestMessage(HttpMethod.Get, pollUrl);
request.Headers.UserAgent.Add(Constants.UserAgentHeader);
SendUtils.PrepareHttpRequest(request, _httpOptions);

HttpResponseMessage response;

Expand Down
21 changes: 19 additions & 2 deletions src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
internal static class SendUtils
{
public static async Task SendMessages(Uri sendUrl, Channel<byte[], SendMessage> application, HttpClient httpClient,
CancellationTokenSource transportCts, ILogger logger, string connectionId)
HttpOptions httpOptions, CancellationTokenSource transportCts, ILogger logger, string connectionId)
{
logger.SendStarted(connectionId);
IList<SendMessage> messages = null;
Expand All @@ -39,7 +39,7 @@ public static async Task SendMessages(Uri sendUrl, Channel<byte[], SendMessage>

// Send them in a single post
var request = new HttpRequestMessage(HttpMethod.Post, sendUrl);
request.Headers.UserAgent.Add(Constants.UserAgentHeader);
PrepareHttpRequest(request, httpOptions);

// TODO: We can probably use a pipeline here or some kind of pooled memory.
// But where do we get the pool from? ArrayBufferPool.Instance?
Expand Down Expand Up @@ -107,5 +107,22 @@ public static async Task SendMessages(Uri sendUrl, Channel<byte[], SendMessage>

logger.SendStopped(connectionId);
}

public static void PrepareHttpRequest(HttpRequestMessage request, HttpOptions httpOptions)
{
if (httpOptions?.Headers != null)
{
foreach (var header in httpOptions.Headers)
{
request.Headers.Add(header.Key, header.Value);
}
}
request.Headers.UserAgent.Add(Constants.UserAgentHeader);

if (httpOptions?.JwtBearerTokenFactory != null)
{
request.Headers.Add("Authorization", $"Bearer {httpOptions.JwtBearerTokenFactory()}");
}
}
}
}
Loading

0 comments on commit 0bafb30

Please sign in to comment.