From 7409f4e07cd28fe320738bf80505f24f449ca383 Mon Sep 17 00:00:00 2001 From: David Tchepak Date: Thu, 7 May 2020 09:49:54 +1000 Subject: [PATCH 1/2] Add protected `BaseResult()` method to CallInfo. Create CallInfo to calls that return results and expose `BaseResult`. This gets messy to push the generic all the way through the code, so am just using a cast in `Returns` extensions to handle this. This should be safe as if we are in `Returns` then the return value should be safe to cast to a `T`. Based off discussion here: https://github.com/nsubstitute/NSubstitute/issues/622#issuecomment-640166693 --- .gitignore | 3 ++ src/NSubstitute/Core/CallInfo.cs | 17 +++++- src/NSubstitute/Core/CallInfoFactory.cs | 2 +- src/NSubstitute/Core/CallInfoWithReturns.cs | 16 ++++++ src/NSubstitute/Core/IReturn.cs | 31 +++++------ .../NoBaseImplementationException.cs | 11 ++++ .../SubstituteExtensions.Returns.Task.cs | 4 +- .../SubstituteExtensions.Returns.ValueTask.cs | 4 +- .../SubstituteExtensions.Returns.cs | 10 ++-- .../ReturnFromBase.cs | 52 +++++++++++++++++++ 10 files changed, 124 insertions(+), 26 deletions(-) create mode 100644 src/NSubstitute/Core/CallInfoWithReturns.cs create mode 100644 src/NSubstitute/Exceptions/NoBaseImplementationException.cs create mode 100644 tests/NSubstitute.Acceptance.Specs/ReturnFromBase.cs diff --git a/.gitignore b/.gitignore index b0d183069..c44c5d5e1 100644 --- a/.gitignore +++ b/.gitignore @@ -307,3 +307,6 @@ docs/_site/* # Ignore Ionide files (https://ionide.io/) .ionide + +# kdiff/merge files +*.orig diff --git a/src/NSubstitute/Core/CallInfo.cs b/src/NSubstitute/Core/CallInfo.cs index 89aa1b8ea..faeb1de1b 100644 --- a/src/NSubstitute/Core/CallInfo.cs +++ b/src/NSubstitute/Core/CallInfo.cs @@ -12,10 +12,25 @@ namespace NSubstitute.Core public class CallInfo { private readonly Argument[] _callArguments; + private readonly Func> _baseResult; - public CallInfo(Argument[] callArguments) + public CallInfo(Argument[] callArguments, Func> baseResult) { _callArguments = callArguments; + _baseResult = baseResult; + } + + protected CallInfo(CallInfo info) : this(info._callArguments, info._baseResult) { + } + + /// + /// Call and returns the result from the base implementation of a substitute for a class. + /// Will throw an exception if no base implementation exists. + /// + /// Result from base implementation + /// Throws in no base implementation exists + protected object GetBaseResult() { + return _baseResult().ValueOr(() => throw new NoBaseImplementationException()); } /// diff --git a/src/NSubstitute/Core/CallInfoFactory.cs b/src/NSubstitute/Core/CallInfoFactory.cs index fa7662d07..a02cc3fa3 100644 --- a/src/NSubstitute/Core/CallInfoFactory.cs +++ b/src/NSubstitute/Core/CallInfoFactory.cs @@ -5,7 +5,7 @@ public class CallInfoFactory : ICallInfoFactory public CallInfo Create(ICall call) { var arguments = GetArgumentsFromCall(call); - return new CallInfo(arguments); + return new CallInfo(arguments, () => call.TryCallBase()); } private static Argument[] GetArgumentsFromCall(ICall call) diff --git a/src/NSubstitute/Core/CallInfoWithReturns.cs b/src/NSubstitute/Core/CallInfoWithReturns.cs new file mode 100644 index 000000000..f08d059a7 --- /dev/null +++ b/src/NSubstitute/Core/CallInfoWithReturns.cs @@ -0,0 +1,16 @@ +namespace NSubstitute.Core +{ + /// + /// Information for a call that returns a value of type T. + /// + /// + public class CallInfo : CallInfo + { + internal CallInfo(CallInfo info) : base(info) { + } + + public T BaseResult() { + return (T)GetBaseResult(); + } + } +} \ No newline at end of file diff --git a/src/NSubstitute/Core/IReturn.cs b/src/NSubstitute/Core/IReturn.cs index 0901f2820..30038a624 100644 --- a/src/NSubstitute/Core/IReturn.cs +++ b/src/NSubstitute/Core/IReturn.cs @@ -39,21 +39,21 @@ public ReturnValue(object? value) public class ReturnValueFromFunc : IReturn { - private readonly Func _funcToReturnValue; + private readonly Func, T?> _funcToReturnValue; - public ReturnValueFromFunc(Func? funcToReturnValue) + public ReturnValueFromFunc(Func, T?>? funcToReturnValue) { _funcToReturnValue = funcToReturnValue ?? ReturnNull(); } - public object? ReturnFor(CallInfo info) => _funcToReturnValue(info); - public Type TypeOrNull() => typeof (T); - public bool CanBeAssignedTo(Type t) => typeof (T).IsAssignableFrom(t); + public object? ReturnFor(CallInfo info) => _funcToReturnValue(new CallInfo(info)); + public Type TypeOrNull() => typeof(T); + public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t); private static Func ReturnNull() { if (typeof(T).GetTypeInfo().IsValueType) throw new CannotReturnNullForValueType(typeof(T)); - return x => default(T); + return x => default; } } @@ -70,27 +70,28 @@ public ReturnMultipleValues(T?[] values) public object? GetReturnValue() => GetNext(); public object? ReturnFor(CallInfo info) => GetReturnValue(); - public Type TypeOrNull() => typeof (T); - public bool CanBeAssignedTo(Type t) => typeof (T).IsAssignableFrom(t); + public Type TypeOrNull() => typeof(T); + public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t); private T? GetNext() => _valuesToReturn.TryDequeue(out var nextResult) ? nextResult : _lastValue; } public class ReturnMultipleFuncsValues : IReturn { - private readonly ConcurrentQueue> _funcsToReturn; - private readonly Func _lastFunc; + private readonly ConcurrentQueue, T?>> _funcsToReturn; + private readonly Func, T?> _lastFunc; - public ReturnMultipleFuncsValues(Func[] funcs) + public ReturnMultipleFuncsValues(Func, T?>[] funcs) { - _funcsToReturn = new ConcurrentQueue>(funcs); + _funcsToReturn = new ConcurrentQueue, T?>>(funcs); _lastFunc = funcs.Last(); } public object? ReturnFor(CallInfo info) => GetNext(info); - public Type TypeOrNull() => typeof (T); - public bool CanBeAssignedTo(Type t) => typeof (T).IsAssignableFrom(t); + public Type TypeOrNull() => typeof(T); + public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t); - private T? GetNext(CallInfo info) => _funcsToReturn.TryDequeue(out var nextFunc) ? nextFunc(info) : _lastFunc(info); + private T? GetNext(CallInfo info) => + _funcsToReturn.TryDequeue(out var nextFunc) ? nextFunc(new CallInfo(info)) : _lastFunc(new CallInfo(info)); } } \ No newline at end of file diff --git a/src/NSubstitute/Exceptions/NoBaseImplementationException.cs b/src/NSubstitute/Exceptions/NoBaseImplementationException.cs new file mode 100644 index 000000000..3ab623ff1 --- /dev/null +++ b/src/NSubstitute/Exceptions/NoBaseImplementationException.cs @@ -0,0 +1,11 @@ +namespace NSubstitute.Exceptions +{ + public class NoBaseImplementationException : SubstituteException + { + private const string Explanation = + "Cannot call the base method as the base method implementation is missing. " + + "You can call base method only if you create a class substitute and the method is not abstract."; + + public NoBaseImplementationException() : base(Explanation) { } + } +} diff --git a/src/NSubstitute/SubstituteExtensions.Returns.Task.cs b/src/NSubstitute/SubstituteExtensions.Returns.Task.cs index 57388dfeb..87ed66eeb 100644 --- a/src/NSubstitute/SubstituteExtensions.Returns.Task.cs +++ b/src/NSubstitute/SubstituteExtensions.Returns.Task.cs @@ -40,7 +40,7 @@ public static ConfiguredCall Returns(this Task value, Func re var wrappedFunc = WrapFuncInTask(returnThis); var wrappedReturnThese = returnThese.Length > 0 ? returnThese.Select(WrapFuncInTask).ToArray() : null; - return ConfigureReturn(MatchArgs.AsSpecifiedInCall, wrappedFunc, wrappedReturnThese); + return ConfigureFuncReturn(MatchArgs.AsSpecifiedInCall, wrappedFunc, wrappedReturnThese); } /// @@ -72,7 +72,7 @@ public static ConfiguredCall ReturnsForAnyArgs(this Task value, Func 0 ? returnThese.Select(WrapFuncInTask).ToArray() : null; - return ConfigureReturn(MatchArgs.Any, wrappedFunc, wrappedReturnThese); + return ConfigureFuncReturn(MatchArgs.Any, wrappedFunc, wrappedReturnThese); } #nullable restore diff --git a/src/NSubstitute/SubstituteExtensions.Returns.ValueTask.cs b/src/NSubstitute/SubstituteExtensions.Returns.ValueTask.cs index d899f0f7d..b96cba512 100644 --- a/src/NSubstitute/SubstituteExtensions.Returns.ValueTask.cs +++ b/src/NSubstitute/SubstituteExtensions.Returns.ValueTask.cs @@ -40,7 +40,7 @@ public static ConfiguredCall Returns(this ValueTask value, Func 0 ? returnThese.Select(WrapFuncInValueTask).ToArray() : null; - return ConfigureReturn(MatchArgs.AsSpecifiedInCall, wrappedFunc, wrappedReturnThese); + return ConfigureFuncReturn(MatchArgs.AsSpecifiedInCall, wrappedFunc, wrappedReturnThese); } /// @@ -72,7 +72,7 @@ public static ConfiguredCall ReturnsForAnyArgs(this ValueTask value, Func< var wrappedFunc = WrapFuncInValueTask(returnThis); var wrappedReturnThese = returnThese.Length > 0 ? returnThese.Select(WrapFuncInValueTask).ToArray() : null; - return ConfigureReturn(MatchArgs.Any, wrappedFunc, wrappedReturnThese); + return ConfigureFuncReturn(MatchArgs.Any, wrappedFunc, wrappedReturnThese); } #nullable restore diff --git a/src/NSubstitute/SubstituteExtensions.Returns.cs b/src/NSubstitute/SubstituteExtensions.Returns.cs index 0e959ebd4..3ff53d65b 100644 --- a/src/NSubstitute/SubstituteExtensions.Returns.cs +++ b/src/NSubstitute/SubstituteExtensions.Returns.cs @@ -24,8 +24,8 @@ public static ConfiguredCall Returns(this T value, T returnThis, params T[] r /// /// Function to calculate the return value /// Optionally use these functions next - public static ConfiguredCall Returns(this T value, Func returnThis, params Func[] returnThese) => - ConfigureReturn(MatchArgs.AsSpecifiedInCall, returnThis, returnThese); + public static ConfiguredCall Returns(this T value, Func, T> returnThis, params Func, T>[] returnThese) => + ConfigureFuncReturn(MatchArgs.AsSpecifiedInCall, returnThis, returnThese); /// /// Set a return value for this call made with any arguments. @@ -43,8 +43,8 @@ public static ConfiguredCall ReturnsForAnyArgs(this T value, T returnThis, pa /// Function to calculate the return value /// Optionally use these functions next /// - public static ConfiguredCall ReturnsForAnyArgs(this T value, Func returnThis, params Func[] returnThese) => - ConfigureReturn(MatchArgs.Any, returnThis, returnThese); + public static ConfiguredCall ReturnsForAnyArgs(this T value, Func, T> returnThis, params Func, T>[] returnThese) => + ConfigureFuncReturn(MatchArgs.Any, returnThis, returnThese); #nullable restore private static ConfiguredCall ConfigureReturn(MatchArgs matchArgs, T? returnThis, T?[]? returnThese) @@ -64,7 +64,7 @@ private static ConfiguredCall ConfigureReturn(MatchArgs matchArgs, T? returnT .LastCallShouldReturn(returnValue, matchArgs); } - private static ConfiguredCall ConfigureReturn(MatchArgs matchArgs, Func returnThis, Func[]? returnThese) + private static ConfiguredCall ConfigureFuncReturn(MatchArgs matchArgs, Func, T?> returnThis, Func, T?>[]? returnThese) { IReturn returnValue; if (returnThese == null || returnThese.Length == 0) diff --git a/tests/NSubstitute.Acceptance.Specs/ReturnFromBase.cs b/tests/NSubstitute.Acceptance.Specs/ReturnFromBase.cs new file mode 100644 index 000000000..b17d73bac --- /dev/null +++ b/tests/NSubstitute.Acceptance.Specs/ReturnFromBase.cs @@ -0,0 +1,52 @@ +using System; +using NSubstitute.Exceptions; +using NUnit.Framework; + +namespace NSubstitute.Acceptance.Specs +{ + public class ReturnFromBase + { + public class Sample + { + public virtual string RepeatButLouder(string s) => s + "!"; + public virtual void VoidMethod() { } + } + + public abstract class SampleWithAbstractMethod + { + public abstract string NoBaseImplementation(); + } + + public interface ISample + { + string InterfaceMethod(); + } + + [Test] + public void UseBaseInReturn() { + var sub = Substitute.For(); + sub.RepeatButLouder(Arg.Any()).Returns(x => x.BaseResult() + "?"); + + Assert.AreEqual("Hi!?", sub.RepeatButLouder("Hi")); + } + + [Test] + public void CallWithNoBaseImplementation() { + var sub = Substitute.For(); + sub.NoBaseImplementation().Returns(x => x.BaseResult()); + + Assert.Throws(() => + sub.NoBaseImplementation() + ); + } + + [Test] + public void CallBaseForInterface() { + var sub = Substitute.For(); + sub.InterfaceMethod().Returns(x => x.BaseResult()); + Assert.Throws(() => + sub.InterfaceMethod() + ); + } + } +} \ No newline at end of file From b89b34ea40f7c2956fbc672abe41286f9e46c5f1 Mon Sep 17 00:00:00 2001 From: David Tchepak Date: Sun, 28 Mar 2021 16:31:13 +1100 Subject: [PATCH 2/2] Extract ICallInfo interfaces (#641) Apply review comments. --- src/NSubstitute/Callback.cs | 26 +++--- .../Callbacks/ConfiguredCallback.cs | 10 +-- src/NSubstitute/Core/CallInfo.cs | 82 ++++++------------- src/NSubstitute/Core/CallInfoWithReturns.cs | 2 +- src/NSubstitute/Core/IReturn.cs | 32 ++++---- src/NSubstitute/Core/WhenCalled.cs | 4 +- .../Extensions/ExceptionExtensions.cs | 4 +- .../Extensions/ReturnsForAllExtensions.cs | 2 +- src/NSubstitute/ICallInfo.cs | 68 +++++++++++++++ .../ReturnFromAndConfigureDynamicCall.cs | 4 +- .../SubstituteExtensions.Returns.Task.cs | 6 +- .../SubstituteExtensions.Returns.ValueTask.cs | 6 +- .../SubstituteExtensions.Returns.cs | 6 +- .../SubstituteExtensions.When.Task.cs | 4 +- .../SubstituteExtensions.When.ValueTask.cs | 4 +- src/NSubstitute/SubstituteExtensions.When.cs | 4 +- .../CallbackCalling.cs | 2 +- .../WhenCalledDo.cs | 2 +- 18 files changed, 152 insertions(+), 116 deletions(-) create mode 100644 src/NSubstitute/ICallInfo.cs diff --git a/src/NSubstitute/Callback.cs b/src/NSubstitute/Callback.cs index 2d7f4dae2..7cf4a1705 100644 --- a/src/NSubstitute/Callback.cs +++ b/src/NSubstitute/Callback.cs @@ -18,7 +18,7 @@ public class Callback /// /// /// - public static ConfiguredCallback First(Action doThis) + public static ConfiguredCallback First(Action doThis) { return new ConfiguredCallback().Then(doThis); } @@ -28,7 +28,7 @@ public static ConfiguredCallback First(Action doThis) /// /// /// - public static Callback Always(Action doThis) + public static Callback Always(Action doThis) { return new ConfiguredCallback().AndAlways(doThis); } @@ -38,7 +38,7 @@ public static Callback Always(Action doThis) /// /// /// - public static ConfiguredCallback FirstThrow(Func throwThis) where TException : Exception + public static ConfiguredCallback FirstThrow(Func throwThis) where TException : Exception { return new ConfiguredCallback().ThenThrow(throwThis); } @@ -59,7 +59,7 @@ public static ConfiguredCallback FirstThrow(TException exception) wh /// The type of the exception. /// The throw this. /// - public static Callback AlwaysThrow(Func throwThis) where TException : Exception + public static Callback AlwaysThrow(Func throwThis) where TException : Exception { return new ConfiguredCallback().AndAlways(ToCallback(throwThis)); } @@ -75,33 +75,33 @@ public static Callback AlwaysThrow(TException exception) where TExce return AlwaysThrow(_ => exception); } - protected static Action ToCallback(Func throwThis) + protected static Action ToCallback(Func throwThis) where TException : notnull, Exception { return ci => { if (throwThis != null) throw throwThis(ci); }; } internal Callback() { } - private readonly ConcurrentQueue> callbackQueue = new ConcurrentQueue>(); - private Action alwaysDo = x => { }; - private Action keepDoing = x => { }; + private readonly ConcurrentQueue> callbackQueue = new ConcurrentQueue>(); + private Action alwaysDo = x => { }; + private Action keepDoing = x => { }; - protected void AddCallback(Action doThis) + protected void AddCallback(Action doThis) { callbackQueue.Enqueue(doThis); } - protected void SetAlwaysDo(Action always) + protected void SetAlwaysDo(Action always) { alwaysDo = always ?? (_ => { }); } - protected void SetKeepDoing(Action keep) + protected void SetKeepDoing(Action keep) { keepDoing = keep ?? (_ => { }); } - public void Call(CallInfo callInfo) + public void Call(ICallInfo callInfo) { try { @@ -113,7 +113,7 @@ public void Call(CallInfo callInfo) } } - private void CallFromStack(CallInfo callInfo) + private void CallFromStack(ICallInfo callInfo) { if (callbackQueue.TryDequeue(out var callback)) { diff --git a/src/NSubstitute/Callbacks/ConfiguredCallback.cs b/src/NSubstitute/Callbacks/ConfiguredCallback.cs index 0c32ad881..871b2cf58 100644 --- a/src/NSubstitute/Callbacks/ConfiguredCallback.cs +++ b/src/NSubstitute/Callbacks/ConfiguredCallback.cs @@ -13,7 +13,7 @@ internal ConfiguredCallback() { } /// /// Perform this action once in chain of called callbacks. /// - public ConfiguredCallback Then(Action doThis) + public ConfiguredCallback Then(Action doThis) { AddCallback(doThis); return this; @@ -22,7 +22,7 @@ public ConfiguredCallback Then(Action doThis) /// /// Keep doing this action after the other callbacks have run. /// - public EndCallbackChain ThenKeepDoing(Action doThis) + public EndCallbackChain ThenKeepDoing(Action doThis) { SetKeepDoing(doThis); return this; @@ -31,7 +31,7 @@ public EndCallbackChain ThenKeepDoing(Action doThis) /// /// Keep throwing this exception after the other callbacks have run. /// - public EndCallbackChain ThenKeepThrowing(Func throwThis) where TException : Exception => + public EndCallbackChain ThenKeepThrowing(Func throwThis) where TException : Exception => ThenKeepDoing(ToCallback(throwThis)); /// @@ -45,7 +45,7 @@ public EndCallbackChain ThenKeepThrowing(TException throwThis) where /// /// The type of the exception /// Produce the exception to throw for a CallInfo - public ConfiguredCallback ThenThrow(Func throwThis) where TException : Exception + public ConfiguredCallback ThenThrow(Func throwThis) where TException : Exception { AddCallback(ToCallback(throwThis)); return this; @@ -68,7 +68,7 @@ internal EndCallbackChain() { } /// Perform the given action for every call. /// /// The action to perform for every call - public Callback AndAlways(Action doThis) + public Callback AndAlways(Action doThis) { SetAlwaysDo(doThis); return this; diff --git a/src/NSubstitute/Core/CallInfo.cs b/src/NSubstitute/Core/CallInfo.cs index faeb1de1b..faf4bc652 100644 --- a/src/NSubstitute/Core/CallInfo.cs +++ b/src/NSubstitute/Core/CallInfo.cs @@ -9,13 +9,12 @@ namespace NSubstitute.Core { - public class CallInfo + public class CallInfo : ICallInfo { private readonly Argument[] _callArguments; private readonly Func> _baseResult; - public CallInfo(Argument[] callArguments, Func> baseResult) - { + public CallInfo(Argument[] callArguments, Func> baseResult) { _callArguments = callArguments; _baseResult = baseResult; } @@ -33,63 +32,41 @@ protected object GetBaseResult() { return _baseResult().ValueOr(() => throw new NoBaseImplementationException()); } - /// - /// Gets the nth argument to this call. - /// - /// Index of argument - /// The value of the argument at the given index - public object this[int index] - { + /// + public object this[int index] { get => _callArguments[index].Value; - set - { + set { var argument = _callArguments[index]; EnsureArgIsSettable(argument, index, value); argument.Value = value; } } - private void EnsureArgIsSettable(Argument argument, int index, object value) - { - if (!argument.IsByRef) - { + private void EnsureArgIsSettable(Argument argument, int index, object value) { + if (!argument.IsByRef) { throw new ArgumentIsNotOutOrRefException(index, argument.DeclaredType); } - if (value != null && !argument.CanSetValueWithInstanceOf(value.GetType())) - { + if (value != null && !argument.CanSetValueWithInstanceOf(value.GetType())) { throw new ArgumentSetWithIncompatibleValueException(index, argument.DeclaredType, value.GetType()); } } - /// - /// Get the arguments passed to this call. - /// - /// Array of all arguments passed to this call + /// public object[] Args() => _callArguments.Select(x => x.Value).ToArray(); - /// - /// Gets the types of all the arguments passed to this call. - /// - /// Array of types of all arguments passed to this call + /// public Type[] ArgTypes() => _callArguments.Select(x => x.DeclaredType).ToArray(); - /// - /// Gets the argument of type `T` passed to this call. This will throw if there are no arguments - /// of this type, or if there is more than one matching argument. - /// - /// The type of the argument to retrieve - /// The argument passed to the call, or throws if there is not exactly one argument of this type - public T Arg() - { + /// + public T Arg() { T arg; if (TryGetArg(x => x.IsDeclaredTypeEqualToOrByRefVersionOf(typeof(T)), out arg)) return arg; if (TryGetArg(x => x.IsValueAssignableTo(typeof(T)), out arg)) return arg; throw new ArgumentNotFoundException("Can not find an argument of type " + typeof(T).FullName + " to this call."); } - private bool TryGetArg(Func condition, [MaybeNullWhen(false)] out T value) - { + private bool TryGetArg(Func condition, [MaybeNullWhen(false)] out T value) { value = default; var matchingArgs = _callArguments.Where(condition); @@ -100,10 +77,8 @@ private bool TryGetArg(Func condition, [MaybeNullWhen(false)] return true; } - private void ThrowIfMoreThanOne(IEnumerable arguments) - { - if (arguments.Skip(1).Any()) - { + private void ThrowIfMoreThanOne(IEnumerable arguments) { + if (arguments.Skip(1).Any()) { throw new AmbiguousArgumentsException( "There is more than one argument of type " + typeof(T).FullName + " to this call.\n" + "The call signature is (" + DisplayTypes(ArgTypes()) + ")\n" + @@ -112,27 +87,15 @@ private void ThrowIfMoreThanOne(IEnumerable arguments) } } - /// - /// Gets the argument passed to this call at the specified zero-based position, converted to type `T`. - /// This will throw if there are no arguments, if the argument is out of range or if it - /// cannot be converted to the specified type. - /// - /// The type of the argument to retrieve - /// The zero-based position of the argument to retrieve - /// The argument passed to the call, or throws if there is not exactly one argument of this type - public T ArgAt(int position) - { - if (position >= _callArguments.Length) - { + /// + public T ArgAt(int position) { + if (position >= _callArguments.Length) { throw new ArgumentOutOfRangeException(nameof(position), $"There is no argument at position {position}"); } - try - { - return (T) _callArguments[position].Value!; - } - catch (InvalidCastException) - { + try { + return (T)_callArguments[position].Value!; + } catch (InvalidCastException) { throw new InvalidCastException( $"Couldn't convert parameter at position {position} to type {typeof(T).FullName}"); } @@ -140,5 +103,8 @@ public T ArgAt(int position) private static string DisplayTypes(IEnumerable types) => string.Join(", ", types.Select(x => x.Name).ToArray()); + + /// + public ICallInfo ForCallReturning() => new CallInfo(this); } } diff --git a/src/NSubstitute/Core/CallInfoWithReturns.cs b/src/NSubstitute/Core/CallInfoWithReturns.cs index f08d059a7..5eb12aefd 100644 --- a/src/NSubstitute/Core/CallInfoWithReturns.cs +++ b/src/NSubstitute/Core/CallInfoWithReturns.cs @@ -4,7 +4,7 @@ /// Information for a call that returns a value of type T. /// /// - public class CallInfo : CallInfo + public class CallInfo : CallInfo, ICallInfo { internal CallInfo(CallInfo info) : base(info) { } diff --git a/src/NSubstitute/Core/IReturn.cs b/src/NSubstitute/Core/IReturn.cs index 30038a624..817e25008 100644 --- a/src/NSubstitute/Core/IReturn.cs +++ b/src/NSubstitute/Core/IReturn.cs @@ -9,13 +9,13 @@ namespace NSubstitute.Core { public interface IReturn { - object? ReturnFor(CallInfo info); + object? ReturnFor(ICallInfo info); Type? TypeOrNull(); bool CanBeAssignedTo(Type t); } /// - /// Performance optimization. Allows to not construct if configured result doesn't depend on it. + /// Performance optimization. Allows to not construct if configured result doesn't depend on it. /// internal interface ICallIndependentReturn { @@ -32,25 +32,25 @@ public ReturnValue(object? value) } public object? GetReturnValue() => _value; - public object? ReturnFor(CallInfo info) => GetReturnValue(); + public object? ReturnFor(ICallInfo info) => GetReturnValue(); public Type? TypeOrNull() => _value?.GetType(); public bool CanBeAssignedTo(Type t) => _value.IsCompatibleWith(t); } public class ReturnValueFromFunc : IReturn { - private readonly Func, T?> _funcToReturnValue; + private readonly Func, T?> _funcToReturnValue; - public ReturnValueFromFunc(Func, T?>? funcToReturnValue) + public ReturnValueFromFunc(Func, T?>? funcToReturnValue) { _funcToReturnValue = funcToReturnValue ?? ReturnNull(); } - public object? ReturnFor(CallInfo info) => _funcToReturnValue(new CallInfo(info)); + public object? ReturnFor(ICallInfo info) => _funcToReturnValue(info.ForCallReturning()); public Type TypeOrNull() => typeof(T); public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t); - private static Func ReturnNull() + private static Func ReturnNull() { if (typeof(T).GetTypeInfo().IsValueType) throw new CannotReturnNullForValueType(typeof(T)); return x => default; @@ -69,7 +69,7 @@ public ReturnMultipleValues(T?[] values) } public object? GetReturnValue() => GetNext(); - public object? ReturnFor(CallInfo info) => GetReturnValue(); + public object? ReturnFor(ICallInfo info) => GetReturnValue(); public Type TypeOrNull() => typeof(T); public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t); @@ -78,20 +78,22 @@ public ReturnMultipleValues(T?[] values) public class ReturnMultipleFuncsValues : IReturn { - private readonly ConcurrentQueue, T?>> _funcsToReturn; - private readonly Func, T?> _lastFunc; + private readonly ConcurrentQueue, T?>> _funcsToReturn; + private readonly Func, T?> _lastFunc; - public ReturnMultipleFuncsValues(Func, T?>[] funcs) + public ReturnMultipleFuncsValues(Func, T?>[] funcs) { - _funcsToReturn = new ConcurrentQueue, T?>>(funcs); + _funcsToReturn = new ConcurrentQueue, T?>>(funcs); _lastFunc = funcs.Last(); } - public object? ReturnFor(CallInfo info) => GetNext(info); + public object? ReturnFor(ICallInfo info) => GetNext(info); public Type TypeOrNull() => typeof(T); public bool CanBeAssignedTo(Type t) => typeof(T).IsAssignableFrom(t); - private T? GetNext(CallInfo info) => - _funcsToReturn.TryDequeue(out var nextFunc) ? nextFunc(new CallInfo(info)) : _lastFunc(new CallInfo(info)); + private T? GetNext(ICallInfo info) => + _funcsToReturn.TryDequeue(out var nextFunc) + ? nextFunc(info.ForCallReturning()) + : _lastFunc(info.ForCallReturning()); } } \ No newline at end of file diff --git a/src/NSubstitute/Core/WhenCalled.cs b/src/NSubstitute/Core/WhenCalled.cs index fbc80c24d..daf9e4104 100644 --- a/src/NSubstitute/Core/WhenCalled.cs +++ b/src/NSubstitute/Core/WhenCalled.cs @@ -29,7 +29,7 @@ public WhenCalled(ISubstitutionContext context, T substitute, Action call, Ma /// Perform this action when called. /// /// - public void Do(Action callbackWithArguments) + public void Do(Action callbackWithArguments) { _threadContext.SetNextRoute(_callRouter, x => _routeFactory.DoWhenCalled(x, callbackWithArguments, _matchArgs)); _call(_substitute); @@ -82,7 +82,7 @@ public void Throw(Exception exception) => /// /// Throw an exception generated by the specified function when called. /// - public void Throw(Func createException) => + public void Throw(Func createException) => Do(ci => throw createException(ci)); } } \ No newline at end of file diff --git a/src/NSubstitute/Extensions/ExceptionExtensions.cs b/src/NSubstitute/Extensions/ExceptionExtensions.cs index 60f3b5339..15004dc88 100644 --- a/src/NSubstitute/Extensions/ExceptionExtensions.cs +++ b/src/NSubstitute/Extensions/ExceptionExtensions.cs @@ -35,7 +35,7 @@ public static ConfiguredCall Throws(this object value) /// /// Func creating exception object /// - public static ConfiguredCall Throws(this object value, Func createException) => + public static ConfiguredCall Throws(this object value, Func createException) => value.Returns(ci => throw createException(ci)); /// @@ -65,7 +65,7 @@ public static ConfiguredCall ThrowsForAnyArgs(this object value) /// /// Func creating exception object /// - public static ConfiguredCall ThrowsForAnyArgs(this object value, Func createException) => + public static ConfiguredCall ThrowsForAnyArgs(this object value, Func createException) => value.ReturnsForAnyArgs(ci => throw createException(ci)); } } diff --git a/src/NSubstitute/Extensions/ReturnsForAllExtensions.cs b/src/NSubstitute/Extensions/ReturnsForAllExtensions.cs index 302033d4c..81df863c6 100644 --- a/src/NSubstitute/Extensions/ReturnsForAllExtensions.cs +++ b/src/NSubstitute/Extensions/ReturnsForAllExtensions.cs @@ -31,7 +31,7 @@ public static void ReturnsForAll(this object substitute, T returnThis) /// /// /// - public static void ReturnsForAll(this object substitute, Func returnThis) + public static void ReturnsForAll(this object substitute, Func returnThis) { if (substitute == null) throw new NullSubstituteReferenceException(); diff --git a/src/NSubstitute/ICallInfo.cs b/src/NSubstitute/ICallInfo.cs new file mode 100644 index 000000000..2a72b572f --- /dev/null +++ b/src/NSubstitute/ICallInfo.cs @@ -0,0 +1,68 @@ +using System; + +namespace NSubstitute +{ + /// + /// Access information on arguments for call. + /// + public interface ICallInfo + { + /// + /// Gets the nth argument to this call. + /// + /// Index of argument + /// The value of the argument at the given index + object this[int index] { get; set; } + + /// + /// Gets the argument of type `T` passed to this call. This will throw if there are no arguments + /// of this type, or if there is more than one matching argument. + /// + /// The type of the argument to retrieve + /// The argument passed to the call, or throws if there is not exactly one argument of this type + T Arg(); + + /// + /// Gets the argument passed to this call at the specified zero-based position, converted to type `T`. + /// This will throw if there are no arguments, if the argument is out of range or if it + /// cannot be converted to the specified type. + /// + /// The type of the argument to retrieve + /// The zero-based position of the argument to retrieve + /// The argument passed to the call, or throws if there is not exactly one argument of this type + T ArgAt(int position); + + /// + /// Get the arguments passed to this call. + /// + /// Array of all arguments passed to this call + object[] Args(); + + /// + /// Gets the types of all the arguments passed to this call. + /// + /// Array of types of all arguments passed to this call + Type[] ArgTypes(); + + /// + /// If we are sure this call returns a value of type , return an + /// that allows us to access the . + /// + /// This will not be checked by the compiler, so if this method is misused the resulting + /// may throw . + /// + /// + /// + ICallInfo ForCallReturning(); + } + + public interface ICallInfo : ICallInfo + { + /// + /// Calls the base implementation and attempts to cast the result to . + /// + /// Result from base (non-substituted) implementation of call + /// + T BaseResult(); + } +} \ No newline at end of file diff --git a/src/NSubstitute/Routing/Handlers/ReturnFromAndConfigureDynamicCall.cs b/src/NSubstitute/Routing/Handlers/ReturnFromAndConfigureDynamicCall.cs index e4fadd2c4..70a720dad 100644 --- a/src/NSubstitute/Routing/Handlers/ReturnFromAndConfigureDynamicCall.cs +++ b/src/NSubstitute/Routing/Handlers/ReturnFromAndConfigureDynamicCall.cs @@ -55,7 +55,7 @@ public ConfiguredCall Returns(T? returnThis, params T?[] returnThese) return default(T).Returns(returnThis, returnThese); } - public ConfiguredCall Returns(Func returnThis, params Func[] returnThese) + public ConfiguredCall Returns(Func returnThis, params Func[] returnThese) { return default(T).Returns(returnThis, returnThese); } @@ -65,7 +65,7 @@ public ConfiguredCall ReturnsForAnyArgs(T? returnThis, params T?[] returnThes return default(T).ReturnsForAnyArgs(returnThis, returnThese); } - public ConfiguredCall ReturnsForAnyArgs(Func returnThis, params Func[] returnThese) + public ConfiguredCall ReturnsForAnyArgs(Func returnThis, params Func[] returnThese) { return default(T).ReturnsForAnyArgs(returnThis, returnThese); } diff --git a/src/NSubstitute/SubstituteExtensions.Returns.Task.cs b/src/NSubstitute/SubstituteExtensions.Returns.Task.cs index 87ed66eeb..6d8ba1cda 100644 --- a/src/NSubstitute/SubstituteExtensions.Returns.Task.cs +++ b/src/NSubstitute/SubstituteExtensions.Returns.Task.cs @@ -33,7 +33,7 @@ public static ConfiguredCall Returns(this Task value, T returnThis, params /// /// Function to calculate the return value /// Optionally use these functions next - public static ConfiguredCall Returns(this Task value, Func returnThis, params Func[] returnThese) + public static ConfiguredCall Returns(this Task value, Func returnThis, params Func[] returnThese) { ReThrowOnNSubstituteFault(value); @@ -65,7 +65,7 @@ public static ConfiguredCall ReturnsForAnyArgs(this Task value, T returnTh /// /// Function to calculate the return value /// Optionally use these functions next - public static ConfiguredCall ReturnsForAnyArgs(this Task value, Func returnThis, params Func[] returnThese) + public static ConfiguredCall ReturnsForAnyArgs(this Task value, Func returnThis, params Func[] returnThese) { ReThrowOnNSubstituteFault(value); @@ -86,7 +86,7 @@ private static void ReThrowOnNSubstituteFault(Task task) private static Task CompletedTask(T? result) => Task.FromResult(result); - private static Func> WrapFuncInTask(Func returnThis) => + private static Func> WrapFuncInTask(Func returnThis) => x => CompletedTask(returnThis(x)); } } \ No newline at end of file diff --git a/src/NSubstitute/SubstituteExtensions.Returns.ValueTask.cs b/src/NSubstitute/SubstituteExtensions.Returns.ValueTask.cs index b96cba512..91483835e 100644 --- a/src/NSubstitute/SubstituteExtensions.Returns.ValueTask.cs +++ b/src/NSubstitute/SubstituteExtensions.Returns.ValueTask.cs @@ -33,7 +33,7 @@ public static ConfiguredCall Returns(this ValueTask value, T returnThis, p /// /// Function to calculate the return value /// Optionally use these functions next - public static ConfiguredCall Returns(this ValueTask value, Func returnThis, params Func[] returnThese) + public static ConfiguredCall Returns(this ValueTask value, Func returnThis, params Func[] returnThese) { ReThrowOnNSubstituteFault(value); @@ -65,7 +65,7 @@ public static ConfiguredCall ReturnsForAnyArgs(this ValueTask value, T ret /// /// Function to calculate the return value /// Optionally use these functions next - public static ConfiguredCall ReturnsForAnyArgs(this ValueTask value, Func returnThis, params Func[] returnThese) + public static ConfiguredCall ReturnsForAnyArgs(this ValueTask value, Func returnThis, params Func[] returnThese) { ReThrowOnNSubstituteFault(value); @@ -86,7 +86,7 @@ private static void ReThrowOnNSubstituteFault(ValueTask task) private static ValueTask CompletedValueTask(T? result) => new(result); - private static Func> WrapFuncInValueTask(Func returnThis) => + private static Func> WrapFuncInValueTask(Func returnThis) => x => CompletedValueTask(returnThis(x)); } } \ No newline at end of file diff --git a/src/NSubstitute/SubstituteExtensions.Returns.cs b/src/NSubstitute/SubstituteExtensions.Returns.cs index 3ff53d65b..13bcbd57f 100644 --- a/src/NSubstitute/SubstituteExtensions.Returns.cs +++ b/src/NSubstitute/SubstituteExtensions.Returns.cs @@ -24,7 +24,7 @@ public static ConfiguredCall Returns(this T value, T returnThis, params T[] r /// /// Function to calculate the return value /// Optionally use these functions next - public static ConfiguredCall Returns(this T value, Func, T> returnThis, params Func, T>[] returnThese) => + public static ConfiguredCall Returns(this T value, Func, T> returnThis, params Func, T>[] returnThese) => ConfigureFuncReturn(MatchArgs.AsSpecifiedInCall, returnThis, returnThese); /// @@ -43,7 +43,7 @@ public static ConfiguredCall ReturnsForAnyArgs(this T value, T returnThis, pa /// Function to calculate the return value /// Optionally use these functions next /// - public static ConfiguredCall ReturnsForAnyArgs(this T value, Func, T> returnThis, params Func, T>[] returnThese) => + public static ConfiguredCall ReturnsForAnyArgs(this T value, Func, T> returnThis, params Func, T>[] returnThese) => ConfigureFuncReturn(MatchArgs.Any, returnThis, returnThese); #nullable restore @@ -64,7 +64,7 @@ private static ConfiguredCall ConfigureReturn(MatchArgs matchArgs, T? returnT .LastCallShouldReturn(returnValue, matchArgs); } - private static ConfiguredCall ConfigureFuncReturn(MatchArgs matchArgs, Func, T?> returnThis, Func, T?>[]? returnThese) + private static ConfiguredCall ConfigureFuncReturn(MatchArgs matchArgs, Func, T?> returnThis, Func, T?>[]? returnThese) { IReturn returnValue; if (returnThese == null || returnThese.Length == 0) diff --git a/src/NSubstitute/SubstituteExtensions.When.Task.cs b/src/NSubstitute/SubstituteExtensions.When.Task.cs index 79a20a2fa..37231c93d 100644 --- a/src/NSubstitute/SubstituteExtensions.When.Task.cs +++ b/src/NSubstitute/SubstituteExtensions.When.Task.cs @@ -11,7 +11,7 @@ public static partial class SubstituteExtensions { /// /// Perform an action when this member is called. - /// Must be followed by to provide the callback. + /// Must be followed by to provide the callback. /// public static WhenCalled When(this T substitute, Func substituteCall) where T : class { @@ -20,7 +20,7 @@ public static WhenCalled When(this T substitute, Func substituteC /// /// Perform an action when this member is called with any arguments. - /// Must be followed by to provide the callback. + /// Must be followed by to provide the callback. /// public static WhenCalled WhenForAnyArgs(this T substitute, Func substituteCall) where T : class { diff --git a/src/NSubstitute/SubstituteExtensions.When.ValueTask.cs b/src/NSubstitute/SubstituteExtensions.When.ValueTask.cs index 7759804da..f0d1ece1c 100644 --- a/src/NSubstitute/SubstituteExtensions.When.ValueTask.cs +++ b/src/NSubstitute/SubstituteExtensions.When.ValueTask.cs @@ -11,7 +11,7 @@ public static partial class SubstituteExtensions { /// /// Perform an action when this member is called. - /// Must be followed by to provide the callback. + /// Must be followed by to provide the callback. /// public static WhenCalled When(this TSubstitute substitute, Func> substituteCall) where TSubstitute : class @@ -21,7 +21,7 @@ public static WhenCalled When(this TSubstitut /// /// Perform an action when this member is called with any arguments. - /// Must be followed by to provide the callback. + /// Must be followed by to provide the callback. /// public static WhenCalled WhenForAnyArgs(this TSubstitute substitute, Func> substituteCall) where TSubstitute : class diff --git a/src/NSubstitute/SubstituteExtensions.When.cs b/src/NSubstitute/SubstituteExtensions.When.cs index 055e2aa69..0da4b33c6 100644 --- a/src/NSubstitute/SubstituteExtensions.When.cs +++ b/src/NSubstitute/SubstituteExtensions.When.cs @@ -11,7 +11,7 @@ public static partial class SubstituteExtensions { /// /// Perform an action when this member is called. - /// Must be followed by to provide the callback. + /// Must be followed by to provide the callback. /// public static WhenCalled When(this T substitute, Action substituteCall) where T : class { @@ -20,7 +20,7 @@ public static WhenCalled When(this T substitute, Action substituteCall) /// /// Perform an action when this member is called with any arguments. - /// Must be followed by to provide the callback. + /// Must be followed by to provide the callback. /// public static WhenCalled WhenForAnyArgs(this T substitute, Action substituteCall) where T : class { diff --git a/tests/NSubstitute.Acceptance.Specs/CallbackCalling.cs b/tests/NSubstitute.Acceptance.Specs/CallbackCalling.cs index 95beadbee..e24dd47bb 100644 --- a/tests/NSubstitute.Acceptance.Specs/CallbackCalling.cs +++ b/tests/NSubstitute.Acceptance.Specs/CallbackCalling.cs @@ -100,7 +100,7 @@ public void Throw_exception_when_Throw_with_specific_exception() [Test] public void Throw_exception_when_Throw_with_exception_generator() { - Func createException = ci => new ArgumentException("Argument: " + ci.Args()[0]); + Func createException = ci => new ArgumentException("Argument: " + ci.Args()[0]); int called = 0; _something.When(x => x.Echo(Arg.Any())).Do(Callback.Always(x => called++)); _something.When(x => x.Echo(Arg.Any())).Do(Callback.AlwaysThrow(createException)); diff --git a/tests/NSubstitute.Acceptance.Specs/WhenCalledDo.cs b/tests/NSubstitute.Acceptance.Specs/WhenCalledDo.cs index 6778947c1..7029fb3d7 100644 --- a/tests/NSubstitute.Acceptance.Specs/WhenCalledDo.cs +++ b/tests/NSubstitute.Acceptance.Specs/WhenCalledDo.cs @@ -99,7 +99,7 @@ public void Throw_exception_when_Throw_with_specific_exception() [Test] public void Throw_exception_when_Throw_with_exception_generator() { - Func createException = ci => new ArgumentException("Argument: " + ci.Args()[0]); + Func createException = ci => new ArgumentException("Argument: " + ci.Args()[0]); int called = 0; _something.When(x => x.Echo(Arg.Any())).Do(x => called++); _something.When(x => x.Echo(Arg.Any())).Throw(createException);