Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for no top-level statements #1895

Merged
merged 4 commits into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ internal class ProjectModifier
private readonly IEnumerable<string> _files;
private readonly IConsoleLogger _consoleLogger;
private PropertyInfo? _codeModifierConfigPropertyInfo;

private const string Main = nameof(Main);
public ProjectModifier(ProvisioningToolOptions toolOptions, IEnumerable<string> files, IConsoleLogger consoleLogger)
{
_toolOptions = toolOptions ?? throw new ArgumentNullException(nameof(toolOptions));
Expand Down Expand Up @@ -75,12 +75,14 @@ public async Task AddAuthCodeAsync()
return;
}

var isMinimalApp = await ProjectModifierHelper.IsMinimalApp(project);
var isMinimalApp = await ProjectModifierHelper.IsMinimalApp(project.Documents.ToList());
var useTopLevelsStatements = await ProjectModifierHelper.IsUsingTopLevelStatements(project.Documents.ToList());
CodeChangeOptions options = new CodeChangeOptions
{
MicrosoftGraph = _toolOptions.CallsGraph,
DownstreamApi = _toolOptions.CallsDownstreamApi,
IsMinimalApp = isMinimalApp
IsMinimalApp = isMinimalApp,
UsingTopLevelsStatements = useTopLevelsStatements
};

// Go through all the files, make changes using DocumentBuilder.
Expand Down Expand Up @@ -225,7 +227,7 @@ internal async Task ModifyCsFile(CodeFile file, CodeAnalysis.Project project, Co
if (file.FileName.Equals("Startup.cs"))
{
// Startup class file name may be different
file.FileName = await ProjectModifierHelper.GetStartupClass(project) ?? file.FileName;
file.FileName = await ProjectModifierHelper.GetStartupClass(project.Documents.ToList()) ?? file.FileName;
}

var fileDoc = project.Documents.Where(d => d.Name.Equals(file.FileName)).FirstOrDefault();
Expand Down Expand Up @@ -260,8 +262,7 @@ internal async Task ModifyCsFile(CodeFile file, CodeAnalysis.Project project, Co
private static SyntaxNode? ModifyRoot(DocumentBuilder documentBuilder, CodeChangeOptions options, CodeFile file)
{
var root = documentBuilder.AddUsings(options);
if (file.FileName.Equals("Program.cs") && file.Methods.TryGetValue("Global", out var globalChanges)
&& root.Members.Any(node => node.IsKind(SyntaxKind.GlobalStatement)))
if (file.FileName.Equals("Program.cs") && file.Methods.TryGetValue("Global", out var globalChanges))
{
var filteredChanges = ProjectModifierHelper.FilterCodeSnippets(globalChanges.CodeChanges, options);
var updatedIdentifer = ProjectModifierHelper.GetBuilderVariableIdentifierTransformation(root.Members);
Expand All @@ -270,9 +271,20 @@ internal async Task ModifyCsFile(CodeFile file, CodeAnalysis.Project project, Co
(string oldValue, string newValue) = updatedIdentifer.Value;
filteredChanges = ProjectModifierHelper.UpdateVariables(filteredChanges, oldValue, newValue);
}

var updatedRoot = DocumentBuilder.ApplyChangesToMethod(root, filteredChanges);
return updatedRoot;
if (!options.UsingTopLevelsStatements)
{
var mainMethod = root?.ChildNodes().FirstOrDefault(n => n is MethodDeclarationSyntax
&& ((MethodDeclarationSyntax)n).Identifier.ToString().Equals(Main, StringComparison.OrdinalIgnoreCase));
if (mainMethod != null)
{
var updatedMethod = DocumentBuilder.ApplyChangesToMethod(mainMethod, filteredChanges);
return root?.ReplaceNode(mainMethod, updatedMethod);
}
}
else if (root.Members.Any(node => node.IsKind(SyntaxKind.GlobalStatement)))
{
return DocumentBuilder.ApplyChangesToMethod(root, filteredChanges);
}
}
else
{
Expand Down
126 changes: 88 additions & 38 deletions src/Scaffolding/VS.Web.CG.EFCore/DbContextEditorServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;
using Microsoft.DotNet.Scaffolding.Shared.CodeModifier;
using Microsoft.DotNet.Scaffolding.Shared.Project;
using Microsoft.DotNet.Scaffolding.Shared.ProjectModel;
using Microsoft.VisualStudio.Web.CodeGeneration.DotNet;
Expand All @@ -35,7 +36,7 @@ public class DbContextEditorServices : IDbContextEditorServices
private const string WebApplicationCreateBuilder = "WebApplication.CreateBuilder";
private const string AddRazorPages = "Services.AddRazorPages()";
private const string CreateBuilder = "CreateBuilder(args)";

private const string Main = nameof(Main);

public DbContextEditorServices(
IProjectContext projectContext,
Expand Down Expand Up @@ -155,7 +156,13 @@ private string GetSafeModelName(string name, ITypeSymbol dbContext)
return safeName;
}

public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string dbContextTypeName, string dbContextNamespace, string dataBaseName, bool useSqlite)
public EditSyntaxTreeResult EditStartupForNewContext(
ModelType startUp,
string dbContextTypeName,
string dbContextNamespace,
string dataBaseName,
bool useSqlite,
bool useTopLevelStatements)
{
Contract.Assert(startUp != null && startUp.TypeSymbol != null);
Contract.Assert(!String.IsNullOrEmpty(dbContextTypeName));
Expand All @@ -169,12 +176,13 @@ public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string d

var startUpClassNode = rootNode.FindNode(declarationReference.Span);

var configServicesMethod = startUpClassNode.ChildNodes()
.FirstOrDefault(n => n is MethodDeclarationSyntax
&& ((MethodDeclarationSyntax)n).Identifier.ToString() == ConfigureServices) as MethodDeclarationSyntax;
var configRootProperty = TryGetIConfigurationRootProperty(startUp.TypeSymbol);
//if using Startup.cs, the ConfigureServices method should exist.
if (configServicesMethod != null && configRootProperty != null)
if (startUpClassNode.ChildNodes()
.FirstOrDefault(n =>
n is MethodDeclarationSyntax syntax &&
syntax.Identifier.ToString() == ConfigureServices)
is MethodDeclarationSyntax configServicesMethod && configRootProperty != null)
{
var servicesParam = configServicesMethod.ParameterList.Parameters
.FirstOrDefault(p => p.Type.ToString().Equals(IServiceCollection));
Expand Down Expand Up @@ -217,46 +225,56 @@ public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string d
//minimal hosting scenario
else
{
CompilationUnitSyntax classSyntax = startUpClassNode as CompilationUnitSyntax;
if (classSyntax != null)
var statementLeadingTrivia = string.Empty;
StatementSyntax dbContextExpression = null;
var compilationSyntax = rootNode as CompilationUnitSyntax;
if (!useTopLevelStatements)
{
//get leading trivia. there should be atleast one member
var statementLeadingTrivia = classSyntax.Members.First()?.GetLeadingTrivia().ToString();

string textToAddAtEnd = AddDbContextString(minimalHostingTemplate: true, useSqlite, statementLeadingTrivia);
_connectionStringsWriter.AddConnectionString(dbContextTypeName, dataBaseName, useSqlite: useSqlite);
textToAddAtEnd = Environment.NewLine + textToAddAtEnd;

//get builder identifier string, should exist
var builderExpression = classSyntax.Members.Where(st => st.ToString().Contains(WebApplicationCreateBuilder)).FirstOrDefault();
var builderIdentifierString = GetBuilderIdentifier(builderExpression);

//create syntax expression that adds DbContext
//added InvalidOperationExceptino if Configuration.GetConnectionString returns null.
var expression = SyntaxFactory.ParseStatement(string.Format(textToAddAtEnd,
string.Format("{0}.Services", builderIdentifierString),
dbContextTypeName,
string.Format("{0}.Configuration", builderIdentifierString),
string.Format(" ?? throw new InvalidOperationException(\"Connection string '{0}' not found.\")", dbContextTypeName)));
var dbContextExpression = SyntaxFactory.GlobalStatement(expression);

//get global statement to insert after (different for web app vs web api)
var statementToInsertAfter = classSyntax.Members.Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault();
if (statementToInsertAfter == null)
{
statementToInsertAfter = classSyntax.Members.Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault();
}

var newClassSyntax = classSyntax.InsertNodesAfter(statementToInsertAfter, new List<GlobalStatementSyntax>() { dbContextExpression });
var newRoot = rootNode.ReplaceNode(classSyntax, newClassSyntax);
MethodDeclarationSyntax methodSyntax = DocumentBuilder.GetMethodFromSyntaxRoot(compilationSyntax, Main);
dbContextExpression = GetAddDbContextStatement(methodSyntax.Body, dbContextTypeName, dbContextNamespace, useSqlite);
}
else if(useTopLevelStatements)
{
dbContextExpression = GetAddDbContextStatement(compilationSyntax, dbContextTypeName, dbContextNamespace, useSqlite);
}

if (statementLeadingTrivia != null && dbContextExpression != null)
{
var newRoot = compilationSyntax;
//add additional namespaces
var namespacesToAdd = new[] { "Microsoft.EntityFrameworkCore", "Microsoft.Extensions.DependencyInjection", dbContextNamespace };
foreach (var namespaceName in namespacesToAdd)
{
newRoot = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(namespaceName, newRoot as CompilationUnitSyntax);
newRoot = RoslynCodeEditUtilities.AddUsingDirectiveIfNeeded(namespaceName, newRoot);
}
if (!useTopLevelStatements)
{
MethodDeclarationSyntax methodSyntax = DocumentBuilder.GetMethodFromSyntaxRoot(newRoot, Main);
var modifiedBlock = methodSyntax.Body;
var statementToInsertAround = methodSyntax.Body.ChildNodes().Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault();
if (statementToInsertAround == null)
{
statementToInsertAround = methodSyntax.Body.ChildNodes().Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault();
modifiedBlock = methodSyntax.Body.InsertNodesAfter(statementToInsertAround, new List<StatementSyntax>() { dbContextExpression });
}
else
{
modifiedBlock = methodSyntax.Body.InsertNodesBefore(statementToInsertAround, new List<StatementSyntax>() { dbContextExpression });
}
var modifiedMethod = methodSyntax.WithBody(modifiedBlock);
newRoot = newRoot.ReplaceNode(methodSyntax, modifiedMethod);
}
else
{
var statementToInsertAfter = newRoot.Members.Where(st => st.ToString().Contains(AddRazorPages)).FirstOrDefault();
if (statementToInsertAfter == null)
{
statementToInsertAfter = newRoot.Members.Where(st => st.ToString().Contains(CreateBuilder)).FirstOrDefault();
}

newRoot = newRoot.InsertNodesAfter(statementToInsertAfter, new List<GlobalStatementSyntax>() { SyntaxFactory.GlobalStatement(dbContextExpression) });
}

return new EditSyntaxTreeResult()
{
Edited = true,
Expand All @@ -273,6 +291,38 @@ public EditSyntaxTreeResult EditStartupForNewContext(ModelType startUp, string d
};
}

/// <summary>
/// Get the StatementSyntax that adds the db context to the WebApplicationBuilder.
/// </summary>
/// <param name="rootNode">Using the base class to allow this var to be either CompilationUnitSyntax or a MethodBodySyntax
/// To get the WebApplicationBuilder variable name
/// </param>
/// <param name="dbContextTypeName"></param>
/// <param name="dataBaseName"></param>
/// <param name="useSqlite"></param>
internal StatementSyntax GetAddDbContextStatement(SyntaxNode rootNode, string dbContextTypeName, string dataBaseName, bool useSqlite)
{
//get leading trivia. there should be atleast one member var statementLeadingTrivia = classSyntax.ChildNodes()
var statementLeadingTrivia = rootNode.ChildNodes().First()?.GetLeadingTrivia().ToString();
string textToAddAtEnd = AddDbContextString(minimalHostingTemplate: true, useSqlite, statementLeadingTrivia);
_connectionStringsWriter.AddConnectionString(dbContextTypeName, dataBaseName, useSqlite: useSqlite);
textToAddAtEnd = Environment.NewLine + textToAddAtEnd;

//get builder identifier string, should exist
var builderExpression = rootNode.ChildNodes().Where(st => st.ToString().Contains(WebApplicationCreateBuilder)).FirstOrDefault() as MemberDeclarationSyntax;
var builderIdentifierString = GetBuilderIdentifier(builderExpression);

//create syntax expression that adds DbContext
//added InvalidOperationExceptino if Configuration.GetConnectionString returns null.
var expression = SyntaxFactory.ParseStatement(string.Format(textToAddAtEnd,
string.Format("{0}.Services", builderIdentifierString),
dbContextTypeName,
string.Format("{0}.Configuration", builderIdentifierString),
string.Format(" ?? throw new InvalidOperationException(\"Connection string '{0}' not found.\")", dbContextTypeName))).WithLeadingTrivia(SyntaxFactory.Whitespace(statementLeadingTrivia));

return expression;
}

private string GetBuilderIdentifier(MemberDeclarationSyntax builderMember)
{
if (builderMember != null)
Expand Down
15 changes: 10 additions & 5 deletions src/Scaffolding/VS.Web.CG.EFCore/EntityFrameworkModelProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ public async Task Process()
{
throw new InvalidOperationException(string.Format(MessageStrings.ModelTypeNotFound, "Program"));
}

if (!dbContextSymbols.Any())
{
//add nullable properties
Expand Down Expand Up @@ -367,6 +367,7 @@ private ReflectedTypesProvider GetReflectedTypesProvider(Compilation projectComp
_loader,
_logger);
}

private async Task GenerateNewDbContextAndRegisterProgramFile(ModelType programType, IApplicationInfo applicationInfo)
{
AssemblyAttributeGenerator assemblyAttributeGenerator = GetAssemblyAttributeGenerator();
Expand All @@ -382,17 +383,20 @@ private async Task GenerateNewDbContextAndRegisterProgramFile(ModelType programT
// Create a new Context
_logger.LogMessage(string.Format(MessageStrings.GeneratingDbContext, _dbContextFullTypeName));
bool nullabledEnabled = "enable".Equals(applicationInfo?.WorkspaceHelper?.GetMsBuildProperty("Nullable"), StringComparison.OrdinalIgnoreCase);
bool useTopLevelsStatements = await ProjectModifierHelper.IsUsingTopLevelStatements(_modelTypesLocator);
var dbContextTemplateModel = new NewDbContextTemplateModel(_dbContextFullTypeName, _modelTypeSymbol, programType, nullabledEnabled);
_dbContextSyntaxTree = await _dbContextEditorServices.AddNewContext(dbContextTemplateModel);
ContextProcessingStatus = ContextProcessingStatus.ContextAdded;

if (programType != null)
{
_programEditResult = _dbContextEditorServices.EditStartupForNewContext(programType,
_programEditResult = _dbContextEditorServices.EditStartupForNewContext(
programType,
dbContextTemplateModel.DbContextTypeName,
dbContextTemplateModel.DbContextNamespace,
dataBaseName: dbContextTemplateModel.DbContextTypeName + "-" + Guid.NewGuid().ToString(),
_useSqlite);
_useSqlite,
useTopLevelsStatements);
}

if (!_programEditResult.Edited)
Expand Down Expand Up @@ -452,14 +456,15 @@ private async Task GenerateNewDbContextAndRegister(ModelType startupType, ModelT

_dbContextSyntaxTree = await _dbContextEditorServices.AddNewContext(dbContextTemplateModel);
ContextProcessingStatus = ContextProcessingStatus.ContextAdded;

bool useTopLevelsStatements = await ProjectModifierHelper.IsUsingTopLevelStatements(_modelTypesLocator);
if (startupType != null)
{
_startupEditResult = _dbContextEditorServices.EditStartupForNewContext(startupType,
dbContextTemplateModel.DbContextTypeName,
dbContextTemplateModel.DbContextNamespace,
dataBaseName: dbContextTemplateModel.DbContextTypeName + "-" + Guid.NewGuid().ToString(),
_useSqlite);
_useSqlite,
useTopLevelsStatements);
}

if (!_startupEditResult.Edited)
Expand Down
4 changes: 3 additions & 1 deletion src/Scaffolding/VS.Web.CG.EFCore/IDbContextEditorServices.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.CodeAnalysis;
using Microsoft.DotNet.Scaffolding.Shared.Project;
Expand All @@ -13,6 +15,6 @@ public interface IDbContextEditorServices

EditSyntaxTreeResult AddModelToContext(ModelType dbContext, ModelType modelType, bool nullableEnabled);

EditSyntaxTreeResult EditStartupForNewContext(ModelType startup, string dbContextTypeName, string dbContextNamespace, string dataBaseName, bool useSqlite);
EditSyntaxTreeResult EditStartupForNewContext(ModelType startup, string dbContextTypeName, string dbContextNamespace, string dataBaseName, bool useSqlite, bool useTopLevelStatements);
}
}
Loading