Skip to content

Commit

Permalink
Introduce SecurityTokenAuthenticator.ValidateTokenAsync api (CoreWCF#419
Browse files Browse the repository at this point in the history
)

* Introduce SecurityTokenAuthenticator.ValidateTokenAsync api
* Add async UserNamePasswordValidator impl
* Deprecates sync ValidateUserNamePassword
  • Loading branch information
g7ed6e authored Oct 5, 2021
1 parent 0c8c0f0 commit 62fd7e0
Show file tree
Hide file tree
Showing 37 changed files with 348 additions and 224 deletions.
43 changes: 34 additions & 9 deletions src/CoreWCF.Http/tests/BasicHttpTransportMessage.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Net;
using System.Net.Security;
using System.Security.Cryptography.X509Certificates;
using System.ServiceModel.Channels;
using System.ServiceModel.Description;
using System.Threading;
using System.Threading.Tasks;
using CoreWCF;
using CoreWCF.Configuration;
using CoreWCF.IdentityModel.Selectors;
Expand All @@ -30,12 +32,19 @@ public SimpleBasicHttpTest(ITestOutputHelper output)
_output = output;
}

[Fact, Description("transport-security-with-basic-authentication")]
public void BasicHttpRequestReplyWithTransportMessageEchoString()
public static IEnumerable<object[]> GetTestsVariations()
{
yield return new[] { typeof(BasicHttpTransportWithMessageCredentialWithUserName<CustomTestValidator>) };
yield return new[] { typeof(BasicHttpTransportWithMessageCredentialWithUserName<CustomAsynchronousTestValidator>) };
}

[Theory, Description("transport-security-with-basic-authentication")]
[MemberData(nameof(GetTestsVariations))]
public void BasicHttpRequestReplyWithTransportMessageEchoString(Type startupType)
{
ServicePointManager.ServerCertificateValidationCallback += new RemoteCertificateValidationCallback(ValidateCertificate);
string testString = new string('a', 3000);
IWebHost host = ServiceHelper.CreateHttpsWebHostBuilder<BasicHttpTransportWithMessageCredentialWithUserName>(_output).Build();
IWebHost host = ServiceHelper.CreateHttpsWebHostBuilder(_output, startupType).Build();
using (host)
{
host.Start();
Expand Down Expand Up @@ -64,13 +73,14 @@ public void BasicHttpRequestReplyWithTransportMessageEchoString()
[Theory]
[InlineData(false)]
[InlineData(true)]
[UseCulture("en-US")]
public void BasicHttpsCustomBindingRequestReplyEchoString(bool useHttps)
{
string testString = new string('a', 4000);
IWebHost host = ServiceHelper.CreateHttpsWebHostBuilder<StartupCustomBinding>(_output).Build();
using (host)
{
String serviceUrl = (useHttps ? "https" : "http") + "://localhost:8443/BasicHttpWcfService/basichttp.svc";
string serviceUrl = (useHttps ? "https" : "http") + "://localhost:8443/BasicHttpWcfService/basichttp.svc";
host.Start();
System.ServiceModel.BasicHttpBinding BasicHttpBinding = ClientHelper.GetBufferedModeBinding(System.ServiceModel.BasicHttpSecurityMode.Transport);
var factory = new System.ServiceModel.ChannelFactory<ClientContract.IEchoService>(BasicHttpBinding,
Expand All @@ -86,7 +96,8 @@ public void BasicHttpsCustomBindingRequestReplyEchoString(bool useHttps)
string result = channel.EchoString(testString);
Assert.Equal(testString, result);
((IChannel)channel).Close();
}catch(Exception ex)
}
catch (Exception ex)
{
Assert.True(!useHttps && ex.Message.Contains("The provided URI scheme 'http' is invalid"));
}
Expand Down Expand Up @@ -136,10 +147,23 @@ public override void Validate(string userName, string password)
{
return;
}
else

throw new Exception("Permission Denied");
}
}

internal class CustomAsynchronousTestValidator : UserNamePasswordValidator
{
public override async ValueTask ValidateAsync(string userName, string password)
{
// simulate a DB / API roundtrip
await Task.Delay(100);
if (string.Compare(userName, "testuser@corewcf", StringComparison.OrdinalIgnoreCase) == 0)
{
throw new Exception("Permission Denied");
return;
}

throw new Exception("Permission Denied");
}
}

Expand All @@ -164,7 +188,8 @@ public override void ChangeHostBehavior(ServiceHostBase host)
}
}

internal class BasicHttpTransportWithMessageCredentialWithUserName : StartupBasicHttpBase
internal class BasicHttpTransportWithMessageCredentialWithUserName<TUserNamePasswordValidator> : StartupBasicHttpBase
where TUserNamePasswordValidator : UserNamePasswordValidator, new()
{
public BasicHttpTransportWithMessageCredentialWithUserName() :
base(CoreWCF.Channels.BasicHttpSecurityMode.TransportWithMessageCredential, BasicHttpMessageCredentialType.UserName)
Expand All @@ -177,7 +202,7 @@ public override void ChangeHostBehavior(ServiceHostBase host)
srvCredentials.UserNameAuthentication.UserNamePasswordValidationMode =
CoreWCF.Security.UserNamePasswordValidationMode.Custom;
srvCredentials.UserNameAuthentication.CustomUserNamePasswordValidator =
new CustomTestValidator();
new TUserNamePasswordValidator();
host.Description.Behaviors.Add(srvCredentials);
}
}
Expand Down
43 changes: 43 additions & 0 deletions src/CoreWCF.Http/tests/Helpers/ServiceHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,49 @@ public static IWebHostBuilder CreateHttpsWebHostBuilder<TStartup>(ITestOutputHel
})
.UseStartup<TStartup>();

public static IWebHostBuilder CreateHttpsWebHostBuilder(ITestOutputHelper outputHelper, Type startupType) =>
WebHost.CreateDefaultBuilder(Array.Empty<string>())
#if DEBUG
.ConfigureLogging((ILoggingBuilder logging) =>
{
if (outputHelper != default)
logging.AddProvider(new XunitLoggerProvider(outputHelper));
logging.AddFilter("Default", LogLevel.Debug);
logging.AddFilter("Microsoft", LogLevel.Debug);
logging.SetMinimumLevel(LogLevel.Debug);
})
#endif // DEBUG
.UseKestrel(options =>
{
options.Listen(address: IPAddress.Loopback, 8444, listenOptions =>
{
listenOptions.UseHttps(httpsOptions =>
{
#if NET472
httpsOptions.SslProtocols = SslProtocols.Tls12 | SslProtocols.Tls11 | SslProtocols.Tls;
#endif // NET472
});
if (Debugger.IsAttached)
{
listenOptions.UseConnectionLogging();
}
});
options.Listen(address: IPAddress.Loopback, 8443, listenOptions =>
{
listenOptions.UseHttps(httpsOptions =>
{
#if NET472
httpsOptions.SslProtocols = SslProtocols.Tls12 | SslProtocols.Tls11 | SslProtocols.Tls;
#endif // NET472
});
if (Debugger.IsAttached)
{
listenOptions.UseConnectionLogging();
}
});
})
.UseStartup(startupType);

public static void CloseServiceModelObjects(params System.ServiceModel.ICommunicationObject[] objects)
{
foreach (System.ServiceModel.ICommunicationObject comObj in objects)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.ServiceModel.Channels;
using System.Threading;
using System.Threading.Tasks;
using CoreWCF.Configuration;
using CoreWCF.IdentityModel.Selectors;
using Helpers;
Expand Down Expand Up @@ -78,7 +79,8 @@ private void BasicUserNameAuth(bool isError, string userName)
}

[Fact, Description("Demuxer-failure-nettcp")]
public void NetTCPRequestReplyWithTransportMessageEchoStringDemuxFailure()
[UseCulture("en-US")]
public async Task NetTCPRequestReplyWithTransportMessageEchoStringDemuxFailure()
{
string testString = new string('a', 3000);
IWebHost host = ServiceHelper.CreateWebHostBuilder<StartUpPermissionBaseForTCDemuxFailure>(_output).Build();
Expand All @@ -102,12 +104,12 @@ public void NetTCPRequestReplyWithTransportMessageEchoStringDemuxFailure()
try
{
((IChannel)channel).Open();
Thread.Sleep(6000);
await Task.Delay(6000);
string result = channel.EchoString(testString);
}
catch (Exception ex)
{
Assert.Equal(typeof(System.ServiceModel.FaultException), ex.InnerException?.GetType());
Assert.IsAssignableFrom<System.ServiceModel.FaultException>(ex.InnerException);
Assert.Contains("expired security context token", ex.InnerException.Message);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,14 @@ void IChannelBindingProvider.EnableChannelBindingSupport()
}


bool IChannelBindingProvider.IsChannelBindingSupportEnabled
{
get
{
return _enableChannelBinding;
}
}
bool IChannelBindingProvider.IsChannelBindingSupportEnabled => _enableChannelBinding;

public override StreamUpgradeAcceptor CreateUpgradeAcceptor()
{
ThrowIfDisposedOrNotOpen();
return new SslStreamSecurityUpgradeAcceptor(this);
}



protected override void OnAbort()
{
if (_clientCertificateAuthenticator != null)
Expand Down Expand Up @@ -244,13 +236,7 @@ internal ChannelBinding ChannelBinding
}
}

internal bool IsChannelBindingSupportEnabled
{
get
{
return ((IChannelBindingProvider)_parent).IsChannelBindingSupportEnabled;
}
}
internal bool IsChannelBindingSupportEnabled => ((IChannelBindingProvider)_parent).IsChannelBindingSupportEnabled;

protected override async Task<(Stream, SecurityMessageProperty)> OnAcceptUpgradeAsync(Stream stream)
{
Expand Down Expand Up @@ -298,7 +284,12 @@ private bool ValidateRemoteCertificate(object sender, X509Certificate certificat
try
{
SecurityToken token = new X509SecurityToken(certificate2, false);
ReadOnlyCollection<IAuthorizationPolicy> authorizationPolicies = _parent.ClientCertificateAuthenticator.ValidateToken(token);
ReadOnlyCollection<IAuthorizationPolicy> authorizationPolicies;
var validationValueTask = _parent.ClientCertificateAuthenticator.ValidateTokenAsync(token);
authorizationPolicies = validationValueTask.IsCompleted
? validationValueTask.Result
: validationValueTask.AsTask().GetAwaiter().GetResult();

_clientSecurity = new SecurityMessageProperty
{
TransportToken = new SecurityTokenSpecification(token, authorizationPolicies),
Expand All @@ -322,13 +313,6 @@ public override SecurityMessageProperty GetRemoteSecurity()
}
if (_clientCertificate != null)
{
SecurityToken token = new X509SecurityToken(_clientCertificate);
ReadOnlyCollection<IAuthorizationPolicy> authorizationPolicies = SecurityUtils.NonValidatingX509Authenticator.ValidateToken(token);
_clientSecurity = new SecurityMessageProperty
{
TransportToken = new SecurityTokenSpecification(token, authorizationPolicies),
ServiceSecurityContext = new ServiceSecurityContext(authorizationPolicies)
};
return _clientSecurity;
}
return base.GetRemoteSecurity();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ protected override TimeSpan DefaultCloseTimeout

protected override TimeSpan DefaultOpenTimeout
{
get { return _closeTimeout; }
get { return _openTimeout; }
}

public virtual T GetProperty<T>() where T : class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,11 +168,11 @@ await negotiateStream.AuthenticateAsServerAsync(_parent.ServerCredential, _paren
SR.Format(SR.NegotiationFailedIO, ioException.Message), ioException));
}

SecurityMessageProperty remoteSecurity = CreateClientSecurity(negotiateStream, _parent.ExtractGroupsForWindowsAccounts);
SecurityMessageProperty remoteSecurity = await CreateClientSecurityAsync(negotiateStream, _parent.ExtractGroupsForWindowsAccounts);
return (negotiateStream, remoteSecurity);
}

private SecurityMessageProperty CreateClientSecurity(NegotiateStream negotiateStream,
private async Task<SecurityMessageProperty> CreateClientSecurityAsync(NegotiateStream negotiateStream,
bool extractGroupsForWindowsAccounts)
{
IIdentity remoteIdentity = negotiateStream.RemoteIdentity;
Expand All @@ -190,7 +190,7 @@ private SecurityMessageProperty CreateClientSecurity(NegotiateStream negotiateSt
ClaimsIdentity claimsIdentity = new ClaimsIdentity(remoteIdentity);
token = new GenericSecurityToken(remoteIdentity.Name, SecurityUniqueId.Create().Value);
}
authorizationPolicies = authenticator.ValidateToken(token);
authorizationPolicies = await authenticator.ValidateTokenAsync(token);
SecurityMessageProperty clientSecurity = new SecurityMessageProperty
{
TransportToken = new SecurityTokenSpecification(token, authorizationPolicies),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,25 +335,25 @@ private void ThrowIfSecureConversationCloseMessage(Message message)
}
}

internal SecurityProtocolCorrelationState VerifyIncomingMessage(ref Message message, TimeSpan timeout, params SecurityProtocolCorrelationState[] correlationState)
internal async ValueTask<(Message, SecurityProtocolCorrelationState)> VerifyIncomingMessageAsync(Message message, TimeSpan timeout, params SecurityProtocolCorrelationState[] correlationState)
{
if (message == null)
{
return null;
return (null, null);
}
Fx.Assert(SecurityProtocol != null, "SecurityProtocol can't be null");
ThrowIfSecureConversationCloseMessage(message);
return SecurityProtocol.VerifyIncomingMessage(ref message, timeout, correlationState);
return await SecurityProtocol.VerifyIncomingMessageAsync(message, timeout, correlationState);
}

internal void VerifyIncomingMessage(ref Message message, TimeSpan timeout)
internal ValueTask<Message> VerifyIncomingMessageAsync(Message message, TimeSpan timeout)
{
if (message == null)
{
return;
return new ValueTask<Message>((Message)null);
}
ThrowIfSecureConversationCloseMessage(message);
SecurityProtocol.VerifyIncomingMessage(ref message, timeout);
return SecurityProtocol.VerifyIncomingMessageAsync(message, timeout);
}

public abstract Task DispatchAsync(RequestContext context);
Expand Down Expand Up @@ -389,7 +389,7 @@ public SecurityReplyChannelDispatcher(SecurityServiceDispatcher securityServiceD
public CommunicationState State => OuterChannel.State;


internal RequestContext ProcessReceivedRequest(RequestContext requestContext)
internal async ValueTask<RequestContext> ProcessReceivedRequestAsync(RequestContext requestContext)
{
if (requestContext == null)
{
Expand All @@ -405,7 +405,10 @@ internal RequestContext ProcessReceivedRequest(RequestContext requestContext)
}
try
{
SecurityProtocolCorrelationState correlationState = VerifyIncomingMessage(ref message, timeoutHelper.RemainingTime(), null);
(Message message, SecurityProtocolCorrelationState correlationState) verifiedIncomingMessage = await VerifyIncomingMessageAsync(message, timeoutHelper.RemainingTime(), null);
message = verifiedIncomingMessage.message;
SecurityProtocolCorrelationState correlationState = verifiedIncomingMessage.correlationState;

if (message.Headers.RelatesTo == null && message.Headers.MessageId != null)
{
message.Headers.RelatesTo = message.Headers.MessageId;
Expand Down Expand Up @@ -454,7 +457,7 @@ private void SendFaultIfRequired(Exception e, RequestContext innerContext, TimeS

public override async Task DispatchAsync(RequestContext context)
{
SecurityRequestContext securedMessage = (SecurityRequestContext)ProcessReceivedRequest(context);
SecurityRequestContext securedMessage = (SecurityRequestContext)(await ProcessReceivedRequestAsync(context));
if (SecurityServiceDispatcher.SessionMode) // for SCT, sessiontoken is created so we channel the call to SecurityAuthentication and evevntually SecurityServerSession.
{
IServiceChannelDispatcher serviceChannelDispatcher =
Expand Down Expand Up @@ -599,7 +602,7 @@ public override Task DispatchAsync(RequestContext context)
public override async Task DispatchAsync(Message message)
{
Fx.Assert(State == CommunicationState.Opened, "Expected dispatcher state to be Opened, instead it's " + State.ToString());
ProcessInnerItem(message, ServiceDefaults.SendTimeout);
message = await ProcessInnerItemAsync(message, ServiceDefaults.SendTimeout);
if (_serviceChannelDispatcher == null)
{
_serviceChannelDispatcher = await SecurityServiceDispatcher.
Expand All @@ -618,7 +621,7 @@ public Task SendAsync(Message message, CancellationToken token)
return SendAsync(message);
}

private Message ProcessInnerItem(Message innerItem, TimeSpan timeout)
private async ValueTask<Message> ProcessInnerItemAsync(Message innerItem, TimeSpan timeout)
{
if (innerItem == null)
{
Expand All @@ -629,7 +632,7 @@ private Message ProcessInnerItem(Message innerItem, TimeSpan timeout)
Message unverifiedMessage = innerItem;
try
{
VerifyIncomingMessage(ref innerItem, timeout);
innerItem = await VerifyIncomingMessageAsync(innerItem, timeout);
}
catch (MessageSecurityException e)
{
Expand Down
Loading

0 comments on commit 62fd7e0

Please sign in to comment.