Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce len for ADTs to be used as a termination measure #591

Merged
merged 4 commits into from
Jan 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ trait ExprTyping extends BaseTyping { this: TypeInfoImpl =>
case PLength(op) => isExpr(op).out ++ {
underlyingType(exprType(op)) match {
case _: ArrayT | _: SliceT | _: GhostSliceT | StringT | _: VariadicT | _: MapT | _: MathMapT => noMessages
case _: SequenceT | _: SetT | _: MultisetT => isPureExpr(op)
case _: SequenceT | _: SetT | _: MultisetT | _: AdtT => isPureExpr(op)
case typ => error(op, s"expected an array, string, sequence or slice type, but got $typ")
}
}
Expand Down Expand Up @@ -1051,7 +1051,7 @@ trait ExprTyping extends BaseTyping { this: TypeInfoImpl =>
private[typing] def typeOfPLength(expr: PLength): Type =
underlyingType(exprType(expr.exp)) match {
case _: ArrayT | _: SliceT | _: GhostSliceT | StringT | _: VariadicT | _: MapT => INT_TYPE
case _: SequenceT | _: SetT | _: MultisetT | _: MathMapT => UNTYPED_INT_CONST
case _: SequenceT | _: SetT | _: MultisetT | _: MathMapT | _: AdtT => UNTYPED_INT_CONST
case t => violation(s"unexpected argument ${expr.exp} of type $t passed to len")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ class AdtEncoding extends LeafTypeEncoding {
* unique X_clause1_tag(): Int
* ...
*
* // rank function
* rank(): Int
*
* axiom {
* forall f11: F11, ... :: { X_clause1(f11, ...) }
* X_tag(X_clause1(f11, ...)) == X_clause1_tag() && X_F11(X_clause1(f11, ...)) )) == f11 && ...
Expand All @@ -76,6 +79,16 @@ class AdtEncoding extends LeafTypeEncoding {
* forall t: X :: {X_tag(t)} t == X_clause1(X_clause1_f11(t), ...) || t == ...
* }
*
* axiom {
* forall t X :: {rank(t)} 0 <= rank(t)
* }
*
* // for every parameter pi of a constructor C with arity n, if pi has an ADT type:
* axiom {
* forall p1, ..., pn ==> rank(pi) < rank(C(p1, ..., pi, ..., pn))
* }
* ...
*
* }
*/
override def member(ctx: Context): in.Member ==> MemberWriter[Vector[vpr.Member]] = {
Expand Down Expand Up @@ -227,13 +240,77 @@ class AdtEncoding extends LeafTypeEncoding {
vpr.Forall(Seq(variableDecl), Seq(trigger), vu.bigOr(equalities)(aPos, aInfo, aErrT))(aPos, aInfo, aErrT)
)(aPos, aInfo, adtName, aErrT)
}

// rank function for ADTs, as axiomatized in Paul Dahlke's thesis - section 5.3.2 of
// https://ethz.ch/content/dam/ethz/special-interest/infk/chair-program-method/pm/documents/Education/Theses/Paul_Dahlke_BA_Report.pdf
val rankFunc = adtRankFunc(adtName)(aPos, aInfo, aErrT)
val rankAxioms = {
// the following axiom is useful for Gobra to easily infer that there is a lower bound
// on the values produced by rank:
// forall x X :: {rank(x)} 0 <= rank(x)
val rankIsBounded = {
val variableDecl = vpr.LocalVarDecl("x", adtT)(aPos, aInfo, aErrT)
val rankApp = applyRankFunc(adtName, variableDecl.localVar)(aPos, aInfo, aErrT)
val trigger = vpr.Trigger(Seq(rankApp))(aPos, aInfo, aErrT)
val body = vpr.Forall(
variables = Seq(variableDecl),
triggers = Seq(trigger),
exp = vpr.LeCmp(vpr.IntLit(0)(aPos, aInfo, aErrT), rankApp)(aPos, aInfo, aErrT)
)(aPos, aInfo, aErrT)
vpr.AnonymousDomainAxiom(body)(aPos, aInfo, adtName, aErrT)
}

// for every parameter pi of a constructor C with arity n, if pi has an ADT type:
// forall p1, ..., pn ==> rank(pi) < rank(C(p1, ..., pi, ..., pn))
val defsRankPerClause = adt.clauses zip constructors flatMap { case (clause, cons) =>
val variables = fieldDecls(clause)
val constApp = vpr.DomainFuncApp(
funcname = cons.name,
args = variables map (_.localVar),
typVarMap = Map()
)(aPos, aInfo, adtT, adtName, aErrT)
val rankOfConst = applyRankFunc(adtName, constApp)(aPos, aInfo, aErrT) // rank(C(p1, ..., pn))
val trigger = vpr.Trigger(Seq(rankOfConst))(aPos, aInfo, aErrT)
clause.args.map(arg => underlyingType(arg.typ)(ctx)) zip variables collect {
case (inVarT: in.AdtT, vprVar) =>
// selects the appropriate rank function according to the ADT type
val rankOfParam = applyRankFunc(inVarT.name, vprVar.localVar)(aPos, aInfo, aErrT) // rank(pi)
val body = vpr.LtCmp(rankOfParam, rankOfConst)(aPos, aInfo, aErrT)
val axExp = vpr.Forall(variables = variables, triggers = Seq(trigger), exp = body)(aPos, aInfo, aErrT)
vpr.AnonymousDomainAxiom(axExp)(aPos, aInfo, adtName, aErrT)
}
}
rankIsBounded +: defsRankPerClause
}

ml.unit(Vector(vpr.Domain(
adtName,
functions = (defaultFunc +: tagFunc +: clauseTags) ++ constructors ++ destructors,
axioms = (exclusiveAxiom +: constructorAxioms) ++ destructorAxioms
functions = (defaultFunc +: rankFunc +: tagFunc +: clauseTags) ++ constructors ++ destructors,
axioms = (exclusiveAxiom +: constructorAxioms) ++ destructorAxioms ++ rankAxioms
)(pos = aPos, info = aInfo, errT = aErrT)))
}

private def adtRankFuncName(adtName: String): String = s"rank$$$adtName"
private def adtRankFunc(adtName: String)(pos: vpr.Position, info: vpr.Info, errT: vpr.ErrorTrafo) = {
val funcName = adtRankFuncName(adtName)
val adtT = adtType(adtName)
val variableDecl = vpr.LocalVarDecl("x", adtT)(pos, info, errT)
vpr.DomainFunc(
name = funcName,
formalArgs = Seq(variableDecl),
typ = vpr.Int,
unique = false,
interpretation = None
)(pos, info, adtName, errT)
}

private def applyRankFunc(adtName: String, arg: vpr.Exp)
(pos: vpr.Position, info: vpr.Info, errT: vpr.ErrorTrafo): vpr.DomainFuncApp = vpr.DomainFuncApp(
funcname = adtRankFuncName(adtName),
args = Seq(arg),
typVarMap = Map()
)(pos, info, vpr.Int, adtName, errT)

/**
* [ dflt(adt{N}) ] -> N_default()
* [ C{args}: adt{N} ] -> N_C([args])
Expand Down Expand Up @@ -263,6 +340,12 @@ class AdtEncoding extends LeafTypeEncoding {
} yield withSrc(destructor(ad.field.name, adtType.name, value, ctx.typ(ad.field.typ)), ad)

case p: in.PatternMatchExp => translatePatternMatchExp(p)(ctx)

case l@in.Length(expr :: ctx.Adt(a)) =>
for {
e <- ctx.expression(expr)
rankApp = withSrc(applyRankFunc(a.name, e), l)
} yield rankApp
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Any copyright is dedicated to the Public Domain.
// http://creativecommons.org/publicdomain/zero/1.0/

package pkg

type tree adt{
leaf{ value int }
node{ left, right tree }
}

ghost
pure
decreases // cause: wrong termination measure
func leafCount(t tree) int {
return match t {
case leaf{_}: 1
//:: ExpectedOutput(pure_function_termination_error)
case node{?l, ?r}: leafCount(l) + leafCount(r)
}
}

type list adt {
Empty{}

Cons{
head any
tail list
}
}

ghost
decreases len(l)
func length(l list) int {
match l {
case Empty{}:
return 0
case Cons{_, ?t}:
//:: ExpectedOutput(function_termination_error)
return 1 + length(l) // cause: pass l to length instead of t
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Any copyright is dedicated to the Public Domain.
// http://creativecommons.org/publicdomain/zero/1.0/

package pkg

type tree adt{
leaf{ value int }
node{ left, right tree }
}

ghost
pure
decreases len(t)
func leafCount(t tree) int {
return match t {
case leaf{_}: 1
case node{?l, ?r}: leafCount(l) + leafCount(r)
}
}

type list adt {
Empty{}

Cons{
head any
tail list
}
}

ghost
decreases len(l)
func length(l list) int {
match l {
case Empty{}:
return 0
case Cons{_, ?t}:
return 1 + length(t)
}
}

ghost
requires l === r || l.tail.tail === r
decreases len(l)
func testSubSubList(l, r list) {
if (l === r) {
return
} else {
assert l.tail.tail === r
assert l !== r
assume l.isCons // Gobra cannot infer this - adt axiomatization still a bit weak?
assume l.tail.isCons // Gobra cannot infer this - adt axiomatization still a bit weak?
assert len(l.tail) < len(l)
assert len(l.tail.tail) < len(l)
testSubSubList(l.tail.tail, r)
}
}