Skip to content

Commit

Permalink
Merge pull request #330 from sideeffffect/fix-YamlDecoder
Browse files Browse the repository at this point in the history
Fix YamlDecoder and NaN tag
  • Loading branch information
lbialy authored Jul 30, 2024
2 parents fca6e33 + 1520652 commit 0bc9b52
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 46 deletions.
42 changes: 23 additions & 19 deletions core/shared/src/main/scala/org/virtuslab/yaml/Tag.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,33 @@ object Tag {
val corePrimitives = Set(nullTag, boolean, int, float, str)
val coreSchemaValues = (corePrimitives ++ Set(seq, map)).map(_.value)

private val nullPattern = "null|Null|NULL|~".r
private val booleanPattern = "true|True|TRUE|false|False|FALSE".r
private val int10Pattern = "[-+]?[0-9]+".r
private val int8Pattern = "0o[0-7]+".r
private val int16Pattern = "0x[0-9a-fA-F]+".r
private val floatPattern = "[-+]?(\\.[0-9]+|[0-9]+(\\.[0-9]*)?)([eE][-+]?[0-9]+)?".r
private val minusInfinity = "-(\\.inf|\\.Inf|\\.INF)".r
private val plusInfinity = "\\+?(\\.inf|\\.Inf|\\.INF)".r
private[yaml] val nullPattern = "^(null|Null|NULL|~)?$".r
private[yaml] val falsePattern = "false|False|FALSE".r
private[yaml] val truePattern = "true|True|TRUE".r
private val int10Pattern = "[-+]?[0-9]+".r
private val int8Pattern = "0o[0-7]+".r
private val int16Pattern = "0x[0-9a-fA-F]+".r
private val floatPattern = "[-+]?(\\.[0-9]+|[0-9]+(\\.[0-9]*)?)([eE][-+]?[0-9]+)?".r
private[yaml] val minusInfinity = "-(\\.inf|\\.Inf|\\.INF)".r
private[yaml] val plusInfinity = "\\+?(\\.inf|\\.Inf|\\.INF)".r
private[yaml] val nan = "\\.nan|\\.NaN|\\.NAN".r

def resolveTag(value: String, style: Option[ScalarStyle] = None): Tag = {
val assumeString = style.exists(s => s == DoubleQuoted || s == SingleQuoted)
value match {
case null => nullTag
case _ if assumeString => str
case nullPattern(_*) => nullTag
case booleanPattern(_*) => boolean
case int10Pattern(_*) => int
case int8Pattern(_*) => int
case int16Pattern(_*) => int
case floatPattern(_*) => float
case minusInfinity(_*) => float
case plusInfinity(_*) => float
case _ => str
case null => nullTag
case _ if assumeString => str
case nullPattern(_*) => nullTag
case falsePattern(_*) => boolean
case truePattern(_*) => boolean
case int10Pattern(_*) => int
case int8Pattern(_*) => int
case int16Pattern(_*) => int
case floatPattern(_*) => float
case minusInfinity(_*) => float
case plusInfinity(_*) => float
case nan(_*) => float
case _ => str
}
}
}
69 changes: 43 additions & 26 deletions core/shared/src/main/scala/org/virtuslab/yaml/YamlDecoder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,67 +92,83 @@ object YamlDecoder extends YamlDecoderCompanionCrossCompat {
tpe
)

implicit def forInt: YamlDecoder[Int] = YamlDecoder { case s @ ScalarNode(value, _) =>
val normalizedValue =
if (value.startsWith("0o")) value.stripPrefix("0o").prepended('0') else value
private def normalizeInt(string: String): String = {
val octal = if (string.startsWith("0o")) string.stripPrefix("0o").prepended('0') else string
octal.replaceAll("_", "")
}

Try(java.lang.Integer.decode(normalizedValue.replaceAll("_", "")).toInt).toEither.left
implicit def forInt: YamlDecoder[Int] = YamlDecoder { case s @ ScalarNode(value, _) =>
Try(java.lang.Integer.decode(normalizeInt(value)).toInt).toEither.left
.map(ConstructError.from(_, "Int", s))
}

implicit def forLong: YamlDecoder[Long] = YamlDecoder { case s @ ScalarNode(value, _) =>
val normalizedValue =
if (value.startsWith("0o")) value.stripPrefix("0o").prepended('0') else value

Try(java.lang.Long.decode(normalizedValue.replaceAll("_", "")).toLong).toEither.left
Try(java.lang.Long.decode(normalizeInt(value)).toLong).toEither.left
.map(ConstructError.from(_, "Long", s))
}

implicit def forDouble: YamlDecoder[Double] = YamlDecoder { case s @ ScalarNode(value, _) =>
val lowercased = value.toLowerCase
if (lowercased.endsWith("inf")) {
if (value.startsWith("-")) Right(Double.NegativeInfinity)
else Right(Double.PositiveInfinity)
} else if (lowercased.endsWith("nan")) {
if (Tag.nan.matches(value)) {
Right(Double.NaN)
} else if (Tag.plusInfinity.matches(value)) {
Right(Double.PositiveInfinity)
} else if (Tag.minusInfinity.matches(value)) {
Right(Double.NegativeInfinity)
} else {
Try(java.lang.Double.parseDouble(value.replaceAll("_", ""))).toEither.left
.map(ConstructError.from(_, "Double", s))
}
}

def forDoublePrecise: YamlDecoder[Double] = YamlDecoder { case s @ ScalarNode(value, _) =>
forDouble.construct(s).flatMap { n =>
val ns = n.toString
if (ns == value) Right(n) else Left(ConstructError.from(s"Double, decoded $ns", s))
}
}

implicit def forFloat: YamlDecoder[Float] = YamlDecoder { case s @ ScalarNode(value, _) =>
val lowercased = value.toLowerCase
if (lowercased.endsWith("inf")) {
if (value.startsWith("-")) Right(Float.NegativeInfinity)
else Right(Float.PositiveInfinity)
} else if (lowercased.endsWith("nan")) {
if (Tag.nan.matches(value)) {
Right(Float.NaN)
} else if (Tag.plusInfinity.matches(value)) {
Right(Float.PositiveInfinity)
} else if (Tag.minusInfinity.matches(value)) {
Right(Float.NegativeInfinity)
} else {
Try(java.lang.Float.parseFloat(value.replaceAll("_", ""))).toEither.left
.map(ConstructError.from(_, "Float", s))
}
}

implicit def forShort: YamlDecoder[Short] = YamlDecoder { case s @ ScalarNode(value, _) =>
val normalizedValue =
if (value.startsWith("0o")) value.stripPrefix("0o").prepended('0') else value
def forFloatPrecise: YamlDecoder[Float] = YamlDecoder { case s @ ScalarNode(value, _) =>
forFloat.construct(s).flatMap { n =>
val ns = n.toString
if (ns == value) Right(n) else Left(ConstructError.from(s"Float, decoded $ns", s))
}
}

Try(java.lang.Short.decode(normalizedValue.replaceAll("_", "")).toShort).toEither.left
implicit def forShort: YamlDecoder[Short] = YamlDecoder { case s @ ScalarNode(value, _) =>
Try(java.lang.Short.decode(normalizeInt(value)).toShort).toEither.left
.map(ConstructError.from(_, "Short", s))
}

implicit def forByte: YamlDecoder[Byte] = YamlDecoder { case s @ ScalarNode(value, _) =>
Try(java.lang.Byte.decode(value.replaceAll("_", "")).toByte).toEither.left
Try(java.lang.Byte.decode(normalizeInt(value)).toByte).toEither.left
.map(ConstructError.from(_, "Byte", s))
}

implicit def forBoolean: YamlDecoder[Boolean] = YamlDecoder { case s @ ScalarNode(value, _) =>
value.toBooleanOption.toRight(cannotParse(value, "Boolean", s))
if (Tag.falsePattern.matches(value)) {
Right(false)
} else if (Tag.truePattern.matches(value)) {
Right(true)
} else {
Left(cannotParse(value, "Boolean", s))
}
}

implicit def forBigInt: YamlDecoder[BigInt] = YamlDecoder { case s @ ScalarNode(value, _) =>
Try(BigInt(value.replaceAll("_", ""))).toEither.left
Try(BigInt(normalizeInt(value))).toEither.left
.map(ConstructError.from(_, "BigInt", s))
}

Expand All @@ -177,8 +193,9 @@ object YamlDecoder extends YamlDecoderCompanionCrossCompat {
.orElse(forBigInt.widen)
.construct(node)
case node @ ScalarNode(_, Tag.float) =>
forDouble
forFloatPrecise
.widen[Any]
.orElse(forDoublePrecise.widen)
.orElse(forBigDecimal.widen)
.construct(node)
case ScalarNode(value, Tag.str) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ class DecoderSuite extends munit.FunSuite:
123 -> 321,
"string" -> "aezakmi",
true -> false,
5.5f -> 55.55d
5.5f -> 55.55f
)

assertEquals(yaml.as[Map[Any, Any]], Right(expected))
Expand Down Expand Up @@ -565,3 +565,13 @@ class DecoderSuite extends munit.FunSuite:
assertEquals(foo.b, "from yaml")
assert(!evaluated)
}

test("Fails decoding -XXXinf as Float") {
val yaml = "-XXXinf"

yaml.as[Float] match
case Left(e: ConstructError) =>
assertEquals(e.expected, Some("Float"))
case Right(value) =>
fail(s"Should fail, but got $value")
}

0 comments on commit 0bc9b52

Please sign in to comment.