Skip to content

Commit

Permalink
implement VariantGetUnsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
konnov committed Jun 20, 2022
1 parent ff6798d commit b0c78f6
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 9 deletions.
9 changes: 6 additions & 3 deletions src/tla/Variants.tla
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
* Apalache treats these operators as typed, so it enforces type safety of
* variants.
*
* Igor Konnov, Informal Systems, 2021
* Igor Konnov, Informal Systems, 2021-2022
*)

(**
Expand Down Expand Up @@ -42,6 +42,9 @@ VariantFilter(__tagName, __S) ==


(**
* NOTE: This operator is not supported by the model checker yet.
* We are thinking about a reasonably simple implementation of it.
*
* Test the tag of `variant` against the value `tagValue`.
* If `variant.tag = tagValue`, then apply `ThenOper(rec)`,
* where `rec` is a record extracted from `variant`.
Expand Down Expand Up @@ -89,7 +92,7 @@ VariantGetOnly(__tagName, __variant) ==
\* default untyped implementation
IF __variant.tag = __tagName
THEN __variant.value
ELSE \* trigger an error in TLC by choosing a non-existant element
ELSE \* trigger an error in TLC by choosing a non-existent element
CHOOSE x \in { __variant }: x.tag = __tagName

(**
Expand All @@ -103,7 +106,7 @@ VariantGetOnly(__tagName, __variant) ==
*
* Its type could look like follows:
*
* (Str, Tag(a)) => a
* (Str, Tag(a) | b, a) => a
*)
VariantGetOrElse(__tagName, __variant, __defaultValue) ==
\* default untyped implementation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,8 @@ class SymbStateRewriterImpl(
// variants
key(tla.variant("Tag", tla.int(33)))
-> List(new VariantOpsRule(this)),
key(tla.variantGetUnsafe("Tag", tla.name("V")))
-> List(new VariantOpsRule(this)),
// FiniteSets
key(OperEx(ApalacheOper.constCard, tla.ge(tla.card(tla.name("S")), tla.int(3))))
-> List(new CardinalityConstRule(this)),
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package at.forsyte.apalache.tla.bmcmt.rules

import at.forsyte.apalache.tla.bmcmt._
import at.forsyte.apalache.tla.bmcmt.rules.aux.{CherryPick, RecordAndVariantOps}
import at.forsyte.apalache.tla.bmcmt.rules.aux.RecordAndVariantOps
import at.forsyte.apalache.tla.lir._
import at.forsyte.apalache.tla.lir.oper.VariantOper
import at.forsyte.apalache.tla.lir.values.TlaStr
Expand All @@ -13,13 +13,13 @@ import at.forsyte.apalache.tla.lir.values.TlaStr
* Igor Konnov
*/
class VariantOpsRule(rewriter: SymbStateRewriter) extends RewritingRule {
private val picker = new CherryPick(rewriter)
private val variantOps = new RecordAndVariantOps(rewriter)

override def isApplicable(symbState: SymbState): Boolean = {
symbState.ex match {
case OperEx(VariantOper.variant, _, _) => true
case _ => false
case OperEx(VariantOper.variant, _, _) => true
case OperEx(VariantOper.variantGetUnsafe, _, _) => true
case _ => false
}
}

Expand All @@ -29,6 +29,9 @@ class VariantOpsRule(rewriter: SymbStateRewriter) extends RewritingRule {
val variantT = TlaType1.fromTypeTag(ex.typeTag)
translateVariant(state, tagName, valueEx, variantT)

case OperEx(VariantOper.variantGetUnsafe, ValEx(TlaStr(tagName)), variantEx) =>
translateVariantGetUnsafe(state, tagName, variantEx)

case _ =>
throw new RewriterException("%s is not applicable".format(getClass.getSimpleName), state.ex)
}
Expand All @@ -46,4 +49,17 @@ class VariantOpsRule(rewriter: SymbStateRewriter) extends RewritingRule {
val valueCell = nextState.asCell
variantOps.makeVariant(nextState, variantT, tagName, valueCell)
}

/**
* Translate VariantGetUnsafe(tagName, variant).
*/
private def translateVariantGetUnsafe(
state: SymbState,
tagName: String,
variantEx: TlaEx): SymbState = {
val nextState = rewriter.rewriteUntilDone(state.setRex(variantEx))
val variantCell = nextState.asCell
val unsafeValueCell = variantOps.getUnsafeVariantValue(nextState.arena, variantCell, tagName)
nextState.setRex(unsafeValueCell.toNameEx)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class RecordAndVariantOps(rewriter: SymbStateRewriter) {

/**
* Get the variant value by tag. This is an unsafe method, that is, if the associated tag name is different from the
* provided one, this method returns some of the proper type (usually, the default value).
* provided one, this method returns some value of the proper type (usually, the default value).
*
* @param arena
* current arena
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,30 @@ trait TestSymbStateRewriterVariant extends RewriterBase {
assertTlaExAndRestore(rewriter, state)
}

test("""VariantGetUnsafe with a right tag""") { rewriterType: SMTEncoding =>
val variantT = parser("Foo(Int) | Bar(Bool)")
val vrt1 = variant("Foo", int(33)).as(variantT)
val unsafe = variantGetUnsafe("Foo", vrt1).as(IntT1)
val eq = eql(unsafe, int(33)).as(BoolT1)

val state = new SymbState(eq, arena, Binding())
val rewriter = create(rewriterType)
assertTlaExAndRestore(rewriter, state)
}

test("""VariantGetUnsafe with a wrong tag""") { rewriterType: SMTEncoding =>
val variantT = parser("Foo(Int) | Bar(Bool)")
val vrt2 = variant("Foo", minus(int(44), int(11)).as(IntT1)).as(variantT)
val unsafe = variantGetUnsafe("Foo", vrt2).as(IntT1)

val state = new SymbState(unsafe, arena, Binding())
val rewriter = create(rewriterType)
rewriter.rewriteUntilDone(state)
// The implementation is free to return any value of the right type (Int).
// This operator should not make the solver stuck, that is, produce unsatisfiable constraints.
assert(solverContext.sat())
}

private def getVariantOptions(tp: CellT): Map[String, TlaType1] = {
tp match {
case CellTFrom(VariantT1(RowT1(variantOptions, None))) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -747,7 +747,7 @@ class Builder {
def variantGetUnsafe(
tagName: String,
variantEx: BuilderEx): BuilderEx = {
BuilderOper(VariantOper.variantGetOnly, str(tagName), variantEx)
BuilderOper(VariantOper.variantGetUnsafe, str(tagName), variantEx)
}

/**
Expand Down Expand Up @@ -865,6 +865,8 @@ class Builder {
ApalacheOper.guess.name -> ApalacheOper.guess,
VariantOper.variant.name -> VariantOper.variant,
VariantOper.variantGetOnly.name -> VariantOper.variantGetOnly,
VariantOper.variantGetUnsafe.name -> VariantOper.variantGetUnsafe,
VariantOper.variantGetOrElse.name -> VariantOper.variantGetOrElse,
VariantOper.variantMatch.name -> VariantOper.variantMatch,
VariantOper.variantFilter.name -> VariantOper.variantFilter,
)
Expand Down

0 comments on commit b0c78f6

Please sign in to comment.