This commit is contained in:
dmitry 2021-03-05 14:26:56 +03:00
parent 8b1eae4c9c
commit 716ea53240
2 changed files with 34 additions and 8 deletions

View File

@ -97,13 +97,13 @@ object Type {
case (ProductType(_, xFields), ProductType(_, yFields)) =>
cmpProd(xFields, yFields)
case (ArrowType(argL, resL), ArrowType(argR, resR)) =>
val cmpTypes = cmpTypesList(argL, argR)
val cmpTypes = cmpTypesList(argR, argL)
val cmpRes =
if (resL == resR) 0.0
else (resL, resR).mapN(cmp).getOrElse(NaN)
if (cmpTypes >= 0 && cmpRes <= 0) 1.0
else if (cmpTypes <= 0 && cmpRes >= 0) -1.0
if (cmpTypes >= 0 && cmpRes >= 0) 1.0
else if (cmpTypes <= 0 && cmpRes <= 0) -1.0
else NaN
case (x: FuncArrowType, y: ArrowType) =>

View File

@ -5,6 +5,7 @@ import org.scalatest.matchers.should.Matchers
import cats.syntax.partialOrder._
import Type.typesPartialOrder
import cats.data.NonEmptyMap
import cats.kernel.PartialOrder
class TypeSpec extends AnyFlatSpec with Matchers {
@ -12,8 +13,11 @@ class TypeSpec extends AnyFlatSpec with Matchers {
def `[]`(t: DataType): DataType = ArrayType(t)
def accepts(recv: Type, incoming: Type) =
recv >= incoming
"scalar types" should "be variant" in {
(i32: Type) < i64 should be(true)
accepts(i64, i32) should be(true)
(i32: Type) <= i32 should be(true)
(i32: Type) >= i32 should be(true)
(i32: Type) > i32 should be(false)
@ -40,25 +44,47 @@ class TypeSpec extends AnyFlatSpec with Matchers {
"products of scalars" should "be variant" in {
val one: Type = ProductType("one", NonEmptyMap.of("field" -> i32))
val two: Type = ProductType("two", NonEmptyMap.of("field" -> i64, "other" -> string))
val three: Type = ProductType("three", NonEmptyMap.of("field" -> i32))
one < two should be(true)
two > one should be(true)
PartialOrder[Type].eqv(one, three) should be(true)
}
"arrows" should "be variant on arguments" in {
val one: Type = ArrowType(i32 :: Nil, None)
val two: Type = ArrowType(i64 :: Nil, None)
one < two should be(true)
two > one should be(true)
accepts(one, two) should be(true)
one > two should be(true)
two < one should be(true)
}
"arrows" should "be contravariant on results" in {
val one: Type = ArrowType(Nil, Some(i64))
val two: Type = ArrowType(Nil, Some(i32))
one < two should be(true)
two > one should be(true)
accepts(one, two) should be(true)
one > two should be(true)
two < one should be(true)
}
"arrows" should "respect both args and results" in {
val one: Type = ArrowType(bool :: f64 :: Nil, Some(i64))
val two: Type = ArrowType(bool :: Nil, Some(i64))
val three: Type = ArrowType(bool :: f32 :: Nil, Some(i64))
val four: Type = ArrowType(bool :: f32 :: Nil, Some(i32))
accepts(one, two) should be(false)
accepts(two, one) should be(false)
accepts(one, three) should be(false)
accepts(three, one) should be(true)
accepts(one, four) should be(false)
accepts(four, one) should be(false)
}
}