From 1300e764335ff2ce612ba2546c7a0af4ca4db932 Mon Sep 17 00:00:00 2001 From: uranus0515 <110005227+uranus0515@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:27:30 +0800 Subject: [PATCH] GNNE-1714:Fix/regression bugs (#1117) * revise bug for ort.transposeConv * Avoid change caused by reference types * Apply code-format changes * Change equal to similarity because of rounding * ToArray should be called outside the loop. * change name to Camel * change var name to camel --------- Co-authored-by: guodongliang Co-authored-by: uranus0515 Co-authored-by: FusionBolt <59008347+FusionBolt@users.noreply.github.com> --- src/Nncase.Evaluator/NN/Conv2DTranspose.cs | 95 ++++++++++++++++--- src/Nncase.Importer/TFLite/MatMul.cs | 12 +-- .../Evaluator/UnitTestEvaluatorNN.cs | 6 +- 3 files changed, 91 insertions(+), 22 deletions(-) diff --git a/src/Nncase.Evaluator/NN/Conv2DTranspose.cs b/src/Nncase.Evaluator/NN/Conv2DTranspose.cs index 626043680e..56ae681279 100644 --- a/src/Nncase.Evaluator/NN/Conv2DTranspose.cs +++ b/src/Nncase.Evaluator/NN/Conv2DTranspose.cs @@ -27,24 +27,91 @@ public IValue Visit(IEvaluateContext context, Conv2DTranspose conv) var stride = context.GetArgumentValueAsArray(conv, Conv2DTranspose.Stride); var outputShape = context.GetArgumentValueAsArray(conv, Conv2DTranspose.OutputShape); - // [w:[left right] h:[top bottom]] + // [h:[top bottom] w:[left right] ] var pads = context.GetArgumentValueAsArray(conv, Conv2DTranspose.Padding); - var outputPaddings = context.GetArgumentValueAsArray(conv, Conv2DTranspose.OutputPadding); + _ = context.GetArgumentValueAsArray(conv, Conv2DTranspose.OutputPadding); var dilation = context.GetArgumentValueAsArray(conv, Conv2DTranspose.Dilation); var groups = context.GetArgumentValueAsScalar(conv, Conv2DTranspose.Groups); var kernelShape = weights.Shape; - return OrtKI.ConvTranspose( - input, - OrtKI.Transpose(weights, new long[] { 1, 0, 2, 3 }), - bias, - "NOTSET", - dilation, - groups, - new long[] { kernelShape[2], kernelShape[3] }, - outputPaddings, - outputShape, - pads, - stride).ToValue(); + var inputShape = input.Shape; + + var outputSize = outputShape[0] * outputShape[1] * outputShape[2] * outputShape[3]; + float[] outCache = new float[outputSize]; + Array.Clear(outCache, 0, (int)outputSize); + + var gIC = inputShape[1] / groups; + var gOC = outputShape[1] / groups; + + var weightsArray = weights.ToArray(); + var inputsArray = input.ToArray(); + var biasArray = bias.ToArray(); + int inputIndex = 0; + for (int batch = 0; batch < inputShape[0]; batch++) + { + var outBatchP = outCache.AsSpan().Slice(batch * (int)outputShape[1] * (int)outputShape[2] * (int)outputShape[3]); + + for (int g = 0; g < groups; g++) + { + var outGroupP = outBatchP.Slice(g * (int)gOC * (int)outputShape[2] * (int)outputShape[3]); + var wGroupP = weightsArray.AsSpan().Slice((int)g * (int)gOC * (int)gIC * (int)kernelShape[2] * (int)kernelShape[3]); + + for (int ic = 0; ic < gIC; ic++) + { + for (int iy = 0; iy < inputShape[2]; iy++) + { + for (int ix = 0; ix < inputShape[3]; ix++) + { + int outYOrigin = (int)((iy * stride[0]) - pads[0]); + int outXOrigin = (int)((ix * stride[1]) - pads[2]); + int filterYStart = System.Math.Max(0, (int)((-outYOrigin + dilation[0] - 1) / dilation[0])); + int filterYEnd = (int)System.Math.Min(kernelShape[2], ((int)outputShape[2] - outYOrigin + dilation[0] - 1) / dilation[0]); + int filterXStart = (int)System.Math.Max(0, (-outXOrigin + dilation[1] - 1) / dilation[1]); + int filterXEnd = (int)System.Math.Min(kernelShape[3], ((int)outputShape[3] - outXOrigin + dilation[1] - 1) / dilation[1]); + + float inV; + if (ix < 0 || ix >= inputShape[3] || iy < 0 || iy >= inputShape[2]) + { + inV = 0f; + } + else + { + inV = inputsArray[inputIndex]; + } + + inputIndex++; + + for (int oc = 0; oc < gOC; oc++) + { + var outCP = outGroupP.Slice((int)(oc * outputShape[2] * outputShape[3])); + var wOCP = wGroupP.Slice((int)(oc * gIC * kernelShape[2] * kernelShape[3])); + var wICP = wOCP.Slice((int)(ic * kernelShape[2] * kernelShape[3])); + + for (int ky = filterYStart; ky < filterYEnd; ky++) + { + for (int kx = filterXStart; kx < filterXEnd; kx++) + { + int outY = (int)(outYOrigin + (dilation[0] * ky)); + int outX = (int)(outXOrigin + (dilation[1] * kx)); + + var w = wICP[(int)((ky * kernelShape[3]) + kx)]; + + outCP[(int)((outY * outputShape[3]) + outX)] += (float)inV * w; + } + } + } + } + } + } + } + } + + for (int i = 0; i < outputSize; i++) + { + var biasIdx = i / (outputShape[2] * outputShape[3]) % outputShape[1]; + outCache[i] = outCache[i] + biasArray[biasIdx]; + } + + return new TensorValue(Tensor.From(outCache, new[] { (int)outputShape[0], (int)outputShape[1], (int)outputShape[2], (int)outputShape[3] })); } /// diff --git a/src/Nncase.Importer/TFLite/MatMul.cs b/src/Nncase.Importer/TFLite/MatMul.cs index 59bbfcd1d7..056d2bdee5 100644 --- a/src/Nncase.Importer/TFLite/MatMul.cs +++ b/src/Nncase.Importer/TFLite/MatMul.cs @@ -66,14 +66,12 @@ private Expr VisitMatMul(in tflite.Operator op, bool isFullyConnected = true) : Expand(Cast(0, GetDataType(GetInputTensor(op, 0).Type)), new[] { otherTensor.Shape(0) }).Evaluate().AsTensor(); var matmul = MatMul(lhs, rhs); - List outputNames = new() { GetOutputTensor(op, 0).Name + "_matmul" }; - matmul.Metadata.OutputNames = outputNames; - outputNames.Clear(); - outputNames.Add(GetOutputTensor(op, 0).Name + "_bias"); - bias.Metadata.OutputNames = outputNames; + List outputNames_matmul = new() { GetOutputTensor(op, 0).Name + "_matmul" }; + matmul.Metadata.OutputNames = outputNames_matmul; + List outputNames_bias = new() { GetOutputTensor(op, 0).Name + "_bias" }; + bias.Metadata.OutputNames = outputNames_bias; var mm = matmul + bias; - outputNames.Clear(); - outputNames.Add(GetOutputTensor(op, 0).Name); + List outputNames = new() { GetOutputTensor(op, 0).Name }; mm.Metadata.OutputNames = outputNames; return fusedActivationFunction switch diff --git a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs index 385dd863fe..74296b8cfb 100755 --- a/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs @@ -5,6 +5,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; +using NetFabric.Hyperlinq; using Nncase.Evaluator; using Nncase.IR; using Nncase.IR.F; @@ -275,7 +276,10 @@ public void TestConv2DTranspose() PadMode.Constant, 1); CompilerServices.InferenceType(expr); - Assert.Equal(expect, expr.Evaluate().AsTensor().ToOrtTensor()); + var expectValue = expect.ToArray(); + var realValue = expr.Evaluate().AsTensor().ToArray(); + var cos = Nncase.Tests.Comparator.CosSimilarity(expectValue, realValue); + Assert.True(cos >= 0.99); } [Fact]