diff --git a/README.md b/README.md
index 4d226fe..4791c47 100644
--- a/README.md
+++ b/README.md
@@ -1,17 +1,17 @@
+# RxDBDotNet
+
-
+
+
+
+
-
-
-
-# RxDBDotNet
-
RxDBDotNet is a powerful .NET library that implements the RxDB replication protocol, enabling real-time data synchronization between RxDB clients and .NET servers using GraphQL and Hot Chocolate. It extends the standard RxDB replication protocol with .NET-specific enhancements.
## Key Features
@@ -427,10 +427,6 @@ mutation PushWorkspace($input: PushWorkspaceInput!) {
... on UnauthorizedAccessError {
message
}
- ... on ArgumentNullError {
- message
- paramName
- }
}
}
}
diff --git a/example/LiveDocs.GraphQLApi/LiveDocs.GraphQLApi.csproj b/example/LiveDocs.GraphQLApi/LiveDocs.GraphQLApi.csproj
index 1339fbe..c4931af 100644
--- a/example/LiveDocs.GraphQLApi/LiveDocs.GraphQLApi.csproj
+++ b/example/LiveDocs.GraphQLApi/LiveDocs.GraphQLApi.csproj
@@ -29,9 +29,9 @@
-
-
-
+
+
+
diff --git a/example/LiveDocs.ServiceDefaults/LiveDocs.ServiceDefaults.csproj b/example/LiveDocs.ServiceDefaults/LiveDocs.ServiceDefaults.csproj
index 2b27347..cda14be 100644
--- a/example/LiveDocs.ServiceDefaults/LiveDocs.ServiceDefaults.csproj
+++ b/example/LiveDocs.ServiceDefaults/LiveDocs.ServiceDefaults.csproj
@@ -23,7 +23,7 @@
-
+
diff --git a/src/RxDBDotNet/Extensions/GraphQLBuilderExtensions.cs b/src/RxDBDotNet/Extensions/GraphQLBuilderExtensions.cs
index 244a75a..a18f606 100644
--- a/src/RxDBDotNet/Extensions/GraphQLBuilderExtensions.cs
+++ b/src/RxDBDotNet/Extensions/GraphQLBuilderExtensions.cs
@@ -37,6 +37,8 @@ public static IRequestExecutorBuilder AddReplication(this IRequestExecutorBuilde
builder.AddFiltering();
+ builder.AddSocketSessionInterceptor();
+
// Ensure Query, Mutation, and Subscription types exist
EnsureRootTypesExist(builder);
@@ -224,9 +226,6 @@ private static void AddFieldErrorTypes(IObjectFieldDescriptor field,
field.Error();
addedErrorTypes.Add(typeof(UnauthorizedAccessException));
- field.Error();
- addedErrorTypes.Add(typeof(ArgumentNullException));
-
// update the foreach code to not add the AuthenticationException error type if it has already been added
foreach (var errorType in replicationOptions.Errors)
{
diff --git a/src/RxDBDotNet/RxDBDotNet.csproj b/src/RxDBDotNet/RxDBDotNet.csproj
index 981b203..2cf8157 100644
--- a/src/RxDBDotNet/RxDBDotNet.csproj
+++ b/src/RxDBDotNet/RxDBDotNet.csproj
@@ -30,10 +30,11 @@
MIT
-
+
-
-
+
+
+
diff --git a/src/RxDBDotNet/Security/SocketConnectPayload.cs b/src/RxDBDotNet/Security/SocketConnectPayload.cs
new file mode 100644
index 0000000..5d6ddef
--- /dev/null
+++ b/src/RxDBDotNet/Security/SocketConnectPayload.cs
@@ -0,0 +1,15 @@
+namespace RxDBDotNet.Security;
+
+///
+/// Represents the payload for a socket connection, containing the necessary authorization information.
+///
+public class SocketConnectPayload
+{
+ ///
+ /// Gets the authorization token required for establishing a WebSocket connection.
+ ///
+ ///
+ /// A string representing the authorization token, typically in the form of a JWT.
+ ///
+ public required string Authorization { get; init; }
+}
diff --git a/src/RxDBDotNet/Security/SubscriptionAuthMiddleware.cs b/src/RxDBDotNet/Security/SubscriptionAuthMiddleware.cs
new file mode 100644
index 0000000..c8b3205
--- /dev/null
+++ b/src/RxDBDotNet/Security/SubscriptionAuthMiddleware.cs
@@ -0,0 +1,211 @@
+using System.Security.Claims;
+using HotChocolate.AspNetCore;
+using HotChocolate.AspNetCore.Subscriptions;
+using HotChocolate.AspNetCore.Subscriptions.Protocols;
+using Microsoft.AspNetCore.Authentication;
+using Microsoft.AspNetCore.Authentication.JwtBearer;
+using Microsoft.Extensions.Options;
+
+namespace RxDBDotNet.Security;
+
+///
+/// Middleware for authenticating WebSocket connections in GraphQL subscriptions.
+/// This middleware implements part of the graphql-transport-ws protocol, specifically handling the ConnectionInit message.
+/// It validates JWT tokens sent in the connection payload and sets up the ClaimsPrincipal for authenticated connections.
+/// If JWT authentication is not configured, it allows all connections.
+///
+///
+/// This middleware should be registered with the GraphQL server using AddSocketSessionInterceptor<SubscriptionAuthMiddleware>().
+/// It uses the same JWT configuration as set up in AddJwtBearer() for consistency across HTTP and WebSocket connections, if available.
+///
+/// According to the graphql-transport-ws protocol:
+/// - The server must receive the connection initialisation message within the allowed waiting time.
+/// - If the server wishes to reject the connection during authentication, it should close the socket with the event 4403: Forbidden.
+/// - If the server receives more than one ConnectionInit message, it should close the socket with the event 4429: Too many initialisation requests.
+///
+/// Note: This implementation assumes that Hot Chocolate handles the connection timeout and multiple ConnectionInit messages internally.
+/// If this is not the case, additional logic would need to be added to this middleware to fully comply with the protocol.
+///
+public class SubscriptionAuthMiddleware : DefaultSocketSessionInterceptor
+{
+ private readonly IAuthenticationSchemeProvider? _schemeProvider;
+ private readonly IOptionsMonitor _jwtOptionsMonitor;
+
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The authentication scheme provider.
+ /// The options monitor for JWT bearer token validation.
+ ///
+ /// We inject both IAuthenticationSchemeProvider and IOptionsMonitor<JwtBearerOptions> for the following reasons:
+ /// 1. IAuthenticationSchemeProvider allows us to check if JWT Bearer authentication is configured.
+ /// 2. IOptionsMonitor<JwtBearerOptions> provides access to the JWT configuration for token validation.
+ /// This approach allows the middleware to work correctly whether authentication is configured or not.
+ ///
+ public SubscriptionAuthMiddleware(
+ IAuthenticationSchemeProvider? schemeProvider,
+ IOptionsMonitor jwtOptionsMonitor)
+ {
+ _schemeProvider = schemeProvider;
+ _jwtOptionsMonitor = jwtOptionsMonitor;
+ }
+
+ ///
+ /// Called when a new WebSocket connection is being established.
+ /// This method handles the ConnectionInit message as per the graphql-transport-ws protocol.
+ /// It validates the JWT token in the connection payload and sets up the ClaimsPrincipal for authenticated connections.
+ /// If JWT authentication is not configured, it allows all connections.
+ ///
+ /// The socket session for the connection.
+ /// The payload of the ConnectionInit message.
+ /// A token to cancel the operation.
+ /// Thrown when or is null.
+ /// A indicating whether the connection was accepted or rejected.
+ ///
+ /// This method follows these steps:
+ /// 1. Check if JWT Bearer authentication is configured.
+ /// 2. If not configured, accept all connections (allowing for non-authenticated setups).
+ /// 3. If configured, validate the JWT token from the ConnectionInit message payload.
+ /// 4. Set up the ClaimsPrincipal for authenticated connections.
+ /// 5. If authentication fails, reject the connection with a 4403: Forbidden status.
+ /// This approach ensures that the middleware works in both authenticated and non-authenticated scenarios,
+ /// providing flexibility for different application setups while adhering to the graphql-transport-ws protocol.
+ ///
+ public override async ValueTask OnConnectAsync(
+ ISocketSession session,
+ IOperationMessagePayload connectionInitMessage,
+ CancellationToken cancellationToken = default)
+ {
+ ArgumentNullException.ThrowIfNull(session);
+ ArgumentNullException.ThrowIfNull(connectionInitMessage);
+
+ try
+ {
+ // Check if JWT Bearer authentication is configured
+ // This allows the middleware to work in both authenticated and non-authenticated setups
+ if (!await IsJwtBearerConfiguredAsync().ConfigureAwait(false))
+ {
+ // If JWT Bearer is not configured, we accept all connections
+ // This is crucial for supporting non-authenticated scenarios
+ return ConnectionStatus.Accept();
+ }
+
+ // JWT Bearer is configured, so we proceed with token validation
+ var connectPayload = connectionInitMessage.As();
+ var authorizationHeader = connectPayload?.Authorization;
+
+ // Ensure the Authorization header is present and in the correct format
+ if (string.IsNullOrEmpty(authorizationHeader) || !authorizationHeader.StartsWith("Bearer ", StringComparison.OrdinalIgnoreCase))
+ {
+ // As per the protocol, we reject the connection with a 4403: Forbidden status
+ return RejectConnection();
+ }
+
+ // Extract the token from the Authorization header
+ var token = authorizationHeader["Bearer ".Length..].Trim();
+
+ // Validate the token
+ var claimsPrincipal = await ValidateTokenAsync(token).ConfigureAwait(false);
+
+ if (claimsPrincipal != null)
+ {
+ // If the token is valid, set the ClaimsPrincipal on the HttpContext
+ // This allows the rest of the application to access the authenticated user's claims
+ session.Connection.HttpContext.User = claimsPrincipal;
+ return ConnectionStatus.Accept();
+ }
+
+ // If the token is invalid, reject the connection with a 4403: Forbidden status
+ return RejectConnection();
+ }
+ catch
+ {
+ // If any unexpected error occurs during the process, reject the connection
+ // This ensures that we don't accidentally allow unauthorized access in case of errors
+ return RejectConnection();
+ }
+ }
+
+ private static ConnectionStatus RejectConnection()
+ {
+ return ConnectionStatus.Reject("4403: Forbidden", new Dictionary(StringComparer.Ordinal)
+ {
+ { "reason", "Authentication failed" },
+ });
+ }
+
+ ///
+ /// Validates the provided JWT token using the configured JWT bearer options.
+ ///
+ /// The JWT token to validate.
+ ///
+ /// A if the token is valid and a non-null principal was created; otherwise, null.
+ ///
+ ///
+ /// This method uses the same validation parameters as configured in AddJwtBearer(),
+ /// ensuring consistency between HTTP and WebSocket authentication.
+ /// The method is designed to handle exceptions during token validation, returning null for any validation failure.
+ /// This approach allows the calling method to easily distinguish between valid and invalid tokens.
+ ///
+ private async Task ValidateTokenAsync(string token)
+ {
+ // Retrieve the JWT Bearer options. These options are configured when setting up JWT authentication
+ var jwtBearerOptions = _jwtOptionsMonitor.Get(JwtBearerDefaults.AuthenticationScheme);
+
+ // Get the token handler from the options. This is typically a JwtSecurityTokenHandler
+ var tokenHandler = jwtBearerOptions.TokenHandlers.Single();
+
+ // Get the token validation parameters from the options
+ var validationParameters = jwtBearerOptions.TokenValidationParameters;
+
+ try
+ {
+ // Attempt to validate the token
+ var tokenValidationResult = await tokenHandler
+ .ValidateTokenAsync(token, validationParameters)
+ .ConfigureAwait(false);
+
+ // If the token is valid, create and return a new ClaimsPrincipal
+ if (tokenValidationResult.IsValid)
+ {
+ return new ClaimsPrincipal(tokenValidationResult.ClaimsIdentity);
+ }
+ }
+ catch (Exception)
+ {
+ // If any exception occurs during validation, we catch it and return null
+ // This is to ensure that any unexpected errors in token validation are treated as validation failures
+ return null;
+ }
+
+ // If we reach here, the token was invalid, so we return null
+ return null;
+ }
+
+ ///
+ /// Checks if JWT Bearer authentication is explicitly configured.
+ ///
+ /// true if JWT Bearer authentication is explicitly configured; otherwise, false.
+ ///
+ /// This method checks for the presence of the JWT Bearer authentication scheme.
+ /// The presence of this scheme is a reliable indicator that JWT Bearer authentication has been configured.
+ /// This approach is chosen because:
+ /// 1. It's more reliable than checking specific option values, which might have default values even when not explicitly set.
+ /// 2. It's simpler and faster than comparing multiple option values.
+ /// 3. It directly reflects whether the AddJwtBearer() method has been called in the application's startup configuration.
+ ///
+ private async Task IsJwtBearerConfiguredAsync()
+ {
+ // Attempt to retrieve the JWT Bearer authentication scheme
+ if (_schemeProvider != null)
+ {
+ var scheme = await _schemeProvider.GetSchemeAsync(JwtBearerDefaults.AuthenticationScheme).ConfigureAwait(false);
+
+ // If the scheme is not null, it means JWT Bearer authentication has been configured
+ return scheme != null;
+ }
+
+ // if _schemeProvider is null, we assume that JWT Bearer authentication is not configured
+ return false;
+ }
+}
diff --git a/tests/RxDBDotNet.TestModelGenerator/RxDBDotNet.TestModelGenerator.csproj b/tests/RxDBDotNet.TestModelGenerator/RxDBDotNet.TestModelGenerator.csproj
index 4009e62..f266379 100644
--- a/tests/RxDBDotNet.TestModelGenerator/RxDBDotNet.TestModelGenerator.csproj
+++ b/tests/RxDBDotNet.TestModelGenerator/RxDBDotNet.TestModelGenerator.csproj
@@ -23,7 +23,7 @@
-
+
diff --git a/tests/RxDBDotNet.Tests.Setup/RxDBDotNet.Tests.Setup.csproj b/tests/RxDBDotNet.Tests.Setup/RxDBDotNet.Tests.Setup.csproj
index 1986a4d..f0487f4 100644
--- a/tests/RxDBDotNet.Tests.Setup/RxDBDotNet.Tests.Setup.csproj
+++ b/tests/RxDBDotNet.Tests.Setup/RxDBDotNet.Tests.Setup.csproj
@@ -28,8 +28,8 @@
-
-
+
+
diff --git a/tests/RxDBDotNet.Tests/Model/GraphQLTestModel.cs b/tests/RxDBDotNet.Tests/Model/GraphQLTestModel.cs
index d83c812..2964d7a 100644
--- a/tests/RxDBDotNet.Tests/Model/GraphQLTestModel.cs
+++ b/tests/RxDBDotNet.Tests/Model/GraphQLTestModel.cs
@@ -1065,7 +1065,6 @@ public static class GraphQlTypes
public const string String = "String";
public const string Uuid = "UUID";
- public const string ArgumentNullError = "ArgumentNullError";
public const string AuthenticationError = "AuthenticationError";
public const string Checkpoint = "Checkpoint";
public const string LiveDoc = "LiveDoc";
@@ -1517,27 +1516,6 @@ public partial class UnauthorizedAccessErrorQueryBuilderGql : GraphQlQueryBuilde
public UnauthorizedAccessErrorQueryBuilderGql ExceptMessage() => ExceptField("message");
}
- public partial class ArgumentNullErrorQueryBuilderGql : GraphQlQueryBuilder
- {
- private static readonly GraphQlFieldMetadata[] AllFieldMetadata =
- {
- new GraphQlFieldMetadata { Name = "message" },
- new GraphQlFieldMetadata { Name = "paramName" }
- };
-
- protected override string TypeName { get; } = "ArgumentNullError";
-
- public override IReadOnlyList AllFields { get; } = AllFieldMetadata;
-
- public ArgumentNullErrorQueryBuilderGql WithMessage(string? alias = null, SkipDirective? skip = null, IncludeDirective? include = null) => WithScalarField("message", alias, new GraphQlDirective?[] { skip, include });
-
- public ArgumentNullErrorQueryBuilderGql ExceptMessage() => ExceptField("message");
-
- public ArgumentNullErrorQueryBuilderGql WithParamName(string? alias = null, SkipDirective? skip = null, IncludeDirective? include = null) => WithScalarField("paramName", alias, new GraphQlDirective?[] { skip, include });
-
- public ArgumentNullErrorQueryBuilderGql ExceptParamName() => ExceptField("paramName");
- }
-
public partial class WorkspaceQueryBuilderGql : GraphQlQueryBuilder
{
private static readonly GraphQlFieldMetadata[] AllFieldMetadata =
@@ -1640,8 +1618,6 @@ public partial class ErrorQueryBuilderGql : GraphQlQueryBuilder WithFragment(authenticationErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
public ErrorQueryBuilderGql WithUnauthorizedAccessErrorFragment(UnauthorizedAccessErrorQueryBuilderGql unauthorizedAccessErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(unauthorizedAccessErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
-
- public ErrorQueryBuilderGql WithArgumentNullErrorFragment(ArgumentNullErrorQueryBuilderGql argumentNullErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(argumentNullErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
}
public partial class PushUserErrorQueryBuilderGql : GraphQlQueryBuilder
@@ -1657,8 +1633,6 @@ public partial class PushUserErrorQueryBuilderGql : GraphQlQueryBuilder WithFragment(authenticationErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
public PushUserErrorQueryBuilderGql WithUnauthorizedAccessErrorFragment(UnauthorizedAccessErrorQueryBuilderGql unauthorizedAccessErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(unauthorizedAccessErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
-
- public PushUserErrorQueryBuilderGql WithArgumentNullErrorFragment(ArgumentNullErrorQueryBuilderGql argumentNullErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(argumentNullErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
}
public partial class PushUserPayloadQueryBuilderGql : GraphQlQueryBuilder
@@ -1695,8 +1669,6 @@ public partial class PushWorkspaceErrorQueryBuilderGql : GraphQlQueryBuilder WithFragment(authenticationErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
public PushWorkspaceErrorQueryBuilderGql WithUnauthorizedAccessErrorFragment(UnauthorizedAccessErrorQueryBuilderGql unauthorizedAccessErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(unauthorizedAccessErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
-
- public PushWorkspaceErrorQueryBuilderGql WithArgumentNullErrorFragment(ArgumentNullErrorQueryBuilderGql argumentNullErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(argumentNullErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
}
public partial class PushWorkspacePayloadQueryBuilderGql : GraphQlQueryBuilder
@@ -1733,8 +1705,6 @@ public partial class PushLiveDocErrorQueryBuilderGql : GraphQlQueryBuilder WithFragment(authenticationErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
public PushLiveDocErrorQueryBuilderGql WithUnauthorizedAccessErrorFragment(UnauthorizedAccessErrorQueryBuilderGql unauthorizedAccessErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(unauthorizedAccessErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
-
- public PushLiveDocErrorQueryBuilderGql WithArgumentNullErrorFragment(ArgumentNullErrorQueryBuilderGql argumentNullErrorQueryBuilder, SkipDirective? skip = null, IncludeDirective? include = null) => WithFragment(argumentNullErrorQueryBuilder, new GraphQlDirective?[] { skip, include });
}
public partial class PushLiveDocPayloadQueryBuilderGql : GraphQlQueryBuilder
@@ -3216,22 +3186,15 @@ public partial class CheckpointGql
}
[GraphQlObjectType("AuthenticationError")]
- public partial class AuthenticationErrorGql : IPushUserErrorGql, IErrorGql
+ public partial class AuthenticationErrorGql : IPushUserErrorGql, IPushWorkspaceErrorGql, IPushLiveDocErrorGql, IErrorGql
{
public string Message { get; set; }
}
[GraphQlObjectType("UnauthorizedAccessError")]
- public partial class UnauthorizedAccessErrorGql : IPushUserErrorGql, IErrorGql
- {
- public string Message { get; set; }
- }
-
- [GraphQlObjectType("ArgumentNullError")]
- public partial class ArgumentNullErrorGql : IPushUserErrorGql, IErrorGql
+ public partial class UnauthorizedAccessErrorGql : IPushUserErrorGql, IPushWorkspaceErrorGql, IPushLiveDocErrorGql, IErrorGql
{
public string Message { get; set; }
- public string? ParamName { get; set; }
}
public partial class WorkspaceGql
@@ -3298,5 +3261,5 @@ public partial class PushLiveDocPayloadGql
public ICollection? Errors { get; set; }
}
#endregion
-#nullable restore
+ #nullable restore
}
diff --git a/tests/RxDBDotNet.Tests/RxDBDotNet.Tests.csproj b/tests/RxDBDotNet.Tests/RxDBDotNet.Tests.csproj
index f6752a7..59a0563 100644
--- a/tests/RxDBDotNet.Tests/RxDBDotNet.Tests.csproj
+++ b/tests/RxDBDotNet.Tests/RxDBDotNet.Tests.csproj
@@ -25,15 +25,15 @@
-
-
+
+
all
runtime; build; native; contentfiles; analyzers; buildtransitive
-
-
+
+
all
diff --git a/tests/RxDBDotNet.Tests/SubscriptionTests.cs b/tests/RxDBDotNet.Tests/SubscriptionTests.cs
index 167ada9..e83c765 100644
--- a/tests/RxDBDotNet.Tests/SubscriptionTests.cs
+++ b/tests/RxDBDotNet.Tests/SubscriptionTests.cs
@@ -24,6 +24,135 @@ public async Task DisposeAsync()
await TestContext.DisposeAsync();
}
+ [Fact]
+ public async Task CreateWorkspaceShouldNotPropagateNewWorkspaceForAnAuthenticatedUserThroughASecuredSubscriptionAsync()
+ {
+ // Arrange
+ TestContext = new TestScenarioBuilder()
+ .WithAuthorization()
+ .ConfigureReplicatedDocument(options => options.Security.RequirePolicyToRead("IsSystemAdmin"))
+ .Build();
+
+ await TestContext.CreateWorkspaceAsync(TestContext.CancellationToken);
+
+ Func act = async () => await TestContext.Factory.CreateGraphQLSubscriptionClientAsync(TestContext.CancellationToken);
+
+ await act.Should().ThrowAsync();
+ }
+
+ [Fact]
+ public async Task CreateWorkspaceShouldNotPropagateNewWorkspaceForAnUnAuthorizedUserThroughASecuredSubscriptionAsync()
+ {
+ // Arrange
+ TestContext = new TestScenarioBuilder()
+ .WithAuthorization()
+ .ConfigureReplicatedDocument(options => options.Security.RequirePolicyToRead("IsSystemAdmin"))
+ .Build();
+
+ var workspace = await TestContext.CreateWorkspaceAsync(TestContext.CancellationToken);
+ var admin = await TestContext.CreateUserAsync(workspace, UserRole.WorkspaceAdmin, TestContext.CancellationToken);
+
+ await using var subscriptionClient = await TestContext.Factory.CreateGraphQLSubscriptionClientAsync(TestContext.CancellationToken, bearerToken: admin.JwtAccessToken);
+
+ var subscriptionQuery = new SubscriptionQueryBuilderGql().WithStreamWorkspace(new WorkspacePullBulkQueryBuilderGql()
+ .WithDocuments(new WorkspaceQueryBuilderGql().WithAllFields())
+ .WithCheckpoint(new CheckpointQueryBuilderGql().WithAllFields()))
+ .Build();
+
+ // Start the subscription task before creating the workspace
+ // so that we do not miss subscription data
+ var subscriptionTask = CollectSubscriptionDataAsync(subscriptionClient, subscriptionQuery, TestContext.CancellationToken, maxResponses: 3);
+
+ // Ensure the subscription is established
+ await Task.Delay(1000, TestContext.CancellationToken);
+
+ // Act
+ await TestContext.HttpClient.CreateWorkspaceAsync(TestContext.CancellationToken);
+
+ // Assert
+ var subscriptionResponses = await subscriptionTask;
+
+ subscriptionResponses.Should()
+ .HaveCount(1);
+ var subscriptionResponse = subscriptionResponses.Single();
+ subscriptionResponse.Errors.Should().HaveCount(1);
+ subscriptionResponse.Errors.Single()
+ .Message.Should()
+ .Be("The current user is not authorized to access this resource.");
+ }
+
+ [Fact]
+ public async Task CreateWorkspaceShouldPropagateNewWorkspaceForAuthorizedUserThroughASecuredSubscriptionAsync()
+ {
+ // Arrange
+ TestContext = new TestScenarioBuilder()
+ .WithAuthorization()
+ .ConfigureReplicatedDocument(options => options.Security.RequirePolicyToRead("IsWorkspaceAdmin"))
+ .Build();
+
+ var workspace = await TestContext.CreateWorkspaceAsync(TestContext.CancellationToken);
+ var admin = await TestContext.CreateUserAsync(workspace, UserRole.WorkspaceAdmin, TestContext.CancellationToken);
+
+ await using var subscriptionClient = await TestContext.Factory.CreateGraphQLSubscriptionClientAsync(TestContext.CancellationToken, bearerToken: admin.JwtAccessToken);
+
+ var subscriptionQuery = new SubscriptionQueryBuilderGql().WithStreamWorkspace(new WorkspacePullBulkQueryBuilderGql()
+ .WithDocuments(new WorkspaceQueryBuilderGql().WithAllFields())
+ .WithCheckpoint(new CheckpointQueryBuilderGql().WithAllFields()))
+ .Build();
+
+ // Start the subscription task before creating the workspace
+ // so that we do not miss subscription data
+ var subscriptionTask = CollectSubscriptionDataAsync(subscriptionClient, subscriptionQuery, TestContext.CancellationToken, maxResponses: 3);
+
+ // Ensure the subscription is established
+ await Task.Delay(1000, TestContext.CancellationToken);
+
+ // Act
+ var (newWorkspace, _) = await TestContext.HttpClient.CreateWorkspaceAsync(TestContext.CancellationToken);
+
+ // Assert
+ var subscriptionResponses = await subscriptionTask;
+
+ subscriptionResponses.Should()
+ .HaveCount(1);
+ var subscriptionResponse = subscriptionResponses[0];
+ subscriptionResponse.Should()
+ .NotBeNull("Subscription data should not be null.");
+ subscriptionResponse.Errors.Should()
+ .BeNullOrEmpty();
+ subscriptionResponse.Data.Should()
+ .NotBeNull();
+ subscriptionResponse.Data?.StreamWorkspace.Should()
+ .NotBeNull();
+ subscriptionResponse.Data?.StreamWorkspace?.Documents.Should()
+ .NotBeEmpty();
+
+ var streamedWorkspace = subscriptionResponse.Data?.StreamWorkspace?.Documents?.First();
+ streamedWorkspace.Should()
+ .NotBeNull();
+
+ // Assert that the streamed workspace properties match the newWorkspace properties
+ streamedWorkspace?.Id.Should()
+ .Be(newWorkspace.Id, "The streamed workspace ID should match the created workspace ID");
+ streamedWorkspace?.Name.Should()
+ .Be(newWorkspace.Name?.Value, "The streamed workspace name should match the created workspace name");
+ streamedWorkspace?.IsDeleted.Should()
+ .Be(newWorkspace.IsDeleted, "The streamed workspace IsDeleted status should match the created workspace");
+ streamedWorkspace?.UpdatedAt.Should()
+ .BeCloseTo(newWorkspace.UpdatedAt?.Value ?? DateTimeOffset.UtcNow, TimeSpan.FromSeconds(5),
+ "The streamed workspace UpdatedAt should be close to the created workspace's timestamp");
+
+ // Assert on the checkpoint
+ subscriptionResponse.Data?.StreamWorkspace?.Checkpoint.Should()
+ .NotBeNull("The checkpoint should be present");
+ subscriptionResponse.Data?.StreamWorkspace?.Checkpoint?.LastDocumentId.Should()
+ .Be(newWorkspace.Id?.Value, "The checkpoint's LastDocumentId should match the new workspace's ID");
+ Debug.Assert(newWorkspace.UpdatedAt != null, "newWorkspace.UpdatedAt != null");
+ subscriptionResponse.Data?.StreamWorkspace?.Checkpoint?.UpdatedAt.Should()
+ .BeCloseTo(newWorkspace.UpdatedAt.Value, TimeSpan.FromSeconds(5),
+ "The checkpoint's UpdatedAt should be close to the new workspace's timestamp");
+ }
+
///
/// Tests the behavior of the SubscriptionResolver when handling empty document updates.
/// This test validates that the resolver correctly processes updates with no documents
diff --git a/tests/RxDBDotNet.Tests/Utils/GraphQLSubscriptionClient.cs b/tests/RxDBDotNet.Tests/Utils/GraphQLSubscriptionClient.cs
index 3f1e12d..e1f9f20 100644
--- a/tests/RxDBDotNet.Tests/Utils/GraphQLSubscriptionClient.cs
+++ b/tests/RxDBDotNet.Tests/Utils/GraphQLSubscriptionClient.cs
@@ -11,32 +11,36 @@ namespace RxDBDotNet.Tests.Utils;
public sealed class GraphQLSubscriptionClient : IAsyncDisposable
{
private readonly WebSocket _webSocket;
+ private readonly string? _bearerToken;
private bool _isDisposed;
///
- /// Initializes a new instance of the GraphQLSubscriptionClient class for use in test scenarios.
+ /// Initializes a new instance of the class for use in test scenarios.
///
///
- /// A WebSocket instance created by the test server, already connected to the GraphQL endpoint.
+ /// A instance created by the test server, already connected to the GraphQL endpoint.
+ ///
+ ///
+ /// An optional bearer token for authentication.
///
///
- ///
- /// This constructor initializes a new GraphQLSubscriptionClient with the provided WebSocket connection.
- /// The client uses the graphql-transport-ws protocol for communication with a GraphQL server that supports
- /// subscriptions.
- ///
- ///
- /// The timeout parameter is particularly useful for testing and debugging scenarios where operations
- /// might take longer than usual, allowing for extended debugging sessions without connection timeouts.
- ///
- ///
- /// After creating an instance of GraphQLSubscriptionClient, you must call InitializeAsync()
- /// before attempting to use the client for subscriptions.
- ///
+ ///
+ /// This constructor initializes a new with the provided WebSocket connection.
+ /// The client uses the graphql-transport-ws protocol for communication with a GraphQL server that supports
+ /// subscriptions.
+ ///
+ ///
+ /// The timeout parameter is particularly useful for testing and debugging scenarios where operations
+ /// might take longer than usual, allowing for extended debugging sessions without connection timeouts.
+ ///
+ ///
+ /// After creating an instance of , you must call
+ /// before attempting to use the client for subscriptions.
+ ///
///
///
- /// This example shows how to create and initialize a GraphQLSubscriptionClient in a test context:
- ///
+ /// This example shows how to create and initialize a in a test context:
+ ///
/// // Assuming 'factory' is a WebApplicationFactory<TProgram> instance
/// var wsClient = factory.Server.CreateWebSocketClient();
/// var webSocket = await wsClient.ConnectAsync(
@@ -47,10 +51,11 @@ public sealed class GraphQLSubscriptionClient : IAsyncDisposable
/// // The client is now ready for use in tests
///
///
- /// Thrown if the webSocket parameter is null.
- public GraphQLSubscriptionClient(WebSocket webSocket)
+ /// Thrown if the parameter is null.
+ public GraphQLSubscriptionClient(WebSocket webSocket, string? bearerToken = null)
{
_webSocket = webSocket ?? throw new ArgumentNullException(nameof(webSocket));
+ _bearerToken = bearerToken;
}
///
@@ -89,10 +94,26 @@ public async ValueTask DisposeAsync()
/// Thrown when the connection initialization fails.
public async Task InitializeAsync(CancellationToken cancellationToken)
{
- var initMessage = new
+ object initMessage;
+
+ if (!string.IsNullOrEmpty(_bearerToken))
{
- type = "connection_init",
- };
+ initMessage = new
+ {
+ type = "connection_init",
+ payload = new
+ {
+ Authorization = $"Bearer {_bearerToken}",
+ },
+ };
+ }
+ else
+ {
+ initMessage = new
+ {
+ type = "connection_init",
+ };
+ }
await SendMessageAsync(initMessage, cancellationToken);
diff --git a/tests/RxDBDotNet.Tests/Utils/WebApplicationFactoryExtensions.cs b/tests/RxDBDotNet.Tests/Utils/WebApplicationFactoryExtensions.cs
index 5387936..9467e4b 100644
--- a/tests/RxDBDotNet.Tests/Utils/WebApplicationFactoryExtensions.cs
+++ b/tests/RxDBDotNet.Tests/Utils/WebApplicationFactoryExtensions.cs
@@ -6,7 +6,8 @@ public static class WebApplicationFactoryExtensions
{
public static async Task CreateGraphQLSubscriptionClientAsync(
this WebApplicationFactory factory,
- CancellationToken cancellationToken) where TProgram : class
+ CancellationToken cancellationToken,
+ string? bearerToken = null) where TProgram : class
{
ArgumentNullException.ThrowIfNull(factory);
@@ -22,8 +23,7 @@ public static async Task CreateGraphQLSubscriptionCli
var webSocket = await wsClient.ConnectAsync(new Uri(factory.Server.BaseAddress, "/graphql"), cancellationToken);
- // Pass the timeout to the GraphQLSubscriptionClient
- var client = new GraphQLSubscriptionClient(webSocket);
+ var client = new GraphQLSubscriptionClient(webSocket, bearerToken);
await client.InitializeAsync(cancellationToken);
return client;
}