Restructure for memoisation

This commit is contained in:
Dan Balasescu
2022-12-03 04:17:50 +09:00
parent 127a7bab6f
commit 485293484b
10 changed files with 222 additions and 69 deletions

View File

@@ -41,6 +41,9 @@ namespace osu.Framework.SourceGeneration.Emitters
public void Emit(AddSourceDelegate addSource)
{
if (!Candidate.IsValid)
return;
StringBuilder result = new StringBuilder();
result.Append(headers);

View File

@@ -1,7 +1,6 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
@@ -11,25 +10,37 @@ using osu.Framework.SourceGeneration.Data;
namespace osu.Framework.SourceGeneration
{
public class GeneratorClassCandidate : IEquatable<GeneratorClassCandidate>
public class GeneratorClassCandidate
{
public readonly string TypeName;
public readonly string FullyQualifiedTypeName;
public readonly ClassDeclarationSyntax ClassSyntax;
public readonly string FullyQualifiedTypeName = string.Empty;
public readonly string TypeName = string.Empty;
public readonly bool NeedsOverride;
public readonly string? ContainingNamespace;
public readonly bool IsValid;
public readonly List<string> TypeHierarchy = new List<string>();
public readonly HashSet<CachedAttributeData> CachedInterfaces = new HashSet<CachedAttributeData>();
public readonly HashSet<CachedAttributeData> CachedMembers = new HashSet<CachedAttributeData>();
public readonly HashSet<CachedAttributeData> CachedClasses = new HashSet<CachedAttributeData>();
public readonly HashSet<ResolvedAttributeData> ResolvedMembers = new HashSet<ResolvedAttributeData>();
public readonly HashSet<BackgroundDependencyLoaderAttributeData> DependencyLoaderMembers = new HashSet<BackgroundDependencyLoaderAttributeData>();
public GeneratorClassCandidate(INamedTypeSymbol symbol)
public GeneratorClassCandidate(ClassDeclarationSyntax classSyntax, SemanticModel semanticModel)
{
TypeName = symbol.ToDisplayString();
ClassSyntax = classSyntax;
INamedTypeSymbol symbol = semanticModel.GetDeclaredSymbol(ClassSyntax)!;
// Determine if the class is a candidate for the source generator.
IsValid = symbol.AllInterfaces.Any(SyntaxHelpers.IsIDependencyInjectionCandidateInterface);
if (!IsValid)
return;
FullyQualifiedTypeName = SyntaxHelpers.GetFullyQualifiedTypeName(symbol);
TypeName = symbol.ToDisplayString();
NeedsOverride = symbol.BaseType != null && symbol.BaseType.AllInterfaces.Any(SyntaxHelpers.IsIDependencyInjectionCandidateInterface);
ContainingNamespace = symbol.ContainingNamespace.IsGlobalNamespace ? null : symbol.ContainingNamespace.ToDisplayString();
@@ -90,7 +101,7 @@ namespace osu.Framework.SourceGeneration
}
}
public static bool IsCandidate(SyntaxNode syntaxNode)
public static bool IsSyntaxTarget(SyntaxNode syntaxNode)
{
if (syntaxNode is not ClassDeclarationSyntax classSyntax)
return false;
@@ -98,29 +109,9 @@ namespace osu.Framework.SourceGeneration
if (classSyntax.AncestorsAndSelf().OfType<ClassDeclarationSyntax>().Any(c => !c.Modifiers.Any(SyntaxKind.PartialKeyword)))
return false;
if (classSyntax.BaseList == null && classSyntax.AttributeLists.Count == 0)
return false;
return true;
}
public static GeneratorClassCandidate? TryCreate(SyntaxNode syntaxNode, SemanticModel semanticModel)
{
if (syntaxNode is not ClassDeclarationSyntax classSyntax)
return null;
INamedTypeSymbol? symbol = semanticModel.GetDeclaredSymbol(classSyntax);
if (symbol == null)
return null;
// Determine if the class is a candidate for the source generator.
if (!symbol.AllInterfaces.Any(SyntaxHelpers.IsIDependencyInjectionCandidateInterface))
return null;
return new GeneratorClassCandidate(symbol);
}
private static string createTypeName(ITypeSymbol typeSymbol)
{
string name = typeSymbol.Name;
@@ -130,27 +121,5 @@ namespace osu.Framework.SourceGeneration
return name;
}
public bool Equals(GeneratorClassCandidate? other)
{
if (ReferenceEquals(null, other)) return false;
if (ReferenceEquals(this, other)) return true;
return FullyQualifiedTypeName == other.FullyQualifiedTypeName;
}
public override bool Equals(object? obj)
{
if (ReferenceEquals(null, obj)) return false;
if (ReferenceEquals(this, obj)) return true;
if (obj.GetType() != GetType()) return false;
return Equals((GeneratorClassCandidate)obj);
}
public override int GetHashCode()
{
return FullyQualifiedTypeName.GetHashCode();
}
}
}

View File

@@ -2,6 +2,8 @@
// See the LICENCE file in the repository root for full licence text.
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
@@ -113,6 +115,33 @@ namespace osu.Framework.SourceGeneration
public static string GetFullyQualifiedTypeName(INamedTypeSymbol type)
=> type.ToDisplayString(SymbolDisplayFormat.CSharpErrorMessageFormat);
public static string GetFullyQualifiedSyntaxName(TypeDeclarationSyntax syntax)
{
StringBuilder sb = new StringBuilder();
foreach (var node in syntax.AncestorsAndSelf())
{
switch (node)
{
case NamespaceDeclarationSyntax ns:
sb.Append(ns.Name);
break;
case ClassDeclarationSyntax cls:
sb.Append(cls.Identifier.ToString());
if (cls.TypeParameterList != null)
sb.Append($"{{{string.Join(",", cls.TypeParameterList.Parameters.Select(p => p.Identifier.ToString()))}}}");
break;
default:
continue;
}
}
return sb.ToString();
}
public static IEnumerable<ITypeSymbol> GetDeclaredInterfacesOnType(INamedTypeSymbol type)
{
foreach (var declared in type.Interfaces)

View File

@@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using osu.Framework.SourceGeneration.Emitters;
namespace osu.Framework.SourceGeneration
@@ -21,7 +22,7 @@ namespace osu.Framework.SourceGeneration
if (context.SyntaxContextReceiver is not CustomSyntaxContextReceiver receiver)
return;
foreach (var candidate in receiver.Candidates.Distinct())
foreach (var candidate in receiver.Candidates.Where(c => c.IsValid).Distinct(GeneratorClassCandidateComparer.DEFAULT))
new DependenciesFileEmitter(candidate).Emit(context.AddSource);
}
@@ -31,13 +32,10 @@ namespace osu.Framework.SourceGeneration
public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
{
if (!GeneratorClassCandidate.IsCandidate(context.Node))
if (!GeneratorClassCandidate.IsSyntaxTarget(context.Node))
return;
GeneratorClassCandidate? candidate = GeneratorClassCandidate.TryCreate(context.Node, context.SemanticModel);
if (candidate != null)
Candidates.Add(candidate);
Candidates.Add(new GeneratorClassCandidate((ClassDeclarationSyntax)context.Node, context.SemanticModel));
}
}
}

View File

@@ -0,0 +1,27 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System.Collections.Generic;
namespace osu.Framework.SourceGeneration
{
public class GeneratorClassCandidateComparer : IEqualityComparer<GeneratorClassCandidate>
{
public static readonly GeneratorClassCandidateComparer DEFAULT = new GeneratorClassCandidateComparer();
public bool Equals(GeneratorClassCandidate x, GeneratorClassCandidate y)
{
if (ReferenceEquals(x, y)) return true;
if (ReferenceEquals(x, null)) return false;
if (ReferenceEquals(y, null)) return false;
if (x.GetType() != y.GetType()) return false;
return string.Equals(x.FullyQualifiedTypeName, y.FullyQualifiedTypeName);
}
public int GetHashCode(GeneratorClassCandidate obj)
{
return obj.FullyQualifiedTypeName.GetHashCode();
}
}
}

View File

@@ -8,5 +8,6 @@
<ItemGroup>
<Compile Include="DependencyInjectionSourceGenerator.cs" />
<Compile Include="GeneratorClassCandidateComparer.cs" />
</ItemGroup>
</Project>

View File

@@ -1,9 +1,11 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using osu.Framework.SourceGeneration.Emitters;
namespace osu.Framework.SourceGeneration
@@ -13,22 +15,56 @@ namespace osu.Framework.SourceGeneration
{
public void Initialize(IncrementalGeneratorInitializationContext context)
{
IncrementalValuesProvider<GeneratorClassCandidate> candidateClasses =
context.SyntaxProvider.CreateSyntaxProvider(selectClasses, extractCandidates)
.Where(c => c != null);
// Stage 1: Create SyntaxTarget objects for all classes.
IncrementalValuesProvider<SyntaxTarget> syntaxTargets =
context.SyntaxProvider.CreateSyntaxProvider(
(n, _) => GeneratorClassCandidate.IsSyntaxTarget(n),
(ctx, _) => new SyntaxTarget((ClassDeclarationSyntax)ctx.Node, ctx.SemanticModel))
.Select((t, _) => t.WithName());
IncrementalValuesProvider<GeneratorClassCandidate> distinctCandidates =
candidateClasses.Collect().SelectMany((c, _) => c.Distinct());
// Stage 2: Separate out the old and new syntax targets for the same class object.
// At this point, there are a bunch of old and new syntax targets that may refer to the same class object.
// Find a distinct syntax target for any one class object, preferring the most-recent target.
// Example: Multi-partial definitions where one file is updated. We need to find the definition that was newly-updated.
// Example: Multi-partial definitions where an unrelated file is updated. Need to find the definition that was used for the last generation.
// Bug: Due to an internal bug in Roslyn, this may also occur for non-multi-partial files.
IncrementalValuesProvider<SyntaxTarget> distinctSyntaxTargets =
syntaxTargets
.Collect()
.SelectMany((targets, _) =>
{
// Ensure all targets have a generation ID. This is over-engineered as two loops to:
// 1. Increment the generation ID locally for deterministic test output.
// 2. Remain performant across many thousands of objects.
Dictionary<SyntaxTarget, long> maxGenerationIds = new Dictionary<SyntaxTarget, long>(SyntaxTargetNameComparer.DEFAULT);
context.RegisterImplementationSourceOutput(distinctCandidates, emit);
foreach (var target in targets)
{
maxGenerationIds.TryGetValue(target, out long existingValue);
maxGenerationIds[target] = Math.Max(existingValue, target.GenerationId ?? 0);
}
foreach (var target in targets)
target.GenerationId ??= maxGenerationIds[target] + 1;
HashSet<SyntaxTarget> result = new HashSet<SyntaxTarget>(SyntaxTargetNameComparer.DEFAULT);
// Filter out the targets, preferring the most recent at all times.
foreach (SyntaxTarget t in targets.OrderByDescending(t => t.GenerationId))
result.Add(t);
return result;
});
// Stage 3: Generate the semantic targets for the filtered syntax targets.
// For any old syntax targets, this is a no-op. For any new targets, this is a fairly complex operation involving semantic lookup.
IncrementalValuesProvider<GeneratorClassCandidate> semanticTargets =
distinctSyntaxTargets
.Select((t, _) => t.ResolveSemanticTarget());
context.RegisterImplementationSourceOutput(semanticTargets, emit);
}
private bool selectClasses(SyntaxNode syntaxNode, CancellationToken cancellationToken)
=> GeneratorClassCandidate.IsCandidate(syntaxNode);
private GeneratorClassCandidate extractCandidates(GeneratorSyntaxContext context, CancellationToken cancellationToken)
=> GeneratorClassCandidate.TryCreate(context.Node, context.SemanticModel)!;
private void emit(SourceProductionContext context, GeneratorClassCandidate candidate)
=> new DependenciesFileEmitter(candidate).Emit(context.AddSource);
}

View File

@@ -0,0 +1,61 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace osu.Framework.SourceGeneration
{
public class SyntaxTarget : IEquatable<SyntaxTarget>
{
public readonly ClassDeclarationSyntax Syntax;
public string? SyntaxName { get; set; }
public long? GenerationId;
private SemanticModel? semanticModel;
private GeneratorClassCandidate? semanticTarget;
public SyntaxTarget(ClassDeclarationSyntax syntax, SemanticModel semanticModel)
{
Syntax = syntax;
this.semanticModel = semanticModel;
}
public SyntaxTarget WithName()
{
SyntaxName ??= SyntaxHelpers.GetFullyQualifiedSyntaxName(Syntax);
return this;
}
public GeneratorClassCandidate ResolveSemanticTarget()
{
semanticTarget ??= new GeneratorClassCandidate(Syntax, semanticModel!);
semanticModel = null;
return semanticTarget;
}
public bool Equals(SyntaxTarget? other)
{
if (ReferenceEquals(null, other)) return false;
if (ReferenceEquals(this, other)) return true;
return Syntax == other.Syntax;
}
public override bool Equals(object? obj)
{
if (ReferenceEquals(null, obj)) return false;
if (ReferenceEquals(this, obj)) return true;
if (obj.GetType() != GetType()) return false;
return Equals((SyntaxTarget)obj);
}
public override int GetHashCode()
{
return Syntax.GetHashCode();
}
}
}

View File

@@ -0,0 +1,27 @@
// Copyright (c) ppy Pty Ltd <contact@ppy.sh>. Licensed under the MIT Licence.
// See the LICENCE file in the repository root for full licence text.
using System.Collections.Generic;
namespace osu.Framework.SourceGeneration
{
public class SyntaxTargetNameComparer : IEqualityComparer<SyntaxTarget>
{
public static readonly SyntaxTargetNameComparer DEFAULT = new SyntaxTargetNameComparer();
public bool Equals(SyntaxTarget x, SyntaxTarget y)
{
if (ReferenceEquals(x, y)) return true;
if (ReferenceEquals(x, null)) return false;
if (ReferenceEquals(y, null)) return false;
if (x.GetType() != y.GetType()) return false;
return x.SyntaxName == y.SyntaxName;
}
public int GetHashCode(SyntaxTarget obj)
{
return obj.SyntaxName!.GetHashCode();
}
}
}

View File

@@ -8,5 +8,7 @@
<ItemGroup>
<Compile Include="DependencyInjectionSourceGenerator.cs" />
<Compile Include="SyntaxTarget.cs" />
<Compile Include="SyntaxTargetNameComparer.cs" />
</ItemGroup>
</Project>