Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeGen] Support transform.oneflow.apply_patterns Op in MLIR #10255

Merged
merged 12 commits into from
May 12, 2023
12 changes: 6 additions & 6 deletions oneflow/ir/include/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
set(LLVM_TARGET_DEFINITIONS TestTransformDialectExtension.td)
mlir_tablegen(TestTransformDialectExtension.h.inc -gen-op-decls)
mlir_tablegen(TestTransformDialectExtension.cpp.inc -gen-op-defs)
mlir_tablegen(TestTransformDialectExtensionTypes.h.inc -gen-typedef-decls
set(LLVM_TARGET_DEFINITIONS TransformDialectExtension.td)
mlir_tablegen(TransformDialectExtension.h.inc -gen-op-decls)
mlir_tablegen(TransformDialectExtension.cpp.inc -gen-op-defs)
mlir_tablegen(TransformDialectExtensionTypes.h.inc -gen-typedef-decls
-typedefs-dialect=transform)
mlir_tablegen(TestTransformDialectExtensionTypes.cpp.inc -gen-typedef-defs
mlir_tablegen(TransformDialectExtensionTypes.cpp.inc -gen-typedef-defs
-typedefs-dialect=transform)
add_public_tablegen_target(MLIRTestTransformDialectExtensionIncGen)
add_public_tablegen_target(MLIROneFlowTransformDialectExtensionIncGen)
452 changes: 0 additions & 452 deletions oneflow/ir/include/Transform/TestTransformDialectExtension.td

This file was deleted.

51 changes: 0 additions & 51 deletions oneflow/ir/include/Transform/TestTransformStateExtension.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
//===- TestTransformDialectExtension.h --------------------------*- C++ -*-===//
//===- TransformDialectExtension.h --------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
Expand All @@ -26,8 +26,8 @@ limitations under the License.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_TESTTRANSFORMDIALECTEXTENSION_H
#define MLIR_TESTTRANSFORMDIALECTEXTENSION_H
#ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_
#define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_

#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
Expand All @@ -38,19 +38,26 @@ limitations under the License.
namespace mlir {
class DialectRegistry;

namespace transform {
namespace oneflow {
namespace transform_dialect {
/// Registers the test extension to the Transform dialect.
void registerTestTransformDialectExtension(::mlir::DialectRegistry& registry);
void registerTestTransformDialectEraseSchedulePass();
void registerTestTransformDialectInterpreterPass();
} // namespace transform
void registerTransformDialectExtension(::mlir::DialectRegistry& registry);
void registerTransformDialectEraseSchedulePass();
void registerTransformDialectInterpreterPass();

struct ApplyPatternsOpPatterns {
bool canonicalization = false;
};

} // namespace transform_dialect

} // namespace oneflow
} // namespace mlir

#define GET_TYPEDEF_CLASSES
#include "Transform/TestTransformDialectExtensionTypes.h.inc"
#include "Transform/TransformDialectExtensionTypes.h.inc"

#define GET_OP_CLASSES
#include "Transform/TestTransformDialectExtension.h.inc"
#include "Transform/TransformDialectExtension.h.inc"

#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_H
#endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_H_
71 changes: 71 additions & 0 deletions oneflow/ir/include/Transform/TransformDialectExtension.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_
#define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_

include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"

def ApplyPatternsOp : Op<Transform_Dialect, "oneflow.apply_patterns",
[DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
TransformEachOpTrait,
TransformOpInterface]> {
let description = [{
Modified from iree project: https://github.com/openxla/iree
Greedily applies patterns as specified by its attributes.

Must be applied to an op with trait IsolatedFromAbove since the
GreedyPatternRewriter asserts those. Internally, uses the tracking rewriter
to preserve handles to payload operations nested within operations
associated with `target`. Fails if tracking cannot find replacement for a
payload operation. This may become controllable with an attribute in the
future.

Returns the IsolatedFromAbove op whose content it has modified for better
chaining APIs.

The following additive attributes can be set, they add patterns in an
unspecified order:
- canonicalization: adds all the canonicalization patterns of all
registered dialects and ops.


#### Return modes:

This operation applies a set of patterns specified by attributes. To apply
these patterns, this operation must target an operation that is isolated
from above, otherwise the transform definitely fails.

If the pattern application fails, or if the underlying listener fails to
capture op handles, the transformation definitely fails.

Otherwise the transformation is successful.

This operation does not consume the target handle and does not produce any
handle.
}];

let arguments = (ins TransformHandleTypeInterface:$target,
UnitAttr:$canonicalization);
let results = (outs);

let assemblyFormat = "$target attr-dict `:` functional-type($target, results)";
let cppNamespace = "mlir::oneflow::transform_dialect";

let builders = [
OpBuilder<(ins "Value":$target,
"const ApplyPatternsOpPatterns &":$patterns)>
];

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
::mlir::Operation *target,
::mlir::transform::ApplyToEachResultList &results,
::mlir::transform::TransformState &state);
}];
}

#endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_DIALECT_EXTENSION_TD_
43 changes: 43 additions & 0 deletions oneflow/ir/include/Transform/TransformStateExtension.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#ifndef ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_
#define ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_

#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"

namespace mlir {
namespace oneflow {

namespace transform_dialect {
class TransformStateExtension : public ::mlir::transform::TransformState::Extension {
public:
TransformStateExtension(::mlir::transform::TransformState& state, StringAttr message)
: Extension(state), message(message) {}

StringRef getMessage() const { return message.getValue(); }

LogicalResult updateMapping(Operation* previous, Operation* updated);

private:
StringAttr message;
};

} // namespace transform_dialect
} // namespace oneflow
} // namespace mlir

#endif // ONEFLOW_IR_INCLUDE_TRANSOFRM_TRANSFORM_STATE_EXTENSION_H_
10 changes: 5 additions & 5 deletions oneflow/ir/lib/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
add_mlir_library(
MLIRTestTransformDialect
TestTransformDialectExtension.cpp
TestTransformDialectInterpreter.cpp
TestTransformStateExtension.cpp
MLIROneFlowTransformDialect
TransformDialectExtension.cpp
TransformDialectInterpreter.cpp
TransformStateExtension.cpp
EXCLUDE_FROM_LIBMLIR
DEPENDS
MLIRTestTransformDialectExtensionIncGen
MLIROneFlowTransformDialectExtensionIncGen
LINK_LIBS
PUBLIC
MLIRIR
Expand Down
Loading