Skip to content

Commit

Permalink
extend Variants.tla with VariantGetUnsafe and VariantGetOrElse
Browse files Browse the repository at this point in the history
  • Loading branch information
konnov committed Jun 10, 2022
1 parent 4ed5d34 commit ff6798d
Show file tree
Hide file tree
Showing 8 changed files with 215 additions and 24 deletions.
41 changes: 40 additions & 1 deletion src/tla/Variants.tla
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,49 @@ VariantMatch(__tagName, __variant, __ThenOper(_), __ElseOper(_)) ==
*
* (Str, Tag(a)) => a
*)
VariantGet(__tagName, __variant) ==
VariantGetOnly(__tagName, __variant) ==
\* default untyped implementation
IF __variant.tag = __tagName
THEN __variant.value
ELSE \* trigger an error in TLC by choosing a non-existant element
CHOOSE x \in { __variant }: x.tag = __tagName

(**
* Return the value associated with the tag, when the tag equals to __tagName.
* Otherwise, return __elseValue.
*
* @param `__tagName` the tag attached to the variant
* @param `__variant` a variant that is constructed with `Variant(...)`
* @param `__defaultValue` the default value to return, if not tagged with __tagName
* @return the value extracted from the variant, or the __defaultValue
*
* Its type could look like follows:
*
* (Str, Tag(a)) => a
*)
VariantGetOrElse(__tagName, __variant, __defaultValue) ==
\* default untyped implementation
IF __variant.tag = __tagName
THEN __variant.value
ELSE __defaultValue


(**
* Unsafely return a value of the type associated with __tagName.
* If the variant is tagged with __tagName, then return the associated value.
* Otherwise, return some value of the type associated with __tagName.
*
* @param `tagValue` the tag attached to the variant
* @param `variant` a variant that is constructed with `Variant(...)`
* @return the value extracted from the variant, when tagged __tagName;
* otherwise, return some value
*
* Its type could look like follows:
*
* (Str, Tag(a) | b) => a
*)
VariantGetUnsafe(__tagName, __variant) ==
\* the default untyped implementation
__variant.value

===============================================================================
16 changes: 13 additions & 3 deletions test/tla/TestVariants.tla
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,18 @@ TestVariantFilter ==
\E v \in VariantFilter("B", { VarA, VarB }):
v.value = "hello"

TestVariantGet ==
TestVariantGetOnly ==
\* We could just pass "hello", without wrapping it in a record.
\* But we want to see how it works with records too.
VariantGet("B", VarB) = [ value |-> "hello" ]
VariantGetOnly("B", VarB) = [ value |-> "hello" ]

TestVariantGetUnsafe ==
\* The unsafe version gives us only a type guarantee.
VariantGetUnsafe("A", VarB) \in Int

TestVariantGetOrElse ==
\* When the tag name is different from the actual one, return the default value.
VariantGetOrElse("A", VarB, 12) = 12

TestVariantMatch ==
VariantMatch(
Expand All @@ -38,7 +46,9 @@ TestVariantMatch ==
AllTests ==
/\ TestVariant
/\ TestVariantFilter
/\ TestVariantGet
/\ TestVariantGetOnly
/\ TestVariantMatch
/\ TestVariantGetUnsafe
/\ TestVariantGetOrElse

===============================================================================
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ object StandardLibrary {
("Variants", "Variant") -> VariantOper.variant,
("Variants", "VariantFilter") -> VariantOper.variantFilter,
("Variants", "VariantMatch") -> VariantOper.variantMatch,
("Variants", "VariantGet") -> VariantOper.variantGet,
("Variants", "VariantGetOnly") -> VariantOper.variantGetOnly,
("Variants", "VariantGetUnsafe") -> VariantOper.variantGetUnsafe,
("Variants", "VariantGetOrElse") -> VariantOper.variantGetOrElse,
// internal modules
("__apalache_folds", "__ApalacheFoldSeq") -> ApalacheOper.foldSeq,
("__apalache_folds", "__ApalacheMkSeq") -> ApalacheOper.mkSeq,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,16 @@ class TestSanyImporterStandardModules extends SanyImporterTestBase {
| VariantMatch("T1a", var, ThenOper, ElseOper)
|
|\* @type: T1a({ val: Int, found: Bool }) => { val: Int, found: Bool };
|MO(var) ==
| VariantGet("T1a", var)
|TestVariantGetOnly(var) ==
| VariantGetOnly("T1a", var)
|
|\* @type: T1a({ val: Int, found: Bool }) => { val: Int, found: Bool };
|TestVariantGetUnsafe(var) ==
| VariantGetUnsafe("T1a", var)
|
|\* @type: (T1a({ val: Int, found: Bool }), { val: Int, found: Bool }) => { val: Int, found: Bool };
|TestVariantGetOrElse(var) ==
| VariantGetOrElse("T1a", var, [ val |-> 0, found |-> FALSE])
|================================
""".stripMargin

Expand Down Expand Up @@ -640,19 +648,44 @@ class TestSanyImporterStandardModules extends SanyImporterTestBase {
OperParam("var"),
)

// MO(var) ==
// VariantGet("T1a", var)
// TestVariantGetOnly(var) ==
// VariantGetOnly("T1a", var)
val applyMatchOnly =
OperEx(
VariantOper.variantGet,
VariantOper.variantGetOnly,
str("T1a"),
name("var"),
)

expectDecl(
"MO",
"TestVariantGetOnly",
applyMatchOnly,
OperParam("var"),
)

// TestVariantGetUnsafe(var) == VariantGetUnsafe("T1a", var)
expectDecl(
"TestVariantGetUnsafe",
OperEx(
VariantOper.variantGetUnsafe,
ValEx(TlaStr("T1a")),
NameEx("var"),
),
OperParam("var"),
)

// TestVariantGetOrElse(var) ==
// VariantGetOrElse("T1a", var, [ val |-> 0, found |-> FALSE])
expectDecl(
"TestVariantGetOrElse",
OperEx(
VariantOper.variantGetOrElse,
ValEx(TlaStr("T1a")),
NameEx("var"),
OperEx(TlaFunOper.rec, ValEx(TlaStr("val")), ValEx(TlaInt(0)), ValEx(TlaStr("found")),
ValEx(TlaBool(false))),
),
OperParam("var"),
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -783,7 +783,7 @@ class ToEtcExpr(
mkExRefApp(opsig, Seq(v, setEx))

case ex @ OperEx(VariantOper.variantFilter, tag @ _, _) =>
throw new TypingInputException(s"The first argument of FilterByTag must be a string, found: $tag", ex.ID)
throw new TypingInputException(s"The first argument of VariantFilter must be a string, found: $tag", ex.ID)

case OperEx(VariantOper.variantMatch, v @ ValEx(TlaStr(tagName)), variantEx, thenOper, elseOper) =>
val a = varPool.fresh
Expand All @@ -806,9 +806,9 @@ class ToEtcExpr(
mkExRefApp(opsig, Seq(v, variantEx, thenOper, elseOper))

case OperEx(VariantOper.variantMatch, tag @ _, _, _, _) =>
throw new TypingInputException(s"The first argument of MatchTag must be a string, found: $tag", ex.ID)
throw new TypingInputException(s"The first argument of VariantMatch must be a string, found: $tag", ex.ID)

case OperEx(VariantOper.variantGet, v @ ValEx(TlaStr(tagName)), variantEx) =>
case OperEx(VariantOper.variantGetOnly, v @ ValEx(TlaStr(tagName)), variantEx) =>
val a = varPool.fresh
// (Str, T1a(a)) => a
val operArgs =
Expand All @@ -820,6 +820,37 @@ class ToEtcExpr(
val opsig = OperT1(operArgs, a)
mkExRefApp(opsig, Seq(v, variantEx))

case OperEx(VariantOper.variantGetOnly, tag @ _, _) =>
throw new TypingInputException(s"The first argument of VariantGetOnly must be a string, found: $tag", ex.ID)

case OperEx(VariantOper.variantGetUnsafe, v @ ValEx(TlaStr(tagName)), variantEx) =>
val a = varPool.fresh
// (Str, T1a(a)) => a
val operArgs =
Seq(
StrT1,
VariantT1(RowT1(tagName -> a)),
)

val opsig = OperT1(operArgs, a)
mkExRefApp(opsig, Seq(v, variantEx))

case OperEx(VariantOper.variantGetUnsafe, tag @ _, _) =>
throw new TypingInputException(s"The first argument of VariantGetUnsafe must be a string, found: $tag", ex.ID)

case OperEx(VariantOper.variantGetOrElse, v @ ValEx(TlaStr(tagName)), variantEx, defaultEx) =>
val a = varPool.fresh
// (Str, T1a(a), a) => a
val operArgs =
Seq(
StrT1,
VariantT1(RowT1(tagName -> a)),
a,
)

val opsig = OperT1(operArgs, a)
mkExRefApp(opsig, Seq(v, variantEx, defaultEx))

// ******************************************** Apalache **************************************************
case OperEx(ApalacheOper.mkSeq, len, ctor) =>
val a = varPool.fresh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -376,22 +376,22 @@ class TestToEtcExpr extends AnyFunSuite with BeforeAndAfterEach with ToEtcExprBa
}
}

test("""FilterByTag("T1a", set)""") {
test("""VariantFilter("T1a", set)""") {
val operType = parser("""(Str, Set(T1a(a) | b)) => Set(a)""")
val expected = mkUniqApp(Seq(operType), mkUniqConst(StrT1), mkUniqName("set"))
val filterEx = tla.variantFilter("T1a", tla.name("set"))
val produced = gen(filterEx)
produced should equal(expected)
}

test("""FilterByTag(foo, set)""") {
test("""VariantFilter(foo, set)""") {
val filterEx = OperEx(VariantOper.variantFilter, tla.name("foo"), tla.name("set"))
assertThrows[TypingInputException] {
gen(filterEx)
}
}

test("""MatchTag("T1a", v, ThenOper, ElseOper)""") {
test("""VariantMatch("T1a", v, ThenOper, ElseOper)""") {
val thenType = parser("a => c")
val elseType = parser("Variant(b) => c")
val operType = parser(s"""(Str, T1a(a) | b, $thenType, $elseType) => c""")
Expand All @@ -403,10 +403,26 @@ class TestToEtcExpr extends AnyFunSuite with BeforeAndAfterEach with ToEtcExprBa
produced should equal(expected)
}

test("""VariantGet("T1a", v)""") {
test("""VariantGetOnly("T1a", v)""") {
val operType = parser(s"""(Str, T1a(a)) => a""")
val expected = mkUniqApp(Seq(operType), mkUniqConst(StrT1), mkUniqName("v"))
val matchEx = tla.variantGet("T1a", tla.name("v"))
val matchEx = tla.variantGetOnly("T1a", tla.name("v"))
val produced = gen(matchEx)
produced should equal(expected)
}

test("""VariantGetUnsafe("T1a", v)""") {
val operType = parser(s"""(Str, T1a(a)) => a""")
val expected = mkUniqApp(Seq(operType), mkUniqConst(StrT1), mkUniqName("v"))
val matchEx = tla.variantGetUnsafe("T1a", tla.name("v"))
val produced = gen(matchEx)
produced should equal(expected)
}

test("""VariantGetOrElse("T1a", v, d)""") {
val operType = parser(s"""(Str, T1a(a), a) => a""")
val expected = mkUniqApp(Seq(operType), mkUniqConst(StrT1), mkUniqName("v"), mkUniqName("d"))
val matchEx = tla.variantGetOrElse("T1a", tla.name("v"), tla.name("d"))
val produced = gen(matchEx)
produced should equal(expected)
}
Expand Down
42 changes: 39 additions & 3 deletions tlair/src/main/scala/at/forsyte/apalache/tla/lir/Builder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -727,10 +727,46 @@ class Builder {
* @return
* the value extracted from the variant
*/
def variantGet(
def variantGetOnly(
tagName: String,
variantEx: BuilderEx): BuilderEx = {
BuilderOper(VariantOper.variantGet, str(tagName), variantEx)
BuilderOper(VariantOper.variantGetOnly, str(tagName), variantEx)
}

/**
* Unsafely extract the value associated with a tag. If the tag name is different from the actual tag, return some
* value of proper type.
*
* @param tagName
* a tag value (string)
* @param variantEx
* a variant expression
* @return
* the value extracted from the variant, when tagged with tagName; otherwise, return some value
*/
def variantGetUnsafe(
tagName: String,
variantEx: BuilderEx): BuilderEx = {
BuilderOper(VariantOper.variantGetOnly, str(tagName), variantEx)
}

/**
* Return the value associated with the tag, when the tag equals to __tagName. Otherwise, return __elseValue.
*
* @param tagName
* a tag value (string)
* @param variantEx
* a variant expression
* @param defaultEx
* default expression
* @return
* the value extracted from the variant
*/
def variantGetOrElse(
tagName: String,
variantEx: BuilderEx,
defaultEx: BuilderEx): BuilderEx = {
BuilderOper(VariantOper.variantGetOrElse, str(tagName), variantEx, defaultEx)
}

private val m_nameMap: Map[String, TlaOper] =
Expand Down Expand Up @@ -828,7 +864,7 @@ class Builder {
ApalacheOper.setAsFun.name -> ApalacheOper.setAsFun,
ApalacheOper.guess.name -> ApalacheOper.guess,
VariantOper.variant.name -> VariantOper.variant,
VariantOper.variantGet.name -> VariantOper.variantGet,
VariantOper.variantGetOnly.name -> VariantOper.variantGetOnly,
VariantOper.variantMatch.name -> VariantOper.variantMatch,
VariantOper.variantFilter.name -> VariantOper.variantFilter,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,32 @@ object VariantOper {
/**
* Match a single variant.
*/
object variantGet extends VariantOper {
override def name: String = "Variants!VariantGet"
object variantGetOnly extends VariantOper {
override def name: String = "Variants!VariantGetOnly"

override def arity: OperArity = FixedArity(2)

override val precedence: (Int, Int) = (100, 100)
}

/**
* Get the value associated with the tag name, if the tag is matching the tag name. Otherwise, return the default
* value.
*/
object variantGetOrElse extends VariantOper {
override def name: String = "Variants!VariantGetOrElse"

override def arity: OperArity = FixedArity(3)

override val precedence: (Int, Int) = (100, 100)
}

/**
* Unsafely extract the value associated with a tag. If the tag name is different from the actual tag, return some
* value of proper type.
*/
object variantGetUnsafe extends VariantOper {
override def name: String = "Variants!VariantGetUnsafe"

override def arity: OperArity = FixedArity(2)

Expand Down

0 comments on commit ff6798d

Please sign in to comment.