From b4f2e1417f253a3439a5a80d7aeb50997d28346d Mon Sep 17 00:00:00 2001 From: David Wengier Date: Mon, 16 Nov 2020 11:37:55 +1100 Subject: [PATCH] Create an unsafe method from a local function when necessary --- .../ConvertLocalFunctionToMethodTests.cs | 96 +++++++++++++++++++ ...FunctionToMethodCodeRefactoringProvider.cs | 7 +- 2 files changed, 102 insertions(+), 1 deletion(-) diff --git a/src/EditorFeatures/CSharpTest/CodeActions/ConvertLocalFunctionToMethod/ConvertLocalFunctionToMethodTests.cs b/src/EditorFeatures/CSharpTest/CodeActions/ConvertLocalFunctionToMethod/ConvertLocalFunctionToMethodTests.cs index 86ca5bec30ba8..11700e987a376 100644 --- a/src/EditorFeatures/CSharpTest/CodeActions/ConvertLocalFunctionToMethod/ConvertLocalFunctionToMethodTests.cs +++ b/src/EditorFeatures/CSharpTest/CodeActions/ConvertLocalFunctionToMethod/ConvertLocalFunctionToMethodTests.cs @@ -821,6 +821,102 @@ C LocalFunction(C c) return null; } } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsConvertLocalFunctionToMethod)] + [WorkItem(32976, "https://github.com/dotnet/roslyn/issues/32976")] + public async Task TestUnsafeLocalFunction() + { + await TestInRegularAndScriptAsync( +@"class C +{ + public unsafe void UnsafeFunction() + { + byte b = 1; + [|unsafe byte* GetPtr(byte* bytePt) + { + return bytePt; + }|] + var aReference = GetPtr(&b); + } +}", +@"class C +{ + public unsafe void UnsafeFunction() + { + byte b = 1; + var aReference = GetPtr(&b); + } + + private static unsafe byte* GetPtr(byte* bytePt) + { + return bytePt; + } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsConvertLocalFunctionToMethod)] + [WorkItem(32976, "https://github.com/dotnet/roslyn/issues/32976")] + public async Task TestUnsafeLocalFunctionInUnsafeMethod() + { + await TestInRegularAndScriptAsync( +@"class C +{ + public unsafe void UnsafeFunction() + { + byte b = 1; + [|byte* GetPtr(byte* bytePt) + { + return bytePt; + }|] + var aReference = GetPtr(&b); + } +}", +@"class C +{ + public unsafe void UnsafeFunction() + { + byte b = 1; + var aReference = GetPtr(&b); + } + + private static unsafe byte* GetPtr(byte* bytePt) + { + return bytePt; + } +}"); + } + + [Fact, Trait(Traits.Feature, Traits.Features.CodeActionsConvertLocalFunctionToMethod)] + [WorkItem(32976, "https://github.com/dotnet/roslyn/issues/32976")] + public async Task TestLocalFunctionInUnsafeMethod() + { + await TestInRegularAndScriptAsync( +@"class C +{ + public unsafe void UnsafeFunction() + { + byte b = 1; + [|byte GetPtr(byte bytePt) + { + return bytePt; + }|] + var aReference = GetPtr(b); + } +}", +@"class C +{ + public unsafe void UnsafeFunction() + { + byte b = 1; + var aReference = GetPtr(b); + } + + private static byte GetPtr(byte bytePt) + { + return bytePt; + } }"); } } diff --git a/src/Features/CSharp/Portable/CodeRefactorings/ConvertLocalFunctionToMethod/CSharpConvertLocalFunctionToMethodCodeRefactoringProvider.cs b/src/Features/CSharp/Portable/CodeRefactorings/ConvertLocalFunctionToMethod/CSharpConvertLocalFunctionToMethodCodeRefactoringProvider.cs index dad44de950eed..56a2b494bbedf 100644 --- a/src/Features/CSharp/Portable/CodeRefactorings/ConvertLocalFunctionToMethod/CSharpConvertLocalFunctionToMethodCodeRefactoringProvider.cs +++ b/src/Features/CSharp/Portable/CodeRefactorings/ConvertLocalFunctionToMethod/CSharpConvertLocalFunctionToMethodCodeRefactoringProvider.cs @@ -107,13 +107,18 @@ private static async Task UpdateDocumentAsync( var containerSymbol = semanticModel.GetDeclaredSymbol(container, cancellationToken); var isStatic = containerSymbol.IsStatic || captures.All(capture => !capture.IsThisParameter()); + // GetSymbolModifiers actually checks if the local function needs to be unsafe, not whether + // it is declared as such, so this check we don't need to worry about whether the containing method + // is unsafe, this will just work regardless. + var needsUnsafe = declaredSymbol.GetSymbolModifiers().IsUnsafe; + var methodName = GenerateUniqueMethodName(declaredSymbol); var parameters = declaredSymbol.Parameters; var methodSymbol = CodeGenerationSymbolFactory.CreateMethodSymbol( containingType: declaredSymbol.ContainingType, attributes: default, accessibility: Accessibility.Private, - modifiers: new DeclarationModifiers(isStatic, isAsync: declaredSymbol.IsAsync), + modifiers: new DeclarationModifiers(isStatic, isAsync: declaredSymbol.IsAsync, isUnsafe: needsUnsafe), returnType: declaredSymbol.ReturnType, refKind: default, explicitInterfaceImplementations: default,