diff --git a/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala b/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala index 2e7ce8d9..5e86f52c 100644 --- a/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala @@ -5,11 +5,13 @@ import aqua.model.* import aqua.model.inline.state.{Arrows, Exports, Mangler} import aqua.raw.ops.RawTag import aqua.raw.value.{ValueRaw, VarRaw} -import aqua.types.{AbilityType, ArrowType, BoxType, StreamType} +import aqua.types.{AbilityType, ArrowType, BoxType, StreamType, Type} + import cats.data.{Chain, IndexedStateT, State} import cats.syntax.bifunctor.* import cats.syntax.foldable.* import cats.syntax.traverse.* +import cats.syntax.option.* import cats.{Eval, Monoid} import scribe.Logging @@ -117,24 +119,23 @@ object ArrowInliner extends Logging { exports <- Exports[S].exports arrows <- Arrows[S].arrows // gather all arrows and variables from abilities - returnedFromAbilities = rets.collect { case VarModel(name, st @ AbilityType(_, _), _) => - getVarsAndArrowsFromAbilities(name, None, st, exports, arrows) - }.foldMapA(_.bimap(_.toList, _.toList)).bimap(_.toMap, _.toMap) + returnedAbilities = rets.collect { case VarModel(name, at: AbilityType, _) => name -> at } + varsFromAbilities = returnedAbilities.flatMap { case (name, at) => + getAbilityVars(name, None, at, exports) + }.toMap + arrowsFromAbilities = returnedAbilities.flatMap { case (name, at) => + getAbilityArrows(name, None, at, exports, arrows) + }.toMap // find and get resolved arrows if we return them from the function - returnedArrows = rets.collect { case VarModel(name, ArrowType(_, _), _) => - name - }.toSet + returnedArrows = rets.collect { case VarModel(name, _: ArrowType, _) => name }.toSet arrowsToSave <- Arrows[S].pickArrows(returnedArrows) - } yield { - val (valsFromAbilities, arrowsFromAbilities) = returnedFromAbilities - InlineResult( - SeqModel.wrap(ops.reverse: _*), - rets.reverse, - valsFromAbilities, - arrowsFromAbilities ++ arrowsToSave - ) - } + } yield InlineResult( + SeqModel.wrap(ops.reverse), + rets.reverse, + varsFromAbilities, + arrowsFromAbilities ++ arrowsToSave + ) /** * Get all arrows that is arguments from outer Arrows. @@ -271,75 +272,108 @@ object ArrowInliner extends Logging { AbilityType.fullName(name, n) -> AbilityType.fullName(newName, n) } allNewNames = newFieldsName.add((name, newName)).toSortedMap - } yield { - val (allVars, allArrows) = - getVarsAndArrowsFromAbilities(vm.name, Option(newName), t, exports, arrows) - AbilityResolvingResult(allNewNames, allVars, allArrows) + allVars = getAbilityVars(vm.name, newName.some, t, exports) + allArrows = getAbilityArrows(vm.name, newName.some, t, exports, arrows) + } yield AbilityResolvingResult(allNewNames, allVars, allArrows) + } + + /** + * Get ability fields (vars or arrows) from exports + * + * @param abilityName ability current name in state + * @param abilityNewName ability new name (for renaming) + * @param abilityType ability type + * @param exports exports state to resolve fields + * @param fields fields selector + * @return resolved ability fields (renamed if necessary) + */ + private def getAbilityFields[T <: Type]( + abilityName: String, + abilityNewName: Option[String], + abilityType: AbilityType, + exports: Map[String, ValueModel] + )(fields: AbilityType => Map[String, T]): Map[String, ValueModel] = + fields(abilityType).flatMap { case (fName, _) => + val fullName = AbilityType.fullName(abilityName, fName) + val newFullName = AbilityType.fullName(abilityNewName.getOrElse(abilityName), fName) + + Exports + .getLastValue(fullName, exports) + .map(newFullName -> _) + } + + /** + * Get ability vars and arrows as vars from exports + * + * @param abilityName ability current name in state + * @param abilityNewName ability new name (for renaming) + * @param abilityType ability type + * @param exports exports state to resolve fields + * @return resolved ability vars and arrows as vars (renamed if necessary) + */ + private def getAbilityVars( + abilityName: String, + abilityNewName: Option[String], + abilityType: AbilityType, + exports: Map[String, ValueModel] + ): Map[String, ValueModel] = { + val get = getAbilityFields( + abilityName, + abilityNewName, + abilityType, + exports + ) + + get(_.variables) ++ get(_.arrows).flatMap { + case arrow @ (_, vm: VarModel) => + arrow.some + case (_, m) => + logger.error(s"Unexpected: '$m' cannot be an arrow") + None } } /** - * Gather all arrows and variables from abilities recursively (because of possible nested abilities). - * Rename top names if needed in gathered fields and arrows. - * `top` name is a first name, i.e.: `topName.fieldName`. - * Only top name must be renamed to keep all field names unique. - * @param topOldName old name to find all fields in states - * @param topNewName new name to rename all fields in states - * @param abilityType type of current ability - * @param exports where to get values - * @param arrows where to get arrows - * @return + * Get ability arrows from arrows + * + * @param abilityName ability current name in state + * @param abilityNewName ability new name (for renaming) + * @param abilityType ability type + * @param exports exports state to resolve fields + * @param arrows arrows state to resolve arrows + * @return resolved ability arrows (renamed if necessary) */ - private def getVarsAndArrowsFromAbilities( - topOldName: String, - topNewName: Option[String], + private def getAbilityArrows( + abilityName: String, + abilityNewName: Option[String], abilityType: AbilityType, exports: Map[String, ValueModel], arrows: Map[String, FuncArrow] - ): (Map[String, ValueModel], Map[String, FuncArrow]) = { - abilityType.fields.toSortedMap.toList.map { case (fName, fValue) => - val currentOldName = AbilityType.fullName(topOldName, fName) - // for all nested fields, arrows and abilities only left side must be renamed - val currentNewName = topNewName.map(AbilityType.fullName(_, fName)) - fValue match { - case nestedAbilityType @ AbilityType(_, _) => - getVarsAndArrowsFromAbilities( - currentOldName, - currentNewName, - nestedAbilityType, - exports, - arrows - ) - case ArrowType(_, _) => - Exports - .getLastValue(currentOldName, exports) - .flatMap { - case vm @ VarModel(name, _, _) => - arrows - .get(name) - .map(fa => - ( - Map(currentNewName.getOrElse(currentOldName) -> vm), - Map(name -> fa) - ) - ) - case lm @ LiteralModel(_, _) => - logger.error(s"Unexpected. Literal '$lm' cannot be an arrow") - None - } - .getOrElse((Map.empty, Map.empty)) + ): Map[String, FuncArrow] = { + val get = getAbilityFields( + abilityName, + abilityNewName, + abilityType, + exports + ) - case _ => - Exports - .getLastValue(currentOldName, exports) - .map { vm => - (Map(currentNewName.getOrElse(currentOldName) -> vm), Map.empty) - } - .getOrElse((Map.empty, Map.empty)) - } - }.foldMapA(_.bimap(_.toList, _.toList)).bimap(_.toMap, _.toMap) + get(_.arrows).flatMap { + case (_, VarModel(name, _, _)) => + arrows.get(name).map(name -> _) + case (_, m) => + logger.error(s"Unexpected: '$m' cannot be an arrow") + None + } } + private def getAbilityArrows[S: Arrows: Exports]( + abilityName: String, + abilityType: AbilityType + ): State[S, Map[String, FuncArrow]] = for { + exports <- Exports[S].exports + arrows <- Arrows[S].arrows + } yield getAbilityArrows(abilityName, None, abilityType, exports, arrows) + /** * Prepare the state context for this function call * @@ -435,25 +469,6 @@ object ArrowInliner extends Logging { // Result could be renamed; take care about that } yield (tree, fn.ret.map(_.renameVars(shouldRename))) - private def getAllArrowsFromAbility[S: Exports: Arrows: Mangler]( - name: String, - sc: AbilityType - ): State[S, Map[String, FuncArrow]] = { - for { - exports <- Exports[S].exports - arrows <- Arrows[S].arrows - } yield { - sc.fields.toSortedMap.toList.flatMap { - case (n, ArrowType(_, _)) => - exports.get(AbilityType.fullName(name, n)).flatMap { - case VarModel(n, _, _) => arrows.get(n).map(n -> _) - case _ => None - } - case _ => None - }.toMap - } - } - private[inline] def callArrowRet[S: Exports: Arrows: Mangler]( arrow: FuncArrow, call: CallModel @@ -461,8 +476,8 @@ object ArrowInliner extends Logging { for { passArrows <- Arrows[S].pickArrows(call.arrowArgNames) arrowsFromAbilities <- call.abilityArgs - .traverse(getAllArrowsFromAbility) - .map(_.fold(Map.empty)(_ ++ _)) + .traverse(getAbilityArrows.tupled) + .map(_.flatMap(_.toList).toMap) exports <- Exports[S].exports streams <- getOutsideStreamNames diff --git a/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala b/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala index ba2c39c9..f9003148 100644 --- a/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala +++ b/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala @@ -3,7 +3,7 @@ package aqua.model.inline import aqua.model.* import aqua.model.inline.state.InliningState import aqua.raw.ops.* -import aqua.raw.value.{ApplyPropertyRaw, FunctorRaw, IntoFieldRaw, IntoIndexRaw, LiteralRaw, VarRaw} +import aqua.raw.value.* import aqua.types.* import cats.syntax.show.* import cats.syntax.option.* @@ -49,7 +49,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { records: *string cb(records) */ - "arrow inliner" should "pass stream to callback properly" in { + it should "pass stream to callback properly" in { val streamType = StreamType(ScalarType.string) val streamVar = VarRaw("records", streamType) val streamModel = VarModel("records", StreamType(ScalarType.string)) @@ -141,7 +141,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { records: *string cb(records!) */ - ignore /*"arrow inliner"*/ should "pass stream with gate to callback properly" in { + ignore /*it*/ should "pass stream with gate to callback properly" in { val streamType = StreamType(ScalarType.string) val streamVar = VarRaw("records", streamType) val streamVarLambda = @@ -242,7 +242,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { append_records(records) <- records */ - "arrow inliner" should "work with streams as arguments" in { + it should "work with streams as arguments" in { val returnType = ArrayType(ArrayType(ScalarType.string)) val streamType = StreamType(ArrayType(ScalarType.string)) @@ -338,7 +338,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { * retval = res1 + res2 + res3 * <- retval */ - "arrow inliner" should "leave meta after function inlining" in { + it should "leave meta after function inlining" in { val innerName = "inner" val innerRes = VarRaw("res", ScalarType.u16) @@ -475,7 +475,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { * retval = inner() + inner() + 37 * <- retval */ - "arrow inliner" should "omit meta if arrow was completely erased" in { + it should "omit meta if arrow was completely erased" in { val innerName = "inner" val innerRes = VarRaw("res", ScalarType.u16) val innerRet = "42" @@ -753,7 +753,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { * retval = 37 + c(1) + c(2) * <- retval */ - "arrow inliner" should "leave meta after returned closure inlining" in { + it should "leave meta after returned closure inlining" in { val innerName = "inner" val closureName = "closure" val outterClosureName = "c" @@ -853,7 +853,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { * retval = 37 + c() + c() * <- retval */ - "arrow inliner" should "omit meta if returned closure was completely erased" in { + it should "omit meta if returned closure was completely erased" in { val innerName = "inner" val closureName = "closure" @@ -1023,7 +1023,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { * retval = 37 + a(1) + b(2) + c{3} * <- retval */ - "arrow inliner" should "correctly inline renamed closure [bug LNG-193]" in { + it should "correctly inline renamed closure [bug LNG-193]" in { val innerName = "inner" val closureName = "closure" val outterClosureName = "c" @@ -1151,7 +1151,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { * resT <- accept_closure(closure) * <- resT */ - "arrow inliner" should "correctly handle closure as argument [bug LNG-92]" in { + it should "correctly handle closure as argument [bug LNG-92]" in { val acceptName = "accept_closure" val closureName = "closure" val testName = "test" @@ -1274,7 +1274,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { * method = "method" * test(method) */ - "arrow inliner" should "not rename service call [bug LNG-199]" in { + it should "not rename service call [bug LNG-199]" in { val testName = "test" val argMethodName = "method" val serviceName = "Test" @@ -1381,7 +1381,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { v = arg.value OpHa.identity(v) */ - "arrow inliner" should "hold lambda" in { + it should "hold lambda" in { val innerName = "inner" // lambda that will be assigned to another variable @@ -1485,7 +1485,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { func joinIdxLocal(idx: i16, nodes: []string): join nodes[idx] */ - "arrow inliner" should "not rename value in index array lambda" in { + it should "not rename value in index array lambda" in { val innerName = "inner" // lambda that will be assigned to another variable @@ -1575,7 +1575,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { } - "arrow inliner" should "rename value in arrow with same name as in for" in { + it should "rename value in arrow with same name as in for" in { val argVar = VarRaw("arg", ScalarType.u32) val iVar = VarRaw("i", ScalarType.string) val iVar0 = VarRaw("i-0", ScalarType.string) @@ -1665,7 +1665,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { get() -> string func inner() -> string: - results <- DTGetter.get_dt() + results <- Get.get() <- results func outer() -> []string: @@ -1673,7 +1673,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { results <- use_name1() <- results */ - "arrow inliner" should "generate result in right order" in { + it should "generate result in right order" in { val innerName = "inner" val results = VarRaw("results", ScalarType.string) val resultsOut = VarRaw("results", StreamType(ScalarType.string)) @@ -1744,4 +1744,210 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { ) should be(true) } + /** + * ability Inner: + * arrow(x: i8) -> bool + * field: i8 + * + * ability Outer: + * inner: Inner + * + * func accept{Outer}() -> bool: + * res = Outer.inner.arrow(Outer.inner.field) + * <- res + * + * func main() -> bool: + * closure = (x: i8) -> bool: + * res = x > 0 + * <- res + * inner = Inner(arrow = closure, field = 42) + * outer = Outer(inner = inner) + * res <- accept{outer}() + * <- res + */ + it should "handle nested abilities" in { + val arrowType = ArrowType( + ProductType(ScalarType.i8 :: Nil), + ProductType(ScalarType.bool :: Nil) + ) + val closureType = arrowType.copy( + domain = ProductType.labelled( + "x" -> ScalarType.i8 :: Nil + ) + ) + val resVar = VarRaw("res", ScalarType.bool) + val innerType = AbilityType( + "Inner", + NonEmptyMap.of( + "arrow" -> arrowType, + "field" -> ScalarType.i8 + ) + ) + val outerType = AbilityType( + "Outer", + NonEmptyMap.of( + "inner" -> innerType + ) + ) + + val acceptBody = SeqTag.wrap( + AssignmentTag( + ApplyPropertyRaw.fromChain( + VarRaw("Outer", outerType), + Chain( + IntoFieldRaw("inner", innerType), + IntoArrowRaw( + "arrow", + arrowType, + List( + ApplyPropertyRaw.fromChain( + VarRaw("Outer", outerType), + Chain( + IntoFieldRaw("inner", innerType), + IntoFieldRaw("field", ScalarType.i8) + ) + ) + ) + ) + ) + ), + resVar.name + ).leaf, + ReturnTag(NonEmptyList.one(resVar)).leaf + ) + + val accept = FuncArrow( + "accept", + acceptBody, + ArrowType( + ProductType.labelled("Outer" -> outerType :: Nil), + ProductType(ScalarType.bool :: Nil) + ), + resVar :: Nil, + Map.empty, + Map.empty, + None + ) + + val closureBody = SeqTag.wrap( + AssignmentTag( + CallArrowRaw.service( + "cmp", + LiteralRaw.quote("cmp"), + "gt", + ArrowType( + ProductType(ScalarType.i64 :: ScalarType.i64 :: Nil), + ProductType(ScalarType.bool :: Nil) + ), + List( + VarRaw("x", ScalarType.i8), + LiteralRaw.number(0) + ) + ), + "res" + ).leaf, + ReturnTag(NonEmptyList.one(resVar)).leaf + ) + + val mainBody = SeqTag.wrap( + ClosureTag( + FuncRaw( + name = "closure", + arrow = ArrowRaw( + `type` = closureType, + ret = resVar :: Nil, + body = closureBody + ) + ), + detach = true + ).leaf, + AssignmentTag( + AbilityRaw( + NonEmptyMap.of( + "arrow" -> VarRaw("closure", closureType), + "field" -> LiteralRaw.number(42) + ), + innerType + ), + "inner" + ).leaf, + AssignmentTag( + AbilityRaw( + NonEmptyMap.of( + "inner" -> VarRaw("inner", innerType) + ), + outerType + ), + "outer" + ).leaf, + CallArrowRawTag + .func( + "accept", + Call( + VarRaw("outer", outerType) :: Nil, + Call.Export(resVar.name, resVar.`type`) :: Nil + ) + ) + .leaf, + ReturnTag(NonEmptyList.one(resVar)).leaf + ) + + val main = FuncArrow( + "main", + mainBody, + ArrowType( + ProductType(Nil), + ProductType(ScalarType.bool :: Nil) + ), + resVar :: Nil, + Map("accept" -> accept), + Map.empty, + None + ) + + val model = ArrowInliner + .callArrow[InliningState]( + FuncArrow( + "wrapper", + CallArrowRawTag + .func( + "main", + Call(Nil, Nil) + ) + .leaf, + ArrowType( + ProductType(Nil), + ProductType(Nil) + ), + Nil, + Map("main" -> main), + Map.empty, + None + ), + CallModel(Nil, Nil) + ) + .runA(InliningState()) + .value + + val body = SeqModel.wrap( + SeqModel.wrap( + FlattenModel(LiteralModel.number(42), "literal_ap").leaf, + FlattenModel(VarModel("literal_ap", LiteralType.unsigned), "literal_props").leaf + ), + CallServiceModel( + LiteralModel.quote("cmp"), + "gt", + CallModel( + VarModel("literal_props", LiteralType.unsigned) :: LiteralModel.number(0) :: Nil, + CallModel.Export("gt", ScalarType.bool) :: Nil + ) + ).leaf + ) + + val expected = List("main", "accept", "closure").foldRight(body) { case (name, body) => + MetaModel.CallArrowModel(name).wrap(body) + } + + model.equalsOrShowDiff(expected) shouldEqual true + } } diff --git a/types/src/main/scala/aqua/types/Type.scala b/types/src/main/scala/aqua/types/Type.scala index 9337c03c..6e795292 100644 --- a/types/src/main/scala/aqua/types/Type.scala +++ b/types/src/main/scala/aqua/types/Type.scala @@ -2,6 +2,10 @@ package aqua.types import cats.PartialOrder import cats.data.NonEmptyMap +import cats.Eval +import cats.syntax.traverse.* +import cats.syntax.applicative.* +import cats.syntax.option.* sealed trait Type { @@ -252,18 +256,69 @@ case class StructType(name: String, fields: NonEmptyMap[String, Type]) // Ability is an unordered collection of labelled types and arrows case class AbilityType(name: String, fields: NonEmptyMap[String, Type]) extends NamedType { - lazy val arrows: Map[String, ArrowType] = fields.toNel.collect { - case (name, at @ ArrowType(_, _)) => (name, at) - }.toMap + /** + * Get all arrows defined in this ability and its sub-abilities. + * Paths to arrows are returned **without** ability name + * to allow renaming on call site. + */ + lazy val arrows: Map[String, ArrowType] = { + def getArrowsEval(path: Option[String], ability: AbilityType): Eval[List[(String, ArrowType)]] = + ability.fields.toNel.toList.flatTraverse { + case (abName, abType: AbilityType) => + val newPath = path.fold(abName)(AbilityType.fullName(_, abName)) + getArrowsEval(newPath.some, abType) + case (aName, aType: ArrowType) => + val newPath = path.fold(aName)(AbilityType.fullName(_, aName)) + List(newPath -> aType).pure + case _ => Nil.pure + } - lazy val abilities: List[(String, AbilityType)] = fields.toNel.collect { - case (name, at @ AbilityType(_, _)) => (name, at) + getArrowsEval(None, this).value.toMap } - lazy val variables: List[(String, Type)] = fields.toNel.filter { - case (_, AbilityType(_, _)) => false - case (_, ArrowType(_, _)) => false - case (_, _) => true + /** + * Get all abilities defined in this ability and its sub-abilities. + * Paths to abilities are returned **without** ability name + * to allow renaming on call site. + */ + lazy val abilities: Map[String, AbilityType] = { + def getAbilitiesEval( + path: Option[String], + ability: AbilityType + ): Eval[List[(String, AbilityType)]] = + ability.fields.toNel.toList.flatTraverse { + case (abName, abType: AbilityType) => + val fullName = path.fold(abName)(AbilityType.fullName(_, abName)) + getAbilitiesEval(fullName.some, abType).map( + (fullName -> abType) :: _ + ) + case _ => Nil.pure + } + + getAbilitiesEval(None, this).value.toMap + } + + /** + * Get all variables defined in this ability and its sub-abilities. + * Paths to variables are returned **without** ability name + * to allow renaming on call site. + */ + lazy val variables: Map[String, DataType] = { + def getVariablesEval( + path: Option[String], + ability: AbilityType + ): Eval[List[(String, DataType)]] = + ability.fields.toNel.toList.flatTraverse { + case (abName, abType: AbilityType) => + val newPath = path.fold(abName)(AbilityType.fullName(_, abName)) + getVariablesEval(newPath.some, abType) + case (dName, dType: DataType) => + val newPath = path.fold(dName)(AbilityType.fullName(_, dName)) + List(newPath -> dType).pure + case _ => Nil.pure + } + + getVariablesEval(None, this).value.toMap } override def toString: String =