From bf6a61590273e795324a0c3900c7cea7b7f401ee Mon Sep 17 00:00:00 2001 From: Richard Beauchamp Date: Fri, 13 Sep 2024 08:12:25 -0700 Subject: [PATCH] fix: issue with subscription authorization (#66) --- README.md | 16 +- .../LiveDocs.GraphQLApi.csproj | 6 +- .../LiveDocs.ServiceDefaults.csproj | 2 +- .../Extensions/GraphQLBuilderExtensions.cs | 5 +- src/RxDBDotNet/RxDBDotNet.csproj | 7 +- .../Security/SocketConnectPayload.cs | 15 ++ .../Security/SubscriptionAuthMiddleware.cs | 211 ++++++++++++++++++ .../RxDBDotNet.TestModelGenerator.csproj | 2 +- .../RxDBDotNet.Tests.Setup.csproj | 4 +- .../Model/GraphQLTestModel.cs | 43 +--- .../RxDBDotNet.Tests/RxDBDotNet.Tests.csproj | 8 +- tests/RxDBDotNet.Tests/SubscriptionTests.cs | 129 +++++++++++ .../Utils/GraphQLSubscriptionClient.cs | 65 ++++-- .../Utils/WebApplicationFactoryExtensions.cs | 6 +- 14 files changed, 427 insertions(+), 92 deletions(-) create mode 100644 src/RxDBDotNet/Security/SocketConnectPayload.cs create mode 100644 src/RxDBDotNet/Security/SubscriptionAuthMiddleware.cs diff --git a/README.md b/README.md index 4d226fe..4791c47 100644 --- a/README.md +++ b/README.md @@ -1,17 +1,17 @@ +# RxDBDotNet +

- NuGet + NuGet Version + + + NuGet Downloads codecov - - CI -

-# RxDBDotNet - RxDBDotNet is a powerful .NET library that implements the RxDB replication protocol, enabling real-time data synchronization between RxDB clients and .NET servers using GraphQL and Hot Chocolate. It extends the standard RxDB replication protocol with .NET-specific enhancements. ## Key Features @@ -427,10 +427,6 @@ mutation PushWorkspace($input: PushWorkspaceInput!) { ... on UnauthorizedAccessError { message } - ... on ArgumentNullError { - message - paramName - } } } } diff --git a/example/LiveDocs.GraphQLApi/LiveDocs.GraphQLApi.csproj b/example/LiveDocs.GraphQLApi/LiveDocs.GraphQLApi.csproj index 1339fbe..c4931af 100644 --- a/example/LiveDocs.GraphQLApi/LiveDocs.GraphQLApi.csproj +++ b/example/LiveDocs.GraphQLApi/LiveDocs.GraphQLApi.csproj @@ -29,9 +29,9 @@ - - - + + + diff --git a/example/LiveDocs.ServiceDefaults/LiveDocs.ServiceDefaults.csproj b/example/LiveDocs.ServiceDefaults/LiveDocs.ServiceDefaults.csproj index 2b27347..cda14be 100644 --- a/example/LiveDocs.ServiceDefaults/LiveDocs.ServiceDefaults.csproj +++ b/example/LiveDocs.ServiceDefaults/LiveDocs.ServiceDefaults.csproj @@ -23,7 +23,7 @@ - + diff --git a/src/RxDBDotNet/Extensions/GraphQLBuilderExtensions.cs b/src/RxDBDotNet/Extensions/GraphQLBuilderExtensions.cs index 244a75a..a18f606 100644 --- a/src/RxDBDotNet/Extensions/GraphQLBuilderExtensions.cs +++ b/src/RxDBDotNet/Extensions/GraphQLBuilderExtensions.cs @@ -37,6 +37,8 @@ public static IRequestExecutorBuilder AddReplication(this IRequestExecutorBuilde builder.AddFiltering(); + builder.AddSocketSessionInterceptor(); + // Ensure Query, Mutation, and Subscription types exist EnsureRootTypesExist(builder); @@ -224,9 +226,6 @@ private static void AddFieldErrorTypes(IObjectFieldDescriptor field, field.Error(); addedErrorTypes.Add(typeof(UnauthorizedAccessException)); - field.Error(); - addedErrorTypes.Add(typeof(ArgumentNullException)); - // update the foreach code to not add the AuthenticationException error type if it has already been added foreach (var errorType in replicationOptions.Errors) { diff --git a/src/RxDBDotNet/RxDBDotNet.csproj b/src/RxDBDotNet/RxDBDotNet.csproj index 981b203..2cf8157 100644 --- a/src/RxDBDotNet/RxDBDotNet.csproj +++ b/src/RxDBDotNet/RxDBDotNet.csproj @@ -30,10 +30,11 @@ MIT - + - - + + + diff --git a/src/RxDBDotNet/Security/SocketConnectPayload.cs b/src/RxDBDotNet/Security/SocketConnectPayload.cs new file mode 100644 index 0000000..5d6ddef --- /dev/null +++ b/src/RxDBDotNet/Security/SocketConnectPayload.cs @@ -0,0 +1,15 @@ +namespace RxDBDotNet.Security; + +/// +/// Represents the payload for a socket connection, containing the necessary authorization information. +/// +public class SocketConnectPayload +{ + /// + /// Gets the authorization token required for establishing a WebSocket connection. + /// + /// + /// A string representing the authorization token, typically in the form of a JWT. + /// + public required string Authorization { get; init; } +} diff --git a/src/RxDBDotNet/Security/SubscriptionAuthMiddleware.cs b/src/RxDBDotNet/Security/SubscriptionAuthMiddleware.cs new file mode 100644 index 0000000..c8b3205 --- /dev/null +++ b/src/RxDBDotNet/Security/SubscriptionAuthMiddleware.cs @@ -0,0 +1,211 @@ +using System.Security.Claims; +using HotChocolate.AspNetCore; +using HotChocolate.AspNetCore.Subscriptions; +using HotChocolate.AspNetCore.Subscriptions.Protocols; +using Microsoft.AspNetCore.Authentication; +using Microsoft.AspNetCore.Authentication.JwtBearer; +using Microsoft.Extensions.Options; + +namespace RxDBDotNet.Security; + +/// +/// Middleware for authenticating WebSocket connections in GraphQL subscriptions. +/// This middleware implements part of the graphql-transport-ws protocol, specifically handling the ConnectionInit message. +/// It validates JWT tokens sent in the connection payload and sets up the ClaimsPrincipal for authenticated connections. +/// If JWT authentication is not configured, it allows all connections. +/// +/// +/// This middleware should be registered with the GraphQL server using AddSocketSessionInterceptor<SubscriptionAuthMiddleware>(). +/// It uses the same JWT configuration as set up in AddJwtBearer() for consistency across HTTP and WebSocket connections, if available. +/// +/// According to the graphql-transport-ws protocol: +/// - The server must receive the connection initialisation message within the allowed waiting time. +/// - If the server wishes to reject the connection during authentication, it should close the socket with the event 4403: Forbidden. +/// - If the server receives more than one ConnectionInit message, it should close the socket with the event 4429: Too many initialisation requests. +/// +/// Note: This implementation assumes that Hot Chocolate handles the connection timeout and multiple ConnectionInit messages internally. +/// If this is not the case, additional logic would need to be added to this middleware to fully comply with the protocol. +/// +public class SubscriptionAuthMiddleware : DefaultSocketSessionInterceptor +{ + private readonly IAuthenticationSchemeProvider? _schemeProvider; + private readonly IOptionsMonitor _jwtOptionsMonitor; + + /// + /// Initializes a new instance of the class. + /// + /// The authentication scheme provider. + /// The options monitor for JWT bearer token validation. + /// + /// We inject both IAuthenticationSchemeProvider and IOptionsMonitor<JwtBearerOptions> for the following reasons: + /// 1. IAuthenticationSchemeProvider allows us to check if JWT Bearer authentication is configured. + /// 2. IOptionsMonitor<JwtBearerOptions> provides access to the JWT configuration for token validation. + /// This approach allows the middleware to work correctly whether authentication is configured or not. + /// + public SubscriptionAuthMiddleware( + IAuthenticationSchemeProvider? schemeProvider, + IOptionsMonitor jwtOptionsMonitor) + { + _schemeProvider = schemeProvider; + _jwtOptionsMonitor = jwtOptionsMonitor; + } + + /// + /// Called when a new WebSocket connection is being established. + /// This method handles the ConnectionInit message as per the graphql-transport-ws protocol. + /// It validates the JWT token in the connection payload and sets up the ClaimsPrincipal for authenticated connections. + /// If JWT authentication is not configured, it allows all connections. + /// + /// The socket session for the connection. + /// The payload of the ConnectionInit message. + /// A token to cancel the operation. + /// Thrown when or is null. + /// A indicating whether the connection was accepted or rejected. + /// + /// This method follows these steps: + /// 1. Check if JWT Bearer authentication is configured. + /// 2. If not configured, accept all connections (allowing for non-authenticated setups). + /// 3. If configured, validate the JWT token from the ConnectionInit message payload. + /// 4. Set up the ClaimsPrincipal for authenticated connections. + /// 5. If authentication fails, reject the connection with a 4403: Forbidden status. + /// This approach ensures that the middleware works in both authenticated and non-authenticated scenarios, + /// providing flexibility for different application setups while adhering to the graphql-transport-ws protocol. + /// + public override async ValueTask OnConnectAsync( + ISocketSession session, + IOperationMessagePayload connectionInitMessage, + CancellationToken cancellationToken = default) + { + ArgumentNullException.ThrowIfNull(session); + ArgumentNullException.ThrowIfNull(connectionInitMessage); + + try + { + // Check if JWT Bearer authentication is configured + // This allows the middleware to work in both authenticated and non-authenticated setups + if (!await IsJwtBearerConfiguredAsync().ConfigureAwait(false)) + { + // If JWT Bearer is not configured, we accept all connections + // This is crucial for supporting non-authenticated scenarios + return ConnectionStatus.Accept(); + } + + // JWT Bearer is configured, so we proceed with token validation + var connectPayload = connectionInitMessage.As(); + var authorizationHeader = connectPayload?.Authorization; + + // Ensure the Authorization header is present and in the correct format + if (string.IsNullOrEmpty(authorizationHeader) || !authorizationHeader.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase)) + { + // As per the protocol, we reject the connection with a 4403: Forbidden status + return RejectConnection(); + } + + // Extract the token from the Authorization header + var token = authorizationHeader["Bearer ".Length..].Trim(); + + // Validate the token + var claimsPrincipal = await ValidateTokenAsync(token).ConfigureAwait(false); + + if (claimsPrincipal != null) + { + // If the token is valid, set the ClaimsPrincipal on the HttpContext + // This allows the rest of the application to access the authenticated user's claims + session.Connection.HttpContext.User = claimsPrincipal; + return ConnectionStatus.Accept(); + } + + // If the token is invalid, reject the connection with a 4403: Forbidden status + return RejectConnection(); + } + catch + { + // If any unexpected error occurs during the process, reject the connection + // This ensures that we don't accidentally allow unauthorized access in case of errors + return RejectConnection(); + } + } + + private static ConnectionStatus RejectConnection() + { + return ConnectionStatus.Reject("4403: Forbidden", new Dictionary(StringComparer.Ordinal) + { + { "reason", "Authentication failed" }, + }); + } + + /// + /// Validates the provided JWT token using the configured JWT bearer options. + /// + /// The JWT token to validate. + /// + /// A if the token is valid and a non-null principal was created; otherwise, null. + /// + /// + /// This method uses the same validation parameters as configured in AddJwtBearer(), + /// ensuring consistency between HTTP and WebSocket authentication. + /// The method is designed to handle exceptions during token validation, returning null for any validation failure. + /// This approach allows the calling method to easily distinguish between valid and invalid tokens. + /// + private async Task ValidateTokenAsync(string token) + { + // Retrieve the JWT Bearer options. These options are configured when setting up JWT authentication + var jwtBearerOptions = _jwtOptionsMonitor.Get(JwtBearerDefaults.AuthenticationScheme); + + // Get the token handler from the options. This is typically a JwtSecurityTokenHandler + var tokenHandler = jwtBearerOptions.TokenHandlers.Single(); + + // Get the token validation parameters from the options + var validationParameters = jwtBearerOptions.TokenValidationParameters; + + try + { + // Attempt to validate the token + var tokenValidationResult = await tokenHandler + .ValidateTokenAsync(token, validationParameters) + .ConfigureAwait(false); + + // If the token is valid, create and return a new ClaimsPrincipal + if (tokenValidationResult.IsValid) + { + return new ClaimsPrincipal(tokenValidationResult.ClaimsIdentity); + } + } + catch (Exception) + { + // If any exception occurs during validation, we catch it and return null + // This is to ensure that any unexpected errors in token validation are treated as validation failures + return null; + } + + // If we reach here, the token was invalid, so we return null + return null; + } + + /// + /// Checks if JWT Bearer authentication is explicitly configured. + /// + /// true if JWT Bearer authentication is explicitly configured; otherwise, false. + /// + /// This method checks for the presence of the JWT Bearer authentication scheme. + /// The presence of this scheme is a reliable indicator that JWT Bearer authentication has been configured. + /// This approach is chosen because: + /// 1. It's more reliable than checking specific option values, which might have default values even when not explicitly set. + /// 2. It's simpler and faster than comparing multiple option values. + /// 3. It directly reflects whether the AddJwtBearer() method has been called in the application's startup configuration. + /// + private async Task IsJwtBearerConfiguredAsync() + { + // Attempt to retrieve the JWT Bearer authentication scheme + if (_schemeProvider != null) + { + var scheme = await _schemeProvider.GetSchemeAsync(JwtBearerDefaults.AuthenticationScheme).ConfigureAwait(false); + + // If the scheme is not null, it means JWT Bearer authentication has been configured + return scheme != null; + } + + // if _schemeProvider is null, we assume that JWT Bearer authentication is not configured + return false; + } +} diff --git a/tests/RxDBDotNet.TestModelGenerator/RxDBDotNet.TestModelGenerator.csproj b/tests/RxDBDotNet.TestModelGenerator/RxDBDotNet.TestModelGenerator.csproj index 4009e62..f266379 100644 --- a/tests/RxDBDotNet.TestModelGenerator/RxDBDotNet.TestModelGenerator.csproj +++ b/tests/RxDBDotNet.TestModelGenerator/RxDBDotNet.TestModelGenerator.csproj @@ -23,7 +23,7 @@ - + diff --git a/tests/RxDBDotNet.Tests.Setup/RxDBDotNet.Tests.Setup.csproj b/tests/RxDBDotNet.Tests.Setup/RxDBDotNet.Tests.Setup.csproj index 1986a4d..f0487f4 100644 --- a/tests/RxDBDotNet.Tests.Setup/RxDBDotNet.Tests.Setup.csproj +++ b/tests/RxDBDotNet.Tests.Setup/RxDBDotNet.Tests.Setup.csproj @@ -28,8 +28,8 @@ - - + + diff --git a/tests/RxDBDotNet.Tests/Model/GraphQLTestModel.cs b/tests/RxDBDotNet.Tests/Model/GraphQLTestModel.cs index d83c812..2964d7a 100644 --- a/tests/RxDBDotNet.Tests/Model/GraphQLTestModel.cs +++ b/tests/RxDBDotNet.Tests/Model/GraphQLTestModel.cs @@ -1065,7 +1065,6 @@ public static class GraphQlTypes public const string String = "String"; public const string Uuid = "UUID"; - public const string ArgumentNullError = "ArgumentNullError"; public const string AuthenticationError = "AuthenticationError"; public const string Checkpoint = "Checkpoint"; public const string LiveDoc = "LiveDoc"; @@ -1517,27 +1516,6 @@ public partial class UnauthorizedAccessErrorQueryBuilderGql : GraphQlQueryBuilde public UnauthorizedAccessErrorQueryBuilderGql ExceptMessage() => ExceptField("message"); } - public partial class ArgumentNullErrorQueryBuilderGql : GraphQlQueryBuilder - { - private static readonly GraphQlFieldMetadata[] AllFieldMetadata = - { - new GraphQlFieldMetadata { Name = "message" }, - new GraphQlFieldMetadata { Name = "paramName" } - }; - - protected override string TypeName { get; } = "ArgumentNullError"; - - public override IReadOnlyList AllFields { get; } = AllFieldMetadata; - - public ArgumentNullErrorQueryBuilderGql WithMessage(string? alias = null, SkipDirective? skip = null, IncludeDirective? include = null) => WithScalarField("message", alias, new GraphQlDirective?[] { skip, include }); - - public ArgumentNullErrorQueryBuilderGql ExceptMessage() => ExceptField("message"); - - public ArgumentNullErrorQueryBuilderGql WithParamName(string? alias = null, SkipDirective? skip = null, IncludeDirective? include = null) => WithScalarField("paramName", alias, new GraphQlDirective?[] { skip, include }); - - public ArgumentNullErrorQueryBuilderGql ExceptParamName() => ExceptField("paramName"); - } - public partial class WorkspaceQueryBuilderGql : GraphQlQueryBuilder { private static readonly GraphQlFieldMetadata[] AllFieldMetadata = @@ -1640,8 +1618,6 @@ public partial class ErrorQueryBuilderGql : GraphQlQueryBuilder WithFragment(authenticationErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); public ErrorQueryBuilderGql WithUnauthorizedAccessErrorFragment(UnauthorizedAccessErrorQueryBuilderGql unauthorizedAccessErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(unauthorizedAccessErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); - - public ErrorQueryBuilderGql WithArgumentNullErrorFragment(ArgumentNullErrorQueryBuilderGql argumentNullErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(argumentNullErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); } public partial class PushUserErrorQueryBuilderGql : GraphQlQueryBuilder @@ -1657,8 +1633,6 @@ public partial class PushUserErrorQueryBuilderGql : GraphQlQueryBuilder WithFragment(authenticationErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); public PushUserErrorQueryBuilderGql WithUnauthorizedAccessErrorFragment(UnauthorizedAccessErrorQueryBuilderGql unauthorizedAccessErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(unauthorizedAccessErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); - - public PushUserErrorQueryBuilderGql WithArgumentNullErrorFragment(ArgumentNullErrorQueryBuilderGql argumentNullErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(argumentNullErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); } public partial class PushUserPayloadQueryBuilderGql : GraphQlQueryBuilder @@ -1695,8 +1669,6 @@ public partial class PushWorkspaceErrorQueryBuilderGql : GraphQlQueryBuilder WithFragment(authenticationErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); public PushWorkspaceErrorQueryBuilderGql WithUnauthorizedAccessErrorFragment(UnauthorizedAccessErrorQueryBuilderGql unauthorizedAccessErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(unauthorizedAccessErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); - - public PushWorkspaceErrorQueryBuilderGql WithArgumentNullErrorFragment(ArgumentNullErrorQueryBuilderGql argumentNullErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(argumentNullErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); } public partial class PushWorkspacePayloadQueryBuilderGql : GraphQlQueryBuilder @@ -1733,8 +1705,6 @@ public partial class PushLiveDocErrorQueryBuilderGql : GraphQlQueryBuilder WithFragment(authenticationErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); public PushLiveDocErrorQueryBuilderGql WithUnauthorizedAccessErrorFragment(UnauthorizedAccessErrorQueryBuilderGql unauthorizedAccessErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(unauthorizedAccessErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); - - public PushLiveDocErrorQueryBuilderGql WithArgumentNullErrorFragment(ArgumentNullErrorQueryBuilderGql argumentNullErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(argumentNullErrorQueryBuilder, new GraphQlDirective?[] { skip, include }); } public partial class PushLiveDocPayloadQueryBuilderGql : GraphQlQueryBuilder @@ -3216,22 +3186,15 @@ public partial class CheckpointGql } [GraphQlObjectType("AuthenticationError")] - public partial class AuthenticationErrorGql : IPushUserErrorGql, IErrorGql + public partial class AuthenticationErrorGql : IPushUserErrorGql, IPushWorkspaceErrorGql, IPushLiveDocErrorGql, IErrorGql { public string Message { get; set; } } [GraphQlObjectType("UnauthorizedAccessError")] - public partial class UnauthorizedAccessErrorGql : IPushUserErrorGql, IErrorGql - { - public string Message { get; set; } - } - - [GraphQlObjectType("ArgumentNullError")] - public partial class ArgumentNullErrorGql : IPushUserErrorGql, IErrorGql + public partial class UnauthorizedAccessErrorGql : IPushUserErrorGql, IPushWorkspaceErrorGql, IPushLiveDocErrorGql, IErrorGql { public string Message { get; set; } - public string? ParamName { get; set; } } public partial class WorkspaceGql @@ -3298,5 +3261,5 @@ public partial class PushLiveDocPayloadGql public ICollection? Errors { get; set; } } #endregion -#nullable restore + #nullable restore } diff --git a/tests/RxDBDotNet.Tests/RxDBDotNet.Tests.csproj b/tests/RxDBDotNet.Tests/RxDBDotNet.Tests.csproj index f6752a7..59a0563 100644 --- a/tests/RxDBDotNet.Tests/RxDBDotNet.Tests.csproj +++ b/tests/RxDBDotNet.Tests/RxDBDotNet.Tests.csproj @@ -25,15 +25,15 @@ - - + + all runtime; build; native; contentfiles; analyzers; buildtransitive - - + + all diff --git a/tests/RxDBDotNet.Tests/SubscriptionTests.cs b/tests/RxDBDotNet.Tests/SubscriptionTests.cs index 167ada9..e83c765 100644 --- a/tests/RxDBDotNet.Tests/SubscriptionTests.cs +++ b/tests/RxDBDotNet.Tests/SubscriptionTests.cs @@ -24,6 +24,135 @@ public async Task DisposeAsync() await TestContext.DisposeAsync(); } + [Fact] + public async Task CreateWorkspaceShouldNotPropagateNewWorkspaceForAnAuthenticatedUserThroughASecuredSubscriptionAsync() + { + // Arrange + TestContext = new TestScenarioBuilder() + .WithAuthorization() + .ConfigureReplicatedDocument(options => options.Security.RequirePolicyToRead("IsSystemAdmin")) + .Build(); + + await TestContext.CreateWorkspaceAsync(TestContext.CancellationToken); + + Func act = async () => await TestContext.Factory.CreateGraphQLSubscriptionClientAsync(TestContext.CancellationToken); + + await act.Should().ThrowAsync(); + } + + [Fact] + public async Task CreateWorkspaceShouldNotPropagateNewWorkspaceForAnUnAuthorizedUserThroughASecuredSubscriptionAsync() + { + // Arrange + TestContext = new TestScenarioBuilder() + .WithAuthorization() + .ConfigureReplicatedDocument(options => options.Security.RequirePolicyToRead("IsSystemAdmin")) + .Build(); + + var workspace = await TestContext.CreateWorkspaceAsync(TestContext.CancellationToken); + var admin = await TestContext.CreateUserAsync(workspace, UserRole.WorkspaceAdmin, TestContext.CancellationToken); + + await using var subscriptionClient = await TestContext.Factory.CreateGraphQLSubscriptionClientAsync(TestContext.CancellationToken, bearerToken: admin.JwtAccessToken); + + var subscriptionQuery = new SubscriptionQueryBuilderGql().WithStreamWorkspace(new WorkspacePullBulkQueryBuilderGql() + .WithDocuments(new WorkspaceQueryBuilderGql().WithAllFields()) + .WithCheckpoint(new CheckpointQueryBuilderGql().WithAllFields())) + .Build(); + + // Start the subscription task before creating the workspace + // so that we do not miss subscription data + var subscriptionTask = CollectSubscriptionDataAsync(subscriptionClient, subscriptionQuery, TestContext.CancellationToken, maxResponses: 3); + + // Ensure the subscription is established + await Task.Delay(1000, TestContext.CancellationToken); + + // Act + await TestContext.HttpClient.CreateWorkspaceAsync(TestContext.CancellationToken); + + // Assert + var subscriptionResponses = await subscriptionTask; + + subscriptionResponses.Should() + .HaveCount(1); + var subscriptionResponse = subscriptionResponses.Single(); + subscriptionResponse.Errors.Should().HaveCount(1); + subscriptionResponse.Errors.Single() + .Message.Should() + .Be("The current user is not authorized to access this resource."); + } + + [Fact] + public async Task CreateWorkspaceShouldPropagateNewWorkspaceForAuthorizedUserThroughASecuredSubscriptionAsync() + { + // Arrange + TestContext = new TestScenarioBuilder() + .WithAuthorization() + .ConfigureReplicatedDocument(options => options.Security.RequirePolicyToRead("IsWorkspaceAdmin")) + .Build(); + + var workspace = await TestContext.CreateWorkspaceAsync(TestContext.CancellationToken); + var admin = await TestContext.CreateUserAsync(workspace, UserRole.WorkspaceAdmin, TestContext.CancellationToken); + + await using var subscriptionClient = await TestContext.Factory.CreateGraphQLSubscriptionClientAsync(TestContext.CancellationToken, bearerToken: admin.JwtAccessToken); + + var subscriptionQuery = new SubscriptionQueryBuilderGql().WithStreamWorkspace(new WorkspacePullBulkQueryBuilderGql() + .WithDocuments(new WorkspaceQueryBuilderGql().WithAllFields()) + .WithCheckpoint(new CheckpointQueryBuilderGql().WithAllFields())) + .Build(); + + // Start the subscription task before creating the workspace + // so that we do not miss subscription data + var subscriptionTask = CollectSubscriptionDataAsync(subscriptionClient, subscriptionQuery, TestContext.CancellationToken, maxResponses: 3); + + // Ensure the subscription is established + await Task.Delay(1000, TestContext.CancellationToken); + + // Act + var (newWorkspace, _) = await TestContext.HttpClient.CreateWorkspaceAsync(TestContext.CancellationToken); + + // Assert + var subscriptionResponses = await subscriptionTask; + + subscriptionResponses.Should() + .HaveCount(1); + var subscriptionResponse = subscriptionResponses[0]; + subscriptionResponse.Should() + .NotBeNull("Subscription data should not be null."); + subscriptionResponse.Errors.Should() + .BeNullOrEmpty(); + subscriptionResponse.Data.Should() + .NotBeNull(); + subscriptionResponse.Data?.StreamWorkspace.Should() + .NotBeNull(); + subscriptionResponse.Data?.StreamWorkspace?.Documents.Should() + .NotBeEmpty(); + + var streamedWorkspace = subscriptionResponse.Data?.StreamWorkspace?.Documents?.First(); + streamedWorkspace.Should() + .NotBeNull(); + + // Assert that the streamed workspace properties match the newWorkspace properties + streamedWorkspace?.Id.Should() + .Be(newWorkspace.Id, "The streamed workspace ID should match the created workspace ID"); + streamedWorkspace?.Name.Should() + .Be(newWorkspace.Name?.Value, "The streamed workspace name should match the created workspace name"); + streamedWorkspace?.IsDeleted.Should() + .Be(newWorkspace.IsDeleted, "The streamed workspace IsDeleted status should match the created workspace"); + streamedWorkspace?.UpdatedAt.Should() + .BeCloseTo(newWorkspace.UpdatedAt?.Value ?? DateTimeOffset.UtcNow, TimeSpan.FromSeconds(5), + "The streamed workspace UpdatedAt should be close to the created workspace's timestamp"); + + // Assert on the checkpoint + subscriptionResponse.Data?.StreamWorkspace?.Checkpoint.Should() + .NotBeNull("The checkpoint should be present"); + subscriptionResponse.Data?.StreamWorkspace?.Checkpoint?.LastDocumentId.Should() + .Be(newWorkspace.Id?.Value, "The checkpoint's LastDocumentId should match the new workspace's ID"); + Debug.Assert(newWorkspace.UpdatedAt != null, "newWorkspace.UpdatedAt != null"); + subscriptionResponse.Data?.StreamWorkspace?.Checkpoint?.UpdatedAt.Should() + .BeCloseTo(newWorkspace.UpdatedAt.Value, TimeSpan.FromSeconds(5), + "The checkpoint's UpdatedAt should be close to the new workspace's timestamp"); + } + /// /// Tests the behavior of the SubscriptionResolver when handling empty document updates. /// This test validates that the resolver correctly processes updates with no documents diff --git a/tests/RxDBDotNet.Tests/Utils/GraphQLSubscriptionClient.cs b/tests/RxDBDotNet.Tests/Utils/GraphQLSubscriptionClient.cs index 3f1e12d..e1f9f20 100644 --- a/tests/RxDBDotNet.Tests/Utils/GraphQLSubscriptionClient.cs +++ b/tests/RxDBDotNet.Tests/Utils/GraphQLSubscriptionClient.cs @@ -11,32 +11,36 @@ namespace RxDBDotNet.Tests.Utils; public sealed class GraphQLSubscriptionClient : IAsyncDisposable { private readonly WebSocket _webSocket; + private readonly string? _bearerToken; private bool _isDisposed; /// - /// Initializes a new instance of the GraphQLSubscriptionClient class for use in test scenarios. + /// Initializes a new instance of the class for use in test scenarios. /// /// - /// A WebSocket instance created by the test server, already connected to the GraphQL endpoint. + /// A instance created by the test server, already connected to the GraphQL endpoint. + /// + /// + /// An optional bearer token for authentication. /// /// - /// - /// This constructor initializes a new GraphQLSubscriptionClient with the provided WebSocket connection. - /// The client uses the graphql-transport-ws protocol for communication with a GraphQL server that supports - /// subscriptions. - /// - /// - /// The timeout parameter is particularly useful for testing and debugging scenarios where operations - /// might take longer than usual, allowing for extended debugging sessions without connection timeouts. - /// - /// - /// After creating an instance of GraphQLSubscriptionClient, you must call InitializeAsync() - /// before attempting to use the client for subscriptions. - /// + /// + /// This constructor initializes a new with the provided WebSocket connection. + /// The client uses the graphql-transport-ws protocol for communication with a GraphQL server that supports + /// subscriptions. + /// + /// + /// The timeout parameter is particularly useful for testing and debugging scenarios where operations + /// might take longer than usual, allowing for extended debugging sessions without connection timeouts. + /// + /// + /// After creating an instance of , you must call + /// before attempting to use the client for subscriptions. + /// /// /// - /// This example shows how to create and initialize a GraphQLSubscriptionClient in a test context: - /// + /// This example shows how to create and initialize a in a test context: + /// /// // Assuming 'factory' is a WebApplicationFactory<TProgram> instance /// var wsClient = factory.Server.CreateWebSocketClient(); /// var webSocket = await wsClient.ConnectAsync( @@ -47,10 +51,11 @@ public sealed class GraphQLSubscriptionClient : IAsyncDisposable /// // The client is now ready for use in tests /// /// - /// Thrown if the webSocket parameter is null. - public GraphQLSubscriptionClient(WebSocket webSocket) + /// Thrown if the parameter is null. + public GraphQLSubscriptionClient(WebSocket webSocket, string? bearerToken = null) { _webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket)); + _bearerToken = bearerToken; } /// @@ -89,10 +94,26 @@ public async ValueTask DisposeAsync() /// Thrown when the connection initialization fails. public async Task InitializeAsync(CancellationToken cancellationToken) { - var initMessage = new + object initMessage; + + if (!string.IsNullOrEmpty(_bearerToken)) { - type = "connection_init", - }; + initMessage = new + { + type = "connection_init", + payload = new + { + Authorization = $"Bearer {_bearerToken}", + }, + }; + } + else + { + initMessage = new + { + type = "connection_init", + }; + } await SendMessageAsync(initMessage, cancellationToken); diff --git a/tests/RxDBDotNet.Tests/Utils/WebApplicationFactoryExtensions.cs b/tests/RxDBDotNet.Tests/Utils/WebApplicationFactoryExtensions.cs index 5387936..9467e4b 100644 --- a/tests/RxDBDotNet.Tests/Utils/WebApplicationFactoryExtensions.cs +++ b/tests/RxDBDotNet.Tests/Utils/WebApplicationFactoryExtensions.cs @@ -6,7 +6,8 @@ public static class WebApplicationFactoryExtensions { public static async Task CreateGraphQLSubscriptionClientAsync( this WebApplicationFactory factory, - CancellationToken cancellationToken) where TProgram : class + CancellationToken cancellationToken, + string? bearerToken = null) where TProgram : class { ArgumentNullException.ThrowIfNull(factory); @@ -22,8 +23,7 @@ public static async Task CreateGraphQLSubscriptionCli var webSocket = await wsClient.ConnectAsync(new Uri(factory.Server.BaseAddress, "/graphql"), cancellationToken); - // Pass the timeout to the GraphQLSubscriptionClient - var client = new GraphQLSubscriptionClient(webSocket); + var client = new GraphQLSubscriptionClient(webSocket, bearerToken); await client.InitializeAsync(cancellationToken); return client; }