Skip to content

Commit

Permalink
Merge pull request #1646 from informalsystems/ik/unification1622
Browse files Browse the repository at this point in the history
Type unification for rows, new records, and variants
  • Loading branch information
konnov committed Apr 26, 2022
2 parents 2ef78c2 + 849b7a4 commit 856075f
Show file tree
Hide file tree
Showing 10 changed files with 376 additions and 56 deletions.
5 changes: 5 additions & 0 deletions UNRELEASED.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
DO NOT LEAVE A BLANK LINE BELOW THIS PREAMBLE -->
### Features

* Experimental type unification over rows, new records, and variants, see #1646

### Breaking changes

* Add the option `--features` to enable experimental features, see #1648

### Bug fixes

* Fix references to `--tune-here` (actually `--tuning-options`), see #1579
* Not failing when assignment and `UNCHANGED` appear in invariants, see #1664
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ object DefaultType1Parser extends Parsers with Type1Parser {
private def noFunExpr: Parser[TlaType1] = {
(INT() | REAL() | BOOL() | STR() | typeVar | typeConst
| set | seq | tuple | row | sparseTuple
| record | parametricRecord | recordFromRow | recordVar
| record | recordFromRow
| variant | variantVar | parenExpr) ^^ {
case INT() => IntT1()
case REAL() => RealT1()
Expand Down Expand Up @@ -146,7 +146,8 @@ object DefaultType1Parser extends Parsers with Type1Parser {

case _ ~ list ~ None ~ _ =>
RowT1(list: _*)
}
} | // the degenerate case of (| var |)
LROW() ~> typeVar <~ RROW() ^^ { v => RowT1(v) }
}

// a sparse tuple type like <| 3: Int, 5: Bool |>
Expand Down Expand Up @@ -178,13 +179,6 @@ object DefaultType1Parser extends Parsers with Type1Parser {
}
}

private def parametricRecord: Parser[TlaType1] = {
// special rule for a record that is completely underspecified, that is, { a }
LCURLY() ~ typeVar ~ RCURLY() ^^ { case _ ~ VarT1(v) ~ _ =>
RecRowT1(RowT1(VarT1(v)))
}
}

private def findDups(list: List[String]): Option[String] = {
// we could use list.groupBy(identity) to count the number of occurrences,
// but that would introduce an unnecessary map
Expand Down Expand Up @@ -215,14 +209,8 @@ object DefaultType1Parser extends Parsers with Type1Parser {

case list ~ None =>
RecRowT1(RowT1(list: _*))
}
}

// the general record constructor which may be used in conjunction with a row variable
private def recordVar: Parser[TlaType1] = {
RECORD() ~ LPAREN() ~ typeVar ~ RPAREN() ^^ { case _ ~ _ ~ VarT1(v) ~ _ =>
RecRowT1(RowT1(VarT1(v)))
}
} | // the degenerate case of a single variable
(LCURLY() ~> typeVar <~ RCURLY()) ^^ (v => RecRowT1(RowT1(v)))
}

// An option in the variant type that is constructed from a row.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import at.forsyte.apalache.tla.lir.transformations.standard.{
}
import at.forsyte.apalache.tla.lir.transformations.{TlaExTransformation, TransformationTracker}
import at.forsyte.apalache.tla.pp.Inliner.FilterFun
import at.forsyte.apalache.tla.typecheck.etc.{Substitution, TypeUnifier}
import at.forsyte.apalache.tla.typecheck.etc.{Substitution, TypeUnifier, TypeVarPool}

/**
* Given a module m, with global operators F1,...,Fn, Inliner performs the following transformation:
Expand Down Expand Up @@ -82,7 +82,8 @@ class Inliner(
// a substitution of the two. A substitution is assumed to exist, otherwise TypingException is thrown.
private def getSubstitution(targetType: TlaType1, decl: TlaOperDecl): (Substitution, TlaType1) = {
val genericType = decl.typeTag.asTlaType1()
new TypeUnifier().unify(Substitution.empty, genericType, targetType) match {
val maxUsedVar = Math.max(genericType.usedNames.foldLeft(0)(Math.max), targetType.usedNames.foldLeft(0)(Math.max))
new TypeUnifier(new TypeVarPool(maxUsedVar + 1)).unify(Substitution.empty, genericType, targetType) match {
case None =>
throw new TypingException(
s"Inliner: Unable to unify generic signature $genericType of ${decl.name} with the concrete type $targetType",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import at.forsyte.apalache.tla.lir.TlaType1
* @author
* Igor Konnov
*/
class ConstraintSolver(approximateSolution: Substitution = Substitution.empty) {
class ConstraintSolver(varPool: TypeVarPool, approximateSolution: Substitution = Substitution.empty) {
private var solution: Substitution = approximateSolution
private var constraints: List[Clause] = List.empty
private var typesToReport: List[(Clause, TlaType1)] = List.empty
Expand Down Expand Up @@ -98,7 +98,7 @@ class ConstraintSolver(approximateSolution: Substitution = Substitution.empty) {
constraint match {
case EqClause(unknown, term) =>
// If there is a solution, we return it. We ignore the type, as it should be bound to `unknown`.
new TypeUnifier().unify(solution, unknown, term)
new TypeUnifier(varPool).unify(solution, unknown, term)

case OrClause(eqs @ _*) =>
// try to solve a disjunctive clause
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class EtcTypeChecker(varPool: TypeVarPool, inferPolytypes: Boolean = true) exten

// The types are computed in operator applications, add extra tests and listener calls for non-operators
try {
val rootSolver = new ConstraintSolver
val rootSolver = new ConstraintSolver(varPool)
// The whole expression has been processed. Compute the type of the expression.
val rootType = computeRec(rootCtx, rootSolver, rootEx)
rootSolver.solve() match {
Expand Down Expand Up @@ -240,7 +240,7 @@ class EtcTypeChecker(varPool: TypeVarPool, inferPolytypes: Boolean = true) exten
val approxSolution = solver.solvePartially().getOrElse(throw new UnwindException)

// introduce a new instance of the constraint solver for the operator definition
val letInSolver = new ConstraintSolver()
val letInSolver = new ConstraintSolver(varPool)
val operScheme =
ctx.types.get(name) match {
case Some(scheme @ TlaType1Scheme(OperT1(_, _), _)) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package at.forsyte.apalache.tla.typecheck.etc

import at.forsyte.apalache.tla.lir._

import scala.annotation.tailrec
import scala.collection.immutable.SortedMap

/**
Expand All @@ -15,10 +16,12 @@ import scala.collection.immutable.SortedMap
*
* <p>This class is not designed for concurrency. Use different instances in different threads.</p>
*
* @param varPool
* variable pool that is used to create fresh variables
* @author
* Igor Konnov
*/
class TypeUnifier {
class TypeUnifier(varPool: TypeVarPool) {
// A variable is mapped to its equivalence class. By default, a variable sits in the singleton equivalence class
// of its own. When two variables are unified, they are merged in the same equivalence class.
private var varToClass: Map[Int, EqClass] = Map.empty
Expand Down Expand Up @@ -87,32 +90,33 @@ class TypeUnifier {
}
}

private def compute(lhs: TlaType1, rhs: TlaType1): Option[TlaType1] = {
// Try to unify a variable with a non-variable term `typeTerm`.
// If `typeTerm` refers to a variable in the equivalence class of `typeVar`, then this is a cyclic reference,
// and there should be no unifier.
def unifyVarWithNonVarTerm(typeVar: Int, typeTerm: TlaType1): Option[TlaType1] = {
// Note that `typeTerm` is not a variable.
val varClass = varToClass(typeVar)
if (doesUseClass(typeTerm, varClass)) {
// No unifier: `typeTerm` refers to a variable in the equivalence class of `typeVar`.
None
} else {
// this variable is associated with an equivalence class, unify the class with `typeTerm`
solution(varClass) match {
case VarT1(_) =>
// an equivalence class of free variables, just assign `typeTerm` to this class
solution += varClass -> typeTerm
Some(typeTerm)

case _ =>
// unify `typeTerm` with the term assigned to the equivalence class, if possible
val unifier = compute(solution(varClass), typeTerm)
unifier.foreach { t => solution += varClass -> t }
unifier
}
// Try to unify a variable with a non-variable term `typeTerm`.
// If `typeTerm` refers to a variable in the equivalence class of `typeVar`, then this is a cyclic reference,
// and there should be no unifier.
private def unifyVarWithNonVarTerm(typeVar: Int, typeTerm: TlaType1): Option[TlaType1] = {
// Note that `typeTerm` is not a variable.
val varClass = varToClass(typeVar)
if (doesUseClass(typeTerm, varClass)) {
// No unifier: `typeTerm` refers to a variable in the equivalence class of `typeVar`.
None
} else {
// this variable is associated with an equivalence class, unify the class with `typeTerm`
solution(varClass) match {
case VarT1(_) =>
// an equivalence class of free variables, just assign `typeTerm` to this class
solution += varClass -> typeTerm
Some(typeTerm)

case nonVar =>
// unify `typeTerm` with the term assigned to the equivalence class, if possible
val unifier = compute(nonVar, typeTerm)
unifier.foreach { t => solution += varClass -> t }
unifier
}
}
}

private def compute(lhs: TlaType1, rhs: TlaType1): Option[TlaType1] = {

// unify types as terms
(lhs, rhs) match {
Expand Down Expand Up @@ -201,7 +205,8 @@ class TypeUnifier {
case (l @ TupT1(_ @_*), r @ SparseTupT1(_)) =>
compute(r, l)

// records join their keys, but the values for the intersecting keys should unify
// Records join their keys, but the values for the intersecting keys should unify.
// This is the old unification rule for the records. For the new records, see the rule for RecRowT1.
case (RecT1(lfields), RecT1(rfields)) =>
val jointKeys = (lfields.keySet ++ rfields.keySet).toSeq
val pairs = jointKeys.map(key => (key, computeFields(key, lfields, rfields)))
Expand All @@ -212,12 +217,106 @@ class TypeUnifier {
Some(unifiedTuple)
}

case (RowT1(lfields, lv), RowT1(rfields, rv)) =>
unifyRows(lfields, rfields, lv, rv)

case (RecRowT1(RowT1(lfields, lv)), RecRowT1(RowT1(rfields, rv))) =>
unifyRows(lfields, rfields, lv, rv).map(t => RecRowT1(t))

case (VariantT1(RowT1(lfields, lv)), VariantT1(RowT1(rfields, rv))) =>
unifyRows(lfields, rfields, lv, rv).map(t => VariantT1(t))

// everything else does not unify
case _ =>
None // no unifier
}
}

// unify two rows
@tailrec
private def unifyRows(
lfields: SortedMap[String, TlaType1],
rfields: SortedMap[String, TlaType1],
lvar: Option[VarT1],
rvar: Option[VarT1]): Option[RowT1] = {
// assuming that a type is either a row, or a variable, make it a row type
def asRow(rowOpt: Option[TlaType1]): Option[RowT1] = rowOpt.map {
case r: RowT1 => r
case v: VarT1 => RowT1(v)
case tp => throw new IllegalStateException("Expected RowT1(_, _) or VarT1(_), found: " + tp)
}

// consider four cases
if (lfields.isEmpty) {
// the base case
(lvar, rvar) match {
case (None, None) =>
if (rfields.nonEmpty) None else Some(RowT1())

case (Some(lv), Some(rv)) =>
if (rfields.isEmpty) {
asRow(compute(lv, rv))
} else {
asRow(unifyVarWithNonVarTerm(lv.no, RowT1(rfields, rvar)))
}

case (Some(lv), None) =>
asRow(unifyVarWithNonVarTerm(lv.no, RowT1(rfields, None)))

case (None, Some(rv)) =>
if (rfields.isEmpty) {
// the only way to match is to make the right variable equal to the empty row
asRow(unifyVarWithNonVarTerm(rv.no, RowT1()))
} else {
// the left row is empty, whereas the right row is non-empty
None
}
}
} else if (rfields.isEmpty) {
// the symmetric case above
unifyRows(rfields, lfields, rvar, lvar)
} else {
val sharedFieldNames = lfields.keySet.intersect(rfields.keySet)
if (sharedFieldNames.isEmpty) {
// The easy case: no shared fields.
// The left row is (| lfields | lvar |).
// The right row is (| rfields | rvar |).
// Introduce a fresh type variable to contain the common tail.
val tailVar = freshVar()
// Unify lvar with (| rfields | tailVar |).
// Unify rvar with (| lfields | tailVar |).
// If both unifiers exist, the result is (| lfields | rfields | tailVar |).
if (
compute(lvar.getOrElse(RowT1()), RowT1(rfields, Some(tailVar))).isEmpty
|| compute(rvar.getOrElse(RowT1()), RowT1(lfields, Some(tailVar))).isEmpty
) {
None
} else {
// apply the computed substitution to obtain the whole row
asRow(Some(Substitution(solution).sub(RowT1(lfields, lvar))._1))
}
} else {
// the general case: some fields are shared
val lfieldsUniq = lfields.filter(p => !sharedFieldNames.keySet.contains(p._1))
val rfieldsUniq = rfields.filter(p => !sharedFieldNames.keySet.contains(p._1))
// Unify the disjoint fields and tail variables, see the above case
compute(RowT1(lfieldsUniq, lvar), RowT1(rfieldsUniq, rvar)) match {
case Some(RowT1(disjointFields, tailVar)) =>
// unify the shared fields, if possible
val unifiedSharedFields = sharedFieldNames.map(key => (key, compute(lfields(key), rfields(key))))
if (unifiedSharedFields.exists(_._2.isEmpty)) {
None
} else {
val finalSharedFields = SortedMap(unifiedSharedFields.map(p => (p._1, p._2.get)).toSeq: _*)
Some(RowT1(finalSharedFields ++ disjointFields, tailVar))
}

case _ => None
}
}
}
}

// unify two sequences
private def unifySeqs(ls: Seq[TlaType1], rs: Seq[TlaType1]): Option[Seq[TlaType1]] = {
val len = ls.length
Expand Down Expand Up @@ -289,6 +388,15 @@ class TypeUnifier {
}
new Substitution(Map[EqClass, TlaType1](mapping: _*))
}

// introduce a fresh variable
private def freshVar(): VarT1 = {
val fresh = varPool.fresh
val cls = EqClass(fresh.no)
varToClass += (fresh.no -> cls)
solution += (cls -> fresh)
fresh
}
}

object TypeUnifier {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,12 @@ class TestDefaultType1Parser extends AnyFunSuite with Checkers with TlaType1Gen
assert(RowT1() == result)
}

test("single-variable row") {
val text = """(| c |)"""
val result = DefaultType1Parser.parseType(text)
assert(RowT1(VarT1("c")) == result)
}

test("concrete row") {
val text = """(| f: Int | g: c |)"""
val result = DefaultType1Parser.parseType(text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ import org.scalatestplus.junit.JUnitRunner

@RunWith(classOf[JUnitRunner])
class TestConstraintSolver extends AnyFunSuite with EasyMockSugar with EtcBuilder {
private val FIRST_VAR: Int = 100
private val parser: Type1Parser = DefaultType1Parser

test("unique solution") {
val solver = new ConstraintSolver
val solver = new ConstraintSolver(new TypeVarPool(FIRST_VAR))
// a disjunctive constraint that comes from a tuple constructor
// either a == (b, c) => <<b, c>>
val option1 = EqClause(VarT1("a"), OperT1(Seq(VarT1("b"), VarT1("c")), parser("<<b, c>>")))
Expand All @@ -30,7 +31,7 @@ class TestConstraintSolver extends AnyFunSuite with EasyMockSugar with EtcBuilde
}

test("multiple solutions") {
val solver = new ConstraintSolver
val solver = new ConstraintSolver(new TypeVarPool(FIRST_VAR))
// a disjunctive constraint that comes from a tuple constructor
// either a == (b, c) => <<b, c>>
val option1 = EqClause(VarT1("a"), OperT1(Seq(VarT1("b"), VarT1("c")), parser("<<b, c>>")))
Expand All @@ -47,7 +48,7 @@ class TestConstraintSolver extends AnyFunSuite with EasyMockSugar with EtcBuilde
}

test("constraints in the reverse order") {
val solver = new ConstraintSolver
val solver = new ConstraintSolver(new TypeVarPool(FIRST_VAR))
// The following constraints come in the order that is reverse to the one that is required to solve the constraints.
// These constraints are made up, they do not come from any real constraints that are produced by TLA+ operators.
val eq1 = EqClause(VarT1("a"), parser("(b, c) => b"))
Expand Down
Loading

0 comments on commit 856075f

Please sign in to comment.