Skip to content

Commit

Permalink
Optimize AsyncEnumerable generators that don't use iterator functions.
Browse files Browse the repository at this point in the history
Partial class cleanup.
  • Loading branch information
timcassell committed Dec 4, 2023
1 parent 0379e78 commit 5ce664b
Show file tree
Hide file tree
Showing 14 changed files with 73 additions and 72 deletions.
2 changes: 1 addition & 1 deletion Package/Core/Cancelations/Internal/CancelationInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

namespace Proto.Promises
{
internal static partial class Internal
partial class Internal
{
#if !PROTO_PROMISE_DEVELOPER_MODE
[DebuggerNonUserCode, StackTraceHidden]
Expand Down
2 changes: 1 addition & 1 deletion Package/Core/InternalShared/InterfacesInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace Proto.Promises
{
internal static partial class Internal
partial class Internal
{
internal partial interface ITraceable { }

Expand Down
2 changes: 1 addition & 1 deletion Package/Core/InternalShared/ValueStopwatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

namespace Proto.Promises
{
internal static partial class Internal
partial class Internal
{
// Idea from https://www.meziantou.net/how-to-measure-elapsed-time-without-allocating-a-stopwatch.htm
#pragma warning disable IDE0250 // Make struct 'readonly'
Expand Down
8 changes: 4 additions & 4 deletions Package/Core/Linq/CompilerServices/AsyncStreamWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ namespace Proto.Promises.Async.CompilerServices
#endif
public readonly struct AsyncStreamWriter<T>
{
private readonly Internal.PromiseRefBase.AsyncEnumerableBase<T> _target;
private readonly Internal.PromiseRefBase.AsyncEnumerableWithIterator<T> _target;
private readonly int _id;

[MethodImpl(Internal.InlineOption)]
internal AsyncStreamWriter(Internal.PromiseRefBase.AsyncEnumerableBase<T> target, int id)
internal AsyncStreamWriter(Internal.PromiseRefBase.AsyncEnumerableWithIterator<T> target, int id)
{
_target = target;
_id = id;
Expand All @@ -41,11 +41,11 @@ public AsyncStreamYielder<T> YieldAsync(T value)
#endif
public readonly partial struct AsyncStreamYielder<T> : ICriticalNotifyCompletion, Internal.IPromiseAwaiter
{
private readonly Internal.PromiseRefBase.AsyncEnumerableBase<T> _target;
private readonly Internal.PromiseRefBase.AsyncEnumerableWithIterator<T> _target;
private readonly int _enumerableId;

[MethodImpl(Internal.InlineOption)]
internal AsyncStreamYielder(Internal.PromiseRefBase.AsyncEnumerableBase<T> target, int enumerableId)
internal AsyncStreamYielder(Internal.PromiseRefBase.AsyncEnumerableWithIterator<T> target, int enumerableId)
{
_target = target;
_enumerableId = enumerableId;
Expand Down
4 changes: 1 addition & 3 deletions Package/Core/Linq/Generators/Canceled.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence that will be immediately canceled when iterated.
Expand Down Expand Up @@ -72,8 +72,6 @@ internal override Promise DisposeAsync(int id)
// Do nothing, just return a resolved promise.
=> Promise.Resolved();

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
protected override void DisposeAndReturnToPool() { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
4 changes: 2 additions & 2 deletions Package/Core/Linq/Generators/Create.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public static AsyncEnumerable<T> Create(Func<AsyncStreamWriter<T>, CancelationTo
{
ValidateArgument(asyncIterator, nameof(asyncIterator), 1);

var enumerable = Internal.AsyncEnumerableImpl<T, Internal.AsyncIterator<T>>.GetOrCreate(new Internal.AsyncIterator<T>(asyncIterator));
var enumerable = Internal.AsyncEnumerableCreate<T, Internal.AsyncIterator<T>>.GetOrCreate(new Internal.AsyncIterator<T>(asyncIterator));
return new AsyncEnumerable<T>(enumerable);
}

Expand All @@ -54,7 +54,7 @@ public static AsyncEnumerable<T> Create<TCapture>(TCapture captureValue, Func<TC
{
ValidateArgument(asyncIterator, nameof(asyncIterator), 1);

var enumerable = Internal.AsyncEnumerableImpl<T, Internal.AsyncIterator<T, TCapture>>.GetOrCreate(new Internal.AsyncIterator<T, TCapture>(captureValue, asyncIterator));
var enumerable = Internal.AsyncEnumerableCreate<T, Internal.AsyncIterator<T, TCapture>>.GetOrCreate(new Internal.AsyncIterator<T, TCapture>(captureValue, asyncIterator));
return new AsyncEnumerable<T>(enumerable);
}
}
Expand Down
6 changes: 2 additions & 4 deletions Package/Core/Linq/Generators/Empty.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Returns an empty async-enumerable sequence.
Expand All @@ -21,7 +21,7 @@ public static AsyncEnumerable<T> Empty<T>()
=> AsyncEnumerable<T>.Empty();
}

public readonly partial struct AsyncEnumerable<T>
partial struct AsyncEnumerable<T>
{
/// <summary>
/// Returns an empty async-enumerable sequence.
Expand Down Expand Up @@ -69,8 +69,6 @@ internal override Promise DisposeAsync(int id)
// Do nothing, just return a resolved promise.
=> Promise.Resolved();

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
protected override void DisposeAndReturnToPool() { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
15 changes: 7 additions & 8 deletions Package/Core/Linq/Generators/Range.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence of integral numbers within a specified range.
Expand Down Expand Up @@ -104,23 +104,22 @@ internal override Promise DisposeAsync(int id)
if (Interlocked.CompareExchange(ref _enumerableId, id + 1, id) == id)
{
// This was not already disposed.
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
DisposeAndReturnToPool();
Dispose();
}
// IAsyncDisposable.DisposeAsync must not throw if it's called multiple times, according to MSDN documentation.
return Promise.Resolved();
}

protected override void DisposeAndReturnToPool()
new private void Dispose()
{
Dispose();
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
base.Dispose();
_disposed = true;
ObjectPool.MaybeRepool(this);
}

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
2 changes: 1 addition & 1 deletion Package/Core/Linq/Generators/Rejected.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence that will be immediately rejected with the provided reason when iterated.
Expand Down
15 changes: 7 additions & 8 deletions Package/Core/Linq/Generators/Repeat.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence that contains one repeated value.
Expand Down Expand Up @@ -107,24 +107,23 @@ internal override Promise DisposeAsync(int id)
if (Interlocked.CompareExchange(ref _enumerableId, id + 1, id) == id)
{
// This was not already disposed.
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
DisposeAndReturnToPool();
Dispose();
}
// IAsyncDisposable.DisposeAsync must not throw if it's called multiple times, according to MSDN documentation.
return Promise.Resolved();
}

protected override void DisposeAndReturnToPool()
new private void Dispose()
{
Dispose();
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
base.Dispose();
_current = default;
_disposed = true;
ObjectPool.MaybeRepool(this);
}

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
15 changes: 7 additions & 8 deletions Package/Core/Linq/Generators/Return.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
namespace Proto.Promises.Linq
{
#if CSHARP_7_3_OR_NEWER // We only expose AsyncEnumerable where custom async method builders are supported.
public static partial class AsyncEnumerable
partial class AsyncEnumerable
{
/// <summary>
/// Generates an async-enumerable sequence that contains a single element.
Expand Down Expand Up @@ -83,24 +83,23 @@ internal override Promise DisposeAsync(int id)
if (Interlocked.CompareExchange(ref _enumerableId, id + 1, id) == id)
{
// This was not already disposed.
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
DisposeAndReturnToPool();
Dispose();
}
// IAsyncDisposable.DisposeAsync must not throw if it's called multiple times, according to MSDN documentation.
return Promise.Resolved();
}

protected override void DisposeAndReturnToPool()
new private void Dispose()
{
Dispose();
#if PROMISE_DEBUG || PROTO_PROMISE_DEVELOPER_MODE
SetCompletionState(null, Promise.State.Resolved);
#endif
base.Dispose();
_current = default;
_disposed = true;
ObjectPool.MaybeRepool(this);
}

protected override void Start(int enumerableId) { throw new System.InvalidOperationException(); }
internal override void MaybeDispose() { throw new System.InvalidOperationException(); }
}
}
Expand Down
66 changes: 37 additions & 29 deletions Package/Core/Linq/Internal/AsyncEnumerableInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,11 @@ partial class PromiseRefBase
#endif
internal abstract class AsyncEnumerableBase<T> : PromiseSingleAwait<bool>
{
// This is used as the backing reference to 3 different awaiters. MoveNextAsync (Promise<bool>), DisposeAsync (Promise), and YieldAsync (AsyncStreamYielder<T>).
// We use `Interlocked.CompareExchange(ref _enumerableId` to enforce only 1 awaiter uses it at a time, in the correct order.
// We use a separate field for AsyncStreamYielder continuation, because using _next for 2 separate async functions (the iterator and the consumer) proves problematic.
protected PromiseRefBase _iteratorPromiseRef;
internal CancelationToken _cancelationToken;
protected T _current;
private int _iteratorCompleteExpectedId;
private int _iteratorCompleteId;
protected int _enumerableId = 1; // Start with Id 1 instead of 0 to reduce risk of false positives.
protected bool _disposed;
protected bool _isStarted;
internal CancelationToken _cancelationToken;

internal int EnumerableId
{
Expand Down Expand Up @@ -149,8 +143,23 @@ internal T GetCurrent(int id)
return _current;
}

[MethodImpl(InlineOption)]
internal virtual Promise<bool> MoveNextAsync(int id)
internal abstract Promise<bool> MoveNextAsync(int id);
internal abstract Promise DisposeAsync(int id);
} // class AsyncEnumerableBase<T>

#if !PROTO_PROMISE_DEVELOPER_MODE
[DebuggerNonUserCode, StackTraceHidden]
#endif
internal abstract class AsyncEnumerableWithIterator<T> : AsyncEnumerableBase<T>
{
// This is used as the backing reference to 3 different awaiters. MoveNextAsync (Promise<bool>), DisposeAsync (Promise), and YieldAsync (AsyncStreamYielder<T>).
// We use `Interlocked.CompareExchange(ref _enumerableId` to enforce only 1 awaiter uses it at a time, in the correct order.
// We use a separate field for AsyncStreamYielder continuation, because using _next for 2 separate async functions (the iterator and the consumer) proves problematic.
protected PromiseRefBase _iteratorPromiseRef;
private int _iteratorCompleteExpectedId;
private int _iteratorCompleteId;

internal override Promise<bool> MoveNextAsync(int id)
{
// We increment by 1 when MoveNextAsync, then decrement by 1 when YieldAsync.
int newId = id + 1;
Expand Down Expand Up @@ -183,6 +192,16 @@ internal virtual Promise<bool> MoveNextAsync(int id)
return new Promise<bool>(this, Id, 0);
}

private void MoveNext()
{
// Invalidate the previous awaiter.
IncrementPromiseIdAndClearPrevious();
// Reset for the next awaiter.
ResetWithoutStacktrace();
// Handle iterator promise to move the async state machine forward.
InterlockedExchange(ref _iteratorPromiseRef, null).Handle(this, null, Promise.State.Resolved);
}

[MethodImpl(InlineOption)]
internal AsyncStreamYielder<T> YieldAsync(in T value, int id)
{
Expand All @@ -199,7 +218,7 @@ internal AsyncStreamYielder<T> YieldAsync(in T value, int id)
return new AsyncStreamYielder<T>(this, newId);
}

internal virtual Promise DisposeAsync(int id)
internal override Promise DisposeAsync(int id)
{
int newId = id + 3;
// When the async iterator function completes before DisposeAsync is called, it's set to id + 2.
Expand Down Expand Up @@ -317,43 +336,32 @@ internal void AwaitOnCompletedForAsyncStreamYielder(PromiseRefBase asyncPromiseR
HandleNextInternal(null, Promise.State.Resolved);
}

protected void MoveNext()
{
// Invalidate the previous awaiter.
IncrementPromiseIdAndClearPrevious();
// Reset for the next awaiter.
ResetWithoutStacktrace();
// Handle iterator promise to move the async state machine forward.
InterlockedExchange(ref _iteratorPromiseRef, null).Handle(this, null, Promise.State.Resolved);
}

protected abstract void Start(int enumerableId);

protected abstract void DisposeAndReturnToPool();
} // class AsyncEnumerableBase<T>
} // class AsyncEnumerableWithIterator<TValue>
} // class PromiseRefBase

#if !PROTO_PROMISE_DEVELOPER_MODE
[DebuggerNonUserCode, StackTraceHidden]
#endif
internal sealed class AsyncEnumerableImpl<TValue, TIterator> : PromiseRefBase.AsyncEnumerableBase<TValue>
internal sealed class AsyncEnumerableCreate<TValue, TIterator> : PromiseRefBase.AsyncEnumerableWithIterator<TValue>
where TIterator : IAsyncIterator<TValue>
{
private TIterator _iterator;

private AsyncEnumerableImpl() { }
private AsyncEnumerableCreate() { }

[MethodImpl(InlineOption)]
private static AsyncEnumerableImpl<TValue, TIterator> GetOrCreate()
private static AsyncEnumerableCreate<TValue, TIterator> GetOrCreate()
{
var obj = ObjectPool.TryTakeOrInvalid<AsyncEnumerableImpl<TValue, TIterator>>();
var obj = ObjectPool.TryTakeOrInvalid<AsyncEnumerableCreate<TValue, TIterator>>();
return obj == InvalidAwaitSentinel.s_instance
? new AsyncEnumerableImpl<TValue, TIterator>()
: obj.UnsafeAs<AsyncEnumerableImpl<TValue, TIterator>>();
? new AsyncEnumerableCreate<TValue, TIterator>()
: obj.UnsafeAs<AsyncEnumerableCreate<TValue, TIterator>>();
}

[MethodImpl(InlineOption)]
internal static AsyncEnumerableImpl<TValue, TIterator> GetOrCreate(in TIterator iterator)
internal static AsyncEnumerableCreate<TValue, TIterator> GetOrCreate(in TIterator iterator)
{
var enumerable = GetOrCreate();
enumerable.Reset();
Expand Down
2 changes: 1 addition & 1 deletion Package/Core/Linq/Internal/MergeInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ partial class Internal
#if !PROTO_PROMISE_DEVELOPER_MODE
[DebuggerNonUserCode, StackTraceHidden]
#endif
internal abstract class AsyncEnumerableMergerBase<TValue> : PromiseRefBase.AsyncEnumerableBase<TValue>
internal abstract class AsyncEnumerableMergerBase<TValue> : PromiseRefBase.AsyncEnumerableWithIterator<TValue>
{
// TODO: optimize these collections.
protected readonly List<AsyncEnumerator<TValue>> _enumerators = new List<AsyncEnumerator<TValue>>();
Expand Down
2 changes: 1 addition & 1 deletion Package/Core/Promises/Internal/RejectContainersInternal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

namespace Proto.Promises
{
internal static partial class Internal
partial class Internal
{
// Extension method instead of including on the interface, since old IL2CPP compiler does not support virtual generics with structs.
internal static bool TryGetValue<TValue>(this IRejectContainer rejectContainer, out TValue converted)
Expand Down

0 comments on commit 5ce664b

Please sign in to comment.