Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/Common/EncodingUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ public static byte[] GetUtf8Bytes(ReadOnlySpan<char> utf16)
return bytes;
}

/// <summary>Decodes a UTF-8 <see cref="ReadOnlySequence{T}"/> to a <see cref="string"/>.</summary>
public static string GetUtf8String(in ReadOnlySequence<byte> sequence)
{
if (sequence.IsEmpty)
{
return string.Empty;
}

return sequence.IsSingleSegment
? Encoding.UTF8.GetString(sequence.First.Span)
: Encoding.UTF8.GetString(sequence.ToArray());
}

/// <summary>
/// Encodes binary data to base64-encoded UTF-8 bytes.
/// </summary>
Expand Down
19 changes: 19 additions & 0 deletions src/Common/Polyfills/System/Text/EncodingExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,25 @@ public static int GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, Spa
}
}
}

/// <summary>
/// Decodes all the bytes in the specified span into a string.
/// </summary>
public static string GetString(this Encoding encoding, ReadOnlySpan<byte> bytes)
{
if (bytes.IsEmpty)
{
return string.Empty;
}

unsafe
{
fixed (byte* bytesPtr = bytes)
{
return encoding.GetString(bytesPtr, bytes.Length);
}
}
}
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ namespace ModelContextProtocol.Client;
/// <summary>Provides the client side of a stdio-based session transport.</summary>
internal sealed class StdioClientSessionTransport(
StdioClientTransportOptions options, Process process, string endpointName, Queue<string> stderrRollingLog, ILoggerFactory? loggerFactory) :
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, encoding: null, endpointName, loggerFactory)
StreamClientSessionTransport(process.StandardInput.BaseStream, process.StandardOutput.BaseStream, endpointName, loggerFactory)
{
private readonly StdioClientTransportOptions _options = options;
private readonly Process _process = process;
Expand Down
10 changes: 6 additions & 4 deletions src/ModelContextProtocol.Core/Client/StdioClientTransport.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public sealed partial class StdioClientTransport : IClientTransport
private static readonly object s_consoleEncodingLock = new();
#endif

private static readonly UTF8Encoding s_noBomUtf8Encoding = new(encoderShouldEmitUTF8Identifier: false);

private readonly StdioClientTransportOptions _options;
private readonly ILoggerFactory? _loggerFactory;

Expand Down Expand Up @@ -85,10 +87,10 @@ public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken =
UseShellExecute = false,
CreateNoWindow = true,
WorkingDirectory = _options.WorkingDirectory ?? Environment.CurrentDirectory,
StandardOutputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding,
StandardErrorEncoding = StreamClientSessionTransport.NoBomUtf8Encoding,
StandardOutputEncoding = s_noBomUtf8Encoding,
StandardErrorEncoding = s_noBomUtf8Encoding,
#if NET
StandardInputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding,
StandardInputEncoding = s_noBomUtf8Encoding,
#endif
};

Expand Down Expand Up @@ -173,7 +175,7 @@ public async Task<ITransport> ConnectAsync(CancellationToken cancellationToken =
Encoding originalInputEncoding = Console.InputEncoding;
try
{
Console.InputEncoding = StreamClientSessionTransport.NoBomUtf8Encoding;
Console.InputEncoding = s_noBomUtf8Encoding;
processStarted = process.Start();
}
finally
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Microsoft.Extensions.Logging;
using ModelContextProtocol.Protocol;
using System.Text;
using System.Buffers;
using System.IO.Pipelines;
using System.Text.Json;

namespace ModelContextProtocol.Client;
Expand All @@ -10,9 +11,7 @@ internal class StreamClientSessionTransport : TransportBase
{
private static readonly byte[] s_newlineBytes = "\n"u8.ToArray();

internal static UTF8Encoding NoBomUtf8Encoding { get; } = new(encoderShouldEmitUTF8Identifier: false);

private readonly TextReader _serverOutput;
private readonly PipeReader _serverOutputPipe;
private readonly Stream _serverInputStream;
private readonly SemaphoreSlim _sendLock = new(1, 1);
private CancellationTokenSource? _shutdownCts = new();
Expand All @@ -27,9 +26,6 @@ internal class StreamClientSessionTransport : TransportBase
/// <param name="serverOutput">
/// The server's output stream. Messages read from this stream will be received from the server.
/// </param>
/// <param name="encoding">
/// The encoding used for reading and writing messages from the input and output streams. Defaults to UTF-8 without BOM if null.
/// </param>
/// <param name="endpointName">
/// A name that identifies this transport endpoint in logs.
/// </param>
Expand All @@ -40,18 +36,14 @@ internal class StreamClientSessionTransport : TransportBase
/// This constructor starts a background task to read messages from the server output stream.
/// The transport will be marked as connected once initialized.
/// </remarks>
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, Encoding? encoding, string endpointName, ILoggerFactory? loggerFactory)
public StreamClientSessionTransport(Stream serverInput, Stream serverOutput, string endpointName, ILoggerFactory? loggerFactory)
: base(endpointName, loggerFactory)
{
Throw.IfNull(serverInput);
Throw.IfNull(serverOutput);

_serverInputStream = serverInput;
#if NET
_serverOutput = new StreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#else
_serverOutput = new CancellableStreamReader(serverOutput, encoding ?? NoBomUtf8Encoding);
#endif
_serverOutputPipe = PipeReader.Create(serverOutput);

SetConnected();

Expand Down Expand Up @@ -102,24 +94,8 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
try
{
LogTransportEnteringReadMessagesLoop(Name);

while (true)
{
if (await _serverOutput.ReadLineAsync(cancellationToken).ConfigureAwait(false) is not string line)
{
LogTransportEndOfStream(Name);
break;
}

if (string.IsNullOrWhiteSpace(line))
{
continue;
}

LogTransportReceivedMessageSensitive(Name, line);

await ProcessMessageAsync(line, cancellationToken).ConfigureAwait(false);
}
await _serverOutputPipe.ReadLinesAsync(ProcessLineAsync, cancellationToken).ConfigureAwait(false);
LogTransportEndOfStream(Name);
}
catch (OperationCanceledException)
{
Expand All @@ -137,25 +113,43 @@ private async Task ReadMessagesAsync(CancellationToken cancellationToken)
}
}

private async Task ProcessMessageAsync(string line, CancellationToken cancellationToken)
private async Task ProcessLineAsync(ReadOnlySequence<byte> line, CancellationToken cancellationToken)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportReceivedMessageSensitive(Name, EncodingUtilities.GetUtf8String(line));
}

try
{
var message = (JsonRpcMessage?)JsonSerializer.Deserialize(line.AsSpan().Trim(), McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage)));
if (message != null)
JsonRpcMessage? message;
if (line.IsSingleSegment)
{
message = JsonSerializer.Deserialize(line.First.Span, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}
else
{
var reader = new Utf8JsonReader(line, isFinalBlock: true, state: default);
message = JsonSerializer.Deserialize(ref reader, McpJsonUtilities.DefaultOptions.GetTypeInfo(typeof(JsonRpcMessage))) as JsonRpcMessage;
}

if (message is not null)
{
await WriteMessageAsync(message, cancellationToken).ConfigureAwait(false);
}
else
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, line);
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseUnexpectedTypeSensitive(Name, EncodingUtilities.GetUtf8String(line));
}
}
}
catch (JsonException ex)
{
if (Logger.IsEnabled(LogLevel.Trace))
{
LogTransportMessageParseFailedSensitive(Name, line, ex);
LogTransportMessageParseFailedSensitive(Name, EncodingUtilities.GetUtf8String(line), ex);
}
else
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ public Task<ITransport> ConnectAsync(CancellationToken cancellationToken = defau
return Task.FromResult<ITransport>(new StreamClientSessionTransport(
_serverInput,
_serverOutput,
encoding: null,
"Client (stream)",
_loggerFactory));
}
Expand Down
72 changes: 72 additions & 0 deletions src/ModelContextProtocol.Core/Protocol/PipeReaderExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
using System.Buffers;
using System.IO.Pipelines;

namespace ModelContextProtocol.Protocol;

/// <summary>Internal helper for reading newline-delimited UTF-8 lines from a <see cref="PipeReader"/>.</summary>
internal static class PipeReaderExtensions
{
/// <summary>
/// Reads newline-delimited lines from <paramref name="reader"/>, invoking
/// <paramref name="processLine"/> for each non-empty line, until the reader signals completion.
/// </summary>
internal static async Task ReadLinesAsync(
this PipeReader reader,
Func<ReadOnlySequence<byte>, CancellationToken, Task> processLine,
CancellationToken cancellationToken)
{
while (true)
{
ReadResult result = await reader.ReadAsync(cancellationToken).ConfigureAwait(false);
ReadOnlySequence<byte> buffer = result.Buffer;

SequencePosition? position;
while ((position = buffer.PositionOf((byte)'\n')) != null)
{
ReadOnlySequence<byte> line = buffer.Slice(0, position.Value);

// Trim trailing \r for Windows-style CRLF line endings.
if (EndsWithCarriageReturn(line))
{
line = line.Slice(0, line.Length - 1);
}

if (!line.IsEmpty)
{
await processLine(line, cancellationToken).ConfigureAwait(false);
}

// Advance past the '\n'.
buffer = buffer.Slice(buffer.GetPosition(1, position.Value));
}

reader.AdvanceTo(buffer.Start, buffer.End);

if (result.IsCompleted)
{
break;
}
}
}

private static bool EndsWithCarriageReturn(in ReadOnlySequence<byte> sequence)
{
if (sequence.IsSingleSegment)
{
ReadOnlySpan<byte> span = sequence.First.Span;
return span.Length > 0 && span[span.Length - 1] == (byte)'\r';
}

// Multi-segment: find the last non-empty segment to check its last byte.
ReadOnlyMemory<byte> last = default;
foreach (ReadOnlyMemory<byte> segment in sequence)
{
if (!segment.IsEmpty)
{
last = segment;
}
}

return !last.IsEmpty && last.Span[last.Length - 1] == (byte)'\r';
}
}
Loading