Skip to content

Commit

Permalink
complete rewrite of the variant types
Browse files Browse the repository at this point in the history
  • Loading branch information
konnov committed Jun 1, 2022
1 parent 13dd4ee commit f0c9fc9
Show file tree
Hide file tree
Showing 15 changed files with 231 additions and 251 deletions.
41 changes: 18 additions & 23 deletions src/tla/Variants.tla
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
*
* The type could look like follows, if we supported string literals in types:
*
* (Str, a) =>
* { tag: "$tagValue", a } | b
* (Str, a) => Tag(a) | b
*)
Variant(__tag, __value) ==
\* default untyped implementation
Expand All @@ -36,9 +35,9 @@ Variant(__tag, __value) ==
*
* The type could look like follows, if we supported string literals in types:
*
* (Str, Set({ tag: "$tagValue", a} | b)) => Set({ a })
* (Str, Set(Tag(a) | b)) => Set(a)
*)
FilterByTag(__tag, __S) ==
VariantFilter(__tag, __S) ==
\* default untyped implementation
{ __d \in { __e \in __S: __e.tag = __tag }: __d.value }

Expand All @@ -63,38 +62,34 @@ FilterByTag(__tag, __S) ==
*
* (
* Str,
* { "$tagValue": a | b },
* { a } => r,
* { Tag(a) | b },
* a => r,
* Variant(b) => r
* ) => r
*)
MatchTag(__tagValue, __variant, __ThenOper(_), __ElseOper(_)) ==
VariantMatch(__tagValue, __variant, __ThenOper(_), __ElseOper(_)) ==
\* default untyped implementation
IF __variant.tag = __tagValue
THEN __ThenOper(__variant.value)
ELSE __ElseOper(__variant)

(**
* In case when `variant` allows for one record type,
* apply `ThenOper(rec)`, where `rec` is a record extracted from `variant`.
* The type checker must enforce that `variant` allows for one record type.
* The untyped implementation does not perform such a test,
* as it is impossible to do so without types.
* In case when `variant` allows for one value,
* extract the associated value and return it.
* The type checker must enforce that `variant` allows for one option.
*
* @param `tagValue` the tag attached to the variant
* @param `variant` a variant that is constructed with `Variant(...)`
* @param `ThenOper` an operator that is called
* when `variant` is tagged with `tagValue`
* @return the result returned by `ThenOper`
* @return the value extracted from the variant
*
* The type could look like follows, if we supported string literals in types:
* Its type could look like follows:
*
* (
* Str,
* { "$tagValue": a },
* { a } => r
* ) => r
* (Str, Tag(a)) => a
*)
MatchOnly(__tagValue, __variant, __ThenOper(_)) ==
VariantGet(__tagValue, __variant) ==
\* default untyped implementation
__ThenOper(__variant.value)
IF __variant.tag = __tagValue
THEN __variant.value
ELSE \* trigger an error in TLC by choosing a non-existant element
CHOOSE x \in { __variant }: x.tag = __tagValue
===============================================================================
37 changes: 17 additions & 20 deletions test/tla/TestReqAckVariants.tla
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
--------------------------- MODULE TestVariants -------------------------------
----------------------- MODULE TestReqAckVariants ------------------------------
\* A test for the Variants module.
\* This test should work when ADR-014 is implemented.

Expand All @@ -10,16 +10,14 @@ VARIABLES
\* @type: Int;
balance,
(*
@typeAlias: MESSAGE =
{ tag: "req", ask: Int }
| { tag: "ack", success: Bool };
@typeAlias: MESSAGE = Req({ ask: Int }) | Ack({ success: Bool });
@type: Set(MESSAGE);
*)
msgs,
(*
@typeAlias: EVENT =
{ tag: "withdraw", amount: Int }
| { tag: "lacking", amount: Int };
@typeAlias: EVENT = Withdraw(Int) | Lacking(Int);
@type: Seq(EVENT);
*)
log
Expand All @@ -31,30 +29,29 @@ Init ==

SendRequest(ask) ==
/\ ask > 0
/\ LET m == Variant([ tag |-> "req", ask |-> ask ]) IN
/\ LET m == Variant("Req", [ ask |-> ask ]) IN
msgs' = msgs \union { m }
/\ UNCHANGED <<balance, log>>


ProcessRequest(m) ==
/\ IF balance >= m.ask
THEN LET entry == Variant([ tag |-> "withdraw", amount |-> m.ask ]) IN
LET ack == Variant([ tag |-> "ack", success |-> TRUE ]) IN
/\ balance' = balance - m.ask
ProcessRequest(ask) ==
/\ IF balance >= ask
THEN LET entry == Variant("Withdraw", ask) IN
LET ack == Variant("Ack", [ success |-> TRUE ]) IN
/\ balance' = balance - ask
/\ log' = Append(log, entry)
/\ msgs' = (msgs \ { m }) \union { ack }
ELSE LET entry ==
Variant([ tag |-> "lacking", amount |-> m.ask - balance ]) IN
LET ack == Variant([ tag |-> "ack", success |-> FALSE ]) IN
/\ msgs' = (msgs \ { Variant("Req", [ ask |-> ask]) }) \union { ack }
ELSE LET entry == Variant("Lacking", ask - balance) IN
LET ack == Variant("Ack", [ success |-> FALSE ]) IN
/\ log' = Append(log, entry)
/\ msgs' = (msgs \ { m }) \union { ack }
/\ msgs' = (msgs \ { Variant("Req", [ ask |-> ask]) }) \union { ack }
/\ UNCHANGED balance


Next ==
\/ \E ask \in Amounts:
SendRequest(ask)
\/ \E m \in FilterByTag(msgs, "req"):
ProcessRequest(m)
\/ \E m \in VariantFilter("Req", msgs):
ProcessRequest(m.ask)

===============================================================================
42 changes: 42 additions & 0 deletions test/tla/TestVariants.tla
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
---------------------------- MODULE TestVariants -------------------------------
(*
* Functional tests for operators over variants.
* We introduce a trivial state machine and write tests as state invariants.
*)

EXTENDS Integers, FiniteSets, Apalache, Variants

Init == TRUE
Next == TRUE

(* DEFINITIONS *)

VarA == Variant("A", 1)

VarB == Variant("B", [ value |-> "hello" ])

TestVariant ==
VarA \in { VarA, VarB }

TestVariantFilter ==
\E v \in VariantFilter("B", { VarA, VarB }):
v.value = "hello"

TestVariantGet ==
VariantGet("B", VarB) = [ value |-> "hello" ]

TestVariantMatch ==
VariantMatch(
"A",
VarB,
LAMBDA i: i > 0,
LAMBDA v: FALSE
)

AllTests ==
/\ TestVariant
/\ TestVariantFilter
/\ TestVariantGet
/\ TestVariantMatch

===============================================================================
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ object DefaultType1Parser extends Parsers with Type1Parser {
}

private def typeExpr: Parser[TlaType1] = {
(operator | function | noFunExpr)
operator | function | noFunExpr
}

// A type expression. We wrap it with a list, as (type, ..., type) may start an operator type
Expand Down Expand Up @@ -214,35 +214,11 @@ object DefaultType1Parser extends Parsers with Type1Parser {
}

// An option in the variant type that is constructed from a row.
// For example, { tag: "tag1", f: Bool } or { tag: "tag2", g: Bool, c }.
// For example, Tag1(a).
private def variantOption: Parser[(String, TlaType1)] = {
// the first rule tests for duplicates in the rule names
(LCURLY() ~> tag ~> COLON() ~> stringLiteral ~>
COMMA() ~> rep1sep(typedField, COMMA()) <~ opt(COMMA() ~ typeVar) <~ RCURLY()) >> { list =>
val dup = findDups("tag" +: list.map(_._1))
if (dup.nonEmpty) {
err(s"Found a duplicate key ${dup.get} in a record")
} else {
// fail here to try the second rule
failure("")
}
} | // the second rule is actually producing the result, provided that the sequence is accepted
(LCURLY() ~> tag ~ COLON() ~ stringLiteral ~
opt(COMMA() ~ rep1sep(typedField, COMMA())) ~ opt(COMMA() ~ typeVar) <~ RCURLY()) ^^ {
case _ ~ _ ~ STR_LITERAL(tagValue) ~ optList ~ Some(_ ~ VarT1(v)) =>
val list = optList match {
case Some(_ ~ l) => l
case _ => Nil
}
(tagValue, RecRowT1(RowT1(VarT1(v), ("tag" -> StrT1()) :: list: _*)))

case _ ~ _ ~ STR_LITERAL(tagValue) ~ optList ~ None =>
val list = optList match {
case Some(_ ~ l) => l
case _ => Nil
}
(tagValue, RecRowT1(RowT1(("tag" -> StrT1()) :: list: _*)))
}
((tag <~ LPAREN()) ~ typeExpr <~ RPAREN()) ^^ { case IDENT(tagName) ~ valueType =>
(tagName, valueType)
}
}

// the user-friendly syntax of the variant type
Expand All @@ -257,8 +233,8 @@ object DefaultType1Parser extends Parsers with Type1Parser {
failure("")
}
} | // the second rule is actually producing the result, provided that the sequence is accepted
(rep1sep(variantOption, PIPE()) ~ opt(PIPE() ~ typeVar)) ^^ {
case list ~ Some(_ ~ VarT1(v)) =>
(rep1sep(variantOption, PIPE()) ~ opt(PIPE() ~> typeVar)) ^^ {
case list ~ Some(VarT1(v)) =>
VariantT1(RowT1(VarT1(v), list: _*))

case list ~ None =>
Expand All @@ -280,14 +256,6 @@ object DefaultType1Parser extends Parsers with Type1Parser {
}
}

// A tag name
private def stringLiteral: Parser[STR_LITERAL] = {
accept("string literal",
{ case f @ STR_LITERAL(_) =>
f
})
}

// A record field name, like foo_BAR2.
// As field name are colliding with CAPS_IDENT and TYPE_VAR, we expect all of them.
private def fieldName: Parser[IDENT] = {
Expand All @@ -298,10 +266,9 @@ object DefaultType1Parser extends Parsers with Type1Parser {
}

private def tag: Parser[IDENT] = {
accept("tag",
{
case f @ IDENT(name) if name == "tag" =>
f
accept("variant tag",
{ case f @ IDENT(_) =>
f
})
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ object StandardLibrary {
("Apalache", "ApaFoldSeqLeft") -> ApalacheOper.foldSeq,
// Variants
("Variants", "Variant") -> VariantOper.variant,
("Variants", "FilterByTag") -> VariantOper.filterByTag,
("Variants", "MatchTag") -> VariantOper.matchTag,
("Variants", "MatchOnly") -> VariantOper.matchOnly,
("Variants", "VariantFilter") -> VariantOper.variantFilter,
("Variants", "VariantMatch") -> VariantOper.variantMatch,
("Variants", "VariantGet") -> VariantOper.variantGet,
// internal modules
("__apalache_folds", "__ApalacheFoldSeq") -> ApalacheOper.foldSeq,
("__apalache_folds", "__ApalacheMkSeq") -> ApalacheOper.mkSeq,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class TestAnnotationParser extends AnyFunSuite with Checkers {
.parse(""" @type: (Int, Int) -> Set(Int) ;""")
.map(a => assert(expected == a))
.swap
.map(r => "Unexpected parser outcome: " + r)
.map(r => fail("Unexpected parser outcome: " + r))
}

test("test on multiline input") {
Expand All @@ -84,7 +84,7 @@ class TestAnnotationParser extends AnyFunSuite with Checkers {
.parse(text)
.map(a => assert(expected == a))
.swap
.map(r => "Unexpected parser outcome: " + r)
.map(r => fail("Unexpected parser outcome: " + r))
}

test("regression") {
Expand All @@ -110,6 +110,26 @@ class TestAnnotationParser extends AnyFunSuite with Checkers {
.map(r => fail("Expected a failure. Found: " + r))
}

test("parse variants in type aliases") {
val extractedText =
"""MESSAGE =""" + "\n" + """ Req(ask: Int) """ + "\n" + """ | Ack(success: Bool)"""
val text =
s"""
|@typeAlias: $extractedText;
|""".stripMargin

val expected =
Annotation(
"typeAlias",
AnnotationStr(extractedText),
)
AnnotationParser
.parse(text)
.map(a => assert(expected == a))
.swap
.map(r => fail("Unexpected parser outcome: " + r))
}

test("multiple annotations as in unit tests") {
// The text should be preprocessed by CommentPreprocessor first, so we expect a failure.
val text =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,12 @@ import org.junit.runner.RunWith
import org.scalacheck.Gen.asciiStr
import org.scalacheck.Prop.forAll
import org.scalatest.funsuite.AnyFunSuite
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.junit.JUnitRunner
import org.scalatestplus.scalacheck.Checkers

@RunWith(classOf[JUnitRunner])
class TestCommentPreprocessor extends AnyFunSuite with Checkers {
class TestCommentPreprocessor extends AnyFunSuite with Checkers with Matchers {

test("test on empty input") {
val (output, potentialAnnotations) = CommentPreprocessor()("")
Expand Down Expand Up @@ -156,6 +157,15 @@ class TestCommentPreprocessor extends AnyFunSuite with Checkers {
hasAnnotationsWhenNonEmpty("""(* aaa *)""")
}

test("accept pipe") {
val extractedText =
"""@typeAlias: MESSAGE = { tag: "req", ask: Int }""" + "\n" + """ | { tag: "ack", success: Bool };"""
val input = s"(*\n $extractedText\n *)"
val (_, potentialAnnotations) = CommentPreprocessor()(input)
potentialAnnotations should equal(
List("""@typeAlias: MESSAGE = { tag: "req", ask: Int } | { tag: "ack", success: Bool };"""))
}

test("no failure on random inputs") {
check(
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,7 @@ class TestSanyImporterStandardModules extends SanyImporterTestBase {
expectDecl(
"FBT",
OperEx(
VariantOper.filterByTag,
VariantOper.variantFilter,
ValEx(TlaStr("1a")),
OperEx(TlaSetOper.enumSet, OperEx(TlaOper.apply, NameEx("V"))),
),
Expand All @@ -628,7 +628,7 @@ class TestSanyImporterStandardModules extends SanyImporterTestBase {
declOp("ElseOper", bool(false), OperParam("v")).untypedOperDecl()
val applyMatchTag =
OperEx(
VariantOper.matchTag,
VariantOper.variantMatch,
ValEx(TlaStr("1a")),
name("var"),
name("ThenOper"),
Expand All @@ -648,7 +648,7 @@ class TestSanyImporterStandardModules extends SanyImporterTestBase {
declOp("ThenOper", appFun(name("v"), str("found")), OperParam("v")).untypedOperDecl()
val applyMatchOnly =
OperEx(
VariantOper.matchOnly,
VariantOper.variantGet,
name("var"),
name("ThenOper"),
)
Expand Down
Loading

0 comments on commit f0c9fc9

Please sign in to comment.