diff --git a/osu.Framework.SourceGeneration/Analysers/DiagnosticRules.cs b/osu.Framework.SourceGeneration/Analysers/DiagnosticRules.cs index dc6a82786..c1af3c327 100644 --- a/osu.Framework.SourceGeneration/Analysers/DiagnosticRules.cs +++ b/osu.Framework.SourceGeneration/Analysers/DiagnosticRules.cs @@ -13,12 +13,12 @@ namespace osu.Framework.SourceGeneration.Analysers public static readonly DiagnosticDescriptor MAKE_DI_CLASS_PARTIAL = new DiagnosticDescriptor( "OFSG001", - "This class is a candidate for dependency injection and should be partial", - "This class is a candidate for dependency injection and should be partial", + "This type, or a nested type, is a candidate for dependency injection and should be partial", + "This type, or a nested type, is a candidate for dependency injection and should be partial", "Performance", DiagnosticSeverity.Warning, true, - "Classes that are candidates for dependency injection should be made partial to benefit from compile-time optimisations."); + "Types that are candidates for dependency injection should be made partial to benefit from compile-time optimisations."); #pragma warning restore RS2008 } diff --git a/osu.Framework.SourceGeneration/Analysers/DrawableAnalyser.cs b/osu.Framework.SourceGeneration/Analysers/DrawableAnalyser.cs index ae43c90ab..ec18e5194 100644 --- a/osu.Framework.SourceGeneration/Analysers/DrawableAnalyser.cs +++ b/osu.Framework.SourceGeneration/Analysers/DrawableAnalyser.cs @@ -24,19 +24,38 @@ namespace osu.Framework.SourceGeneration.Analysers } /// - /// Analyses class definitions for implementations of IDrawable, ISourceGeneratedDependencyActivator, and Transformable. + /// Analyses class definitions for implementations of IDependencyInjectionCandidateInterface. /// private void analyseClass(SyntaxNodeAnalysisContext context) { var classSyntax = (ClassDeclarationSyntax)context.Node; - if (classSyntax.Modifiers.Any(m => m.IsKind(SyntaxKind.PartialKeyword))) + if (classSyntax.Ancestors().OfType().Any()) return; - INamedTypeSymbol? type = context.SemanticModel.GetDeclaredSymbol(classSyntax); + analyseRecursively(context, classSyntax); - if (type?.AllInterfaces.Any(SyntaxHelpers.IsIDependencyInjectionCandidateInterface) == true) - context.ReportDiagnostic(Diagnostic.Create(DiagnosticRules.MAKE_DI_CLASS_PARTIAL, context.Node.GetLocation(), context.Node)); + static bool analyseRecursively(SyntaxNodeAnalysisContext context, ClassDeclarationSyntax node) + { + bool requiresPartial = false; + + // Child nodes always have to be analysed to provide diagnostics. + foreach (var nested in node.DescendantNodes().OfType()) + requiresPartial |= analyseRecursively(context, nested); + + // - If at least one child requires partial, then this node also needs to be partial regardless of its own type (optimisation). + // - If no child requires partial, we need to check if this node is a DI candidate (e.g. If the node has no nested types). + if (!requiresPartial) + requiresPartial = context.SemanticModel.GetDeclaredSymbol(node)?.AllInterfaces.Any(SyntaxHelpers.IsIDependencyInjectionCandidateInterface) == true; + + // Whether the node is already partial. + bool isPartial = node.Modifiers.Any(m => m.IsKind(SyntaxKind.PartialKeyword)); + + if (requiresPartial && !isPartial) + context.ReportDiagnostic(Diagnostic.Create(DiagnosticRules.MAKE_DI_CLASS_PARTIAL, node.GetLocation(), node)); + + return requiresPartial; + } } } }