Skip to content

Commit 6e3e387

Browse files
committed
ref: introduce templated formula and expression visitors
1 parent ac1aeaf commit 6e3e387

31 files changed

+810
-609
lines changed

dlinear/parser/smt2/Driver.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,10 @@ void Smt2Driver::GetValue(const std::vector<Term> &term_list) const {
9191
switch (term.type()) {
9292
case Term::Type::EXPRESSION: {
9393
const Expression &e{term.expression()};
94-
const ExpressionEvaluator evaluator{e};
94+
const ExpressionEvaluator evaluator{e, context_.config()};
9595
pp.Print(e);
9696
term_str = ss.str();
97-
const Interval iv{ExpressionEvaluator(term.expression())(box)};
97+
const Interval iv{ExpressionEvaluator(term.expression(), context_.config())(box)};
9898
value_str = (std::stringstream{} << iv).str();
9999
break;
100100
}

dlinear/parser/smt2/Sort.h

+24-24
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
/**
2-
* @author Ernesto Casablanca (casablancaernesto@gmail.com)
3-
* @copyright 2024 dlinear
4-
* @licence BSD 3-Clause License
5-
* Sort enum.
6-
*/
2+
* @author Ernesto Casablanca (casablancaernesto@gmail.com)
3+
* @copyright 2024 dlinear
4+
* @licence BSD 3-Clause License
5+
* Sort enum.
6+
*/
77
#pragma once
88

99
#include <ostream>
@@ -15,34 +15,34 @@ namespace dlinear::smt2 {
1515

1616
/** Sort of a term. */
1717
enum class Sort {
18-
Binary, ///< Binary sort.
19-
Bool, ///< Boolean sort.
20-
Int, ///< Integer sort.
21-
Real, ///< Real sort.
18+
Binary, ///< Binary sort.
19+
Bool, ///< Boolean sort.
20+
Int, ///< Integer sort.
21+
Real, ///< Real sort.
2222
};
2323

2424
/**
25-
* Parse a string to a sort.
26-
* @param s string to parse
27-
* @return sort parsed from @p s
28-
*/
25+
* Parse a string to a sort.
26+
* @param s string to parse
27+
* @return sort parsed from @p s
28+
*/
2929
Sort ParseSort(const std::string &s);
3030
/**
31-
* Convert a sort to a variable type.
32-
*
33-
* The conversion is as follows:
34-
* - Binary -> BINARY
35-
* - Bool -> BOOLEAN
36-
* - Int -> INTEGER
37-
* - Real -> CONTINUOUS
38-
* @param sort sort to convert
39-
* @return variable type corresponding to @p sort
40-
*/
31+
* Convert a sort to a variable type.
32+
*
33+
* The conversion is as follows:
34+
* - Binary -> BINARY
35+
* - Bool -> BOOLEAN
36+
* - Int -> INTEGER
37+
* - Real -> CONTINUOUS
38+
* @param sort sort to convert
39+
* @return variable type corresponding to @p sort
40+
*/
4141
Variable::Type SortToType(Sort sort);
4242

4343
std::ostream &operator<<(std::ostream &os, const Sort &sort);
4444

45-
} // namespace dlinear::vnnlib
45+
} // namespace dlinear::smt2
4646

4747
#ifdef DLINEAR_INCLUDE_FMT
4848

dlinear/parser/vnnlib/Driver.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ void VnnlibDriver::GetValue(const std::vector<Term> &term_list) const {
9595
switch (term.type()) {
9696
case Term::Type::EXPRESSION: {
9797
const Expression &e{term.expression()};
98-
const ExpressionEvaluator evaluator{e};
98+
const ExpressionEvaluator evaluator{e, context_.config()};
9999
pp.Print(e);
100100
term_str = ss.str();
101-
const Interval iv{ExpressionEvaluator(term.expression())(box)};
101+
const Interval iv{ExpressionEvaluator(term.expression(), context_.config())(box)};
102102
value_str = (std::stringstream{} << iv).str();
103103
break;
104104
}

dlinear/solver/ContextImpl.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ void Context::Impl::AssertPiecewiseLinearFunction(const Variable &var, const For
8989
DLINEAR_ASSERT(!var.is_dummy() && var.get_type() == Variable::Type::CONTINUOUS, "Variable must be a real variable");
9090
DLINEAR_ASSERT(is_relational(cond), "Condition must be a relational formula");
9191

92-
const Formula condition_lit = predicate_abstractor_.Convert(cond);
93-
const Formula active_lit = predicate_abstractor_.Convert(var - active == 0);
94-
const Formula inactive_lit = predicate_abstractor_.Convert(var - inactive == 0);
92+
const Formula condition_lit = predicate_abstractor_(cond);
93+
const Formula active_lit = predicate_abstractor_(var - active == 0);
94+
const Formula inactive_lit = predicate_abstractor_(var - inactive == 0);
9595
// Make sure the cond is assigned a value (true or false) in the SAT solver
9696
const Formula force_assignment(condition_lit || !condition_lit);
9797
const Formula active_assertion{active_lit || !condition_lit};

dlinear/solver/SatSolver.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,13 @@ std::vector<std::vector<Literal>> SatSolver::clauses() const {
4141

4242
void SatSolver::AddFormula(const Formula &f) {
4343
DLINEAR_DEBUG_FMT("SatSolver::AddFormula({})", f);
44-
std::vector<Formula> clauses{cnfizer_.Convert(f)};
44+
auto [clauses, aux] = cnfizer_(f);
4545

4646
// Collect CNF variables and store them in `cnf_variables_`.
47-
for (const Variable &p : cnfizer_.vars()) cnf_variables_.insert(p.get_id());
47+
for (const Variable &p : aux) cnf_variables_.insert(p.get_id());
4848
// Convert a first-order clauses into a Boolean formula by predicate abstraction
4949
// The original can be retrieved by `predicate_abstractor_[abstracted_formula]`.
50-
for (Formula &clause : clauses) clause = predicate_abstractor_.Convert(clause);
50+
for (Formula &clause : clauses) clause = predicate_abstractor_.Process(clause);
5151

5252
AddClauses(clauses);
5353
}

dlinear/symbolic/BUILD.bazel

+23-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ dlinear_cc_library(
2424
":literal",
2525
"//dlinear/libs:gmp",
2626
],
27-
deps = [":symbolic"],
27+
deps = [
28+
":expression_visitor",
29+
":formula_visitor",
30+
":symbolic",
31+
],
2832
)
2933

3034
dlinear_cc_library(
@@ -44,9 +48,20 @@ dlinear_cc_library(
4448

4549
dlinear_cc_library(
4650
name = "formula_visitor",
47-
srcs = ["FormulaVisitor.cpp"],
48-
hdrs = ["FormulaVisitor.h"],
49-
implementation_deps = ["//dlinear/util:exception"],
51+
hdrs = [
52+
"FormulaVisitor.h",
53+
"GenericFormulaVisitor.h",
54+
],
55+
deps = [
56+
":symbolic",
57+
"//dlinear/util:config",
58+
"//dlinear/util:stats",
59+
],
60+
)
61+
62+
dlinear_cc_library(
63+
name = "expression_visitor",
64+
hdrs = ["GenericExpressionVisitor.h"],
5065
deps = [
5166
":symbolic",
5267
"//dlinear/util:config",
@@ -94,6 +109,7 @@ dlinear_cc_library(
94109
"//dlinear/util:exception",
95110
],
96111
deps = [
112+
":expression_visitor",
97113
":symbolic",
98114
"//dlinear/util:box",
99115
],
@@ -105,6 +121,7 @@ dlinear_cc_library(
105121
hdrs = ["Nnfizer.h"],
106122
implementation_deps = ["//dlinear/util:logging"],
107123
deps = [
124+
":formula_visitor",
108125
":symbolic",
109126
"//dlinear/util:config",
110127
],
@@ -132,6 +149,8 @@ dlinear_cc_library(
132149
"//dlinear/util:timer",
133150
],
134151
deps = [
152+
":expression_visitor",
153+
":formula_visitor",
135154
":literal",
136155
":symbolic",
137156
"//dlinear/util:config",

dlinear/symbolic/ExpressionEvaluator.cpp

+16-11
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,26 @@
1616

1717
namespace dlinear {
1818

19-
ExpressionEvaluator::ExpressionEvaluator(Expression e) : e_{std::move(e)} {}
19+
ExpressionEvaluator::ExpressionEvaluator(Expression e, const Config& config)
20+
: GenericExpressionVisitor<Interval, const Box&>{config, "ExpressionEvaluator"}, e_{std::move(e)} {}
2021

21-
Interval ExpressionEvaluator::operator()(const Box& box) const { return Visit(e_, box); }
22-
23-
Interval ExpressionEvaluator::Visit(const Expression& e, const Box& box) const {
24-
return VisitExpression<Interval>(this, e, box);
22+
Interval ExpressionEvaluator::Process(const Box& box) const {
23+
const TimerGuard timer_guard(&stats_.m_timer(), stats_.enabled());
24+
stats_.Increase();
25+
return VisitExpression(e_, box);
2526
}
27+
Interval ExpressionEvaluator::operator()(const Box& box) const { return Process(box); }
2628

27-
Interval ExpressionEvaluator::VisitVariable(const Expression& e, const Box& box) {
29+
Interval ExpressionEvaluator::VisitVariable(const Expression& e, const Box& box) const {
2830
const Variable& var{get_variable(e)};
2931
return box[var];
3032
}
3133

32-
Interval ExpressionEvaluator::VisitConstant(const Expression& e, const Box&) { return Interval{get_constant_value(e)}; }
34+
Interval ExpressionEvaluator::VisitConstant(const Expression& e, const Box&) const {
35+
return Interval{get_constant_value(e)};
36+
}
3337

34-
Interval ExpressionEvaluator::VisitRealConstant(const Expression&, const Box&) {
38+
Interval ExpressionEvaluator::VisitRealConstant(const Expression&, const Box&) const {
3539
DLINEAR_RUNTIME_ERROR("Operation is not supported yet.");
3640
}
3741

@@ -40,7 +44,7 @@ Interval ExpressionEvaluator::VisitAddition(const Expression& e, const Box& box)
4044
const auto& expr_to_coeff_map = get_expr_to_coeff_map_in_addition(e);
4145
return std::accumulate(expr_to_coeff_map.begin(), expr_to_coeff_map.end(), Interval{c},
4246
[this, &box](const Interval& init, const std::pair<const Expression, mpq_class>& p) {
43-
return init + Visit(p.first, box) * p.second;
47+
return init + VisitExpression(p.first, box) * p.second;
4448
});
4549
}
4650

@@ -148,11 +152,12 @@ Interval ExpressionEvaluator::VisitMax(const Expression&, const Box&) const {
148152
DLINEAR_RUNTIME_ERROR("Operation is not supported yet.");
149153
}
150154

151-
Interval ExpressionEvaluator::VisitIfThenElse(const Expression& /* unused */, const Box& /* unused */) {
155+
Interval ExpressionEvaluator::VisitIfThenElse(const Expression& /* unused */, const Box& /* unused */) const {
152156
DLINEAR_RUNTIME_ERROR("If-then-else expression is not supported yet.");
153157
}
154158

155-
Interval ExpressionEvaluator::VisitUninterpretedFunction(const Expression& /* unused */, const Box& /* unused */) {
159+
Interval ExpressionEvaluator::VisitUninterpretedFunction(const Expression& /* unused */,
160+
const Box& /* unused */) const {
156161
DLINEAR_RUNTIME_ERROR("Uninterpreted function is not supported.");
157162
}
158163

dlinear/symbolic/ExpressionEvaluator.h

+35-34
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <iosfwd>
1010

11+
#include "dlinear/symbolic/GenericExpressionVisitor.h"
1112
#include "dlinear/symbolic/symbolic.h"
1213
#include "dlinear/util/Box.h"
1314
#include "dlinear/util/Interval.h"
@@ -20,51 +21,51 @@ namespace dlinear {
2021
* The ExpressionEvaluator is used to evaluate an expression with a given box.
2122
* The box provides the values of the variables in the expression with intervals.
2223
*/
23-
class ExpressionEvaluator {
24+
class ExpressionEvaluator : public GenericExpressionVisitor<Interval, const Box&> {
2425
public:
25-
explicit ExpressionEvaluator(Expression e);
26+
/**
27+
* Construct a new ExpressionEvaluator object with the given expression and configuration.
28+
* @param e expression to evaluate
29+
* @param config configuration to use
30+
*/
31+
ExpressionEvaluator(Expression e, const Config& config);
2632

2733
/// Evaluates the expression with @p box.
28-
Interval operator()(const Box& box) const;
34+
[[nodiscard]] Interval Process(const Box& box) const;
35+
[[nodiscard]] Interval operator()(const Box& box) const;
2936

3037
[[nodiscard]] const Variables& variables() const { return e_.GetVariables(); }
31-
3238
[[nodiscard]] const Expression& expression() const { return e_; }
3339

3440
private:
35-
[[nodiscard]] Interval Visit(const Expression& e, const Box& box) const;
36-
static Interval VisitVariable(const Expression& e, const Box& box);
37-
static Interval VisitConstant(const Expression& e, const Box& box);
38-
static Interval VisitRealConstant(const Expression& e, const Box& box);
39-
[[nodiscard]] Interval VisitAddition(const Expression& e, const Box& box) const;
40-
[[nodiscard]] Interval VisitMultiplication(const Expression& e, const Box& box) const;
41-
[[nodiscard]] Interval VisitDivision(const Expression& e, const Box& box) const;
42-
[[nodiscard]] Interval VisitLog(const Expression& e, const Box& box) const;
43-
[[nodiscard]] Interval VisitAbs(const Expression& e, const Box& box) const;
44-
[[nodiscard]] Interval VisitExp(const Expression& e, const Box& box) const;
45-
[[nodiscard]] Interval VisitSqrt(const Expression& e, const Box& box) const;
46-
[[nodiscard]] Interval VisitPow(const Expression& e, const Box& box) const;
41+
[[nodiscard]] Interval VisitVariable(const Expression& e, const Box& box) const override;
42+
[[nodiscard]] Interval VisitConstant(const Expression& e, const Box& box) const override;
43+
[[nodiscard]] Interval VisitRealConstant(const Expression& e, const Box& box) const;
44+
[[nodiscard]] Interval VisitAddition(const Expression& e, const Box& box) const override;
45+
[[nodiscard]] Interval VisitMultiplication(const Expression& e, const Box& box) const override;
46+
[[nodiscard]] Interval VisitDivision(const Expression& e, const Box& box) const override;
47+
[[nodiscard]] Interval VisitLog(const Expression& e, const Box& box) const override;
48+
[[nodiscard]] Interval VisitAbs(const Expression& e, const Box& box) const override;
49+
[[nodiscard]] Interval VisitExp(const Expression& e, const Box& box) const override;
50+
[[nodiscard]] Interval VisitSqrt(const Expression& e, const Box& box) const override;
51+
[[nodiscard]] Interval VisitPow(const Expression& e, const Box& box) const override;
4752

4853
// Evaluates `pow(e1, e2)` with the @p box.
4954
[[nodiscard]] Interval VisitPow(const Expression& e1, const Expression& e2, const Box& box) const;
50-
[[nodiscard]] Interval VisitSin(const Expression& e, const Box& box) const;
51-
[[nodiscard]] Interval VisitCos(const Expression& e, const Box& box) const;
52-
[[nodiscard]] Interval VisitTan(const Expression& e, const Box& box) const;
53-
[[nodiscard]] Interval VisitAsin(const Expression& e, const Box& box) const;
54-
[[nodiscard]] Interval VisitAcos(const Expression& e, const Box& box) const;
55-
[[nodiscard]] Interval VisitAtan(const Expression& e, const Box& box) const;
56-
[[nodiscard]] Interval VisitAtan2(const Expression& e, const Box& box) const;
57-
[[nodiscard]] Interval VisitSinh(const Expression& e, const Box& box) const;
58-
[[nodiscard]] Interval VisitCosh(const Expression& e, const Box& box) const;
59-
[[nodiscard]] Interval VisitTanh(const Expression& e, const Box& box) const;
60-
[[nodiscard]] Interval VisitMin(const Expression& e, const Box& box) const;
61-
[[nodiscard]] Interval VisitMax(const Expression& e, const Box& box) const;
62-
static Interval VisitIfThenElse(const Expression& e, const Box& box);
63-
static Interval VisitUninterpretedFunction(const Expression& e, const Box& box);
64-
65-
// Makes VisitExpression a friend of this class so that it can use private
66-
// operator()s.
67-
friend Interval drake::symbolic::VisitExpression<Interval>(const ExpressionEvaluator*, const Expression&, const Box&);
55+
[[nodiscard]] Interval VisitSin(const Expression& e, const Box& box) const override;
56+
[[nodiscard]] Interval VisitCos(const Expression& e, const Box& box) const override;
57+
[[nodiscard]] Interval VisitTan(const Expression& e, const Box& box) const override;
58+
[[nodiscard]] Interval VisitAsin(const Expression& e, const Box& box) const override;
59+
[[nodiscard]] Interval VisitAcos(const Expression& e, const Box& box) const override;
60+
[[nodiscard]] Interval VisitAtan(const Expression& e, const Box& box) const override;
61+
[[nodiscard]] Interval VisitAtan2(const Expression& e, const Box& box) const override;
62+
[[nodiscard]] Interval VisitSinh(const Expression& e, const Box& box) const override;
63+
[[nodiscard]] Interval VisitCosh(const Expression& e, const Box& box) const override;
64+
[[nodiscard]] Interval VisitTanh(const Expression& e, const Box& box) const override;
65+
[[nodiscard]] Interval VisitMin(const Expression& e, const Box& box) const override;
66+
[[nodiscard]] Interval VisitMax(const Expression& e, const Box& box) const override;
67+
[[nodiscard]] Interval VisitIfThenElse(const Expression& e, const Box& box) const override;
68+
[[nodiscard]] Interval VisitUninterpretedFunction(const Expression& e, const Box& box) const override;
6869

6970
const Expression e_;
7071
};

dlinear/symbolic/FormulaVisitor.cpp

-44
This file was deleted.

0 commit comments

Comments
 (0)