Skip to content

Commit

Permalink
transpose matmul on graph
Browse files: browse the repository at this point in the history
  • Loading branch information
xhuohai committed Mar 6, 2025
1 parent 4f412cb commit f43c902
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 8 deletions.
14 changes: 10 additions & 4 deletions src/Nncase.Passes/Rules/Neutral/MatMulToConv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,23 +45,26 @@ public sealed partial class MatMulToConv2D : IRewriteRule
return null;
}

var batchGT1 = aShape[0].FixedValue > 1;
var if_shape = new Shape(new[] { aShape[0].FixedValue, aShape[1].FixedValue, 1, 1 });
var w_shape = new Shape(new[] { bShape[1].FixedValue, bShape[0].FixedValue, 1, 1 });
var of_shape = new Shape(new[] { aShape[0].FixedValue, bShape[1].FixedValue });

var if_reshape = Reshape(a, if_shape);
var if_tp = Transpose(if_reshape, new[] { 3, 1, 2, 0 });
var w_tp = Transpose(b, Tensor.From<int>(new[] { 1, 0 })).InheritMetaData(b);
var w_reshape = Reshape(w_tp, w_shape).InheritMetaData(b);
var conv2d = Conv2D(
if_reshape,
batchGT1 ? if_tp : if_reshape,
w_reshape,
Tensor.FromScalar(0.0f, w_shape[0].FixedValue),
Tensor.FromScalar(1, new[] { 2 }),
Tensor.FromScalar(0, new[] { 2, 2 }),
new int[] { 1, 1 },
PadMode.Constant,
1).InheritMetaData(matMulCall);
return Reshape(conv2d, of_shape).InheritMetaData(matMulCall);
var of_tp = Transpose(conv2d, new[] { 3, 1, 2, 0 });
return Reshape(batchGT1 ? of_tp : conv2d, of_shape).InheritMetaData(matMulCall);
}
}

Expand Down Expand Up @@ -89,24 +92,27 @@ public sealed partial class BroadcastMatMulToConv2D : IRewriteRule
return null;
}

var batchGT1 = aShape[0].FixedValue * aShape[1].FixedValue > 1;
var if_shape = new Shape(new[] { aShape[0].FixedValue * aShape[1].FixedValue, aShape[2].FixedValue, 1, 1 });
var w_shape = new Shape(new[] { bShape[1].FixedValue, bShape[0].FixedValue, 1, 1 });
var of_shape = new Shape(new[] { aShape[0].FixedValue, aShape[1].FixedValue, bShape[1].FixedValue });

var if_reshape = Reshape(a, if_shape);
var if_tp = Transpose(if_reshape, new[] { 3, 1, 2, 0 });
var w_tp = Transpose(b, Tensor.From<int>(new[] { 1, 0 })).InheritMetaData(b);
var w_reshape = Reshape(w_tp, w_shape).InheritMetaData(b);

var conv2d = Conv2D(
if_reshape,
batchGT1 ? if_tp : if_reshape,
w_reshape,
Tensor.FromScalar(0.0f, w_shape[0].FixedValue),
Tensor.FromScalar(1, new[] { 2 }),
Tensor.FromScalar(0, new[] { 2, 2 }),
new int[] { 1, 1 },
PadMode.Constant,
1).InheritMetaData(matMulCall);
return Reshape(conv2d, of_shape).InheritMetaData(matMulCall);
var of_tp = Transpose(conv2d, new[] { 3, 1, 2, 0 });
return Reshape(batchGT1 ? of_tp : conv2d, of_shape).InheritMetaData(matMulCall);
}
}

Expand Down
14 changes: 10 additions & 4 deletions src/Nncase.Passes/Rules/WithMarker/MatMulToConv2DWithMarker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,26 @@ public sealed partial class MatMulToConv2DWithMarker : IRewriteRule
return null;
}

var batchGT1 = aShape[^2].FixedValue > 1;
var if_shape = new Shape(new[] { aShape[^2].FixedValue, aShape[^1].FixedValue, 1, 1 });
var w_shape = new Shape(new[] { bShape[^1].FixedValue, bShape[^2].FixedValue, 1, 1 });
var of_shape = new Shape(new[] { aShape[^2].FixedValue, bShape[^1].FixedValue });

var if_reshape = Reshape(a, if_shape);
var if_tp = Transpose(am.With(target: if_reshape), new[] { 3, 1, 2, 0 });
var w_tp = Transpose(b, Tensor.From<int>(new[] { 1, 0 })).InheritMetaData(b);
var w_reshape = Reshape(w_tp, w_shape).InheritMetaData(b);
var conv2d = Conv2D(
am.With(target: if_reshape),
batchGT1 ? am.With(target: if_tp) : am.With(target: if_reshape),
bm.With(target: w_reshape),
Tensor.FromScalar(0.0f, w_shape[0].FixedValue),
Tensor.FromScalar(1, new[] { 2 }),
Tensor.FromScalar(0, new[] { 2, 2 }),
new int[] { 1, 1 },
PadMode.Constant,
1).InheritMetaData(matMulCall);
var m = Reshape(marker.With(target: conv2d), of_shape).InheritMetaData(matMulCall);
var of_tp = Transpose(marker.With(target: conv2d), new[] { 3, 1, 2, 0 });
var m = Reshape(batchGT1 ? marker.With(target: of_tp) : marker.With(target: conv2d), of_shape).InheritMetaData(matMulCall);
DumpScope.Current.DumpIR(m, $"{_counter++}", "withMarker");
return m;
}
Expand Down Expand Up @@ -103,24 +106,27 @@ public sealed partial class BroadcastMatMulToConv2DWithMarker : IRewriteRule
return null;
}

var batchGT1 = aShape[0].FixedValue * aShape[1].FixedValue > 1;
var if_shape = new Shape(new[] { aShape[0].FixedValue * aShape[1].FixedValue, aShape[2].FixedValue, 1, 1 });
var w_shape = new Shape(new[] { bShape[1].FixedValue, bShape[0].FixedValue, 1, 1 });
var of_shape = new Shape(new[] { aShape[0].FixedValue, aShape[1].FixedValue, bShape[1].FixedValue });

var if_reshape = Reshape(am, if_shape);
var if_tp = Transpose(am.With(target: if_reshape), new[] { 3, 1, 2, 0 });
var w_tp = Transpose(b, Tensor.From<int>(new[] { 1, 0 })).InheritMetaData(b);
var w_reshape = Reshape(w_tp, w_shape).InheritMetaData(b);

var conv2d = Conv2D(
am.With(target: if_reshape),
batchGT1 ? am.With(target: if_tp) : am.With(target: if_reshape),
bm.With(target: w_reshape),
Tensor.FromScalar(0.0f, w_shape[0].FixedValue),
Tensor.FromScalar(1, new[] { 2 }),
Tensor.FromScalar(0, new[] { 2, 2 }),
new int[] { 1, 1 },
PadMode.Constant,
1).InheritMetaData(matMulCall);
var m = Reshape(marker.With(target: conv2d), of_shape).InheritMetaData(matMulCall);
var of_tp = Transpose(marker.With(target: conv2d), new[] { 3, 1, 2, 0 });
var m = Reshape(batchGT1 ? marker.With(target: of_tp) : marker.With(target: conv2d), of_shape).InheritMetaData(matMulCall);
return m;
}
}

0 comments on commit f43c902

Please sign in to comment.