diff --git a/aqua-src/print.aqua b/aqua-src/print.aqua deleted file mode 100644 index 5ffdcc24..00000000 --- a/aqua-src/print.aqua +++ /dev/null @@ -1,5 +0,0 @@ -service Println("println-service-id"): - print: string -> () - -func print(str: string): - Println.print(str) \ No newline at end of file diff --git a/aqua-src/ret.aqua b/aqua-src/ret.aqua new file mode 100644 index 00000000..087e40c7 --- /dev/null +++ b/aqua-src/ret.aqua @@ -0,0 +1,16 @@ +data DT: + field: string + +service DTGetter("get-dt"): + get_dt(s: string) -> DT + +func use_name1(name: string) -> string: + results <- DTGetter.get_dt(name) + <- results.field + +func use_name2(name: string) -> []string: + results: *string + results <- use_name1(name) + results <- use_name1(name) + results <- use_name1(name) + <- results \ No newline at end of file diff --git a/aqua-src/test.aqua b/aqua-src/test.aqua deleted file mode 100644 index 39d31f7a..00000000 --- a/aqua-src/test.aqua +++ /dev/null @@ -1,11 +0,0 @@ -service Op("op"): - noop: -> () - bz: string -> string - -func return_none() -> string: - <- "some result in string" - -func use() -> string: - res <- return_none() - res2 <- Op.bz(res) - <- res \ No newline at end of file diff --git a/backend/air/src/main/scala/aqua/backend/air/AirGen.scala b/backend/air/src/main/scala/aqua/backend/air/AirGen.scala index 806b9f39..f4d3b623 100644 --- a/backend/air/src/main/scala/aqua/backend/air/AirGen.scala +++ b/backend/air/src/main/scala/aqua/backend/air/AirGen.scala @@ -80,7 +80,7 @@ object AirGen extends LogSupport { case FoldRes(item, iterable) => Eval later ForGen(valueToData(iterable), item, opsToSingle(ops)) - case CallServiceRes(serviceId, funcName, Call(args, exportTo), peerId) => + case CallServiceRes(serviceId, funcName, CallRes(args, exportTo), peerId) => Eval.later( ServiceCallGen( valueToData(peerId), diff --git a/backend/js/src/main/scala/aqua/backend/js/JavaScriptFunc.scala b/backend/js/src/main/scala/aqua/backend/js/JavaScriptFunc.scala index a8c3f2bc..29aa8748 100644 --- a/backend/js/src/main/scala/aqua/backend/js/JavaScriptFunc.scala +++ b/backend/js/src/main/scala/aqua/backend/js/JavaScriptFunc.scala @@ -1,7 +1,7 @@ package aqua.backend.js import aqua.backend.air.FuncAirGen -import aqua.model.func.{ArgDef, FuncCallable} +import aqua.model.func.FuncCallable import aqua.model.transform.GenerationConfig import aqua.types._ import cats.syntax.show._ @@ -11,7 +11,7 @@ case class JavaScriptFunc(func: FuncCallable) { import JavaScriptFunc._ def argsJavaScript: String = - func.args.args.map(ad => s"${ad.name}").mkString(", ") + func.argNames.mkString(", ") // TODO: use common functions between TypeScript and JavaScript backends private def genReturnCallback( @@ -42,26 +42,29 @@ case class JavaScriptFunc(func: FuncCallable) { val tsAir = FuncAirGen(func).generateAir(conf) - val setCallbacks = func.args.args.map { - case ArgDef.Data(argName, OptionType(_)) => + val setCallbacks = func.args.collect { + case (argName, OptionType(_)) => s"""h.on('${conf.getDataService}', '$argName', () => {return $argName === null ? [] : [$argName];});""" - case ArgDef.Data(argName, _) => + case (argName, _: DataType) => s"""h.on('${conf.getDataService}', '$argName', () => {return $argName;});""" - case ArgDef.Arrow(argName, at) => + case (argName, at: ArrowType) => val value = s"$argName(${argsCallToJs( at )})" val expr = at.res.fold(s"$value; return {}")(_ => s"return $value") s"""h.on('${conf.callbackService}', '$argName', (args) => {$expr;});""" - }.mkString("\n") + } + .mkString("\n") - val returnCallback = func.ret - .map(_._2) + // TODO support multi-return + val returnCallback = func.arrowType.codomain.uncons + .map(_._1) .map(t => genReturnCallback(t, conf.callbackService, conf.respFuncName)) .getOrElse("") + // TODO support multi-return val returnVal = - func.ret.fold("Promise.race([promise, Promise.resolve()])")(_ => "promise") + func.ret.headOption.fold("Promise.race([promise, Promise.resolve()])")(_ => "promise") // TODO: it could be non-unique val configArgName = "config" @@ -113,12 +116,9 @@ case class JavaScriptFunc(func: FuncCallable) { object JavaScriptFunc { def argsToTs(at: ArrowType): String = - at.args.zipWithIndex - .map(_.swap) - .map(kv => "arg" + kv._1) - .mkString(", ") + at.domain.toLabelledList().map(_._1).mkString(", ") def argsCallToJs(at: ArrowType): String = - at.args.zipWithIndex.map(_._2).map(idx => s"args[$idx]").mkString(", ") + at.domain.toList.zipWithIndex.map(_._2).map(idx => s"args[$idx]").mkString(", ") } diff --git a/backend/ts/src/main/scala/aqua/backend/ts/TypeScriptFunc.scala b/backend/ts/src/main/scala/aqua/backend/ts/TypeScriptFunc.scala index 1144c0cf..6a92c49a 100644 --- a/backend/ts/src/main/scala/aqua/backend/ts/TypeScriptFunc.scala +++ b/backend/ts/src/main/scala/aqua/backend/ts/TypeScriptFunc.scala @@ -1,7 +1,7 @@ package aqua.backend.ts import aqua.backend.air.FuncAirGen -import aqua.model.func.{ArgDef, FuncCallable} +import aqua.model.func.FuncCallable import aqua.model.transform.GenerationConfig import aqua.types._ import cats.syntax.show._ @@ -11,7 +11,7 @@ case class TypeScriptFunc(func: FuncCallable) { import TypeScriptFunc._ def argsTypescript: String = - func.args.args.map(ad => s"${ad.name}: " + typeToTs(ad.`type`)).mkString(", ") + func.arrowType.domain.toLabelledList().map(ad => s"${ad._1}: " + typeToTs(ad._2)).mkString(", ") def generateUniqueArgName(args: List[String], basis: String, attempt: Int): String = { val name = if (attempt == 0) { @@ -50,42 +50,45 @@ case class TypeScriptFunc(func: FuncCallable) { val tsAir = FuncAirGen(func).generateAir(conf) - val retType = func.ret - .map(_._2) + // TODO: support multi return + val retType = func.arrowType.codomain.uncons + .map(_._1) + val retTypeTs = retType .fold("void")(typeToTs) - val returnCallback = func.ret - .map(_._2) + val returnCallback = retType .map(t => genReturnCallback(t, conf.callbackService, conf.respFuncName)) .getOrElse("") - val setCallbacks = func.args.args.map { - case ArgDef.Data(argName, OptionType(_)) => + val setCallbacks = func.args.collect { // Product types are not handled + case (argName, OptionType(_)) => s"""h.on('${conf.getDataService}', '$argName', () => {return $argName === null ? [] : [$argName];});""" - case ArgDef.Data(argName, _) => + case (argName, _: DataType) => s"""h.on('${conf.getDataService}', '$argName', () => {return $argName;});""" - case ArgDef.Arrow(argName, at) => + case (argName, at: ArrowType) => val value = s"$argName(${argsCallToTs( at )})" val expr = at.res.fold(s"$value; return {}")(_ => s"return $value") s"""h.on('${conf.callbackService}', '$argName', (args) => {$expr;});""" - }.mkString("\n") + } + .mkString("\n") + // TODO support multi return val returnVal = - func.ret.fold("Promise.race([promise, Promise.resolve()])")(_ => "promise") + func.ret.headOption.fold("Promise.race([promise, Promise.resolve()])")(_ => "promise") - val clientArgName = generateUniqueArgName(func.args.args.map(_.name), "client", 0) - val configArgName = generateUniqueArgName(func.args.args.map(_.name), "config", 0) + val clientArgName = generateUniqueArgName(func.argNames, "client", 0) + val configArgName = generateUniqueArgName(func.argNames, "config", 0) val configType = "{ttl?: number}" s""" |export async function ${func.funcName}($clientArgName: FluenceClient${if (func.args.isEmpty) "" - else ", "}${argsTypescript}, $configArgName?: $configType): Promise<$retType> { + else ", "}${argsTypescript}, $configArgName?: $configType): Promise<$retTypeTs> { | let request: RequestFlow; - | const promise = new Promise<$retType>((resolve, reject) => { + | const promise = new Promise<$retTypeTs>((resolve, reject) => { | const r = new RequestFlowBuilder() | .disableInjections() | .withRawScript( @@ -130,7 +133,7 @@ object TypeScriptFunc { case OptionType(t) => typeToTs(t) + " | null" case ArrayType(t) => typeToTs(t) + "[]" case StreamType(t) => typeToTs(t) + "[]" - case pt: ProductType => + case pt: StructType => s"{${pt.fields.map(typeToTs).toNel.map(kv => kv._1 + ":" + kv._2).toList.mkString(";")}}" case st: ScalarType if ScalarType.number(st) => "number" case ScalarType.bool => "boolean" @@ -142,14 +145,15 @@ object TypeScriptFunc { case at: ArrowType => s"(${argsToTs(at)}) => ${at.res .fold("void")(typeToTs)}" + case _ => + // TODO: handle product types in returns + "any" } def argsToTs(at: ArrowType): String = - at.args - .map(typeToTs) - .zipWithIndex - .map(_.swap) - .map(kv => "arg" + kv._1 + ": " + kv._2) + at.domain + .toLabelledList() + .map(nt => nt._1 + ": " + typeToTs(nt._2)) .mkString(", ") def argsCallToTs(at: ArrowType): String = diff --git a/model/src/main/scala/aqua/model/AquaContext.scala b/model/src/main/scala/aqua/model/AquaContext.scala index 11b0e015..cc611779 100644 --- a/model/src/main/scala/aqua/model/AquaContext.scala +++ b/model/src/main/scala/aqua/model/AquaContext.scala @@ -2,10 +2,9 @@ package aqua.model import aqua.model.func.raw.{CallServiceTag, FuncOp} import aqua.model.func.{ArgsCall, FuncCallable, FuncModel} -import aqua.types.{ProductType, Type} +import aqua.types.{StructType, Type} import cats.Monoid import cats.data.NonEmptyMap -import cats.syntax.apply.* import cats.syntax.functor.* import cats.syntax.monoid.* import wvlet.log.LogSupport @@ -52,7 +51,7 @@ case class AquaContext( } .map(prefixFirst(prefix, _)) - def `type`(name: String): Option[ProductType] = + def `type`(name: String): Option[StructType] = NonEmptyMap .fromMap( SortedMap.from( @@ -65,7 +64,7 @@ case class AquaContext( } ) ) - .map(ProductType(name, _)) + .map(StructType(name, _)) } object AquaContext extends LogSupport { @@ -104,8 +103,8 @@ object AquaContext extends LogSupport { fnName, // TODO: capture ability resolution, get ID from the call context FuncOp.leaf(CallServiceTag(serviceId, fnName, call)), - args, - (ret.map(_.model), arrowType.res).mapN(_ -> _), + arrowType, + ret.map(_.model), Map.empty, Map.empty ) diff --git a/model/src/main/scala/aqua/model/ServiceModel.scala b/model/src/main/scala/aqua/model/ServiceModel.scala index 1bf01ce1..e8b1a331 100644 --- a/model/src/main/scala/aqua/model/ServiceModel.scala +++ b/model/src/main/scala/aqua/model/ServiceModel.scala @@ -1,6 +1,6 @@ package aqua.model -import aqua.types.{ArrowType, ProductType} +import aqua.types.{ArrowType, StructType} import cats.data.NonEmptyMap case class ServiceModel( @@ -8,5 +8,5 @@ case class ServiceModel( arrows: NonEmptyMap[String, ArrowType], defaultId: Option[ValueModel] ) extends Model { - def `type`: ProductType = ProductType(name, arrows) + def `type`: StructType = StructType(name, arrows) } diff --git a/model/src/main/scala/aqua/model/ValueModel.scala b/model/src/main/scala/aqua/model/ValueModel.scala index 7e2e1d01..dcd93f69 100644 --- a/model/src/main/scala/aqua/model/ValueModel.scala +++ b/model/src/main/scala/aqua/model/ValueModel.scala @@ -100,7 +100,7 @@ object VarModel { val lastError: VarModel = VarModel( "%last_error%", - ProductType( + StructType( "LastError", NonEmptyMap.of( "instruction" -> ScalarType.string, @@ -112,6 +112,6 @@ object VarModel { val nil: VarModel = VarModel( "nil", - StreamType(DataType.Bottom) + StreamType(BottomType) ) } diff --git a/model/src/main/scala/aqua/model/func/ArgDef.scala b/model/src/main/scala/aqua/model/func/ArgDef.scala deleted file mode 100644 index e3e297a8..00000000 --- a/model/src/main/scala/aqua/model/func/ArgDef.scala +++ /dev/null @@ -1,12 +0,0 @@ -package aqua.model.func - -import aqua.types.{ArrowType, DataType, Type} - -sealed abstract class ArgDef(val `type`: Type) { - def name: String -} - -object ArgDef { - case class Data(name: String, dataType: DataType) extends ArgDef(dataType) - case class Arrow(name: String, arrowType: ArrowType) extends ArgDef(arrowType) -} diff --git a/model/src/main/scala/aqua/model/func/ArgsCall.scala b/model/src/main/scala/aqua/model/func/ArgsCall.scala index e592b1cc..cedaac81 100644 --- a/model/src/main/scala/aqua/model/func/ArgsCall.scala +++ b/model/src/main/scala/aqua/model/func/ArgsCall.scala @@ -1,27 +1,26 @@ package aqua.model.func import aqua.model.{ValueModel, VarModel} -import aqua.types.{ArrowType, DataType} -import cats.syntax.functor.* +import aqua.types.{ArrowType, DataType, ProductType, Type} /** * Wraps argument definitions of a function, along with values provided when this function is called * @param args Argument definitions * @param callWith Values provided for arguments */ -case class ArgsCall(args: List[ArgDef], callWith: List[ValueModel]) { +case class ArgsCall(args: ProductType, callWith: List[ValueModel]) { // Both arguments (arg names and types how they seen from the function body) // and values (value models and types how they seen on the call site) - lazy val zipped: List[(ArgDef, ValueModel)] = args zip callWith + lazy val zipped: List[((String, Type), ValueModel)] = args.toLabelledList() zip callWith lazy val dataArgs: Map[String, ValueModel] = - zipped.collect { case (ArgDef.Data(name, _), value) => + zipped.collect { case ((name, _: DataType), value) => name -> value }.toMap def arrowArgs(arrowsInScope: Map[String, FuncCallable]): Map[String, FuncCallable] = zipped.collect { - case (ArgDef.Arrow(name, _), VarModel(value, _, _)) if arrowsInScope.contains(value) => + case ((name, _: ArrowType), VarModel(value, _, _)) if arrowsInScope.contains(value) => name -> arrowsInScope(value) }.toMap } @@ -32,22 +31,18 @@ object ArgsCall { arrow: ArrowType, argPrefix: String = "arg", retName: String = "init_call_res" - ): (ArgsDef, Call, Option[Call.Export]) = { - val argNamesTypes = arrow.args.zipWithIndex.map { case (t, i) => (argPrefix + i, t) } - - val argsDef = ArgsDef(argNamesTypes.map { - case (a, t: DataType) => ArgDef.Data(a, t) - case (a, t: ArrowType) => ArgDef.Arrow(a, t) - }) + ): (ProductType, Call, List[Call.Export]) = { + val argNamesTypes = arrow.domain.toLabelledList(argPrefix) + val res = arrow.codomain.toLabelledList(retName).map(Call.Export(_, _)) val call = Call( argNamesTypes.map { case (a, t) => VarModel(a, t) }, - arrow.res.map(Call.Export(retName, _)) + res ) - (argsDef, call, arrow.res.map(t => Call.Export(retName, t))) + (arrow.domain, call, res) } } diff --git a/model/src/main/scala/aqua/model/func/ArgsDef.scala b/model/src/main/scala/aqua/model/func/ArgsDef.scala deleted file mode 100644 index 4076bc49..00000000 --- a/model/src/main/scala/aqua/model/func/ArgsDef.scala +++ /dev/null @@ -1,27 +0,0 @@ -package aqua.model.func - -import aqua.model.VarModel -import aqua.types.Type -import cats.data.Chain - -case class ArgsDef(args: List[ArgDef]) { - def isEmpty: Boolean = args.isEmpty - - def call(c: Call): ArgsCall = ArgsCall(args, c.args) - - def types: List[Type] = args.map(_.`type`) - - def toCallArgs: List[VarModel] = args.map(ad => VarModel(ad.name, ad.`type`)) - - lazy val dataArgs: Chain[ArgDef.Data] = Chain.fromSeq(args.collect { case ad: ArgDef.Data => - ad - }) - - lazy val arrowArgs: Chain[ArgDef.Arrow] = Chain.fromSeq(args.collect { case ad: ArgDef.Arrow => - ad - }) -} - -object ArgsDef { - val empty: ArgsDef = ArgsDef(Nil) -} diff --git a/model/src/main/scala/aqua/model/func/Call.scala b/model/src/main/scala/aqua/model/func/Call.scala index f715ce41..acb84168 100644 --- a/model/src/main/scala/aqua/model/func/Call.scala +++ b/model/src/main/scala/aqua/model/func/Call.scala @@ -3,7 +3,7 @@ package aqua.model.func import aqua.model.{ValueModel, VarModel} import aqua.types.Type -case class Call(args: List[ValueModel], exportTo: Option[Call.Export]) { +case class Call(args: List[ValueModel], exportTo: List[Call.Export]) { def mapValues(f: ValueModel => ValueModel): Call = Call( @@ -18,7 +18,7 @@ case class Call(args: List[ValueModel], exportTo: Option[Call.Export]) { }.toSet override def toString: String = - s"[${args.mkString(" ")}]${exportTo.map(_.model).map(" " + _).getOrElse("")}" + s"[${args.mkString(" ")}]${exportTo.map(_.model).map(" " + _).mkString(",")}" } object Call { diff --git a/model/src/main/scala/aqua/model/func/FuncCallable.scala b/model/src/main/scala/aqua/model/func/FuncCallable.scala index a0e756db..19749cf1 100644 --- a/model/src/main/scala/aqua/model/func/FuncCallable.scala +++ b/model/src/main/scala/aqua/model/func/FuncCallable.scala @@ -1,9 +1,9 @@ package aqua.model.func import aqua.model.ValueModel.varName -import aqua.model.func.raw._ +import aqua.model.func.raw.* import aqua.model.{Model, ValueModel, VarModel} -import aqua.types.{ArrowType, StreamType, Type} +import aqua.types.{ArrowType, ProductType, StreamType, Type} import cats.Eval import cats.data.Chain import cats.free.Cofree @@ -12,8 +12,8 @@ import wvlet.log.Logger case class FuncCallable( funcName: String, body: FuncOp, - args: ArgsDef, - ret: Option[(ValueModel, Type)], + arrowType: ArrowType, + ret: List[ValueModel], capturedArrows: Map[String, FuncCallable], capturedValues: Map[String, ValueModel] ) extends Model { @@ -21,11 +21,8 @@ case class FuncCallable( private val logger = Logger.of[FuncCallable] import logger._ - def arrowType: ArrowType = - ArrowType( - args.types, - ret.map(_._2) - ) + lazy val args: List[(String, Type)] = arrowType.domain.toLabelledList() + lazy val argNames: List[String] = args.map(_._1) def findNewNames(forbidden: Set[String], introduce: Set[String]): Map[String, String] = (forbidden intersect introduce).foldLeft(Map.empty[String, String]) { case (acc, name) => @@ -49,12 +46,12 @@ case class FuncCallable( call: Call, arrows: Map[String, FuncCallable], forbiddenNames: Set[String] - ): Eval[(FuncOp, Option[ValueModel])] = { + ): Eval[(FuncOp, List[ValueModel])] = { debug("Call: " + call) // Collect all arguments: what names are used inside the function, what values are received - val argsFull = args.call(call) + val argsFull = ArgsCall(arrowType.domain, call.args) // DataType arguments val argsToDataRaw = argsFull.dataArgs // Arrow arguments: expected type is Arrow, given by-name @@ -96,7 +93,7 @@ case class FuncCallable( val treeRenamed = treeWithValues.rename(shouldRename) // Result could be derived from arguments, or renamed; take care about that - val result = ret.map(_._1).map(_.resolveWith(argsToData)).map { + val result: List[ValueModel] = ret.map(_.resolveWith(argsToData)).map { case v: VarModel if shouldRename.contains(v.name) => v.copy(shouldRename(v.name)) case v => v } @@ -155,23 +152,22 @@ case class FuncCallable( } .map { case ((_, resolvedExports), callableFuncBody) => // If return value is affected by any of internal functions, resolve it - (for { - exp <- call.exportTo - res <- result - pair <- exp match { - case Call.Export(name, StreamType(_)) => - val resolved = res.resolveWith(resolvedExports) - // path nested function results to a stream - Some( - FuncOps.seq(FuncOp(callableFuncBody), FuncOps.identity(resolved, exp)) -> Some( - exp.model - ) - ) - case _ => None + val resolvedResult = result.map(_.resolveWith(resolvedExports)) + + val (ops, rets) = (call.exportTo zip resolvedResult) + .map[(Option[FuncOp], ValueModel)] { + case (exp @ Call.Export(_, StreamType(_)), res) => + // pass nested function results to a stream + Some(FuncOps.identity(res, exp)) -> exp.model + case (_, res) => + None -> res } - } yield { - pair - }).getOrElse(FuncOp(callableFuncBody) -> result.map(_.resolveWith(resolvedExports))) + .foldLeft[(List[FuncOp], List[ValueModel])]((FuncOp(callableFuncBody) :: Nil, Nil)) { + case ((ops, rets), (Some(fo), r)) => (fo :: ops, r :: rets) + case ((ops, rets), (_, r)) => (ops, r :: rets) + } + + FuncOps.seq(ops.reverse: _*) -> rets } } diff --git a/model/src/main/scala/aqua/model/func/FuncModel.scala b/model/src/main/scala/aqua/model/func/FuncModel.scala index f8719d5e..367f3651 100644 --- a/model/src/main/scala/aqua/model/func/FuncModel.scala +++ b/model/src/main/scala/aqua/model/func/FuncModel.scala @@ -2,12 +2,12 @@ package aqua.model.func import aqua.model.func.raw.FuncOp import aqua.model.{Model, ValueModel} -import aqua.types.Type +import aqua.types.ArrowType case class FuncModel( name: String, - args: ArgsDef, - ret: Option[(ValueModel, Type)], + arrowType: ArrowType, + ret: List[ValueModel], body: FuncOp ) extends Model { @@ -15,6 +15,6 @@ case class FuncModel( arrows: Map[String, FuncCallable], constants: Map[String, ValueModel] ): FuncCallable = - FuncCallable(name, body.fixXorPar, args, ret, arrows, constants) + FuncCallable(name, body.fixXorPar, arrowType, ret, arrows, constants) } diff --git a/model/src/main/scala/aqua/model/func/raw/FuncOp.scala b/model/src/main/scala/aqua/model/func/raw/FuncOp.scala index 4384e0ad..5a408705 100644 --- a/model/src/main/scala/aqua/model/func/raw/FuncOp.scala +++ b/model/src/main/scala/aqua/model/func/raw/FuncOp.scala @@ -23,19 +23,19 @@ case class FuncOp(tree: Cofree[Chain, RawTag]) extends Model { Cofree.cata(tree)(folder) def definesVarNames: Eval[Set[String]] = cata[Set[String]] { - case (CallArrowTag(_, Call(_, Some(exportTo))), acc) => - Eval.later(acc.foldLeft(Set(exportTo.name))(_ ++ _)) - case (CallServiceTag(_, _, Call(_, Some(exportTo))), acc) => - Eval.later(acc.foldLeft(Set(exportTo.name))(_ ++ _)) + case (CallArrowTag(_, Call(_, exportTo)), acc) if exportTo.nonEmpty => + Eval.later(acc.foldLeft(exportTo.map(_.name).toSet)(_ ++ _)) + case (CallServiceTag(_, _, Call(_, exportTo)), acc) if exportTo.nonEmpty => + Eval.later(acc.foldLeft(exportTo.map(_.name).toSet)(_ ++ _)) case (NextTag(exportTo), acc) => Eval.later(acc.foldLeft(Set(exportTo))(_ ++ _)) case (_, acc) => Eval.later(acc.foldLeft(Set.empty[String])(_ ++ _)) } def exportsVarNames: Eval[Set[String]] = cata[Set[String]] { - case (CallArrowTag(_, Call(_, Some(exportTo))), acc) => - Eval.later(acc.foldLeft(Set(exportTo.name))(_ ++ _)) - case (CallServiceTag(_, _, Call(_, Some(exportTo))), acc) => - Eval.later(acc.foldLeft(Set(exportTo.name))(_ ++ _)) + case (CallArrowTag(_, Call(_, exportTo)), acc) if exportTo.nonEmpty => + Eval.later(acc.foldLeft(exportTo.map(_.name).toSet)(_ ++ _)) + case (CallServiceTag(_, _, Call(_, exportTo)), acc) if exportTo.nonEmpty => + Eval.later(acc.foldLeft(exportTo.map(_.name).toSet)(_ ++ _)) case (_, acc) => Eval.later(acc.foldLeft(Set.empty[String])(_ ++ _)) } diff --git a/model/src/main/scala/aqua/model/func/raw/FuncOps.scala b/model/src/main/scala/aqua/model/func/raw/FuncOps.scala index 7392ed3d..8e535de8 100644 --- a/model/src/main/scala/aqua/model/func/raw/FuncOps.scala +++ b/model/src/main/scala/aqua/model/func/raw/FuncOps.scala @@ -8,11 +8,11 @@ import cats.free.Cofree object FuncOps { def noop: FuncOp = - FuncOp.leaf(CallServiceTag(LiteralModel.quote("op"), "identity", Call(Nil, None))) + FuncOp.leaf(CallServiceTag(LiteralModel.quote("op"), "identity", Call(Nil, Nil))) def identity(what: ValueModel, to: Call.Export): FuncOp = FuncOp.leaf( - CallServiceTag(LiteralModel.quote("op"), "identity", Call(what :: Nil, Some(to))) + CallServiceTag(LiteralModel.quote("op"), "identity", Call(what :: Nil, to :: Nil)) ) def callService(srvId: ValueModel, funcName: String, call: Call): FuncOp = diff --git a/model/src/main/scala/aqua/model/func/resolved/CallRes.scala b/model/src/main/scala/aqua/model/func/resolved/CallRes.scala new file mode 100644 index 00000000..0031d227 --- /dev/null +++ b/model/src/main/scala/aqua/model/func/resolved/CallRes.scala @@ -0,0 +1,6 @@ +package aqua.model.func.resolved + +import aqua.model.ValueModel +import aqua.model.func.Call + +case class CallRes(args: List[ValueModel], exportTo: Option[Call.Export]) diff --git a/model/src/main/scala/aqua/model/func/resolved/MakeRes.scala b/model/src/main/scala/aqua/model/func/resolved/MakeRes.scala index 007255b1..e86b369a 100644 --- a/model/src/main/scala/aqua/model/func/resolved/MakeRes.scala +++ b/model/src/main/scala/aqua/model/func/resolved/MakeRes.scala @@ -1,6 +1,17 @@ package aqua.model.func.resolved import aqua.model.func.Call +import aqua.model.func.raw.{ + CallServiceTag, + ForTag, + MatchMismatchTag, + NextTag, + OnTag, + ParTag, + RawTag, + SeqTag, + XorTag +} import aqua.model.topology.Topology.Res import aqua.model.{LiteralModel, ValueModel} import cats.Eval @@ -28,5 +39,25 @@ object MakeRes { Cofree[Chain, ResolvedOp](FoldRes(item, iter), Eval.now(Chain.one(body))) def noop(onPeer: ValueModel): Res = - leaf(CallServiceRes(LiteralModel.quote("op"), "noop", Call(Nil, None), onPeer)) + leaf(CallServiceRes(LiteralModel.quote("op"), "noop", CallRes(Nil, None), onPeer)) + + def resolve( + currentPeerId: Option[ValueModel] + ): PartialFunction[RawTag, ResolvedOp] = { + case SeqTag => SeqRes + case _: OnTag => SeqRes + case MatchMismatchTag(a, b, s) => MatchMismatchRes(a, b, s) + case ForTag(item, iter) => FoldRes(item, iter) + case ParTag | ParTag.Detach => ParRes + case XorTag | XorTag.LeftBiased => XorRes + case NextTag(item) => NextRes(item) + case CallServiceTag(serviceId, funcName, Call(args, exportTo)) => + CallServiceRes( + serviceId, + funcName, + CallRes(args, exportTo.headOption), + currentPeerId + .getOrElse(LiteralModel.initPeerId) + ) + } } diff --git a/model/src/main/scala/aqua/model/func/resolved/ResolvedOp.scala b/model/src/main/scala/aqua/model/func/resolved/ResolvedOp.scala index 4e16b539..403f4c9c 100644 --- a/model/src/main/scala/aqua/model/func/resolved/ResolvedOp.scala +++ b/model/src/main/scala/aqua/model/func/resolved/ResolvedOp.scala @@ -27,7 +27,7 @@ case class AbilityIdRes( case class CallServiceRes( serviceId: ValueModel, funcName: String, - call: Call, + call: CallRes, peerId: ValueModel ) extends ResolvedOp { override def toString: String = s"(call $peerId ($serviceId $funcName) $call)" diff --git a/model/src/main/scala/aqua/model/topology/Topology.scala b/model/src/main/scala/aqua/model/topology/Topology.scala index b8c3eade..a4818a27 100644 --- a/model/src/main/scala/aqua/model/topology/Topology.scala +++ b/model/src/main/scala/aqua/model/topology/Topology.scala @@ -46,32 +46,14 @@ object Topology extends LogSupport { else cz.current ) - private def rawToResolved( - currentPeerId: Option[ValueModel] - ): PartialFunction[RawTag, ResolvedOp] = { - case SeqTag => SeqRes - case _: OnTag => SeqRes - case MatchMismatchTag(a, b, s) => MatchMismatchRes(a, b, s) - case ForTag(item, iter) => FoldRes(item, iter) - case ParTag | ParTag.Detach => ParRes - case XorTag | XorTag.LeftBiased => XorRes - case NextTag(item) => NextRes(item) - case CallServiceTag(serviceId, funcName, call) => - CallServiceRes( - serviceId, - funcName, - call, - currentPeerId - .getOrElse(LiteralModel.initPeerId) - ) - } - def resolveOnMoves(op: Tree): Eval[Res] = { val cursor = RawCursor(NonEmptyList.one(ChainZipper.one(op))) val resolvedCofree = cursor .cata(wrap) { rc => debug(s"<:> $rc") - val resolved = rawToResolved(rc.currentPeerId).lift + val resolved = MakeRes + .resolve(rc.currentPeerId) + .lift .apply(rc.tag) .map(MakeRes.leaf) val chainZipperEv = resolved.traverse(cofree => diff --git a/model/src/main/scala/aqua/model/transform/ArgsProvider.scala b/model/src/main/scala/aqua/model/transform/ArgsProvider.scala index 2283ae01..159b03d0 100644 --- a/model/src/main/scala/aqua/model/transform/ArgsProvider.scala +++ b/model/src/main/scala/aqua/model/transform/ArgsProvider.scala @@ -20,7 +20,7 @@ case class ArgsFromService(dataServiceId: ValueModel, names: List[(String, DataT FuncOps.callService( dataServiceId, name, - Call(Nil, Some(Call.Export(iter, ArrayType(t.element)))) + Call(Nil, Call.Export(iter, ArrayType(t.element)) :: Nil) ), FuncOps.fold( item, @@ -41,7 +41,7 @@ case class ArgsFromService(dataServiceId: ValueModel, names: List[(String, DataT FuncOps.callService( dataServiceId, name, - Call(Nil, Some(Call.Export(name, t))) + Call(Nil, Call.Export(name, t) :: Nil) ) } diff --git a/model/src/main/scala/aqua/model/transform/ErrorsCatcher.scala b/model/src/main/scala/aqua/model/transform/ErrorsCatcher.scala index 7920cc21..a7855d55 100644 --- a/model/src/main/scala/aqua/model/transform/ErrorsCatcher.scala +++ b/model/src/main/scala/aqua/model/transform/ErrorsCatcher.scala @@ -58,6 +58,6 @@ object ErrorsCatcher { def lastErrorCall(i: Int): Call = Call( lastErrorArg :: LiteralModel(i.toString, LiteralType.number) :: Nil, - None + Nil ) } diff --git a/model/src/main/scala/aqua/model/transform/ResolveFunc.scala b/model/src/main/scala/aqua/model/transform/ResolveFunc.scala index f7d38c74..9b4e2cf4 100644 --- a/model/src/main/scala/aqua/model/transform/ResolveFunc.scala +++ b/model/src/main/scala/aqua/model/transform/ResolveFunc.scala @@ -1,11 +1,10 @@ package aqua.model.transform -import aqua.model.func._ +import aqua.model.func.* import aqua.model.func.raw.{FuncOp, FuncOps} import aqua.model.{ValueModel, VarModel} -import aqua.types.{ArrayType, ArrowType, StreamType} +import aqua.types.{ArrayType, ArrowType, ConsType, NilType, ProductType, StreamType} import cats.Eval -import cats.syntax.apply._ case class ResolveFunc( transform: FuncOp => FuncOp, @@ -22,7 +21,7 @@ case class ResolveFunc( respFuncName, Call( retModel :: Nil, - None + Nil ) ) @@ -31,19 +30,19 @@ case class ResolveFunc( FuncCallable( arrowCallbackPrefix + name, callback(name, call), - args, - (ret.map(_.model), arrowType.res).mapN(_ -> _), + arrowType, + ret.map(_.model), Map.empty, Map.empty ) } def wrap(func: FuncCallable): FuncCallable = { - val returnType = func.ret.map(_._1.lastType).map { + val returnType = ProductType(func.ret.map(_.lastType).map { // we mustn't return a stream in response callback to avoid pushing stream to `-return-` value case StreamType(t) => ArrayType(t) case t => t - } + }).toLabelledList(returnVar) FuncCallable( wrapCallableName, @@ -53,21 +52,22 @@ case class ResolveFunc( .callArrow( func.funcName, Call( - func.args.toCallArgs, - returnType.map(t => Call.Export(returnVar, t)) + func.arrowType.domain.toLabelledList().map(ad => VarModel(ad._1, ad._2)), + returnType.map { case (l, t) => Call.Export(l, t) } ) ) :: - returnType - .map(t => VarModel(returnVar, t)) - .map(returnCallback) - .toList: _* + returnType.map { case (l, t) => VarModel(l, t) } + .map(returnCallback): _* ) ), - ArgsDef(ArgDef.Arrow(func.funcName, func.arrowType) :: Nil), - None, - func.args.arrowArgs.map { case ArgDef.Arrow(argName, arrowType) => - argName -> arrowToCallback(argName, arrowType) - }.toList.toMap, + ArrowType(ConsType.cons(func.funcName, func.arrowType, NilType), NilType), + Nil, + func.arrowType.domain + .toLabelledList() + .collect { case (argName, arrowType: ArrowType) => + argName -> arrowToCallback(argName, arrowType) + } + .toMap, Map.empty ) } @@ -78,7 +78,7 @@ case class ResolveFunc( ): Eval[FuncOp] = wrap(func) .resolve( - Call(VarModel(funcArgName, func.arrowType) :: Nil, None), + Call(VarModel(funcArgName, func.arrowType) :: Nil, Nil), Map(funcArgName -> func), Set.empty ) diff --git a/model/src/main/scala/aqua/model/transform/Transform.scala b/model/src/main/scala/aqua/model/transform/Transform.scala index bbc734df..30871d12 100644 --- a/model/src/main/scala/aqua/model/transform/Transform.scala +++ b/model/src/main/scala/aqua/model/transform/Transform.scala @@ -35,9 +35,7 @@ object Transform extends LogSupport { val argsProvider: ArgsProvider = ArgsFromService( conf.dataSrvId, - conf.relayVarName.map(_ -> ScalarType.string).toList ::: func.args.dataArgs.toList.map( - add => add.name -> add.dataType - ) + conf.relayVarName.map(_ -> ScalarType.string).toList ::: func.arrowType.domain.labelledData ) val transform = diff --git a/model/test-kit/src/main/scala/aqua/Node.scala b/model/test-kit/src/main/scala/aqua/Node.scala index d528f29e..ddf7b16e 100644 --- a/model/test-kit/src/main/scala/aqua/Node.scala +++ b/model/test-kit/src/main/scala/aqua/Node.scala @@ -1,8 +1,8 @@ package aqua import aqua.model.func.Call -import aqua.model.func.raw._ -import aqua.model.func.resolved.{CallServiceRes, MakeRes, MatchMismatchRes, ResolvedOp} +import aqua.model.func.raw.* +import aqua.model.func.resolved.{CallRes, CallServiceRes, MakeRes, MatchMismatchRes, ResolvedOp} import aqua.model.transform.{ErrorsCatcher, GenerationConfig} import aqua.model.{LiteralModel, ValueModel, VarModel} import aqua.types.{ArrayType, LiteralType, ScalarType} @@ -48,7 +48,7 @@ object Node { val relay = LiteralModel("-relay-", ScalarType.string) val relayV = VarModel("-relay-", ScalarType.string) val initPeer = LiteralModel.initPeerId - val emptyCall = Call(Nil, None) + val emptyCall = Call(Nil, Nil) val otherPeer = LiteralModel("other-peer", ScalarType.string) val otherPeerL = LiteralModel("\"other-peer\"", LiteralType.string) val otherRelay = LiteralModel("other-relay", ScalarType.string) @@ -63,10 +63,10 @@ object Node { exportTo: Option[Call.Export] = None, args: List[ValueModel] = Nil ): Res = Node( - CallServiceRes(LiteralModel(s"srv$i", ScalarType.string), s"fn$i", Call(args, exportTo), on) + CallServiceRes(LiteralModel(s"srv$i", ScalarType.string), s"fn$i", CallRes(args, exportTo), on) ) - def callTag(i: Int, exportTo: Option[Call.Export] = None, args: List[ValueModel] = Nil): Raw = + def callTag(i: Int, exportTo: List[Call.Export] = Nil, args: List[ValueModel] = Nil): Raw = Node( CallServiceTag(LiteralModel(s"srv$i", ScalarType.string), s"fn$i", Call(args, exportTo)) ) @@ -75,12 +75,12 @@ object Node { CallServiceRes( LiteralModel("\"srv" + i + "\"", LiteralType.string), s"fn$i", - Call(Nil, exportTo), + CallRes(Nil, exportTo), on ) ) - def callLiteralRaw(i: Int, exportTo: Option[Call.Export] = None): Raw = Node( + def callLiteralRaw(i: Int, exportTo: List[Call.Export] = Nil): Raw = Node( CallServiceTag( LiteralModel("\"srv" + i + "\"", LiteralType.string), s"fn$i", @@ -92,7 +92,7 @@ object Node { CallServiceRes( bc.errorHandlingCallback, bc.errorFuncName, - Call( + CallRes( ErrorsCatcher.lastErrorArg :: LiteralModel( i.toString, LiteralType.number @@ -108,7 +108,7 @@ object Node { CallServiceRes( bc.callbackSrvId, bc.respFuncName, - Call(value :: Nil, None), + CallRes(value :: Nil, None), on ) ) @@ -118,7 +118,7 @@ object Node { CallServiceRes( bc.dataSrvId, name, - Call(Nil, Some(Call.Export(name, ScalarType.string))), + CallRes(Nil, Some(Call.Export(name, ScalarType.string))), on ) ) @@ -156,7 +156,7 @@ object Node { Console.GREEN + "(" + equalOrNot(left, right) + Console.GREEN + ")" - private def diffCall(left: Call, right: Call): String = + private def diffCall(left: CallRes, right: CallRes): String = if (left == right) Console.GREEN + left + Console.RESET else Console.GREEN + "Call(" + diff --git a/model/test-kit/src/test/scala/aqua/model/topology/TopologySpec.scala b/model/test-kit/src/test/scala/aqua/model/topology/TopologySpec.scala index 2640bdfc..968c34ef 100644 --- a/model/test-kit/src/test/scala/aqua/model/topology/TopologySpec.scala +++ b/model/test-kit/src/test/scala/aqua/model/topology/TopologySpec.scala @@ -90,7 +90,7 @@ class TopologySpec extends AnyFlatSpec with Matchers { } "topology resolver" should "build return path in par if there are exported variables" in { - val exportTo = Some(Call.Export("result", ScalarType.string)) + val exportTo = Call.Export("result", ScalarType.string) :: Nil val result = VarModel("result", ScalarType.string) val init = on( @@ -105,7 +105,7 @@ class TopologySpec extends AnyFlatSpec with Matchers { ), callTag(2) ), - callTag(3, None, result :: Nil) + callTag(3, Nil, result :: Nil) ) ) @@ -117,7 +117,7 @@ class TopologySpec extends AnyFlatSpec with Matchers { MakeRes.seq( through(relay), through(otherRelay), - callRes(1, otherPeer, exportTo), + callRes(1, otherPeer, exportTo.headOption), through(otherRelay), through(relay), // we should return to a caller to continue execution diff --git a/model/test-kit/src/test/scala/aqua/model/transform/TransformSpec.scala b/model/test-kit/src/test/scala/aqua/model/transform/TransformSpec.scala index 6b0890a1..1bb87aaa 100644 --- a/model/test-kit/src/test/scala/aqua/model/transform/TransformSpec.scala +++ b/model/test-kit/src/test/scala/aqua/model/transform/TransformSpec.scala @@ -2,16 +2,18 @@ package aqua.model.transform import aqua.Node import aqua.model.func.raw.{CallArrowTag, CallServiceTag, FuncOp, FuncOps} -import aqua.model.func.resolved.{CallServiceRes, MakeRes} -import aqua.model.func.{ArgsDef, Call, FuncCallable} +import aqua.model.func.resolved.{CallRes, CallServiceRes, MakeRes} +import aqua.model.func.{Call, FuncCallable} import aqua.model.{LiteralModel, VarModel} -import aqua.types.ScalarType +import aqua.types.{ArrowType, NilType, ProductType, ScalarType} import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers class TransformSpec extends AnyFlatSpec with Matchers { import Node._ + val stringArrow: ArrowType = ArrowType(NilType, ProductType(ScalarType.string :: Nil)) + "transform.forClient" should "work well with function 1 (no calls before on), generate correct error handling" in { val ret = LiteralModel.quote("return this") @@ -20,8 +22,8 @@ class TransformSpec extends AnyFlatSpec with Matchers { FuncCallable( "ret", on(otherPeer, otherRelay :: Nil, callTag(1)), - ArgsDef.empty, - Some((ret, ScalarType.string)), + stringArrow, + ret :: Nil, Map.empty, Map.empty ) @@ -70,8 +72,8 @@ class TransformSpec extends AnyFlatSpec with Matchers { val func: FuncCallable = FuncCallable( "ret", FuncOps.seq(callTag(0), on(otherPeer, Nil, callTag(1))), - ArgsDef.empty, - Some((ret, ScalarType.string)), + stringArrow, + ret :: Nil, Map.empty, Map.empty ) @@ -115,12 +117,12 @@ class TransformSpec extends AnyFlatSpec with Matchers { CallServiceTag( LiteralModel.quote("srv1"), "foo", - Call(Nil, Some(Call.Export("v", ScalarType.string))) + Call(Nil, Call.Export("v", ScalarType.string) :: Nil) ) ).cof ), - ArgsDef.empty, - Some((VarModel("v", ScalarType.string), ScalarType.string)), + stringArrow, + VarModel("v", ScalarType.string) :: Nil, Map.empty, Map.empty ) @@ -129,10 +131,10 @@ class TransformSpec extends AnyFlatSpec with Matchers { FuncCallable( "f2", FuncOp( - Node(CallArrowTag("callable", Call(Nil, Some(Call.Export("v", ScalarType.string))))).cof + Node(CallArrowTag("callable", Call(Nil, Call.Export("v", ScalarType.string) :: Nil))).cof ), - ArgsDef.empty, - Some((VarModel("v", ScalarType.string), ScalarType.string)), + stringArrow, + VarModel("v", ScalarType.string) :: Nil, Map("callable" -> f1), Map.empty ) @@ -148,7 +150,7 @@ class TransformSpec extends AnyFlatSpec with Matchers { CallServiceRes( LiteralModel.quote("srv1"), "foo", - Call(Nil, Some(Call.Export("v", ScalarType.string))), + CallRes(Nil, Some(Call.Export("v", ScalarType.string))), initPeer ) ), diff --git a/parser/src/main/scala/aqua/parser/lexer/Token.scala b/parser/src/main/scala/aqua/parser/lexer/Token.scala index f7ea5dd7..0a05840a 100644 --- a/parser/src/main/scala/aqua/parser/lexer/Token.scala +++ b/parser/src/main/scala/aqua/parser/lexer/Token.scala @@ -68,6 +68,7 @@ object Token { val `[]` : P[Unit] = P.string("[]") val `⊤` : P[Unit] = P.char('⊤') val `⊥` : P[Unit] = P.char('⊥') + val `∅` : P[Unit] = P.char('∅') val `(` : P[Unit] = P.char('(').surroundedBy(` `.?) val `)` : P[Unit] = P.char(')').surroundedBy(` `.?) val `()` : P[Unit] = P.string("()") diff --git a/parser/src/main/scala/aqua/parser/lexer/TypeToken.scala b/parser/src/main/scala/aqua/parser/lexer/TypeToken.scala index 1597b8ed..5ec69bc5 100644 --- a/parser/src/main/scala/aqua/parser/lexer/TypeToken.scala +++ b/parser/src/main/scala/aqua/parser/lexer/TypeToken.scala @@ -113,7 +113,8 @@ object DataTypeToken { (`[]`.lift ~ `datatypedef`[F]).map(ud => ArrayTypeToken(ud._1, ud._2)) def `topbottomdef`[F[_]: LiftParser: Comonad]: P[TopBottomToken[F]] = - `⊥`.lift.map(TopBottomToken(_, isTop = false)) | `⊤`.lift.map(TopBottomToken(_, isTop = true)) + `⊥`.lift.map(TopBottomToken(_, isTop = false)) | + `⊤`.lift.map(TopBottomToken(_, isTop = true)) def `datatypedef`[F[_]: LiftParser: Comonad]: P[DataTypeToken[F]] = P.oneOf( diff --git a/semantics/src/main/scala/aqua/semantics/expr/CallArrowSem.scala b/semantics/src/main/scala/aqua/semantics/expr/CallArrowSem.scala index 37d5fa36..20448a9a 100644 --- a/semantics/src/main/scala/aqua/semantics/expr/CallArrowSem.scala +++ b/semantics/src/main/scala/aqua/semantics/expr/CallArrowSem.scala @@ -83,18 +83,18 @@ class CallArrowSem[F[_]](val expr: CallArrowExpr[F]) extends AnyVal { case _ => Free.pure[Alg, Option[Call.Export]](None) - }).map(call => + }).map(maybeExport => FuncOp.leaf(serviceId match { case Some(sid) => CallServiceTag( serviceId = sid, funcName = funcName.value, - Call(argsResolved, call) + Call(argsResolved, maybeExport.toList) ) case None => CallArrowTag( funcName = funcName.value, - Call(argsResolved, call) + Call(argsResolved, maybeExport.toList) ) }) ) diff --git a/semantics/src/main/scala/aqua/semantics/expr/DataStructSem.scala b/semantics/src/main/scala/aqua/semantics/expr/DataStructSem.scala index a9b47a52..33c4a69c 100644 --- a/semantics/src/main/scala/aqua/semantics/expr/DataStructSem.scala +++ b/semantics/src/main/scala/aqua/semantics/expr/DataStructSem.scala @@ -5,7 +5,7 @@ import aqua.parser.expr.DataStructExpr import aqua.semantics.Prog import aqua.semantics.rules.names.NamesAlgebra import aqua.semantics.rules.types.TypesAlgebra -import aqua.types.ProductType +import aqua.types.StructType import cats.free.Free import cats.syntax.functor._ @@ -20,7 +20,7 @@ class DataStructSem[F[_]](val expr: DataStructExpr[F]) extends AnyVal { case Some(fields) => T.defineDataType(expr.name, fields) as (TypeModel( expr.name.value, - ProductType(expr.name.value, fields) + StructType(expr.name.value, fields) ): Model) case None => Free.pure[Alg, Model](Model.error("Data struct types unresolved")) } diff --git a/semantics/src/main/scala/aqua/semantics/expr/FuncSem.scala b/semantics/src/main/scala/aqua/semantics/expr/FuncSem.scala index 144177f8..e0f6db62 100644 --- a/semantics/src/main/scala/aqua/semantics/expr/FuncSem.scala +++ b/semantics/src/main/scala/aqua/semantics/expr/FuncSem.scala @@ -1,7 +1,7 @@ package aqua.semantics.expr import aqua.model.func.raw.{FuncOp, FuncOps} -import aqua.model.func.{ArgDef, ArgsDef, FuncModel} +import aqua.model.func.FuncModel import aqua.model.{Model, ReturnModel, ValueModel} import aqua.parser.expr.FuncExpr import aqua.parser.lexer.Arg @@ -10,7 +10,7 @@ import aqua.semantics.rules.ValuesAlgebra import aqua.semantics.rules.abilities.AbilitiesAlgebra import aqua.semantics.rules.names.NamesAlgebra import aqua.semantics.rules.types.TypesAlgebra -import aqua.types.{ArrowType, DataType, Type} +import aqua.types.{ArrowType, ProductType, Type} import cats.Applicative import cats.data.Chain import cats.free.Free @@ -32,15 +32,15 @@ class FuncSem[F[_]](val expr: FuncExpr[F]) extends AnyVal { args .foldLeft( // Begin scope -- for mangling - N.beginScope(name).as[Chain[Type]](Chain.empty) + N.beginScope(name).as[Chain[(String, Type)]](Chain.empty) ) { case (f, Arg(argName, argType)) => // Resolve arg type, remember it f.flatMap(acc => T.resolveType(argType).flatMap { case Some(t: ArrowType) => - N.defineArrow(argName, t, isRoot = false).as(acc.append(t)) + N.defineArrow(argName, t, isRoot = false).as(acc.append(argName.value -> t)) case Some(t) => - N.define(argName, t).as(acc.append(t)) + N.define(argName, t).as(acc.append(argName.value -> t)) case None => Free.pure(acc) } @@ -50,24 +50,19 @@ class FuncSem[F[_]](val expr: FuncExpr[F]) extends AnyVal { // Resolve return type ret.fold(Free.pure[Alg, Option[Type]](None))(T.resolveType(_)) ) - .map(argsAndRes => ArrowType(argsAndRes._1, argsAndRes._2)) + .map(argsAndRes => + ArrowType(ProductType.labelled(argsAndRes._1), ProductType(argsAndRes._2.toList)) + ) - def generateFuncModel[Alg[_]](funcArrow: ArrowType, retModel: Option[ValueModel], body: FuncOp)( + def generateFuncModel[Alg[_]](funcArrow: ArrowType, retModel: List[ValueModel], body: FuncOp)( implicit N: NamesAlgebra[F, Alg] ): Free[Alg, Model] = { val argNames = args.map(_.name.value) val model = FuncModel( name = name.value, - args = ArgsDef( - argNames - .zip(funcArrow.args) - .map { - case (n, dt: DataType) => ArgDef.Data(n, dt) - case (n, at: ArrowType) => ArgDef.Arrow(n, at) - } - ), - ret = retModel zip funcArrow.res, + arrowType = funcArrow, + ret = retModel, body = body ) @@ -87,12 +82,14 @@ class FuncSem[F[_]](val expr: FuncExpr[F]) extends AnyVal { // Check return value type ((funcArrow.res, retValue) match { case (Some(t), Some(v)) => - V.valueToModel(v).flatTap { - case Some(vt) => T.ensureTypeMatches(v, t, vt.lastType).void - case None => Free.pure[Alg, Unit](()) - } - case _ => - Free.pure[Alg, Option[ValueModel]](None) + V.valueToModel(v) + .flatTap { + case Some(vt) => T.ensureTypeMatches(v, t, vt.lastType).void + case None => Free.pure[Alg, Unit](()) + } + .map(_.toList) + case (_, _) => + Free.pure[Alg, List[ValueModel]](Nil) // Erase arguments and internal variables }).flatMap(retModel => diff --git a/semantics/src/main/scala/aqua/semantics/expr/PushToStreamSem.scala b/semantics/src/main/scala/aqua/semantics/expr/PushToStreamSem.scala index 779f5db1..1b8bd91f 100644 --- a/semantics/src/main/scala/aqua/semantics/expr/PushToStreamSem.scala +++ b/semantics/src/main/scala/aqua/semantics/expr/PushToStreamSem.scala @@ -55,10 +55,11 @@ class PushToStreamSem[F[_]](val expr: PushToStreamExpr[F]) extends AnyVal { .map(t => FuncOp .leaf( + // TODO: replace with Apply CallServiceTag( LiteralModel.quote("op"), "identity", - Call(vm :: Nil, Some(Call.Export(expr.stream.value, t))) + Call(vm :: Nil, Call.Export(expr.stream.value, t) :: Nil) ) ): Model ) diff --git a/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala b/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala index 32eef98e..8b8fe06c 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala @@ -43,15 +43,16 @@ class ValuesAlgebra[F[_], Alg[_]](implicit N: NamesAlgebra[F, Alg], T: TypesAlge } } - def checkArguments(token: Token[F], arr: ArrowType, args: List[Value[F]]): Free[Alg, Boolean] = { - T.checkArgumentsNumber(token, arr.args.length, args.length).flatMap { + def checkArguments(token: Token[F], arr: ArrowType, args: List[Value[F]]): Free[Alg, Boolean] = + // TODO: do we really need to check this? + T.checkArgumentsNumber(token, arr.domain.length, args.length).flatMap { case false => Free.pure[Alg, Boolean](false) case true => args .map[Free[Alg, Option[(Token[F], Type)]]](tkn => resolveType(tkn).map(_.map(t => tkn -> t)) ) - .zip(arr.args) + .zip(arr.domain.toList) .foldLeft( Free.pure[Alg, Boolean](true) ) { case (f, (ft, t)) => @@ -66,7 +67,6 @@ class ValuesAlgebra[F[_], Alg[_]](implicit N: NamesAlgebra[F, Alg], T: TypesAlge ).mapN(_ && _) } } - } } diff --git a/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala b/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala index 93942cc8..d53821e4 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala @@ -2,7 +2,7 @@ package aqua.semantics.rules.types import aqua.parser.lexer.Token import aqua.semantics.rules.ReportError -import aqua.types.{ArrowType, ProductType} +import aqua.types.{ArrowType, StructType} import cats.data.Validated.{Invalid, Valid} import cats.data.{NonEmptyMap, State} import cats.syntax.flatMap._ @@ -71,7 +71,7 @@ class TypesInterpreter[F[_], X](implicit lens: Lens[X, TypesState[F]], error: Re case None => modify(st => st.copy( - strict = st.strict.updated(ddt.name.value, ProductType(ddt.name.value, ddt.fields)), + strict = st.strict.updated(ddt.name.value, StructType(ddt.name.value, ddt.fields)), definitions = st.definitions.updated(ddt.name.value, ddt.name) ) ) diff --git a/semantics/src/main/scala/aqua/semantics/rules/types/TypesState.scala b/semantics/src/main/scala/aqua/semantics/rules/types/TypesState.scala index d74aba84..1f73e6b2 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/types/TypesState.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/types/TypesState.scala @@ -17,7 +17,18 @@ import aqua.parser.lexer.{ TopBottomToken, TypeToken } -import aqua.types.{ArrayType, ArrowType, DataType, OptionType, ProductType, StreamType, Type} +import aqua.types.{ + ArrayType, + ArrowType, + BottomType, + DataType, + OptionType, + ProductType, + StreamType, + StructType, + TopType, + Type +} import cats.data.Validated.{Invalid, Valid} import cats.data.{Chain, NonEmptyChain, ValidatedNec} import cats.kernel.Monoid @@ -32,7 +43,7 @@ case class TypesState[F[_]]( def resolveTypeToken(tt: TypeToken[F]): Option[Type] = tt match { case TopBottomToken(_, isTop) => - Option(if (isTop) DataType.Top else DataType.Bottom) + Option(if (isTop) TopType else BottomType) case ArrayTypeToken(_, dtt) => resolveTypeToken(dtt).collect { case it: DataType => ArrayType(it) @@ -55,7 +66,7 @@ case class TypesState[F[_]]( dt } Option.when(strictRes.isDefined == res.isDefined && strictArgs.length == args.length)( - ArrowType(strictArgs, strictRes) + ArrowType(ProductType(strictArgs), ProductType(strictRes.toList)) ) } @@ -72,7 +83,7 @@ case class TypesState[F[_]]( NonEmptyChain .fromChain(errs) .fold[ValidatedNec[(Token[F], String), ArrowType]]( - Valid(ArrowType(argTypes.toList, resType)) + Valid(ArrowType(ProductType(argTypes.toList), ProductType(resType.toList))) )(Invalid(_)) case _ => @@ -95,7 +106,7 @@ case class TypesState[F[_]]( } case (i @ IntoField(_)) :: tail => rootT match { - case pt @ ProductType(_, fields) => + case pt @ StructType(_, fields) => fields(i.value) .toRight(i -> s"Field `${i.value}` not found in type `${pt.name}``") .flatMap(t => resolveOps(t, tail).map(IntoFieldModel(i.value, t) :: _)) diff --git a/types/src/main/scala/aqua/types/CompareTypes.scala b/types/src/main/scala/aqua/types/CompareTypes.scala new file mode 100644 index 00000000..f4f6db5c --- /dev/null +++ b/types/src/main/scala/aqua/types/CompareTypes.scala @@ -0,0 +1,141 @@ +package aqua.types + +import cats.data.NonEmptyMap +import cats.kernel.PartialOrder + +/** + * Types variance is given as a partial order of types. + * Type A is less than type B if B has more data than A. + * E.g. u8 < u16 + */ +object CompareTypes { + import Double.NaN + + private def compareTypesList(l: List[Type], r: List[Type]): Double = + if (l.length != r.length) NaN + else if (l == r) 0.0 + else + (l zip r).map(lr => apply(lr._1, lr._2)).fold(0.0) { + case (a, b) if a == b => a + case (`NaN`, _) => NaN + case (_, `NaN`) => NaN + case (0, b) => b + case (a, 0) => a + case _ => NaN + } + + import ScalarType.* + + private def isLessThen(a: ScalarType, b: ScalarType): Boolean = (a, b) match { + // Signed numbers + case (`i32` | `i16` | `i8`, `i64`) => true + case (`i16` | `i8`, `i32`) => true + case (`i8`, `i16`) => true + + // Unsigned numbers -- can fit into larger signed ones too + case (`u32` | `u16` | `u8`, `u64` | `i64`) => true + case (`u16` | `u8`, `u32` | `i32`) => true + case (`u8`, `u16` | `i16`) => true + + // Floats + case (`f32`, `f64`) => true + + case (`i8` | `i16` | `u8` | `u16`, `f32` | `f64`) => true + case (`i32` | `u32`, `f64`) => true + + case _ => false + } + + private val scalarOrder: PartialOrder[ScalarType] = + PartialOrder.from { + case (a, b) if a == b => 0.0 + case (a, b) if isLessThen(a, b) => -1.0 + case (a, b) if isLessThen(b, a) => 1.0 + case _ => Double.NaN + } + + private def compareStructs(lf: NonEmptyMap[String, Type], rf: NonEmptyMap[String, Type]): Double = + if (lf.toSortedMap == rf.toSortedMap) 0.0 + else if ( + lf.keys.forall(rf.contains) && compareTypesList( + lf.toSortedMap.toList.map(_._2), + rf.toSortedMap.view.filterKeys(lf.keys.contains).toList.map(_._2) + ) == -1.0 + ) 1.0 + else if ( + rf.keys.forall(lf.contains) && compareTypesList( + lf.toSortedMap.view.filterKeys(rf.keys.contains).toList.map(_._2), + rf.toSortedMap.toList.map(_._2) + ) == 1.0 + ) -1.0 + else NaN + + private def compareProducts(l: ProductType, r: ProductType): Double = (l, r) match { + case (NilType, NilType) => 0.0 + case (_: ConsType, NilType) => -1.0 + case (NilType, _: ConsType) => 1.0 + case (ConsType(lhead, ltail), ConsType(rhead, rtail)) => + // If any is not Cons, than it's Bottom and already handled + val headCmp = apply(lhead, rhead) + if (headCmp.isNaN) NaN + else { + val tailCmp = compareProducts(ltail, rtail) + // If one is >, and another eq, it's >, and vice versa + if (headCmp >= 0 && tailCmp >= 0) 1.0 + else if (headCmp <= 0 && tailCmp <= 0) -1.0 + else NaN + } + } + + /** + * Compare types in the meaning of type variance. + * + * @param l Type + * @param r Type + * @return 0 if types match, + * 1 if left type is supertype for the right one, + * -1 if left is a subtype of the right + */ + def apply(l: Type, r: Type): Double = + if (l == r) 0.0 + else + (l, r) match { + case (TopType, _) | (_, BottomType) => 1.0 + case (BottomType, _) | (_, TopType) => -1.0 + + // Literals and scalars + case (x: ScalarType, y: ScalarType) => scalarOrder.partialCompare(x, y) + case (LiteralType(xs, _), y: ScalarType) if xs == Set(y) => 0.0 + case (LiteralType(xs, _), y: ScalarType) if xs(y) => -1.0 + case (x: ScalarType, LiteralType(ys, _)) if ys == Set(x) => 0.0 + case (x: ScalarType, LiteralType(ys, _)) if ys(x) => 1.0 + + // Collections + case (x: ArrayType, y: ArrayType) => apply(x.element, y.element) + case (x: ArrayType, y: StreamType) => apply(x.element, y.element) + case (x: ArrayType, y: OptionType) => apply(x.element, y.element) + case (x: OptionType, y: StreamType) => apply(x.element, y.element) + case (x: OptionType, y: ArrayType) => apply(x.element, y.element) + case (x: StreamType, y: StreamType) => apply(x.element, y.element) + case (StructType(_, xFields), StructType(_, yFields)) => + compareStructs(xFields, yFields) + + // Products + case (l: ProductType, r: ProductType) => compareProducts(l, r) + + // Arrows + case (ArrowType(ldom, lcodom), ArrowType(rdom, rcodom)) => + val cmpDom = apply(ldom, rdom) + val cmpCodom = apply(lcodom, rcodom) + + if (cmpDom >= 0 && cmpCodom <= 0) -1.0 + else if (cmpDom <= 0 && cmpCodom >= 0) 1.0 + else NaN + + case _ => + Double.NaN + } + + implicit val partialOrder: PartialOrder[Type] = + PartialOrder.from(CompareTypes.apply) +} diff --git a/types/src/main/scala/aqua/types/Type.scala b/types/src/main/scala/aqua/types/Type.scala index 60b5656f..502a4131 100644 --- a/types/src/main/scala/aqua/types/Type.scala +++ b/types/src/main/scala/aqua/types/Type.scala @@ -2,8 +2,6 @@ package aqua.types import cats.PartialOrder import cats.data.NonEmptyMap -import cats.instances.option._ -import cats.syntax.apply._ sealed trait Type { @@ -12,12 +10,107 @@ sealed trait Type { import cats.syntax.partialOrder._ this >= incoming } + + def isInhabited: Boolean = true } + +// Product is a list of (optionally labelled) types +sealed trait ProductType extends Type { + def isEmpty: Boolean = this == NilType + + def length: Int + + def uncons: Option[(Type, ProductType)] = this match { + case ConsType(t, pt) => Some(t -> pt) + case _ => None + } + + lazy val toList: List[Type] = this match { + case ConsType(t, pt) => t :: pt.toList + case _ => Nil + } + + /** + * Converts product type to a list of types, labelling each of them with a string + * Label is either got from the types with labels, or from the given prefix and index of a type. + * @param prefix Prefix to generate a missing label + * @param index Index to ensure generated labels are unique + * @return + */ + def toLabelledList(prefix: String = "arg", index: Int = 0): List[(String, Type)] = this match { + case LabelledConsType(label, t, pt) => (label -> t) :: pt.toLabelledList(prefix, index + 1) + case UnlabelledConsType(t, pt) => + (s"$prefix$index" -> t) :: pt.toLabelledList(prefix, index + 1) + case _ => Nil + } + + lazy val labelledData: List[(String, DataType)] = this match { + case LabelledConsType(label, t: DataType, pt) => (label -> t) :: pt.labelledData + case UnlabelledConsType(_, pt) => pt.labelledData + case _ => Nil + } +} + +object ProductType { + + def apply(types: List[Type]): ProductType = types match { + case h :: t => + ConsType.cons(h, ProductType(t)) + case _ => NilType + } + + def labelled(types: List[(String, Type)]): ProductType = types match { + case (l, h) :: t => + ConsType.cons(l, h, ProductType.labelled(t)) + case _ => NilType + } +} + +/** + * ConsType adds a type to the ProductType, and delegates all the others to tail + * Corresponds to Cons (::) in the List + */ +sealed trait ConsType extends ProductType { + def `type`: Type + def tail: ProductType + + override def length: Int = 1 + tail.length +} + +object ConsType { + def unapply(cons: ConsType): Option[(Type, ProductType)] = Some(cons.`type` -> cons.tail) + def cons(`type`: Type, tail: ProductType): ConsType = UnlabelledConsType(`type`, tail) + + def cons(label: String, `type`: Type, tail: ProductType): ConsType = + LabelledConsType(label, `type`, tail) +} + +case class LabelledConsType(label: String, `type`: Type, tail: ProductType) extends ConsType { + override def toString: String = s"($label: " + `type` + s" :: $tail" +} + +case class UnlabelledConsType(`type`: Type, tail: ProductType) extends ConsType { + override def toString: String = `type`.toString + s" :: $tail" +} + +object NilType extends ProductType { + override def toString: String = "∅" + + override def isInhabited: Boolean = false + + override def length: Int = 0 +} + sealed trait DataType extends Type -object DataType { - case object Top extends DataType - case object Bottom extends DataType +case object TopType extends DataType { + override def toString: String = "⊤" +} + +case object BottomType extends DataType { + override def toString: String = "⊥" + + override def isInhabited: Boolean = false } case class ScalarType private (name: String) extends DataType { @@ -46,34 +139,6 @@ object ScalarType { val signed = float ++ Set(i8, i16, i32, i64) val number = signed ++ Set(u8, u16, u32, u64) val all = number ++ Set(bool, string) - - private def isLessThen(a: ScalarType, b: ScalarType): Boolean = (a, b) match { - // Signed numbers - case (`i32` | `i16` | `i8`, `i64`) => true - case (`i16` | `i8`, `i32`) => true - case (`i8`, `i16`) => true - - // Unsigned numbers -- can fit into larger signed ones too - case (`u32` | `u16` | `u8`, `u64` | `i64`) => true - case (`u16` | `u8`, `u32` | `i32`) => true - case (`u8`, `u16` | `i16`) => true - - // Floats - case (`f32`, `f64`) => true - - case (`i8` | `i16` | `u8` | `u16`, `f32` | `f64`) => true - case (`i32` | `u32`, `f64`) => true - - case _ => false - } - - val scalarOrder: PartialOrder[ScalarType] = - PartialOrder.from { - case (a, b) if a == b => 0.0 - case (a, b) if isLessThen(a, b) => -1.0 - case (a, b) if isLessThen(b, a) => 1.0 - case _ => Double.NaN - } } case class LiteralType private (oneOf: Set[ScalarType], name: String) extends DataType { @@ -100,93 +165,49 @@ case class OptionType(element: Type) extends BoxType { override def toString: String = "?" + element } -case class ProductType(name: String, fields: NonEmptyMap[String, Type]) extends DataType { +// Struct is an unordered collection of labelled types +case class StructType(name: String, fields: NonEmptyMap[String, Type]) extends DataType { override def toString: String = s"$name{${fields.map(_.toString).toNel.toList.map(kv => kv._1 + ": " + kv._2).mkString(", ")}}" } -case class ArrowType(args: List[Type], res: Option[Type]) extends Type { +/** + * ArrowType is a profunctor pointing its domain to codomain. + * Profunctor means variance: Arrow is contravariant on domain, and variant on codomain. + * See tests for details. + * @param domain Where this Arrow is defined + * @param codomain Where this Arrow points on + */ +case class ArrowType(domain: ProductType, codomain: ProductType) extends Type { + @deprecated( + "Use .domain to get arguments, add .args helper to the typed object, if needed", + "5.08.2021" + ) + def args: List[Type] = domain.toList + + @deprecated( + "Use .codomain to get results, add .res helper to the typed object, if needed; consider multi-value return", + "5.08.2021" + ) + def res: Option[Type] = codomain.uncons.map(_._1) + + @deprecated( + "Replace with this function's body", + "5.08.2021" + ) def acceptsAsArguments(valueTypes: List[Type]): Boolean = - (args.length == valueTypes.length) && args - .zip(valueTypes) - .forall(av => av._1.acceptsValueOf(av._2)) + domain.acceptsValueOf(ProductType(valueTypes)) override def toString: String = - args.map(_.toString).mkString(", ") + " -> " + res.map(_.toString).getOrElse("()") + s"$domain -> $codomain" } case class StreamType(element: Type) extends BoxType object Type { - import Double.NaN - private def cmpTypesList(l: List[Type], r: List[Type]): Double = - if (l.length != r.length) NaN - else if (l == r) 0.0 - else - (l zip r).map(lr => cmp(lr._1, lr._2)).fold(0.0) { - case (a, b) if a == b => a - case (`NaN`, _) => NaN - case (_, `NaN`) => NaN - case (0, b) => b - case (a, 0) => a - case _ => NaN - } - - private def cmpProd(lf: NonEmptyMap[String, Type], rf: NonEmptyMap[String, Type]): Double = - if (lf.toSortedMap == rf.toSortedMap) 0.0 - else if ( - lf.keys.forall(rf.contains) && cmpTypesList( - lf.toSortedMap.toList.map(_._2), - rf.toSortedMap.view.filterKeys(lf.keys.contains).toList.map(_._2) - ) == -1.0 - ) 1.0 - else if ( - rf.keys.forall(lf.contains) && cmpTypesList( - lf.toSortedMap.view.filterKeys(rf.keys.contains).toList.map(_._2), - rf.toSortedMap.toList.map(_._2) - ) == 1.0 - ) -1.0 - else NaN - - private def cmp(l: Type, r: Type): Double = - if (l == r) 0.0 - else - (l, r) match { - case (DataType.Top, _: DataType) | (_: DataType, DataType.Bottom) => 1.0 - case (DataType.Bottom, _: DataType) | (_: DataType, DataType.Top) => -1.0 - case (x: ScalarType, y: ScalarType) => ScalarType.scalarOrder.partialCompare(x, y) - case (LiteralType(xs, _), y: ScalarType) if xs == Set(y) => 0.0 - case (LiteralType(xs, _), y: ScalarType) if xs(y) => -1.0 - case (x: ScalarType, LiteralType(ys, _)) if ys == Set(x) => 0.0 - case (x: ScalarType, LiteralType(ys, _)) if ys(x) => 1.0 - case (x: ArrayType, y: ArrayType) => cmp(x.element, y.element) - case (x: ArrayType, y: StreamType) => cmp(x.element, y.element) - case (x: ArrayType, y: OptionType) => cmp(x.element, y.element) - case (x: OptionType, y: StreamType) => cmp(x.element, y.element) - case (x: OptionType, y: ArrayType) => cmp(x.element, y.element) - case (x: StreamType, y: StreamType) => cmp(x.element, y.element) - case (ProductType(_, xFields), ProductType(_, yFields)) => - cmpProd(xFields, yFields) - case (l: ArrowType, r: ArrowType) => - val argL = l.args - val resL = l.res - val argR = r.args - val resR = r.res - val cmpTypes = cmpTypesList(argR, argL) - val cmpRes = - if (resL == resR) 0.0 - else (resL, resR).mapN(cmp).getOrElse(NaN) - - if (cmpTypes >= 0 && cmpRes >= 0) 1.0 - else if (cmpTypes <= 0 && cmpRes <= 0) -1.0 - else NaN - - case _ => - Double.NaN - } - - implicit lazy val typesPartialOrder: PartialOrder[Type] = PartialOrder.from(cmp) + implicit lazy val typesPartialOrder: PartialOrder[Type] = + CompareTypes.partialOrder } diff --git a/types/src/test/scala/aqua/types/TypeSpec.scala b/types/src/test/scala/aqua/types/TypeSpec.scala index 3c99d07c..6c67ff3f 100644 --- a/types/src/test/scala/aqua/types/TypeSpec.scala +++ b/types/src/test/scala/aqua/types/TypeSpec.scala @@ -38,15 +38,15 @@ class TypeSpec extends AnyFlatSpec with Matchers { } "top type" should "accept anything" in { - accepts(DataType.Top, u64) should be(true) - accepts(DataType.Top, LiteralType.bool) should be(true) - accepts(DataType.Top, `*`(u64)) should be(true) + accepts(TopType, u64) should be(true) + accepts(TopType, LiteralType.bool) should be(true) + accepts(TopType, `*`(u64)) should be(true) } "bottom type" should "be accepted by everything" in { - accepts(u64, DataType.Bottom) should be(true) - accepts(LiteralType.bool, DataType.Bottom) should be(true) - accepts(`*`(u64), DataType.Bottom) should be(true) + accepts(u64, BottomType) should be(true) + accepts(LiteralType.bool, BottomType) should be(true) + accepts(`*`(u64), BottomType) should be(true) } "arrays of scalars" should "be variant" in { @@ -62,52 +62,16 @@ class TypeSpec extends AnyFlatSpec with Matchers { (`[]`(`[]`(u32)): Type) <= `[]`(`[]`(u64)) should be(true) } - "products of scalars" should "be variant" in { - val one: Type = ProductType("one", NonEmptyMap.of("field" -> u32)) - val two: Type = ProductType("two", NonEmptyMap.of("field" -> u64, "other" -> string)) - val three: Type = ProductType("three", NonEmptyMap.of("field" -> u32)) + "structs of scalars" should "be variant" in { + val one: Type = StructType("one", NonEmptyMap.of("field" -> u32)) + val two: Type = StructType("two", NonEmptyMap.of("field" -> u64, "other" -> string)) + val three: Type = StructType("three", NonEmptyMap.of("field" -> u32)) accepts(one, two) should be(true) accepts(two, one) should be(false) PartialOrder[Type].eqv(one, three) should be(true) } - "arrows" should "be contravariant on arguments" in { - val one: Type = ArrowType(u32 :: Nil, None) - val two: Type = ArrowType(u64 :: Nil, None) - - accepts(one, two) should be(true) - - one > two should be(true) - two < one should be(true) - } - - "arrows" should "be variant on results" in { - val one: Type = ArrowType(Nil, Some(u64)) - val two: Type = ArrowType(Nil, Some(u32)) - - accepts(one, two) should be(true) - - one > two should be(true) - two < one should be(true) - } - - "arrows" should "respect both args and results" in { - val one: Type = ArrowType(bool :: f64 :: Nil, Some(u64)) - val two: Type = ArrowType(bool :: Nil, Some(u64)) - val three: Type = ArrowType(bool :: f32 :: Nil, Some(u64)) - val four: Type = ArrowType(bool :: f32 :: Nil, Some(u32)) - - accepts(one, two) should be(false) - accepts(two, one) should be(false) - - accepts(one, three) should be(false) - accepts(three, one) should be(true) - - accepts(one, four) should be(false) - accepts(four, one) should be(false) - } - "streams" should "be accepted as an array, but not vice versa" in { val stream: Type = StreamType(bool) val array: Type = ArrayType(bool) @@ -126,4 +90,68 @@ class TypeSpec extends AnyFlatSpec with Matchers { accepts(opt, opt) should be(true) } + "products" should "compare" in { + val empty: ProductType = NilType + val smth: ProductType = ConsType.cons(bool, empty) + + accepts(empty, smth) should be(true) + accepts(smth, empty) should be(false) + + val longer = ConsType.cons(string, smth) + accepts(empty, longer) should be(true) + accepts(smth, longer) should be(false) + accepts(longer, longer) should be(true) + accepts(longer, empty) should be(false) + accepts(longer, smth) should be(false) + accepts(ConsType.cons("label", string, empty), longer) should be(true) + + accepts(ConsType.cons(u64, empty), ConsType.cons(u32, empty)) should be(true) + accepts(ConsType.cons(u32, empty), ConsType.cons(u64, empty)) should be(false) + } + + "arrows" should "be contravariant on arguments" in { + val one: Type = ArrowType(ProductType(u32 :: Nil), NilType) + val onePrime: Type = ArrowType(ProductType(u32 :: bool :: Nil), NilType) + val two: Type = ArrowType(ProductType(u64 :: Nil), NilType) + + accepts(one, onePrime) should be(false) + accepts(onePrime, one) should be(true) + accepts(one, two) should be(true) + accepts(onePrime, two) should be(true) + + one > two should be(true) + two < one should be(true) + } + + "arrows" should "be variant on results" in { + val one: Type = ArrowType(NilType, ProductType(u64 :: Nil)) + val two: Type = ArrowType(NilType, ProductType(u32 :: Nil)) + val three: Type = ArrowType(NilType, ProductType(u32 :: bool :: Nil)) + + accepts(one, two) should be(true) + accepts(one, three) should be(true) + accepts(three, two) should be(false) + accepts(three, one) should be(false) + accepts(two, one) should be(false) + + one > two should be(true) + two < one should be(true) + } + + "arrows" should "respect both args and results" in { + val one: Type = ArrowType(ProductType(bool :: f64 :: Nil), ProductType(u64 :: Nil)) + val two: Type = ArrowType(ProductType(bool :: Nil), ProductType(u64 :: Nil)) + val three: Type = ArrowType(ProductType(bool :: f32 :: Nil), ProductType(u64 :: Nil)) + val four: Type = ArrowType(ProductType(bool :: f32 :: Nil), ProductType(u32 :: Nil)) + + accepts(one, two) should be(true) + accepts(two, one) should be(false) + + accepts(one, three) should be(false) + accepts(three, one) should be(true) + + accepts(one, four) should be(false) + accepts(four, one) should be(false) + } + }