Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AsyncEnumerable optimizations #315

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Package/Core/Cancelations/Internal/CancelationInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

namespace Proto.Promises
{
internal static partial class Internal
partial class Internal
{
#if !PROTO_PROMISE_DEVELOPER_MODE
[DebuggerNonUserCode, StackTraceHidden]
Expand Down
2 changes: 1 addition & 1 deletion Package/Core/InternalShared/InterfacesInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace Proto.Promises
{
internal static partial class Internal
partial class Internal
{
internal partial interface ITraceable { }

Expand Down
2 changes: 1 addition & 1 deletion Package/Core/InternalShared/ValueStopwatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

namespace Proto.Promises
{
internal static partial class Internal
partial class Internal
{
// Idea from https://www.meziantou.net/how-to-measure-elapsed-time-without-allocating-a-stopwatch.htm
#pragma warning disable IDE0250 // Make struct 'readonly'
Expand Down
8 changes: 4 additions & 4 deletions Package/Core/Linq/CompilerServices/AsyncStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ namespace Proto.Promises.Async.CompilerServices
#endif
public readonly struct AsyncStreamWriter<T>
{
private readonly Internal.PromiseRefBase.AsyncEnumerableBase<T> _target;
private readonly Internal.PromiseRefBase.AsyncEnumerableWithIterator<T> _target;
private readonly int _id;

[MethodImpl(Internal.InlineOption)]
internal AsyncStreamWriter(Internal.PromiseRefBase.AsyncEnumerableBase<T> target, int id)
internal AsyncStreamWriter(Internal.PromiseRefBase.AsyncEnumerableWithIterator<T> target, int id)
{
_target = target;
_id = id;
Expand All @@ -41,11 +41,11 @@ public AsyncStreamYielder<T> YieldAsync(T value)
#endif
public readonly partial struct AsyncStreamYielder<T> : ICriticalNotifyCompletion, Internal.IPromiseAwaiter
{
private readonly Internal.PromiseRefBase.AsyncEnumerableBase<T> _target;
private readonly Internal.PromiseRefBase.AsyncEnumerableWithIterator<T> _target;
private readonly int _enumerableId;

[MethodImpl(Internal.InlineOption)]
internal AsyncStreamYielder(Internal.PromiseRefBase.AsyncEnumerableBase<T> target, int enumerableId)
internal AsyncStreamYielder(Internal.PromiseRefBase.AsyncEnumerableWithIterator<T> target, int enumerableId)
{
_target = target;
_enumerableId = enumerableId;
Expand Down
4 changes: 1 addition & 3 deletions Package/Core/Linq/Generators/Canceled.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence that will be immediately canceled when iterated.
Expand Down Expand Up @@ -72,8 +72,6 @@ internal override Promise DisposeAsync(int id)
// Do nothing, just return a resolved promise.
=> Promise.Resolved();

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
protected override void DisposeAndReturnToPool() { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
4 changes: 2 additions & 2 deletions Package/Core/Linq/Generators/Create.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static AsyncEnumerable<T> Create(Func<AsyncStreamWriter<T>, CancelationTo
{
ValidateArgument(asyncIterator, nameof(asyncIterator), 1);

var enumerable = Internal.AsyncEnumerableImpl<T, Internal.AsyncIterator<T>>.GetOrCreate(new Internal.AsyncIterator<T>(asyncIterator));
var enumerable = Internal.AsyncEnumerableCreate<T, Internal.AsyncIterator<T>>.GetOrCreate(new Internal.AsyncIterator<T>(asyncIterator));
return new AsyncEnumerable<T>(enumerable);
}

Expand All @@ -54,7 +54,7 @@ public static AsyncEnumerable<T> Create<TCapture>(TCapture captureValue, Func<TC
{
ValidateArgument(asyncIterator, nameof(asyncIterator), 1);

var enumerable = Internal.AsyncEnumerableImpl<T, Internal.AsyncIterator<T, TCapture>>.GetOrCreate(new Internal.AsyncIterator<T, TCapture>(captureValue, asyncIterator));
var enumerable = Internal.AsyncEnumerableCreate<T, Internal.AsyncIterator<T, TCapture>>.GetOrCreate(new Internal.AsyncIterator<T, TCapture>(captureValue, asyncIterator));
return new AsyncEnumerable<T>(enumerable);
}
}
Expand Down
6 changes: 2 additions & 4 deletions Package/Core/Linq/Generators/Empty.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Returns an empty async-enumerable sequence.
Expand All @@ -21,7 +21,7 @@ public static AsyncEnumerable<T> Empty<T>()
=> AsyncEnumerable<T>.Empty();
}

public readonly partial struct AsyncEnumerable<T>
partial struct AsyncEnumerable<T>
{
/// <summary>
/// Returns an empty async-enumerable sequence.
Expand Down Expand Up @@ -69,8 +69,6 @@ internal override Promise DisposeAsync(int id)
// Do nothing, just return a resolved promise.
=> Promise.Resolved();

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
protected override void DisposeAndReturnToPool() { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
15 changes: 7 additions & 8 deletions Package/Core/Linq/Generators/Range.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence of integral numbers within a specified range.
Expand Down Expand Up @@ -104,23 +104,22 @@ internal override Promise DisposeAsync(int id)
if (Interlocked.CompareExchange(ref _enumerableId, id + 1, id) == id)
{
// This was not already disposed.
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
DisposeAndReturnToPool();
Dispose();
}
// IAsyncDisposable.DisposeAsync must not throw if it's called multiple times, according to MSDN documentation.
return Promise.Resolved();
}

protected override void DisposeAndReturnToPool()
new private void Dispose()
{
Dispose();
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
base.Dispose();
_disposed = true;
ObjectPool.MaybeRepool(this);
}

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
2 changes: 1 addition & 1 deletion Package/Core/Linq/Generators/Rejected.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence that will be immediately rejected with the provided reason when iterated.
Expand Down
15 changes: 7 additions & 8 deletions Package/Core/Linq/Generators/Repeat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence that contains one repeated value.
Expand Down Expand Up @@ -107,24 +107,23 @@ internal override Promise DisposeAsync(int id)
if (Interlocked.CompareExchange(ref _enumerableId, id + 1, id) == id)
{
// This was not already disposed.
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
DisposeAndReturnToPool();
Dispose();
}
// IAsyncDisposable.DisposeAsync must not throw if it's called multiple times, according to MSDN documentation.
return Promise.Resolved();
}

protected override void DisposeAndReturnToPool()
new private void Dispose()
{
Dispose();
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
base.Dispose();
_current = default;
_disposed = true;
ObjectPool.MaybeRepool(this);
}

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
15 changes: 7 additions & 8 deletions Package/Core/Linq/Generators/Return.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence that contains a single element.
Expand Down Expand Up @@ -83,24 +83,23 @@ internal override Promise DisposeAsync(int id)
if (Interlocked.CompareExchange(ref _enumerableId, id + 1, id) == id)
{
// This was not already disposed.
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
DisposeAndReturnToPool();
Dispose();
}
// IAsyncDisposable.DisposeAsync must not throw if it's called multiple times, according to MSDN documentation.
return Promise.Resolved();
}

protected override void DisposeAndReturnToPool()
new private void Dispose()
{
Dispose();
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
base.Dispose();
_current = default;
_disposed = true;
ObjectPool.MaybeRepool(this);
}

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
66 changes: 37 additions & 29 deletions Package/Core/Linq/Internal/AsyncEnumerableInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,11 @@ partial class PromiseRefBase
#endif
internal abstract class AsyncEnumerableBase<T> : PromiseSingleAwait<bool>
{
// This is used as the backing reference to 3 different awaiters. MoveNextAsync (Promise<bool>), DisposeAsync (Promise), and YieldAsync (AsyncStreamYielder<T>).
// We use `Interlocked.CompareExchange(ref _enumerableId` to enforce only 1 awaiter uses it at a time, in the correct order.
// We use a separate field for AsyncStreamYielder continuation, because using _next for 2 separate async functions (the iterator and the consumer) proves problematic.
protected PromiseRefBase _iteratorPromiseRef;
internal CancelationToken _cancelationToken;
protected T _current;
private int _iteratorCompleteExpectedId;
private int _iteratorCompleteId;
protected int _enumerableId = 1; // Start with Id 1 instead of 0 to reduce risk of false positives.
protected bool _disposed;
protected bool _isStarted;
internal CancelationToken _cancelationToken;

internal int EnumerableId
{
Expand Down Expand Up @@ -149,8 +143,23 @@ internal T GetCurrent(int id)
return _current;
}

[MethodImpl(InlineOption)]
internal virtual Promise<bool> MoveNextAsync(int id)
internal abstract Promise<bool> MoveNextAsync(int id);
internal abstract Promise DisposeAsync(int id);
} // class AsyncEnumerableBase<T>

#if !PROTO_PROMISE_DEVELOPER_MODE
[DebuggerNonUserCode, StackTraceHidden]
#endif
internal abstract class AsyncEnumerableWithIterator<T> : AsyncEnumerableBase<T>
{
// This is used as the backing reference to 3 different awaiters. MoveNextAsync (Promise<bool>), DisposeAsync (Promise), and YieldAsync (AsyncStreamYielder<T>).
// We use `Interlocked.CompareExchange(ref _enumerableId` to enforce only 1 awaiter uses it at a time, in the correct order.
// We use a separate field for AsyncStreamYielder continuation, because using _next for 2 separate async functions (the iterator and the consumer) proves problematic.
protected PromiseRefBase _iteratorPromiseRef;
private int _iteratorCompleteExpectedId;
private int _iteratorCompleteId;

internal override Promise<bool> MoveNextAsync(int id)
{
// We increment by 1 when MoveNextAsync, then decrement by 1 when YieldAsync.
int newId = id + 1;
Expand Down Expand Up @@ -183,6 +192,16 @@ internal virtual Promise<bool> MoveNextAsync(int id)
return new Promise<bool>(this, Id, 0);
}

private void MoveNext()
{
// Invalidate the previous awaiter.
IncrementPromiseIdAndClearPrevious();
// Reset for the next awaiter.
ResetWithoutStacktrace();
// Handle iterator promise to move the async state machine forward.
InterlockedExchange(ref _iteratorPromiseRef, null).Handle(this, null, Promise.State.Resolved);
}

[MethodImpl(InlineOption)]
internal AsyncStreamYielder<T> YieldAsync(in T value, int id)
{
Expand All @@ -199,7 +218,7 @@ internal AsyncStreamYielder<T> YieldAsync(in T value, int id)
return new AsyncStreamYielder<T>(this, newId);
}

internal virtual Promise DisposeAsync(int id)
internal override Promise DisposeAsync(int id)
{
int newId = id + 3;
// When the async iterator function completes before DisposeAsync is called, it's set to id + 2.
Expand Down Expand Up @@ -317,43 +336,32 @@ internal void AwaitOnCompletedForAsyncStreamYielder(PromiseRefBase asyncPromiseR
HandleNextInternal(null, Promise.State.Resolved);
}

protected void MoveNext()
{
// Invalidate the previous awaiter.
IncrementPromiseIdAndClearPrevious();
// Reset for the next awaiter.
ResetWithoutStacktrace();
// Handle iterator promise to move the async state machine forward.
InterlockedExchange(ref _iteratorPromiseRef, null).Handle(this, null, Promise.State.Resolved);
}

protected abstract void Start(int enumerableId);

protected abstract void DisposeAndReturnToPool();
} // class AsyncEnumerableBase<T>
} // class AsyncEnumerableWithIterator<TValue>
} // class PromiseRefBase

#if !PROTO_PROMISE_DEVELOPER_MODE
[DebuggerNonUserCode, StackTraceHidden]
#endif
internal sealed class AsyncEnumerableImpl<TValue, TIterator> : PromiseRefBase.AsyncEnumerableBase<TValue>
internal sealed class AsyncEnumerableCreate<TValue, TIterator> : PromiseRefBase.AsyncEnumerableWithIterator<TValue>
where TIterator : IAsyncIterator<TValue>
{
private TIterator _iterator;

private AsyncEnumerableImpl() { }
private AsyncEnumerableCreate() { }

[MethodImpl(InlineOption)]
private static AsyncEnumerableImpl<TValue, TIterator> GetOrCreate()
private static AsyncEnumerableCreate<TValue, TIterator> GetOrCreate()
{
var obj = ObjectPool.TryTakeOrInvalid<AsyncEnumerableImpl<TValue, TIterator>>();
var obj = ObjectPool.TryTakeOrInvalid<AsyncEnumerableCreate<TValue, TIterator>>();
return obj == InvalidAwaitSentinel.s_instance
? new AsyncEnumerableImpl<TValue, TIterator>()
: obj.UnsafeAs<AsyncEnumerableImpl<TValue, TIterator>>();
? new AsyncEnumerableCreate<TValue, TIterator>()
: obj.UnsafeAs<AsyncEnumerableCreate<TValue, TIterator>>();
}

[MethodImpl(InlineOption)]
internal static AsyncEnumerableImpl<TValue, TIterator> GetOrCreate(in TIterator iterator)
internal static AsyncEnumerableCreate<TValue, TIterator> GetOrCreate(in TIterator iterator)
{
var enumerable = GetOrCreate();
enumerable.Reset();
Expand Down
2 changes: 1 addition & 1 deletion Package/Core/Linq/Internal/MergeInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ partial class Internal
#if !PROTO_PROMISE_DEVELOPER_MODE
[DebuggerNonUserCode, StackTraceHidden]
#endif
internal abstract class AsyncEnumerableMergerBase<TValue> : PromiseRefBase.AsyncEnumerableBase<TValue>
internal abstract class AsyncEnumerableMergerBase<TValue> : PromiseRefBase.AsyncEnumerableWithIterator<TValue>
{
// TODO: optimize these collections.
protected readonly List<AsyncEnumerator<TValue>> _enumerators = new List<AsyncEnumerator<TValue>>();
Expand Down
2 changes: 1 addition & 1 deletion Package/Core/Promises/Internal/RejectContainersInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

namespace Proto.Promises
{
internal static partial class Internal
partial class Internal
{
// Extension method instead of including on the interface, since old IL2CPP compiler does not support virtual generics with structs.
internal static bool TryGetValue<TValue>(this IRejectContainer rejectContainer, out TValue converted)
Expand Down
Loading