Skip to content

Commit 8480bcd

Browse files
authored
Add promise_clamped in rfactor (#8608)
Fixes #8600
1 parent d743c09 commit 8480bcd

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

src/Func.cpp

+24-5
Original file line numberDiff line numberDiff line change
@@ -905,18 +905,29 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
905905
}
906906
}
907907

908+
ReductionDomain rdom{definition.schedule().rvars(), definition.predicate(), true};
909+
SubstitutionMap rdom_promises;
910+
for (int i = 0; i < rdom.domain().size(); i++) {
911+
const auto &[var, min, extent] = rdom.domain()[i];
912+
rdom_promises.emplace(var, promise_clamped(RVar(rdom, i), min, min + extent - 1));
913+
}
914+
908915
// Project the RDom into each side
909916
ReductionDomain intermediate_rdom, preserved_rdom;
910917
SubstitutionMap intermediate_map, preserved_map;
911918
{
912-
ReductionDomain rdom{definition.schedule().rvars(), definition.predicate(), true};
913-
914919
// Intermediate
915920
std::tie(intermediate_rdom, intermediate_map) = project_rdom(intermediate_rdims, rdom, rvar_splits);
916921
for (size_t i = 0; i < preserved.size(); i++) {
917922
add_let(intermediate_map, preserved_rdims[i].var, preserved_vars[i]);
918923
}
919-
intermediate_rdom.set_predicate(simplify(substitute(intermediate_map, intermediate_rdom.predicate())));
924+
925+
{
926+
Expr pred = intermediate_rdom.predicate();
927+
pred = substitute(rdom_promises, pred);
928+
pred = substitute(intermediate_map, pred);
929+
intermediate_rdom.set_predicate(simplify(pred));
930+
}
920931

921932
// Preserved
922933
std::tie(preserved_rdom, preserved_map) = project_rdom(preserved_rdims, rdom, rvar_splits);
@@ -926,7 +937,13 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
926937
const auto &[_, min, extent] = intermediate_rdom.domain()[i];
927938
intm_rdom.push(var, Interval{min, min + extent - 1});
928939
}
929-
preserved_rdom.set_predicate(or_condition_over_domain(substitute(preserved_map, preserved_rdom.predicate()), intm_rdom));
940+
{
941+
Expr pred = preserved_rdom.predicate();
942+
pred = substitute(rdom_promises, pred);
943+
pred = substitute(preserved_map, pred);
944+
pred = or_condition_over_domain(pred, intm_rdom);
945+
preserved_rdom.set_predicate(pred);
946+
}
930947
}
931948

932949
// Intermediate func
@@ -943,10 +960,12 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
943960
{
944961
vector<Expr> args = definition.args();
945962
args.insert(args.end(), preserved_vars.begin(), preserved_vars.end());
963+
args = substitute(rdom_promises, args);
946964
args = substitute(intermediate_map, args);
947965

948966
vector<Expr> values = definition.values();
949967
values = substitute_self_reference(values, function.name(), intm.function(), preserved_vars);
968+
values = substitute(rdom_promises, values);
950969
values = substitute(intermediate_map, values);
951970
intm.function().define_update(args, values, intermediate_rdom);
952971

@@ -1041,7 +1060,7 @@ Func Stage::rfactor(const vector<pair<RVar, Var>> &preserved) {
10411060
}
10421061

10431062
definition.args() = dim_vars_exprs;
1044-
definition.values() = substitute(preserved_map, prover_result.pattern.ops);
1063+
definition.values() = substitute(preserved_map, substitute(rdom_promises, prover_result.pattern.ops));
10451064
definition.predicate() = preserved_rdom.predicate();
10461065
definition.schedule().dims() = subst_dims(preserved_map, reducing_dims);
10471066
definition.schedule().rvars() = preserved_rdom.domain();

test/correctness/rfactor.cpp

+30-2
Original file line numberDiff line numberDiff line change
@@ -1031,6 +1031,35 @@ int inlined_rfactor_with_disappearing_rvar_test() {
10311031
return 0;
10321032
}
10331033

1034+
// From issue: https://github.com/halide/Halide/issues/8600
1035+
int rfactor_precise_bounds_test() {
1036+
Var x("x"), y("y");
1037+
RDom r(0, 10, 0, 10);
1038+
1039+
// Create an input with random values
1040+
Buffer<uint8_t> input(10, 10, "input");
1041+
for (int y = 0; y < 10; ++y) {
1042+
for (int x = 0; x < 10; ++x) {
1043+
input(x, y) = (rand() % 256);
1044+
}
1045+
}
1046+
1047+
Func f;
1048+
1049+
f() = 0;
1050+
f() += input(r.x, r.y);
1051+
RVar rxo, rxi, ryo, ryi;
1052+
Func intm = f.update()
1053+
.tile(r.x, r.y, rxo, ryo, rxi, ryi, 4, 4)
1054+
.rfactor({{rxi, x}, {ryi, y}});
1055+
1056+
intm.compute_root();
1057+
1058+
Buffer<int> im = f.realize();
1059+
1060+
return 0;
1061+
}
1062+
10341063
} // namespace
10351064

10361065
int main(int argc, char **argv) {
@@ -1063,15 +1092,14 @@ int main(int argc, char **argv) {
10631092
{"tuple rfactor test: checking output img correctness...", tuple_rfactor_test<false>},
10641093
{"tuple specialize rdom predicate rfactor test: checking call graphs...", tuple_specialize_rdom_predicate_rfactor_test<true>},
10651094
{"tuple specialize rdom predicate rfactor test: checking output img correctness...", tuple_specialize_rdom_predicate_rfactor_test<false>},
1066-
{"parallel dot product rfactor test: checking call graphs...", parallel_dot_product_rfactor_test<true>},
1067-
{"parallel dot product rfactor test: checking output img correctness...", parallel_dot_product_rfactor_test<false>},
10681095
{"tuple partial reduction rfactor test: checking call graphs...", tuple_partial_reduction_rfactor_test<true>},
10691096
{"tuple partial reduction rfactor test: checking output img correctness...", tuple_partial_reduction_rfactor_test<false>},
10701097
{"check allocation bound test", check_allocation_bound_test},
10711098
{"rfactor tile reorder test: checking output img correctness...", rfactor_tile_reorder_test},
10721099
{"complex multiply rfactor test", complex_multiply_rfactor_test},
10731100
{"argmin rfactor test", argmin_rfactor_test},
10741101
{"inlined rfactor with disappearing rvar test", inlined_rfactor_with_disappearing_rvar_test},
1102+
{"rfactor bounds tests", rfactor_precise_bounds_test},
10751103
};
10761104

10771105
using Sharder = Halide::Internal::Test::Sharder;

0 commit comments

Comments
 (0)