diff --git a/osu.Framework.SourceGeneration/Analysers/DiagnosticRules.cs b/osu.Framework.SourceGeneration/Analysers/DiagnosticRules.cs
index dc6a82786..c1af3c327 100644
--- a/osu.Framework.SourceGeneration/Analysers/DiagnosticRules.cs
+++ b/osu.Framework.SourceGeneration/Analysers/DiagnosticRules.cs
@@ -13,12 +13,12 @@ namespace osu.Framework.SourceGeneration.Analysers
public static readonly DiagnosticDescriptor MAKE_DI_CLASS_PARTIAL = new DiagnosticDescriptor(
"OFSG001",
- "This class is a candidate for dependency injection and should be partial",
- "This class is a candidate for dependency injection and should be partial",
+ "This type, or a nested type, is a candidate for dependency injection and should be partial",
+ "This type, or a nested type, is a candidate for dependency injection and should be partial",
"Performance",
DiagnosticSeverity.Warning,
true,
- "Classes that are candidates for dependency injection should be made partial to benefit from compile-time optimisations.");
+ "Types that are candidates for dependency injection should be made partial to benefit from compile-time optimisations.");
#pragma warning restore RS2008
}
diff --git a/osu.Framework.SourceGeneration/Analysers/DrawableAnalyser.cs b/osu.Framework.SourceGeneration/Analysers/DrawableAnalyser.cs
index ae43c90ab..ec18e5194 100644
--- a/osu.Framework.SourceGeneration/Analysers/DrawableAnalyser.cs
+++ b/osu.Framework.SourceGeneration/Analysers/DrawableAnalyser.cs
@@ -24,19 +24,38 @@ namespace osu.Framework.SourceGeneration.Analysers
}
///
- /// Analyses class definitions for implementations of IDrawable, ISourceGeneratedDependencyActivator, and Transformable.
+ /// Analyses class definitions for implementations of IDependencyInjectionCandidateInterface.
///
private void analyseClass(SyntaxNodeAnalysisContext context)
{
var classSyntax = (ClassDeclarationSyntax)context.Node;
- if (classSyntax.Modifiers.Any(m => m.IsKind(SyntaxKind.PartialKeyword)))
+ if (classSyntax.Ancestors().OfType().Any())
return;
- INamedTypeSymbol? type = context.SemanticModel.GetDeclaredSymbol(classSyntax);
+ analyseRecursively(context, classSyntax);
- if (type?.AllInterfaces.Any(SyntaxHelpers.IsIDependencyInjectionCandidateInterface) == true)
- context.ReportDiagnostic(Diagnostic.Create(DiagnosticRules.MAKE_DI_CLASS_PARTIAL, context.Node.GetLocation(), context.Node));
+ static bool analyseRecursively(SyntaxNodeAnalysisContext context, ClassDeclarationSyntax node)
+ {
+ bool requiresPartial = false;
+
+ // Child nodes always have to be analysed to provide diagnostics.
+ foreach (var nested in node.DescendantNodes().OfType())
+ requiresPartial |= analyseRecursively(context, nested);
+
+ // - If at least one child requires partial, then this node also needs to be partial regardless of its own type (optimisation).
+ // - If no child requires partial, we need to check if this node is a DI candidate (e.g. If the node has no nested types).
+ if (!requiresPartial)
+ requiresPartial = context.SemanticModel.GetDeclaredSymbol(node)?.AllInterfaces.Any(SyntaxHelpers.IsIDependencyInjectionCandidateInterface) == true;
+
+ // Whether the node is already partial.
+ bool isPartial = node.Modifiers.Any(m => m.IsKind(SyntaxKind.PartialKeyword));
+
+ if (requiresPartial && !isPartial)
+ context.ReportDiagnostic(Diagnostic.Create(DiagnosticRules.MAKE_DI_CLASS_PARTIAL, node.GetLocation(), node));
+
+ return requiresPartial;
+ }
}
}
}