Skip to content

Commit

Permalink
Fix ReduceArg of TFLite importer
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Mar 6, 2025
1 parent 0273c62 commit b39a7bd
Showing 1 changed file with 2 additions and 13 deletions.
15 changes: 2 additions & 13 deletions src/Nncase.Importer/TFLite/Reduce.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public partial class TFLiteImporter
private Expr VisitReduce(in tflite.Operator op, ReduceOp reduceOp, float initValue)
{
var (input, axis) = GetInputExprs(op, 0, 1);
return Reduce(reduceOp, input, ProcAxis(axis), initValue, op.BuiltinOptionsAsReducerOptions().KeepDims);
return Reduce(reduceOp, input, axis, initValue, op.BuiltinOptionsAsReducerOptions().KeepDims);
}

private Expr VisitReduceArg(in tflite.Operator op, ReduceArgOp reduceArgOp)
Expand All @@ -25,18 +25,7 @@ private Expr VisitReduceArg(in tflite.Operator op, ReduceArgOp reduceArgOp)
_ => throw new ArgumentOutOfRangeException(nameof(reduceArgOp), reduceArgOp, null),
};

return ReduceArg(reduceArgOp, (PrimType)GetDataType(outType), input, ProcAxis(axis), false, false);
}

private Expr ProcAxis(Expr axis)
{
if (axis is TensorConst axisValue)
{
// scalar to array
return axisValue.Value.ToArray<int>();
}

return axis;
return ReduceArg(reduceArgOp, (PrimType)GetDataType(outType), input, axis, false, false);
}
}
}

0 comments on commit b39a7bd

Please sign in to comment.