diff --git a/Servant.Tests/ServantTests.cs b/Servant.Tests/ServantTests.cs index 1d86fe0..e510799 100644 --- a/Servant.Tests/ServantTests.cs +++ b/Servant.Tests/ServantTests.cs @@ -411,5 +411,83 @@ public async Task Add_MultipleConstructors() } #endregion + + #region Disposal + + [ExcludeFromCodeCoverage] + private class Disposable : IDisposable + { + public int DisposeCount { get; private set; } + + public void Dispose() + { + DisposeCount++; + } + } + + [Fact] + public async Task Dispose_DisposesSingletons() + { + var servant = new Servant(); + + var singleton = new Disposable(); + servant.AddSingleton(singleton); + + await servant.CreateSingletonsAsync(); + + Assert.Equal(0, singleton.DisposeCount); + + servant.Dispose(); + + Assert.Equal(1, singleton.DisposeCount); + } + + [Fact] + public async Task Add_AfterDisposeThrows() + { + var servant = new Servant(); + + servant.Dispose(); + + var exception = Assert.Throws(() => servant.AddSingleton(new Disposable())); + + Assert.Equal(nameof(Servant), exception.ObjectName); + } + + [Fact] + public async Task CreateSingletonsAsync_AfterDisposeThrows() + { + var servant = new Servant(); + + servant.Dispose(); + + var exception = await Assert.ThrowsAsync(() => servant.CreateSingletonsAsync()); + + Assert.Equal(nameof(Servant), exception.ObjectName); + } + + [Fact] + public async Task ServeAsync_AfterDisposeThrows() + { + var servant = new Servant(); + + servant.Dispose(); + + var exception = await Assert.ThrowsAsync(() => servant.ServeAsync()); + + Assert.Equal(nameof(Servant), exception.ObjectName); + } + + [Fact] + public async Task Dispose_CanCallRepeatedly() + { + var servant = new Servant(); + + servant.Dispose(); + servant.Dispose(); + servant.Dispose(); + } + + #endregion } } diff --git a/Servant/Servant.cs b/Servant/Servant.cs index f49d34c..0c0cda6 100644 --- a/Servant/Servant.cs +++ b/Servant/Servant.cs @@ -27,6 +27,7 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Linq; +using System.Threading; using System.Threading.Tasks; using JetBrains.Annotations; @@ -109,17 +110,23 @@ public async Task GetAsync() return instance; } - } - // TODO make disposable, disposing all singletons (what about transients?) + public void TryDisposeSingleton() => (_singletonInstance as IDisposable)?.Dispose(); + } /// /// Serves instances of specific types, resolving dependencies as required, and running any async initialisation. /// - public sealed class Servant + /// + /// Disposing this class will dispose any contained singleton instances that implement . + /// Transient instances are not tracked by this class and must be disposed by their consumers. + /// + public sealed class Servant : IDisposable { private readonly ConcurrentDictionary _entryByType = new ConcurrentDictionary(); + private int _disposed; + private TypeEntry GetOrAddTypeEntry(Type declaredType) => _entryByType.GetOrAdd(declaredType, t => new TypeEntry(t)); /// @@ -131,6 +138,9 @@ public sealed class Servant /// The types of dependencies required by . public void Add(Lifestyle lifestyle, Type declaredType, Func> factory, Type[] parameterTypes) { + if (_disposed != 0) + throw new ObjectDisposedException(nameof(Servant)); + // Validate the type doesn't depend upon itself if (parameterTypes.Contains(declaredType)) throw new ServantException($"Type \"{declaredType}\" depends upon its own type, which is disallowed."); @@ -188,6 +198,9 @@ private static bool DependsUpon(TypeEntry dependant, Type dependent) /// A task that completes when singleton initialisation has finished. public Task CreateSingletonsAsync() { + if (_disposed != 0) + throw new ObjectDisposedException(nameof(Servant)); + return Task.WhenAll( from typeEntry in _entryByType.Values let provider = typeEntry.Provider @@ -202,12 +215,31 @@ from typeEntry in _entryByType.Values /// A task that completes when the instance is ready. public Task ServeAsync() { + if (_disposed != 0) + throw new ObjectDisposedException(nameof(Servant)); + TypeEntry entry; if (!_entryByType.TryGetValue(typeof(T), out entry) || entry.Provider == null) throw new ServantException($"Type \"{typeof(T)}\" is not registered."); return TaskUtil.Upcast(entry.Provider.GetAsync()); } + + /// + public void Dispose() + { + if (Interlocked.CompareExchange(ref _disposed, 1, 0) != 0) + return; + + var singletonInstances = + from typeEntry in _entryByType.Values + let provider = typeEntry.Provider + where provider?.Lifestyle == Lifestyle.Singleton + select provider; + + foreach (var provider in singletonInstances) + provider.TryDisposeSingleton(); + } } ///