Skip to content

Commit

Permalink
GNNE-1714:Fix/regression bugs (#1117)
Browse files Browse the repository at this point in the history
* revise bug for ort.transposeConv

* Avoid change caused by reference types

* Apply code-format changes

* Change equal to similarity because of rounding

* ToArray should be called outside the loop.

* change name to Camel

* change var name to camel

---------

Co-authored-by: guodongliang <guodongliang@canaan-creative.com>
Co-authored-by: uranus0515 <uranus0515@users.noreply.github.com>
Co-authored-by: FusionBolt <59008347+FusionBolt@users.noreply.github.com>
  • Loading branch information
4 people authored Nov 2, 2023
1 parent 1073f35 commit 1300e76
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 22 deletions.
95 changes: 81 additions & 14 deletions src/Nncase.Evaluator/NN/Conv2DTranspose.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,91 @@ public IValue Visit(IEvaluateContext context, Conv2DTranspose conv)
var stride = context.GetArgumentValueAsArray<long>(conv, Conv2DTranspose.Stride);
var outputShape = context.GetArgumentValueAsArray<long>(conv, Conv2DTranspose.OutputShape);

// [w:[left right] h:[top bottom]]
// [h:[top bottom] w:[left right] ]
var pads = context.GetArgumentValueAsArray<long>(conv, Conv2DTranspose.Padding);
var outputPaddings = context.GetArgumentValueAsArray<long>(conv, Conv2DTranspose.OutputPadding);
_ = context.GetArgumentValueAsArray<long>(conv, Conv2DTranspose.OutputPadding);
var dilation = context.GetArgumentValueAsArray<long>(conv, Conv2DTranspose.Dilation);
var groups = context.GetArgumentValueAsScalar<long>(conv, Conv2DTranspose.Groups);
var kernelShape = weights.Shape;
return OrtKI.ConvTranspose(
input,
OrtKI.Transpose(weights, new long[] { 1, 0, 2, 3 }),
bias,
"NOTSET",
dilation,
groups,
new long[] { kernelShape[2], kernelShape[3] },
outputPaddings,
outputShape,
pads,
stride).ToValue();
var inputShape = input.Shape;

var outputSize = outputShape[0] * outputShape[1] * outputShape[2] * outputShape[3];
float[] outCache = new float[outputSize];
Array.Clear(outCache, 0, (int)outputSize);

var gIC = inputShape[1] / groups;
var gOC = outputShape[1] / groups;

var weightsArray = weights.ToArray<float>();
var inputsArray = input.ToArray<float>();
var biasArray = bias.ToArray<float>();
int inputIndex = 0;
for (int batch = 0; batch < inputShape[0]; batch++)
{
var outBatchP = outCache.AsSpan().Slice(batch * (int)outputShape[1] * (int)outputShape[2] * (int)outputShape[3]);

for (int g = 0; g < groups; g++)
{
var outGroupP = outBatchP.Slice(g * (int)gOC * (int)outputShape[2] * (int)outputShape[3]);
var wGroupP = weightsArray.AsSpan().Slice((int)g * (int)gOC * (int)gIC * (int)kernelShape[2] * (int)kernelShape[3]);

for (int ic = 0; ic < gIC; ic++)
{
for (int iy = 0; iy < inputShape[2]; iy++)
{
for (int ix = 0; ix < inputShape[3]; ix++)
{
int outYOrigin = (int)((iy * stride[0]) - pads[0]);
int outXOrigin = (int)((ix * stride[1]) - pads[2]);
int filterYStart = System.Math.Max(0, (int)((-outYOrigin + dilation[0] - 1) / dilation[0]));
int filterYEnd = (int)System.Math.Min(kernelShape[2], ((int)outputShape[2] - outYOrigin + dilation[0] - 1) / dilation[0]);
int filterXStart = (int)System.Math.Max(0, (-outXOrigin + dilation[1] - 1) / dilation[1]);
int filterXEnd = (int)System.Math.Min(kernelShape[3], ((int)outputShape[3] - outXOrigin + dilation[1] - 1) / dilation[1]);

float inV;
if (ix < 0 || ix >= inputShape[3] || iy < 0 || iy >= inputShape[2])
{
inV = 0f;
}
else
{
inV = inputsArray[inputIndex];
}

inputIndex++;

for (int oc = 0; oc < gOC; oc++)
{
var outCP = outGroupP.Slice((int)(oc * outputShape[2] * outputShape[3]));
var wOCP = wGroupP.Slice((int)(oc * gIC * kernelShape[2] * kernelShape[3]));
var wICP = wOCP.Slice((int)(ic * kernelShape[2] * kernelShape[3]));

for (int ky = filterYStart; ky < filterYEnd; ky++)
{
for (int kx = filterXStart; kx < filterXEnd; kx++)
{
int outY = (int)(outYOrigin + (dilation[0] * ky));
int outX = (int)(outXOrigin + (dilation[1] * kx));

var w = wICP[(int)((ky * kernelShape[3]) + kx)];

outCP[(int)((outY * outputShape[3]) + outX)] += (float)inV * w;
}
}
}
}
}
}
}
}

for (int i = 0; i < outputSize; i++)
{
var biasIdx = i / (outputShape[2] * outputShape[3]) % outputShape[1];
outCache[i] = outCache[i] + biasArray[biasIdx];
}

return new TensorValue(Tensor.From(outCache, new[] { (int)outputShape[0], (int)outputShape[1], (int)outputShape[2], (int)outputShape[3] }));
}

/// <inheritdoc/>
Expand Down
12 changes: 5 additions & 7 deletions src/Nncase.Importer/TFLite/MatMul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,12 @@ private Expr VisitMatMul(in tflite.Operator op, bool isFullyConnected = true)
: Expand(Cast(0, GetDataType(GetInputTensor(op, 0).Type)), new[] { otherTensor.Shape(0) }).Evaluate().AsTensor();

var matmul = MatMul(lhs, rhs);
List<string> outputNames = new() { GetOutputTensor(op, 0).Name + "_matmul" };
matmul.Metadata.OutputNames = outputNames;
outputNames.Clear();
outputNames.Add(GetOutputTensor(op, 0).Name + "_bias");
bias.Metadata.OutputNames = outputNames;
List<string> outputNames_matmul = new() { GetOutputTensor(op, 0).Name + "_matmul" };
matmul.Metadata.OutputNames = outputNames_matmul;
List<string> outputNames_bias = new() { GetOutputTensor(op, 0).Name + "_bias" };
bias.Metadata.OutputNames = outputNames_bias;
var mm = matmul + bias;
outputNames.Clear();
outputNames.Add(GetOutputTensor(op, 0).Name);
List<string> outputNames = new() { GetOutputTensor(op, 0).Name };
mm.Metadata.OutputNames = outputNames;

return fusedActivationFunction switch
Expand Down
6 changes: 5 additions & 1 deletion src/Nncase.Tests/Evaluator/UnitTestEvaluatorNN.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using NetFabric.Hyperlinq;
using Nncase.Evaluator;
using Nncase.IR;
using Nncase.IR.F;
Expand Down Expand Up @@ -275,7 +276,10 @@ public void TestConv2DTranspose()
PadMode.Constant,
1);
CompilerServices.InferenceType(expr);
Assert.Equal(expect, expr.Evaluate().AsTensor().ToOrtTensor());
var expectValue = expect.ToArray<float>();
var realValue = expr.Evaluate().AsTensor().ToArray<float>();
var cos = Nncase.Tests.Comparator.CosSimilarity(expectValue, realValue);
Assert.True(cos >= 0.99);
}

[Fact]
Expand Down

0 comments on commit 1300e76

Please sign in to comment.