@@ -905,18 +905,29 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
905
905
}
906
906
}
907
907
908
+ ReductionDomain rdom{definition.schedule ().rvars (), definition.predicate (), true };
909
+ SubstitutionMap rdom_promises;
910
+ for (int i = 0 ; i < rdom.domain ().size (); i++) {
911
+ const auto &[var, min, extent] = rdom.domain ()[i];
912
+ rdom_promises.emplace (var, promise_clamped (RVar (rdom, i), min, min + extent - 1 ));
913
+ }
914
+
908
915
// Project the RDom into each side
909
916
ReductionDomain intermediate_rdom, preserved_rdom;
910
917
SubstitutionMap intermediate_map, preserved_map;
911
918
{
912
- ReductionDomain rdom{definition.schedule ().rvars (), definition.predicate (), true };
913
-
914
919
// Intermediate
915
920
std::tie (intermediate_rdom, intermediate_map) = project_rdom (intermediate_rdims, rdom, rvar_splits);
916
921
for (size_t i = 0 ; i < preserved.size (); i++) {
917
922
add_let (intermediate_map, preserved_rdims[i].var , preserved_vars[i]);
918
923
}
919
- intermediate_rdom.set_predicate (simplify (substitute (intermediate_map, intermediate_rdom.predicate ())));
924
+
925
+ {
926
+ Expr pred = intermediate_rdom.predicate ();
927
+ pred = substitute (rdom_promises, pred);
928
+ pred = substitute (intermediate_map, pred);
929
+ intermediate_rdom.set_predicate (simplify (pred));
930
+ }
920
931
921
932
// Preserved
922
933
std::tie (preserved_rdom, preserved_map) = project_rdom (preserved_rdims, rdom, rvar_splits);
@@ -926,7 +937,13 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
926
937
const auto &[_, min, extent] = intermediate_rdom.domain ()[i];
927
938
intm_rdom.push (var, Interval{min, min + extent - 1 });
928
939
}
929
- preserved_rdom.set_predicate (or_condition_over_domain (substitute (preserved_map, preserved_rdom.predicate ()), intm_rdom));
940
+ {
941
+ Expr pred = preserved_rdom.predicate ();
942
+ pred = substitute (rdom_promises, pred);
943
+ pred = substitute (preserved_map, pred);
944
+ pred = or_condition_over_domain (pred, intm_rdom);
945
+ preserved_rdom.set_predicate (pred);
946
+ }
930
947
}
931
948
932
949
// Intermediate func
@@ -943,10 +960,12 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
943
960
{
944
961
vector<Expr> args = definition.args ();
945
962
args.insert (args.end (), preserved_vars.begin (), preserved_vars.end ());
963
+ args = substitute (rdom_promises, args);
946
964
args = substitute (intermediate_map, args);
947
965
948
966
vector<Expr> values = definition.values ();
949
967
values = substitute_self_reference (values, function.name (), intm.function (), preserved_vars);
968
+ values = substitute (rdom_promises, values);
950
969
values = substitute (intermediate_map, values);
951
970
intm.function ().define_update (args, values, intermediate_rdom);
952
971
@@ -1041,7 +1060,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
1041
1060
}
1042
1061
1043
1062
definition.args () = dim_vars_exprs;
1044
- definition.values () = substitute (preserved_map, prover_result.pattern .ops );
1063
+ definition.values () = substitute (preserved_map, substitute (rdom_promises, prover_result.pattern .ops ) );
1045
1064
definition.predicate () = preserved_rdom.predicate ();
1046
1065
definition.schedule ().dims () = subst_dims (preserved_map, reducing_dims);
1047
1066
definition.schedule ().rvars () = preserved_rdom.domain ();
0 commit comments