diff --git a/README.md b/README.md index 76d6e68..07a18ac 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ Inshiminator is a .NET-first developer toolkit that uses **Roslyn analyzers, inc ### 🚀 Key Features - **Detect:** Automatically finds direct usage of `DateTime.UtcNow`, `Guid.NewGuid()`, `File.ReadAllText`, and more. -- **Generate:** Emits strongly typed abstractions (`IClock`, `IGuidGenerator`) and implementations (`SystemClock`, `SystemGuidGenerator`) at compile time. +- **Generate:** Emits strongly typed abstractions (`IClock`, `IGuidGenerator`) and implementations (`SystemClock`, `SystemGuidGenerator`) at compile time, while also supporting framework abstractions like `TimeProvider`. - **Guide:** Provides IDE code fixes to automatically inject shims into your classes. - **Govern:** Enforce boundary rules through analyzer severity and baselines. diff --git a/src/Inshiminator.Analyzers/ClockAnalyzer.cs b/src/Inshiminator.Analyzers/ClockAnalyzer.cs index 43e3530..a9b6f2b 100644 --- a/src/Inshiminator.Analyzers/ClockAnalyzer.cs +++ b/src/Inshiminator.Analyzers/ClockAnalyzer.cs @@ -12,7 +12,7 @@ public class ClockAnalyzer : DiagnosticAnalyzer public const string DiagnosticId = "INSHIM001"; private static readonly LocalizableString Title = "Direct system clock usage detected"; - private static readonly LocalizableString MessageFormat = "Use IClock instead of {0} so time can be controlled in tests"; + private static readonly LocalizableString MessageFormat = "Use IClock (or TimeProvider when available) instead of {0} so time can be controlled in tests"; private static readonly LocalizableString Description = "Direct usage of system clock makes code difficult to test."; private const string Category = "Design"; diff --git a/src/Inshiminator.CodeFixes/ClockCodeFixProvider.cs b/src/Inshiminator.CodeFixes/ClockCodeFixProvider.cs index 35de270..3a3f427 100644 --- a/src/Inshiminator.CodeFixes/ClockCodeFixProvider.cs +++ b/src/Inshiminator.CodeFixes/ClockCodeFixProvider.cs @@ -1,3 +1,4 @@ +using System.Collections.Generic; using System.Collections.Immutable; using System.Composition; using System.Linq; @@ -9,6 +10,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Editing; +using Microsoft.CodeAnalysis.Formatting; namespace Inshiminator.CodeFixes; @@ -24,22 +26,51 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context) var root = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false); if (root is null) return; + var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false); + if (semanticModel is null) return; + + var timeProviderType = semanticModel.Compilation.GetTypeByMetadataName("System.TimeProvider"); + var diagnostic = context.Diagnostics.First(); var diagnosticSpan = diagnostic.Location.SourceSpan; - var memberAccess = root.FindToken(diagnosticSpan.Start).Parent?.AncestorsAndSelf().OfType().First(); + var memberAccessCandidates = root.FindToken(diagnosticSpan.Start).Parent?.AncestorsAndSelf().OfType(); + var memberAccess = memberAccessCandidates?.FirstOrDefault(m => m.Name.Identifier.ValueText is "Now" or "UtcNow") + ?? memberAccessCandidates?.FirstOrDefault(); if (memberAccess is null) return; + var isStaticContext = IsStaticContext(memberAccess, semanticModel, context.CancellationToken); + + if (!isStaticContext + && !IsUnsafeInjectedDependencyUsageContext(memberAccess)) + { + context.RegisterCodeFix( + CodeAction.Create( + title: "Use injected IClock", + createChangedDocument: c => UseInjectedClockAsync(context.Document, memberAccess, c), + equivalenceKey: nameof(ClockCodeFixProvider)), + diagnostic); + } - context.RegisterCodeFix( - CodeAction.Create( - title: "Use injected IClock", - createChangedDocument: c => UseInjectedClockAsync(context.Document, memberAccess, c), - equivalenceKey: nameof(ClockCodeFixProvider)), - diagnostic); + if (timeProviderType is not null + && !isStaticContext + && CanOfferInjectedTimeProviderCodeFix(memberAccess, semanticModel, timeProviderType, context.CancellationToken)) + { + context.RegisterCodeFix( + CodeAction.Create( + title: "Use injected TimeProvider", + createChangedDocument: c => UseInjectedTimeProviderAsync(context.Document, memberAccess, c), + equivalenceKey: $"{nameof(ClockCodeFixProvider)}_TimeProvider"), + diagnostic); + } } private async Task UseInjectedClockAsync(Document document, MemberAccessExpressionSyntax memberAccess, CancellationToken cancellationToken) { + if (IsUnsafeInjectedDependencyUsageContext(memberAccess)) + { + return document; + } + var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false); var root = editor.OriginalRoot; @@ -99,4 +130,719 @@ private async Task UseInjectedClockAsync(Document document, MemberAcce return editor.GetChangedDocument(); } + + private static bool IsStaticContext(MemberAccessExpressionSyntax memberAccess, SemanticModel semanticModel, CancellationToken cancellationToken) + { + var classDeclaration = memberAccess.AncestorsAndSelf().OfType().FirstOrDefault(); + if (classDeclaration?.Modifiers.Any(SyntaxKind.StaticKeyword) == true) + { + return true; + } + + var containingMember = memberAccess.AncestorsAndSelf().OfType().FirstOrDefault( + m => m is BaseMethodDeclarationSyntax + or BasePropertyDeclarationSyntax + or FieldDeclarationSyntax + or EventFieldDeclarationSyntax + or EventDeclarationSyntax); + + if (containingMember?.Modifiers.Any(SyntaxKind.StaticKeyword) == true) + { + return true; + } + + var enclosingSymbol = semanticModel.GetEnclosingSymbol(memberAccess.SpanStart, cancellationToken); + for (var current = enclosingSymbol; current is not null; current = current.ContainingSymbol) + { + if (current is INamedTypeSymbol namedType && namedType.IsStatic) + { + return true; + } + + if (current is IMethodSymbol method && method.IsStatic) + { + return true; + } + } + + return false; + } + + private static bool CanOfferInjectedTimeProviderCodeFix( + MemberAccessExpressionSyntax memberAccess, + SemanticModel semanticModel, + INamedTypeSymbol timeProviderType, + CancellationToken cancellationToken) + { + if (IsUnsafeInjectedDependencyUsageContext(memberAccess)) + { + return false; + } + + var classDeclaration = memberAccess.AncestorsAndSelf().OfType().FirstOrDefault(); + if (classDeclaration is null) + { + return false; + } + + var classSymbol = semanticModel.GetDeclaredSymbol(classDeclaration, cancellationToken); + if (classSymbol is null) + { + return false; + } + + var hasReusableTimeProviderField = classSymbol.GetMembers() + .OfType() + .Any(fieldSymbol => + !fieldSymbol.IsImplicitlyDeclared + && !fieldSymbol.IsStatic + && !fieldSymbol.IsConst + && IsCompatibleTimeProviderType(fieldSymbol.Type, timeProviderType)); + + var currentSyntaxTree = semanticModel.SyntaxTree; + var hasInstanceConstructorsOutsideCurrentDocument = classSymbol.InstanceConstructors + .Where(ctor => !ctor.IsImplicitlyDeclared) + .SelectMany(ctor => ctor.DeclaringSyntaxReferences) + .Any(reference => reference.SyntaxTree != currentSyntaxTree); + + if (!hasInstanceConstructorsOutsideCurrentDocument) + { + return true; + } + + return hasReusableTimeProviderField + && AllExplicitInstanceConstructorsHaveCompatibleTimeProviderParameter(classSymbol, timeProviderType); + } + + private static bool AllExplicitInstanceConstructorsHaveCompatibleTimeProviderParameter( + INamedTypeSymbol classSymbol, + INamedTypeSymbol timeProviderType) + { + return classSymbol.InstanceConstructors + .Where(ctor => !ctor.IsImplicitlyDeclared) + .All(ctor => ctor.Parameters.Any(parameter => IsCompatibleTimeProviderType(parameter.Type, timeProviderType))); + } + + private static bool IsUnsafeInjectedDependencyUsageContext(MemberAccessExpressionSyntax memberAccess) + { + if (memberAccess.Ancestors().OfType().Any()) + { + return true; + } + + var equalsValueClause = memberAccess.Ancestors().OfType().FirstOrDefault(); + if (equalsValueClause is null) + { + return false; + } + + return equalsValueClause.Parent switch + { + VariableDeclaratorSyntax variableDeclarator => variableDeclarator.Parent?.Parent is FieldDeclarationSyntax or EventFieldDeclarationSyntax, + PropertyDeclarationSyntax => true, + _ => false, + }; + } + + private async Task UseInjectedTimeProviderAsync(Document document, MemberAccessExpressionSyntax memberAccess, CancellationToken cancellationToken) + { + var editor = await DocumentEditor.CreateAsync(document, cancellationToken).ConfigureAwait(false); + var semanticModel = editor.SemanticModel; + var timeProviderType = semanticModel.Compilation.GetTypeByMetadataName("System.TimeProvider"); + if (timeProviderType is null) + { + return document; + } + + var classDeclaration = memberAccess.AncestorsAndSelf().OfType().FirstOrDefault(); + if (classDeclaration is null) return document; + if (IsUnsafeInjectedDependencyUsageContext(memberAccess)) + { + return document; + } + + var classSymbol = semanticModel.GetDeclaredSymbol(classDeclaration, cancellationToken); + var currentSyntaxTree = semanticModel.SyntaxTree; + var allInstanceConstructorSymbols = classSymbol?.InstanceConstructors + .ToList() + ?? []; + var allInstanceConstructorsInCurrentDocument = allInstanceConstructorSymbols + .SelectMany(ctor => ctor.DeclaringSyntaxReferences) + .Select(reference => reference.GetSyntax(cancellationToken)) + .OfType() + .Where(constructor => constructor.SyntaxTree == currentSyntaxTree) + .Distinct() + .ToList(); + var hasInstanceConstructorsOutsideCurrentDocument = allInstanceConstructorSymbols + .SelectMany(ctor => ctor.DeclaringSyntaxReferences) + .Any(reference => reference.SyntaxTree != currentSyntaxTree); + + // 1. Add TimeProvider field if it doesn't exist + var timeProviderField = classSymbol?.GetMembers() + .OfType() + .FirstOrDefault(fieldSymbol => + !fieldSymbol.IsImplicitlyDeclared + && !fieldSymbol.IsStatic + && !fieldSymbol.IsConst + && IsCompatibleTimeProviderType(fieldSymbol.Type, timeProviderType)); + + var fieldName = "_timeProvider"; + TypeSyntax fieldTypeSyntax = SyntaxFactory.ParseTypeName("global::System.TimeProvider"); + ITypeSymbol fieldTypeSymbol = timeProviderType; + if (timeProviderField is null && hasInstanceConstructorsOutsideCurrentDocument) + { + return document; + } + + if (timeProviderField is null) + { + fieldName = GetUniqueFieldName(classDeclaration, classSymbol, fieldName); + // Safe to keep readonly because the guard above returns when constructors outside this document cannot be updated. + var field = (FieldDeclarationSyntax)editor.Generator.FieldDeclaration( + fieldName, + fieldTypeSyntax, + Accessibility.Private, + DeclarationModifiers.ReadOnly); + field = field.WithAdditionalAnnotations(Formatter.Annotation); + if (classDeclaration.Members.Count > 0) + { + editor.InsertBefore(classDeclaration.Members.First(), field); + } + else + { + editor.AddMember(classDeclaration, field); + } + } + else + { + fieldName = timeProviderField.Name; + fieldTypeSymbol = timeProviderField.Type; + fieldTypeSyntax = SymbolEqualityComparer.Default.Equals(fieldTypeSymbol, timeProviderType) + ? SyntaxFactory.ParseTypeName("global::System.TimeProvider") + : SyntaxFactory.ParseTypeName(fieldTypeSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); + } + + // 2. Update constructors (or create one) + var instanceConstructors = allInstanceConstructorsInCurrentDocument; + if (instanceConstructors.Count == 0) + { + var allExplicitInstanceConstructors = (classSymbol?.InstanceConstructors ?? ImmutableArray.Empty) + .Where(static constructor => !constructor.IsImplicitlyDeclared) + .ToImmutableArray(); + + if (allExplicitInstanceConstructors.Length == 0) + { + var parameter = (ParameterSyntax)editor.Generator.ParameterDeclaration( + "timeProvider", + fieldTypeSyntax); + var assignment = CreateFieldAssignmentStatement(editor, fieldName, "timeProvider"); + + var newConstructor = (ConstructorDeclarationSyntax)editor.Generator.ConstructorDeclaration( + classDeclaration.Identifier.ValueText, + accessibility: Accessibility.Public, + parameters: [parameter], + statements: [assignment]); + newConstructor = newConstructor.WithAdditionalAnnotations(Formatter.Annotation); + + editor.AddMember(classDeclaration, newConstructor); + } + } + else + { + var constructorSymbols = new HashSet(SymbolEqualityComparer.Default); + var parameterNamesByConstructorSymbol = new Dictionary(SymbolEqualityComparer.Default); + foreach (var constructor in instanceConstructors) + { + var constructorSymbol = semanticModel.GetDeclaredSymbol(constructor, cancellationToken); + if (constructorSymbol is not null) + { + constructorSymbols.Add(constructorSymbol); + } + + var existingParameter = constructor.ParameterList.Parameters.FirstOrDefault( + p => + { + if (p.Type is null) + { + return false; + } + + var parameterType = semanticModel.GetTypeInfo(p.Type, cancellationToken).Type; + return IsCompatibleTimeProviderType(parameterType, timeProviderType) + && CanAssignType(parameterType, fieldTypeSymbol, semanticModel.Compilation); + }); + + var parameterName = existingParameter?.Identifier.ValueText ?? GetUniqueParameterName(constructor, "timeProvider"); + if (constructorSymbol is not null) + { + parameterNamesByConstructorSymbol[constructorSymbol] = parameterName; + } + } + + foreach (var constructor in instanceConstructors) + { + var updatedConstructor = constructor; + + var existingParameter = constructor.ParameterList.Parameters.FirstOrDefault( + p => + { + if (p.Type is null) + { + return false; + } + + var parameterType = semanticModel.GetTypeInfo(p.Type, cancellationToken).Type; + return IsCompatibleTimeProviderType(parameterType, timeProviderType) + && CanAssignType(parameterType, fieldTypeSymbol, semanticModel.Compilation); + }); + + var parameterName = existingParameter?.Identifier.ValueText ?? GetUniqueParameterName(constructor, "timeProvider"); + if (existingParameter is null) + { + var parameter = (ParameterSyntax)editor.Generator.ParameterDeclaration( + parameterName, + fieldTypeSyntax); + updatedConstructor = InsertRequiredParameter(updatedConstructor, parameter); + } + + var thisInitializerArgumentInfo = GetThisInitializerArgumentInfo( + constructor, + semanticModel, + constructorSymbols, + parameterNamesByConstructorSymbol, + fieldTypeSymbol, + timeProviderType, + cancellationToken); + var initializerResult = EnsureThisInitializerHasArgument(updatedConstructor, parameterName, thisInitializerArgumentInfo); + updatedConstructor = initializerResult.UpdatedConstructor; + + if (updatedConstructor.Body is null) + { + var body = updatedConstructor.ExpressionBody is null + ? SyntaxFactory.Block() + : SyntaxFactory.Block(ExpressionToStatement(updatedConstructor.ExpressionBody.Expression)); + updatedConstructor = updatedConstructor + .WithExpressionBody(null) + .WithSemicolonToken(default) + .WithBody(body); + } + + if (!initializerResult.PassesParameterToThisInitializer) + { + updatedConstructor = EnsureFieldAssignmentFromParameter(editor, updatedConstructor, fieldName, parameterName, existingParameter is null); + } + + editor.ReplaceNode(constructor, updatedConstructor.WithAdditionalAnnotations(Formatter.Annotation)); + } + } + + // 3. Replace usage + var memberSymbol = semanticModel.GetSymbolInfo(memberAccess, cancellationToken).Symbol; + var containingType = memberSymbol?.ContainingType; + var isDateTime = containingType?.ToDisplayString() == "System.DateTime"; + var memberName = memberAccess.Name.Identifier.Text; + + var replacement = memberName switch + { + "UtcNow" => editor.Generator.InvocationExpression( + editor.Generator.MemberAccessExpression(editor.Generator.IdentifierName(fieldName), "GetUtcNow")), + "Now" => editor.Generator.InvocationExpression( + editor.Generator.MemberAccessExpression(editor.Generator.IdentifierName(fieldName), "GetLocalNow")), + _ => memberAccess, + }; + + if (isDateTime) + { + replacement = memberName switch + { + "UtcNow" => editor.Generator.MemberAccessExpression(replacement, "UtcDateTime"), + "Now" => editor.Generator.InvocationExpression( + editor.Generator.MemberAccessExpression(SyntaxFactory.ParseExpression("System.DateTime"), "SpecifyKind"), + editor.Generator.MemberAccessExpression(replacement, "LocalDateTime"), + editor.Generator.MemberAccessExpression(SyntaxFactory.ParseExpression("System.DateTimeKind"), "Local")), + _ => replacement, + }; + } + + editor.ReplaceNode(memberAccess, replacement.WithAdditionalAnnotations(Formatter.Annotation)); + + return editor.GetChangedDocument(); + } + + private static bool IsCompatibleTimeProviderType(ITypeSymbol? typeSymbol, INamedTypeSymbol timeProviderType) + { + if (typeSymbol is null) + { + return false; + } + + for (var currentType = typeSymbol; currentType is not null; currentType = currentType.BaseType) + { + if (SymbolEqualityComparer.Default.Equals(currentType, timeProviderType)) + { + return true; + } + } + + return false; + } + + private static string GetUniqueFieldName(ClassDeclarationSyntax classDeclaration, INamedTypeSymbol? classSymbol, string baseName) + { + var usedNameSet = new HashSet(); + if (classSymbol is not null) + { + foreach (var member in classSymbol.GetMembers()) + { + if (!member.IsImplicitlyDeclared && !string.IsNullOrWhiteSpace(member.Name)) + { + usedNameSet.Add(member.Name); + } + } + } + else + { + foreach (var member in classDeclaration.Members) + { + switch (member) + { + case FieldDeclarationSyntax field: + foreach (var variable in field.Declaration.Variables) + { + usedNameSet.Add(variable.Identifier.ValueText); + } + break; + case EventFieldDeclarationSyntax eventField: + foreach (var variable in eventField.Declaration.Variables) + { + usedNameSet.Add(variable.Identifier.ValueText); + } + break; + case PropertyDeclarationSyntax property: + usedNameSet.Add(property.Identifier.ValueText); + break; + case MethodDeclarationSyntax method: + usedNameSet.Add(method.Identifier.ValueText); + break; + case EventDeclarationSyntax @event: + usedNameSet.Add(@event.Identifier.ValueText); + break; + case BaseTypeDeclarationSyntax nestedType: + usedNameSet.Add(nestedType.Identifier.ValueText); + break; + case DelegateDeclarationSyntax @delegate: + usedNameSet.Add(@delegate.Identifier.ValueText); + break; + } + } + } + + if (!usedNameSet.Contains(baseName)) + { + return baseName; + } + + var suffix = 1; + while (usedNameSet.Contains($"{baseName}{suffix}")) + { + suffix++; + } + + return $"{baseName}{suffix}"; + } + + private static bool CanAssignType(ITypeSymbol? sourceType, ITypeSymbol destinationType, Compilation compilation) => + sourceType is not null && compilation.ClassifyConversion(sourceType, destinationType).IsImplicit; + + private static string GetUniqueParameterName(ConstructorDeclarationSyntax constructor, string baseName) + { + var usedNameSet = new HashSet(constructor.ParameterList.Parameters + .Select(parameter => parameter.Identifier.ValueText)); + foreach (var bodyIdentifier in GetDeclaredBodyIdentifiers(constructor)) + { + usedNameSet.Add(bodyIdentifier); + } + + if (!usedNameSet.Contains(baseName)) + { + return baseName; + } + + var suffix = 1; + while (usedNameSet.Contains($"{baseName}{suffix}")) + { + suffix++; + } + + return $"{baseName}{suffix}"; + } + + private static IEnumerable GetDeclaredBodyIdentifiers(ConstructorDeclarationSyntax constructor) + { + if (constructor.Body is null) + { + yield break; + } + + foreach (var variable in constructor.Body.DescendantNodes().OfType()) + { + yield return variable.Identifier.ValueText; + } + + foreach (var @foreach in constructor.Body.DescendantNodes().OfType()) + { + yield return @foreach.Identifier.ValueText; + } + + foreach (var @catch in constructor.Body.DescendantNodes().OfType()) + { + if (@catch.Identifier != default) + { + yield return @catch.Identifier.ValueText; + } + } + + foreach (var designation in constructor.Body.DescendantNodes().OfType()) + { + yield return designation.Identifier.ValueText; + } + } + + private static ConstructorDeclarationSyntax InsertRequiredParameter( + ConstructorDeclarationSyntax constructor, + ParameterSyntax parameter) + { + var parameters = constructor.ParameterList.Parameters; + int? firstOptionalOrParamsIndex = null; + for (var index = 0; index < parameters.Count; index++) + { + if (parameters[index].Default is not null + || parameters[index].Modifiers.Any(SyntaxKind.ParamsKeyword)) + { + firstOptionalOrParamsIndex = index; + break; + } + } + + if (!firstOptionalOrParamsIndex.HasValue) + { + return constructor.AddParameterListParameters(parameter); + } + + return constructor.WithParameterList( + constructor.ParameterList.WithParameters(parameters.Insert(firstOptionalOrParamsIndex.Value, parameter))); + } + + private static (bool CanPassToThisInitializer, int ArgumentIndex, string? TargetParameterName, bool UseNamedArgument, bool HasExistingTargetParameter) GetThisInitializerArgumentInfo( + ConstructorDeclarationSyntax constructor, + SemanticModel semanticModel, + ISet constructorSymbols, + IReadOnlyDictionary parameterNamesByConstructorSymbol, + ITypeSymbol fieldTypeSymbol, + INamedTypeSymbol timeProviderType, + CancellationToken cancellationToken) + { + var initializer = constructor.Initializer; + if (initializer is null || !initializer.IsKind(SyntaxKind.ThisConstructorInitializer)) + { + return (false, -1, null, false, false); + } + + var targetConstructorSymbol = semanticModel.GetSymbolInfo(initializer, cancellationToken).Symbol as IMethodSymbol; + if (targetConstructorSymbol is null || !constructorSymbols.Contains(targetConstructorSymbol)) + { + return (false, -1, null, false, false); + } + + var targetTimeProviderParameterIndex = -1; + string? targetTimeProviderParameterName = null; + for (var index = 0; index < targetConstructorSymbol.Parameters.Length; index++) + { + if (IsCompatibleTimeProviderType(targetConstructorSymbol.Parameters[index].Type, timeProviderType) + && CanAssignType(targetConstructorSymbol.Parameters[index].Type, fieldTypeSymbol, semanticModel.Compilation)) + { + targetTimeProviderParameterIndex = index; + targetTimeProviderParameterName = targetConstructorSymbol.Parameters[index].Name; + break; + } + } + + if (targetTimeProviderParameterIndex >= 0) + { + var shouldUseNamedArgument = + initializer.ArgumentList.Arguments.Any(argument => argument.NameColon is not null) + || targetTimeProviderParameterIndex >= initializer.ArgumentList.Arguments.Count; + return (true, targetTimeProviderParameterIndex, targetTimeProviderParameterName, shouldUseNamedArgument, true); + } + + var firstOptionalOrParamsIndex = -1; + for (var index = 0; index < targetConstructorSymbol.Parameters.Length; index++) + { + if (targetConstructorSymbol.Parameters[index].IsOptional || targetConstructorSymbol.Parameters[index].IsParams) + { + firstOptionalOrParamsIndex = index; + break; + } + } + + var argumentIndex = firstOptionalOrParamsIndex >= 0 ? firstOptionalOrParamsIndex : -1; + var useNamedArgument = initializer.ArgumentList.Arguments.Any(argument => argument.NameColon is not null); + parameterNamesByConstructorSymbol.TryGetValue(targetConstructorSymbol, out targetTimeProviderParameterName); + + return (true, argumentIndex, targetTimeProviderParameterName, useNamedArgument, false); + } + + private static (ConstructorDeclarationSyntax UpdatedConstructor, bool PassesParameterToThisInitializer) EnsureThisInitializerHasArgument( + ConstructorDeclarationSyntax constructor, + string parameterName, + (bool CanPassToThisInitializer, int ArgumentIndex, string? TargetParameterName, bool UseNamedArgument, bool HasExistingTargetParameter) thisInitializerArgumentInfo) + { + var initializer = constructor.Initializer; + if (!thisInitializerArgumentInfo.CanPassToThisInitializer + || initializer is null + || !initializer.IsKind(SyntaxKind.ThisConstructorInitializer)) + { + return (constructor, false); + } + + var targetParameterName = thisInitializerArgumentInfo.TargetParameterName; + ArgumentSyntax? namedTargetArgument = null; + if (targetParameterName is not null) + { + namedTargetArgument = initializer.ArgumentList.Arguments.FirstOrDefault( + argument => argument.NameColon?.Name is IdentifierNameSyntax name + && name.Identifier.ValueText == targetParameterName); + } + if (namedTargetArgument is not null) + { + if (namedTargetArgument.Expression is IdentifierNameSyntax namedTargetIdentifier + && namedTargetIdentifier.Identifier.ValueText == parameterName) + { + return (constructor, true); + } + + var updatedNamedArgument = namedTargetArgument.WithExpression(SyntaxFactory.IdentifierName(parameterName)); + return ( + constructor.WithInitializer( + initializer.WithArgumentList( + initializer.ArgumentList.WithArguments( + initializer.ArgumentList.Arguments.Replace(namedTargetArgument, updatedNamedArgument)))), + true); + } + + if (thisInitializerArgumentInfo.HasExistingTargetParameter + && !initializer.ArgumentList.Arguments.Any(argument => argument.NameColon is not null) + && thisInitializerArgumentInfo.ArgumentIndex >= 0 + && thisInitializerArgumentInfo.ArgumentIndex < initializer.ArgumentList.Arguments.Count) + { + var targetArgument = initializer.ArgumentList.Arguments[thisInitializerArgumentInfo.ArgumentIndex]; + if (targetArgument.Expression is IdentifierNameSyntax positionalTargetIdentifier + && positionalTargetIdentifier.Identifier.ValueText == parameterName) + { + return (constructor, true); + } + + var updatedTargetArgument = targetArgument.WithExpression(SyntaxFactory.IdentifierName(parameterName)); + return ( + constructor.WithInitializer( + initializer.WithArgumentList( + initializer.ArgumentList.WithArguments( + initializer.ArgumentList.Arguments.Replace(targetArgument, updatedTargetArgument)))), + true); + } + + var newArgument = SyntaxFactory.Argument(SyntaxFactory.IdentifierName(parameterName)); + if (thisInitializerArgumentInfo.UseNamedArgument && targetParameterName is not null) + { + newArgument = newArgument.WithNameColon( + SyntaxFactory.NameColon(SyntaxFactory.IdentifierName(targetParameterName))); + return ( + constructor.WithInitializer( + initializer.WithArgumentList( + initializer.ArgumentList.WithArguments(initializer.ArgumentList.Arguments.Add(newArgument)))), + true); + } + + var argumentIndex = thisInitializerArgumentInfo.ArgumentIndex; + var updatedArguments = argumentIndex >= 0 && argumentIndex < initializer.ArgumentList.Arguments.Count + ? initializer.ArgumentList.Arguments.Insert(argumentIndex, newArgument) + : initializer.ArgumentList.Arguments.Add(newArgument); + + return ( + constructor.WithInitializer( + initializer.WithArgumentList( + initializer.ArgumentList.WithArguments(updatedArguments))), + true); + } + + private static StatementSyntax ExpressionToStatement(ExpressionSyntax expression) => + expression is ThrowExpressionSyntax throwExpression + ? SyntaxFactory.ThrowStatement(throwExpression.Expression) + : SyntaxFactory.ExpressionStatement(expression); + + private static ExpressionStatementSyntax CreateFieldAssignmentStatement(DocumentEditor editor, string fieldName, string parameterName) + { + var needsQualifiedFieldAccess = string.Equals(fieldName, parameterName, System.StringComparison.Ordinal); + var assignmentTarget = needsQualifiedFieldAccess + ? editor.Generator.MemberAccessExpression(editor.Generator.ThisExpression(), editor.Generator.IdentifierName(fieldName)) + : editor.Generator.IdentifierName(fieldName); + + return (ExpressionStatementSyntax)editor.Generator.ExpressionStatement( + editor.Generator.AssignmentStatement( + assignmentTarget, + editor.Generator.IdentifierName(parameterName))); + } + + private static ConstructorDeclarationSyntax EnsureFieldAssignmentFromParameter( + DocumentEditor editor, + ConstructorDeclarationSyntax constructor, + string fieldName, + string parameterName, + bool parameterWasAdded) + { + if (constructor.Body is null) + { + return constructor; + } + + var fieldAssignmentStatement = constructor.Body.Statements + .OfType() + .FirstOrDefault(statement => statement.Expression is AssignmentExpressionSyntax assignment && IsFieldAssignmentTarget(assignment.Left, fieldName)); + if (fieldAssignmentStatement is null) + { + var assignment = CreateFieldAssignmentStatement(editor, fieldName, parameterName); + return constructor.WithBody(constructor.Body.WithStatements(constructor.Body.Statements.Insert(0, assignment))); + } + + var fieldAssignment = (AssignmentExpressionSyntax)fieldAssignmentStatement.Expression; + if (fieldAssignment.Right is IdentifierNameSyntax identifier && identifier.Identifier.ValueText == parameterName) + { + return constructor; + } + + if (fieldAssignment.Right is not IdentifierNameSyntax) + { + if (!parameterWasAdded) + { + return constructor; + } + + var updatedNonIdentifierStatement = fieldAssignmentStatement.WithExpression(fieldAssignment.WithRight(SyntaxFactory.IdentifierName(parameterName))); + return constructor.ReplaceNode(fieldAssignmentStatement, updatedNonIdentifierStatement); + } + + var updatedStatement = fieldAssignmentStatement.WithExpression(fieldAssignment.WithRight(SyntaxFactory.IdentifierName(parameterName))); + return constructor.ReplaceNode(fieldAssignmentStatement, updatedStatement); + } + + private static bool IsFieldAssignmentTarget(ExpressionSyntax expression, string fieldName) => + expression switch + { + IdentifierNameSyntax identifier => identifier.Identifier.ValueText == fieldName, + MemberAccessExpressionSyntax memberAccess + when memberAccess.Expression is ThisExpressionSyntax + && memberAccess.Name.Identifier.ValueText == fieldName => true, + _ => false, + }; + } diff --git a/tests/Inshiminator.Analyzers.Tests/ClockCodeFixTests.cs b/tests/Inshiminator.Analyzers.Tests/ClockCodeFixTests.cs index 84cc2a2..d291490 100644 --- a/tests/Inshiminator.Analyzers.Tests/ClockCodeFixTests.cs +++ b/tests/Inshiminator.Analyzers.Tests/ClockCodeFixTests.cs @@ -9,6 +9,18 @@ namespace Inshiminator.Analyzers.Tests; public class ClockCodeFixTests { + private const string TimeProviderCodeFixEquivalenceKey = "ClockCodeFixProvider_TimeProvider"; + private const string TimeProviderStub = """ +namespace System +{ + public abstract class TimeProvider + { + public abstract DateTimeOffset GetUtcNow(); + public virtual DateTimeOffset GetLocalNow() => GetUtcNow().ToLocalTime(); + } +} +"""; + [Fact] public async Task DateTimeUtcNow_AppliesCodeFix() { @@ -82,4 +94,1379 @@ void Method() await VerifyCS.VerifyCodeFixAsync(test, fixedCode); } + + [Fact] + public async Task DateTimeUtcNow_AppliesTimeProviderCodeFix() + { + await VerifyTimeProviderFixAsync("DateTime", "DateTime", "UtcNow", "_timeProvider.GetUtcNow().UtcDateTime"); + } + + [Fact] + public async Task DateTimeNow_AppliesTimeProviderCodeFix() + { + await VerifyTimeProviderFixAsync("DateTime", "DateTime", "Now", "System.DateTime.SpecifyKind(_timeProvider.GetLocalNow().LocalDateTime, System.DateTimeKind.Local)"); + } + + [Fact] + public async Task DateTimeOffsetUtcNow_AppliesTimeProviderCodeFix() + { + await VerifyTimeProviderFixAsync("DateTimeOffset", "DateTimeOffset", "UtcNow", "_timeProvider.GetUtcNow()"); + } + + [Fact] + public async Task DateTimeOffsetNow_AppliesTimeProviderCodeFix() + { + await VerifyTimeProviderFixAsync("DateTimeOffset", "DateTimeOffset", "Now", "_timeProvider.GetLocalNow()"); + } + + [Fact] + public async Task FullyQualifiedDateTimeNow_WithoutUsingSystem_AppliesTimeProviderCodeFix() + { + var test = $$""" +{{TimeProviderStub}} + +class Test +{ + public Test() + { + } + + void Method() + { + System.DateTime now = [|System.DateTime.Now|]; + } +} +"""; + + var fixedCode = $$""" +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + void Method() + { + System.DateTime now = System.DateTime.SpecifyKind(_timeProvider.GetLocalNow().LocalDateTime, System.DateTimeKind.Local); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_UpdatesThisConstructorInitializer() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() : this(1) + { + } + + public Test(int value) + { + } + + void Method() + { + DateTime now = [|DateTime.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) : this(1, timeProvider) + { + } + + public Test(int value, global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTime now = _timeProvider.GetUtcNow().UtcDateTime; + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_InsertsThisInitializerArgumentBeforeOptionalArguments() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() : this(1, "name") + { + } + + public Test(int value, string name = "default") + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) : this(1, timeProvider, "name") + { + } + + public Test(int value, global::System.TimeProvider timeProvider, string name = "default") + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_UsesNamedArgumentWhenThisInitializerUsesNamedArguments() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() : this(value: 1, name: "name") + { + } + + public Test(int value, string name = "default") + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) : this(value: 1, name: "name", timeProvider: timeProvider) + { + } + + public Test(int value, global::System.TimeProvider timeProvider, string name = "default") + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_UsesNamedArgumentWhenTargetConstructorAlreadyHasTimeProviderParameter() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() : this(value: 1) + { + } + + public Test(int value, global::System.TimeProvider timeProvider = null) + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) : this(value: 1, timeProvider: timeProvider) + { + } + + public Test(int value, global::System.TimeProvider timeProvider = null) + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_UsesNamedArgumentWhenOptionalParameterPrecedesExistingTimeProviderParameter() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() : this(1) + { + } + + public Test(int value, string name = "default", global::System.TimeProvider timeProvider = null) + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) : this(1, timeProvider: timeProvider) + { + } + + public Test(int value, string name = "default", global::System.TimeProvider timeProvider = null) + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_ReplacesExistingThisInitializerArgumentWithInjectedParameter() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() : this(value: 1, timeProvider: null) + { + } + + public Test(int value, global::System.TimeProvider timeProvider = null) + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) : this(value: 1, timeProvider: timeProvider) + { + } + + public Test(int value, global::System.TimeProvider timeProvider = null) + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_AddsNamedArgumentWhenNamedInitializerDoesNotSupplyExistingTimeProviderParameter() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() : this(value: 1, name: "name") + { + } + + public Test(int value, global::System.TimeProvider timeProvider = null, string name = "default") + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) : this(value: 1, name: "name", timeProvider: timeProvider) + { + } + + public Test(int value, global::System.TimeProvider timeProvider = null, string name = "default") + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_ConvertsExpressionBodiedConstructorToBlock() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private int _value; + public Test() => _value = 1; + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.Now|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + private int _value; + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + _value = 1; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetLocalNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_UsesUniqueParameterNameWhenColliding() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test(string timeProvider) + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(string timeProvider, global::System.TimeProvider timeProvider1) + { + _timeProvider = timeProvider1; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_InsertsRequiredParameterBeforeOptionalParameters() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test(string name = "default") + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider, string name = "default") + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_RewritesConstructorChainForSameDocumentPartialDeclaration() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +partial class Test +{ + public Test() : this(1) + { + } + + void Method() + { + DateTime now = [|DateTime.UtcNow|]; + } +} + +partial class Test +{ + public Test(int value) + { + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +partial class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) : this(1, timeProvider) + { + } + + void Method() + { + DateTime now = _timeProvider.GetUtcNow().UtcDateTime; + } +} + +partial class Test +{ + public Test(int value, global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_ReusesTimeProviderFieldFromOtherPartialDeclaration() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +partial class Test +{ + public Test() + { + _timeProvider = null!; + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} + +partial class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(string name) + { + _timeProvider = null!; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +partial class Test +{ + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} + +partial class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(string name, global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_ReplacesExistingFieldAssignmentWithInjectedParameter() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(string timeProvider) + { + global::System.TimeProvider assignedProvider = null; + _timeProvider = assignedProvider; + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(string timeProvider, global::System.TimeProvider timeProvider1) + { + global::System.TimeProvider assignedProvider = null; + _timeProvider = timeProvider1; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_ReplacesNonIdentifierFieldAssignmentWhenInjectedParameterAdded() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(string timeProvider) + { + _timeProvider = null; + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(string timeProvider, global::System.TimeProvider timeProvider1) + { + _timeProvider = timeProvider1; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_PreservesNonIdentifierFieldAssignmentExpression() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider ?? throw new ArgumentNullException(nameof(timeProvider)); + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider ?? throw new ArgumentNullException(nameof(timeProvider)); + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_UsesUniqueFieldNameWhenDefaultNameTaken() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly string _timeProvider; + + public Test() + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider1; + private readonly string _timeProvider; + + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider1 = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider1.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_ReusesDerivedTimeProviderFieldAndParameter() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class CustomTimeProvider : TimeProvider +{ + public override DateTimeOffset GetUtcNow() => DateTimeOffset.MinValue; +} + +class Test +{ + private readonly CustomTimeProvider _timeProvider; + + public Test(CustomTimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class CustomTimeProvider : TimeProvider +{ + public override DateTimeOffset GetUtcNow() => DateTimeOffset.MinValue; +} + +class Test +{ + private readonly CustomTimeProvider _timeProvider; + + public Test(CustomTimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_UsesDerivedFieldTypeWhenAddingConstructorParameter() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class CustomTimeProvider : TimeProvider +{ + public override DateTimeOffset GetUtcNow() => DateTimeOffset.MinValue; +} + +class Test +{ + private readonly CustomTimeProvider _timeProvider; + + public Test(string name) + { + _timeProvider = null!; + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class CustomTimeProvider : TimeProvider +{ + public override DateTimeOffset GetUtcNow() => DateTimeOffset.MinValue; +} + +class Test +{ + private readonly CustomTimeProvider _timeProvider; + + public Test(string name, global::CustomTimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_UsesFullyQualifiedDerivedFieldTypeWhenNamespaceIsOutOfScope() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +namespace MyApp +{ + class CustomTimeProvider : TimeProvider + { + public override DateTimeOffset GetUtcNow() => DateTimeOffset.MinValue; + } +} + +class Test +{ + private readonly MyApp.CustomTimeProvider _timeProvider; + + public Test(string name) + { + _timeProvider = null!; + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +namespace MyApp +{ + class CustomTimeProvider : TimeProvider + { + public override DateTimeOffset GetUtcNow() => DateTimeOffset.MinValue; + } +} + +class Test +{ + private readonly MyApp.CustomTimeProvider _timeProvider; + + public Test(string name, global::MyApp.CustomTimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_InsertsAssignmentBeforeConstructorReturn() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() + { + return; + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + return; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_InsertsAssignmentBeforeNestedReturnPaths() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test(bool shouldReturn) + { + if (shouldReturn) + { + return; + } + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(bool shouldReturn, global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + if (shouldReturn) + { + return; + } + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_DoesNotModifyStaticConstructors() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + static Test() + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + static Test() + { + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } + + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_DoesNotReuseStaticTimeProviderField() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private static readonly global::System.TimeProvider _timeProvider = null; + + public Test() + { + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider1; + private static readonly global::System.TimeProvider _timeProvider = null; + + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider1 = timeProvider; + } + + void Method() + { + DateTimeOffset now = _timeProvider1.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + [Fact] + public async Task TimeProviderCodeFix_UsesUniqueParameterNameWhenConstructorBodyDeclaresTimeProviderLocal() + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() + { + var timeProvider = string.Empty; + } + + void Method() + { + DateTimeOffset now = [|DateTimeOffset.UtcNow|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider1) + { + _timeProvider = timeProvider1; + var timeProvider = string.Empty; + } + + void Method() + { + DateTimeOffset now = _timeProvider.GetUtcNow(); + } +} +"""; + + await VerifyTimeProviderCodeFixAsync(test, fixedCode); + } + + private static async Task VerifyTimeProviderFixAsync(string targetTypeName, string sourceTypeName, string memberName, string replacement) + { + var test = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + public Test() + { + } + + void Method() + { + {{targetTypeName}} now = [|{{sourceTypeName}}.{{memberName}}|]; + } +} +"""; + + var fixedCode = $$""" +using System; + +{{TimeProviderStub}} + +class Test +{ + private readonly global::System.TimeProvider _timeProvider; + + public Test(global::System.TimeProvider timeProvider) + { + _timeProvider = timeProvider; + } + + void Method() + { + {{targetTypeName}} now = {{replacement}}; + } +} +"""; + + var testCase = new VerifyCS.Test + { + TestCode = test, + FixedCode = fixedCode, + CodeActionEquivalenceKey = TimeProviderCodeFixEquivalenceKey, + }; + + await testCase.RunAsync(); + } + + private static async Task VerifyTimeProviderCodeFixAsync(string testCode, string fixedCode) + { + var testCase = new VerifyCS.Test + { + TestCode = testCode, + FixedCode = fixedCode, + CodeActionEquivalenceKey = TimeProviderCodeFixEquivalenceKey, + }; + + await testCase.RunAsync(); + } }