diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index c5134b6e718f3..fd3dbca7c5a60 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -366,7 +366,7 @@ class MaskedVectorizeOp: def __init__( self, target: Union[Operation, OpView, Value], - vector_sizes: Union[DynamicIndexList, ArrayAttr], + vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, *, vectorize_nd_extract: Optional[bool] = None, scalable_sizes: OptionalBoolList = None, @@ -374,7 +374,13 @@ def __init__( loc=None, ip=None, ): - if scalable_sizes is None and static_vector_sizes is None: + if ( + scalable_sizes is None + and static_vector_sizes is None + and vector_sizes is None + ): + dynamic_vector_sizes = [] + elif scalable_sizes is None and static_vector_sizes is None: ( dynamic_vector_sizes, static_vector_sizes, diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index 5d5ee945b6686..69181160d5489 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -169,6 +169,16 @@ def testMatchOpNamesList(target): # CHECK-SAME: (!transform.any_op) -> !transform.any_op +@run +@create_sequence +def testMaskedVectorizeNoArgs(target): + structured.MaskedVectorizeOp(target) + # CHECK-LABEL: TEST: testMaskedVectorizeNoArgs + # CHECK: transform.sequence + # CHECK: transform.structured.masked_vectorize + # CHECK-NOT: vector_sizes + + @run @create_sequence def testMaskedVectorizeStatic(target):