Skip to content

Commit

Permalink
feat: deploy-api websocket (#30)
Browse files Browse the repository at this point in the history
* feat: deploy-api websocket

* chore: cleanup
  • Loading branch information
Alxandr authored Oct 9, 2024
1 parent 7b1b033 commit a55a1d9
Show file tree
Hide file tree
Showing 5 changed files with 228 additions and 32 deletions.
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
using Altinn.Authorization.DeployApi.Pipelines;
using System.Text.Json.Serialization;
using Altinn.Authorization.DeployApi.Pipelines;
using Azure.Core;
using Azure.ResourceManager;
using Azure.ResourceManager.KeyVault;
using Azure.ResourceManager.PostgreSql.FlexibleServers;
using Azure.Security.KeyVault.Secrets;
using Npgsql;
using System.Text.Json.Serialization;

namespace Altinn.Authorization.DeployApi.BootstrapDatabase;

internal sealed class BootstrapDatabasePipeline
: Pipeline
: TaskPipeline
{
[JsonPropertyName("resources")]
public required ResourcesConfig Resources { get; init; }
Expand Down Expand Up @@ -108,11 +108,11 @@ await context.RunTask(

connStringBuilder.Username = migratorUser.RoleName;
connStringBuilder.Password = migratorUser.Password;
connectionStrings[$"db-{DatabaseName}-migrator"] = connStringBuilder.ToString();
connectionStrings[$"db-{UserPrefix}-migrator"] = connStringBuilder.ToString();

connStringBuilder.Username = appUser.RoleName;
connStringBuilder.Password = appUser.Password;
connectionStrings[$"db-{DatabaseName}-app"] = connStringBuilder.ToString();
connectionStrings[$"db-{UserPrefix}-app"] = connStringBuilder.ToString();

await context.RunTask(new SaveConnectionStringsTask(secretClient, connectionStrings), cancellationToken);
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System.Text;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Text;
using Altinn.Authorization.DeployApi.Tasks;
using Microsoft.AspNetCore.Http.Features;
using Nerdbank.Streams;
using Spectre.Console;
using Spectre.Console.Rendering;

Expand All @@ -11,18 +14,95 @@ public sealed class PipelineContext
, ISupportRequiredService
, IKeyedServiceProvider
{
internal static async Task Run(Pipeline pipeline, HttpContext context, CancellationToken cancellationToken)
internal static async Task Run(TaskPipeline pipeline, HttpContext context)
{
var ct = context.RequestAborted;
var responseBody = context.Features.Get<IHttpResponseBodyFeature>()!;
responseBody.DisableBuffering();

var response = context.Response;
response.StatusCode = 200;
response.ContentType = "text/plain; charset=utf-8";

await responseBody.StartAsync(cancellationToken);
await responseBody.StartAsync(ct);

await using var textWriter = new StreamWriter(responseBody.Stream, Encoding.UTF8);
////await using var textWriter = new StreamWriter(responseBody.Stream, Encoding.UTF8);
await Run(pipeline, responseBody.Writer, context.RequestServices, ct);

await responseBody.CompleteAsync();
}

internal static async Task<TaskPipelineResult> Run(TaskPipeline pipeline, WebSocket context, IServiceProvider services, CancellationToken cancellationToken)
{
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
var pipe = new Pipe();
var ct = cts.Token;
var reader = pipe.Reader;
var writer = pipe.Writer;

var readerTask = Task.Run(
async () =>
{
ReadResult result;

do
{
result = await reader.ReadAsync(ct);
if (result.IsCanceled)
{
break;
}

var buffer = result.Buffer;
if (buffer.IsSingleSegment)
{
var segment = buffer.First;
await context.SendAsync(segment, WebSocketMessageType.Binary, true, ct);
}
else
{
foreach (var segment in buffer)
{
await context.SendAsync(segment, WebSocketMessageType.Binary, false, ct);
}

await context.SendAsync(ArraySegment<byte>.Empty, WebSocketMessageType.Binary, true, ct);
}

reader.AdvanceTo(buffer.End);
}
while (!result.IsCompleted);
},
ct);

try
{
return await Run(pipeline, writer, services, ct);
}
catch (OperationCanceledException e) when (e.CancellationToken == ct)
{
return TaskPipelineResult.Canceled;
}
finally
{
await writer.CompleteAsync();

try
{
await readerTask;
}
catch (OperationCanceledException e) when (e.CancellationToken == ct)
{
}
catch (Exception e)
{
}
}
}

private static async Task<TaskPipelineResult> Run(TaskPipeline pipeline, PipeWriter writer, IServiceProvider services, CancellationToken cancellationToken)
{
await using var textWriter = new BufferTextWriter(writer, Encoding.UTF8);
var consoleOutput = new ConsoleOutput(textWriter);
var console = new Console(
AnsiConsole.Create(new AnsiConsoleSettings
Expand All @@ -33,7 +113,7 @@ internal static async Task Run(Pipeline pipeline, HttpContext context, Cancellat
Interactive = InteractionSupport.Yes,
}),
textWriter,
responseBody.Stream);
writer);

console.WriteLine();
var progress = console
Expand All @@ -43,13 +123,15 @@ internal static async Task Run(Pipeline pipeline, HttpContext context, Cancellat
new TaskDescriptionColumn { Alignment = Justify.Left },
]);

var ctx = new PipelineContext(console, progress, context);
var ctx = new PipelineContext(console, progress, services);
try
{
await pipeline.ExecuteAsync(ctx, cancellationToken);
return TaskPipelineResult.Ok;
}
catch (OperationCanceledException ex) when (ex.CancellationToken == cancellationToken)
{
return TaskPipelineResult.Canceled;
}
catch (Exception ex)
{
Expand All @@ -63,33 +145,33 @@ internal static async Task Run(Pipeline pipeline, HttpContext context, Cancellat
{
// Ignore exceptions from the exception handler.
}
}

await responseBody.CompleteAsync();
return TaskPipelineResult.Error;
}
}

private readonly IAnsiConsole _console;
private readonly Progress _progress;
private readonly HttpContext _context;
private readonly IServiceProvider _services;

private PipelineContext(IAnsiConsole console, Progress progress, HttpContext context)
private PipelineContext(IAnsiConsole console, Progress progress, IServiceProvider services)
{
_console = console;
_progress = progress;
_context = context;
_services = services;
}

object? IServiceProvider.GetService(Type serviceType)
=> _context.RequestServices.GetService(serviceType);
=> _services.GetService(serviceType);

object ISupportRequiredService.GetRequiredService(Type serviceType)
=> _context.RequestServices.GetRequiredService(serviceType);
=> _services.GetRequiredService(serviceType);

object? IKeyedServiceProvider.GetKeyedService(Type serviceType, object? serviceKey)
=> _context.RequestServices.GetKeyedServices(serviceType, serviceKey);
=> _services.GetKeyedServices(serviceType, serviceKey);

object IKeyedServiceProvider.GetRequiredKeyedService(Type serviceType, object? serviceKey)
=> _context.RequestServices.GetRequiredKeyedService(serviceType, serviceKey);
=> _services.GetRequiredKeyedService(serviceType, serviceKey);

public Task<T> RunTask<T>(StepTask<T> task, CancellationToken cancellationToken)
{
Expand Down Expand Up @@ -160,7 +242,7 @@ public void SetEncoding(Encoding encoding)
}
}

private class Console(IAnsiConsole inner, TextWriter writer, Stream stream)
private class Console(IAnsiConsole inner, TextWriter writer, PipeWriter innerWriter)
: IAnsiConsole
{
public Profile Profile => inner.Profile;
Expand All @@ -182,7 +264,7 @@ public void Write(IRenderable renderable)
{
inner.Write(renderable);
writer.Flush();
stream.Flush();
innerWriter.FlushAsync().AsTask().Wait();
}
}

Expand Down Expand Up @@ -222,3 +304,10 @@ private enum TaskOutcome
Error,
}
}

public enum TaskPipelineResult
{
Ok,
Error,
Canceled,
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
using System.Buffers;
using System.Net.WebSockets;
using System.Text.Json;
using Nerdbank.Streams;

namespace Altinn.Authorization.DeployApi.Pipelines;

internal abstract class TaskPipeline
{
protected internal abstract Task ExecuteAsync(PipelineContext context, CancellationToken cancellationToken);

public Task Run(HttpContext context)
=> PipelineContext.Run(this, context);

public Task<TaskPipelineResult> Run(WebSocket context, IServiceProvider services, CancellationToken cancellationToken)
=> PipelineContext.Run(this, context, services, cancellationToken);
}

internal static class TaskPipelineExtensions
{
private static JsonSerializerOptions Options { get; } = new(JsonSerializerDefaults.Web);

public static IEndpointConventionBuilder MapTaskPipeline<TPipeline>(this IEndpointRouteBuilder endpoints, string pattern)
where TPipeline : TaskPipeline
=> endpoints.Map(pattern, async (HttpContext context) =>
{
if (context.WebSockets.IsWebSocketRequest)
{
using var webSocket = await context.WebSockets.AcceptWebSocketAsync("altinn.task-pipeline");
TPipeline? pipeline;

{
using var sequence = new Sequence<byte>(ArrayPool<byte>.Shared);
var result = await webSocket.ReceiveAsync(sequence, context.RequestAborted);
if (result.MessageType == WebSocketMessageType.Close)
{
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closed by client", context.RequestAborted);
return;
}

pipeline = DeserializePipeline<TPipeline>(sequence.AsReadOnlySequence);
}

if (pipeline is null)
{
await webSocket.CloseAsync(WebSocketCloseStatus.ProtocolError, "Missing pipeline payload", context.RequestAborted);
return;
}

TaskPipelineResult pipelineResult;
try
{
pipelineResult = await pipeline.Run(webSocket, context.RequestServices, context.RequestAborted);
}
catch (OperationCanceledException ex) when (ex.CancellationToken == context.RequestAborted)
{
pipelineResult = TaskPipelineResult.Canceled;
}
catch (Exception)
{
pipelineResult = TaskPipelineResult.Error;
}

var (closeCode, closeDescription) = pipelineResult switch
{
TaskPipelineResult.Ok => ((WebSocketCloseStatus)4000, "ok"),
TaskPipelineResult.Canceled => ((WebSocketCloseStatus)4001, "canceled"),
TaskPipelineResult.Error => ((WebSocketCloseStatus)4002, "error"),
_ => (WebSocketCloseStatus.InternalServerError, "unexpected pipeline result"),
};

await webSocket.CloseAsync(closeCode, closeDescription, context.RequestAborted);
}
else if (HttpMethods.IsPost(context.Request.Method))
{
var pipeline = await context.Request.ReadFromJsonAsync<TPipeline>(Options, context.RequestAborted);
if (pipeline is null)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
return;
}

await pipeline.Run(context);
}
else
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
return;
}
});

private static TPipeline? DeserializePipeline<TPipeline>(ReadOnlySequence<byte> sequence)
where TPipeline : TaskPipeline
{
var reader = new Utf8JsonReader(sequence);
return JsonSerializer.Deserialize<TPipeline>(ref reader, Options);
}

private static async ValueTask<ValueWebSocketReceiveResult> ReceiveAsync(this WebSocket webSocket, IBufferWriter<byte> writer, CancellationToken cancellationToken)
{
ValueWebSocketReceiveResult result;

do
{
var memory = writer.GetMemory(4096);
result = await webSocket.ReceiveAsync(memory, cancellationToken);
writer.Advance(result.Count);
}
while (!result.EndOfMessage);

return result;
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Altinn.Authorization.DeployApi.BootstrapDatabase;
using Altinn.Authorization.Hosting.Extensions;
using Altinn.Authorization.DeployApi.Pipelines;
using Azure.Core;
using Azure.Identity;
using Microsoft.AspNetCore.Server.Kestrel.Core;
Expand Down Expand Up @@ -34,6 +35,8 @@
var app = builder.Build();

app.UseAltinnHostDefaults();
app.MapPost("deployapi/api/v1/databases/bootstrap", (BootstrapDatabasePipeline pipeline, HttpContext context) => pipeline.Run(context));
app.UseWebSockets();

app.MapTaskPipeline<BootstrapDatabasePipeline>("/deployapi/api/v1/database/bootstrap");

app.Run();

0 comments on commit a55a1d9

Please sign in to comment.