Skip to content

Commit

Permalink
Add some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidGregory084 committed Dec 16, 2017
1 parent 19216ed commit 08287ba
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 26 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala/schemes/Fix.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@ package schemes
trait FixModule {
type Fix[F[_]]

def fix[F[_]](f: F[schemes.Fix[F]]): Fix[F]
def apply[F[_]](f: F[schemes.Fix[F]]): Fix[F]
def unfix[F[_]](f: Fix[F]): F[schemes.Fix[F]]
}

private[schemes] object FixImpl extends FixModule {
type Fix[F[_]] = F[schemes.Fix[F]]

def fix[F[_]](f: F[schemes.Fix[F]]): Fix[F] = f
def apply[F[_]](f: F[schemes.Fix[F]]): Fix[F] = f
def unfix[F[_]](f: Fix[F]): F[schemes.Fix[F]] = f
}
64 changes: 43 additions & 21 deletions core/src/main/scala/schemes/Schemes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,68 @@

package schemes

import cats.{ Functor, Monad, Traverse }
import cats.{ Functor, Monad, Traverse, ~> }

object Schemes {
def cata[F[_], A](fix: Fix[F])(algebra: F[A] => A)(implicit F: Functor[F]): A = {
var fn: Fix[F] => A = null
fn = f => algebra(F.map(Fix.unfix(f))(fn))
fn(fix)
def loop(fix: Fix[F]): A = algebra(F.map(Fix.unfix(fix))(loop))
loop(fix)
}

def cataM[M[_], F[_], A](fix: Fix[F])(algebra: F[A] => M[A])(implicit M: Monad[M], T: Traverse[F]): M[A] = {
var fn: Fix[F] => M[A] = null
fn = f => M.flatMap(T.traverse(Fix.unfix(f))(fn))(algebra)
fn(fix)
def loop(fix: Fix[F]): M[A] = M.flatMap(T.traverse(Fix.unfix(fix))(loop))(algebra)
loop(fix)
}

def ana[F[_], A](a: A)(coalgebra: A => F[A])(implicit F: Functor[F]): Fix[F] = {
var fn: A => Fix[F] = null
fn = aa => Fix.fix[F](F.map(coalgebra(aa))(fn))
fn(a)
def loop(a: A): Fix[F] = Fix[F](F.map(coalgebra(a))(loop))
loop(a)
}

def anaM[M[_], F[_], A](a: A)(coalgebra: A => M[F[A]])(implicit M: Monad[M], T: Traverse[F]): M[Fix[F]] = {
var fn: A => M[Fix[F]] = null
fn = aa => M.flatMap(coalgebra(aa)) { fa =>
M.map(T.traverse(fa)(fn))(Fix.fix[F])
def loop(a: A): M[Fix[F]] = M.flatMap(coalgebra(a)) { fa =>
M.map(T.traverse(fa)(loop))(Fix.apply[F])
}
fn(a)

loop(a)
}

def hylo[F[_], A, B](a: A)(coalgebra: A => F[A], algebra: F[B] => B)(implicit F: Functor[F]): B = {
var fn: A => B = null
fn = aa => algebra(F.map(coalgebra(aa))(fn))
fn(a)
def loop(a: A): B = algebra(F.map(coalgebra(a))(loop))
loop(a)
}

def hyloM[M[_], F[_], A, B](a: A)(coalgebra: A => M[F[A]], algebra: F[B] => M[B])(implicit M: Monad[M], T: Traverse[F]): M[B] = {
var fn: A => M[B] = null
fn = aa => M.flatMap(coalgebra(aa)) { fa =>
M.flatMap(T.traverse(fa)(fn))(algebra)
def loop(a: A): M[B] = M.flatMap(coalgebra(a)) { fa =>
M.flatMap(T.traverse(fa)(loop))(algebra)
}

loop(a)
}

def prepro[F[_], A](fix: Fix[F])(pre: F ~> F, algebra: F[A] => A)(implicit F: Functor[F]): A = {
def loop(fixf: Fix[F]): A = {
val fa = F.map(Fix.unfix(fixf)) { fixf =>
loop(cata[F, Fix[F]](fixf) { fa =>
Fix[F](pre(fa))
})
}
algebra(fa)
}
fn(a)

loop(fix)
}

def postpro[F[_], A](a: A)(coalgebra: A => F[A], post: F ~> F)(implicit F: Functor[F]): Fix[F] = {
def loop(a: A): Fix[F] = {
val ffixf = F.map(coalgebra(a)) { aa =>
ana[F, Fix[F]](loop(aa)) { fixf =>
post(Fix.unfix(fixf))
}
}
Fix[F](ffixf)
}

loop(a)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,27 @@
*/

package schemes
package data

import cats._

sealed trait ListF[A, B]
case class ConsF[A, B](head: A, tail: B) extends ListF[A, B]
case class NilF[A, B]() extends ListF[A, B]
sealed abstract class ListF[A, B]
final case class ConsF[A, B](head: A, tail: B) extends ListF[A, B]
final case class NilF[A, B]() extends ListF[A, B]

object ListF {
def apply[A](as: A*): Fix[ListF[A, ?]] =
Schemes.ana[ListF[A, ?], List[A]](as.toList) {
case Nil => NilF()
case h :: t => ConsF(h, t)
}

def toList[A](list: Fix[ListF[A, ?]]) =
Schemes.cata[ListF[A, ?], List[A]](list) {
case NilF() => Nil
case ConsF(h, t) => h :: t
}

implicit def schemesListFFunctor[A]: Functor[ListF[A, ?]] = new Functor[ListF[A, ?]] {
def map[B, C](fa: ListF[A, B])(f: B => C): ListF[A, C] = fa match {
case ConsF(head, tail) => ConsF(head, f(tail))
Expand Down
4 changes: 4 additions & 0 deletions core/src/main/scala/schemes/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,8 @@
package object schemes {
val Fix: FixModule = FixImpl
type Fix[F[_]] = Fix.Fix[F]

implicit class FixOps[F[_]](private val fix: Fix[F]) extends AnyVal {
def unfix: F[Fix[F]] = Fix.unfix(fix)
}
}
100 changes: 100 additions & 0 deletions core/src/test/scala/schemes/SchemesSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package schemes

import org.scalatest._

import cats._
import schemes.data._

sealed abstract class MathExpr[A] extends Product with Serializable
case class Num[A](num: Int) extends MathExpr[A] {
def retag[B]: Num[B] = this.asInstanceOf[Num[B]]
}
case class Mul[A](l: A, r: A) extends MathExpr[A]
case class Add[A](l: A, r: A) extends MathExpr[A]

object MathExpr {
type Expr = Fix[MathExpr]
def num(i: Int): Expr = Fix[MathExpr](Num(i))
def add(l: Expr, r: Expr): Expr = Fix[MathExpr](Add(l, r))
def mul(l: Expr, r: Expr): Expr = Fix[MathExpr](Mul(l, r))
val evalAlgebra: MathExpr[Int] => Int = {
case Num(i) => i
case Mul(l, r) => l * r
case Add(l, r) => l + r
}

implicit val mathExprFunctor: Functor[MathExpr] = new Functor[MathExpr] {
def map[A, B](expr: MathExpr[A])(f: A => B) = expr match {
case num @ Num(_) => num.retag[B]
case Mul(l, r) => Mul(f(l), f(r))
case Add(l, r) => Add(f(l), f(r))
}
}
}

class SchemesSpec extends FlatSpec with Matchers {
"cata" should "evaluate MathExprs" in {
import MathExpr._

val two = add(num(1), num(1))
Schemes.cata(two)(evalAlgebra) shouldBe 2

val four = mul(num(2), num(2))
Schemes.cata(four)(evalAlgebra) shouldBe 4

val sixteen = add(num(2), add(num(3), num(11)))
Schemes.cata(sixteen)(evalAlgebra) shouldBe 16
}

"ana" should "unfold MathExprs" in {
import MathExpr._

val unfoldAdd = Schemes.ana[MathExpr, Int](5) { i =>
if (i < 2)
Num(i)
else
Add(1, i - 1)
}

unfoldAdd shouldBe add(num(1), add(num(1), add(num(1), add(num(1), num(1)))))
}

"hylo" should "unfold and then evaluate MathExprs" in {
import MathExpr._

Schemes.hylo[MathExpr, Int, Int](5)(
i => if (i < 2) Num(i) else Add(1, i - 1),
evalAlgebra) shouldBe 5
}

"prepro" should "apply a transformation at each layer before folding some structure" in {
val sum: ListF[Int, Int] => Int = {
case ConsF(h, t) => h + t
case NilF() => 0
}

val stopAtFive = Lambda[ListF[Int, ?] ~> ListF[Int, ?]] {
case ConsF(n, _) if n > 5 => NilF()
case other => other
}

val `1 to 10` = ListF(1 to 10: _*)

Schemes.prepro[ListF[Int, ?], Int](`1 to 10`)(
stopAtFive,
sum) shouldBe 15
}

"postpro" should "apply a transformation at each layer after unfolding some structure" in {
val stopAtFive = Lambda[ListF[Int, ?] ~> ListF[Int, ?]] {
case ConsF(n, _) if n > 5 => NilF()
case other => other
}

val `1 to 5` = ListF(1 to 5: _*)

Schemes.postpro[ListF[Int, ?], Int](1)(
i => if (i > 100) NilF() else ConsF(i, i + 1),
stopAtFive) shouldBe `1 to 5`
}
}

0 comments on commit 08287ba

Please sign in to comment.