fix(compiler): Fix nested abilities [fixes LNG-220] (#852)

* Fix fields gathering

* Remove println

* Add test

* Remove println

* Add comments

* Add comments
This commit is contained in:
InversionSpaces 2023-08-22 13:53:06 +04:00 committed by GitHub
parent 5db1282c1f
commit bf0b51fa5b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 396 additions and 120 deletions

View File

@ -5,11 +5,13 @@ import aqua.model.*
import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.raw.ops.RawTag
import aqua.raw.value.{ValueRaw, VarRaw}
import aqua.types.{AbilityType, ArrowType, BoxType, StreamType}
import aqua.types.{AbilityType, ArrowType, BoxType, StreamType, Type}
import cats.data.{Chain, IndexedStateT, State}
import cats.syntax.bifunctor.*
import cats.syntax.foldable.*
import cats.syntax.traverse.*
import cats.syntax.option.*
import cats.{Eval, Monoid}
import scribe.Logging
@ -117,24 +119,23 @@ object ArrowInliner extends Logging {
exports <- Exports[S].exports
arrows <- Arrows[S].arrows
// gather all arrows and variables from abilities
returnedFromAbilities = rets.collect { case VarModel(name, st @ AbilityType(_, _), _) =>
getVarsAndArrowsFromAbilities(name, None, st, exports, arrows)
}.foldMapA(_.bimap(_.toList, _.toList)).bimap(_.toMap, _.toMap)
returnedAbilities = rets.collect { case VarModel(name, at: AbilityType, _) => name -> at }
varsFromAbilities = returnedAbilities.flatMap { case (name, at) =>
getAbilityVars(name, None, at, exports)
}.toMap
arrowsFromAbilities = returnedAbilities.flatMap { case (name, at) =>
getAbilityArrows(name, None, at, exports, arrows)
}.toMap
// find and get resolved arrows if we return them from the function
returnedArrows = rets.collect { case VarModel(name, ArrowType(_, _), _) =>
name
}.toSet
returnedArrows = rets.collect { case VarModel(name, _: ArrowType, _) => name }.toSet
arrowsToSave <- Arrows[S].pickArrows(returnedArrows)
} yield {
val (valsFromAbilities, arrowsFromAbilities) = returnedFromAbilities
InlineResult(
SeqModel.wrap(ops.reverse: _*),
rets.reverse,
valsFromAbilities,
arrowsFromAbilities ++ arrowsToSave
)
}
} yield InlineResult(
SeqModel.wrap(ops.reverse),
rets.reverse,
varsFromAbilities,
arrowsFromAbilities ++ arrowsToSave
)
/**
* Get all arrows that is arguments from outer Arrows.
@ -271,75 +272,108 @@ object ArrowInliner extends Logging {
AbilityType.fullName(name, n) -> AbilityType.fullName(newName, n)
}
allNewNames = newFieldsName.add((name, newName)).toSortedMap
} yield {
val (allVars, allArrows) =
getVarsAndArrowsFromAbilities(vm.name, Option(newName), t, exports, arrows)
AbilityResolvingResult(allNewNames, allVars, allArrows)
allVars = getAbilityVars(vm.name, newName.some, t, exports)
allArrows = getAbilityArrows(vm.name, newName.some, t, exports, arrows)
} yield AbilityResolvingResult(allNewNames, allVars, allArrows)
}
/**
* Get ability fields (vars or arrows) from exports
*
* @param abilityName ability current name in state
* @param abilityNewName ability new name (for renaming)
* @param abilityType ability type
* @param exports exports state to resolve fields
* @param fields fields selector
* @return resolved ability fields (renamed if necessary)
*/
private def getAbilityFields[T <: Type](
abilityName: String,
abilityNewName: Option[String],
abilityType: AbilityType,
exports: Map[String, ValueModel]
)(fields: AbilityType => Map[String, T]): Map[String, ValueModel] =
fields(abilityType).flatMap { case (fName, _) =>
val fullName = AbilityType.fullName(abilityName, fName)
val newFullName = AbilityType.fullName(abilityNewName.getOrElse(abilityName), fName)
Exports
.getLastValue(fullName, exports)
.map(newFullName -> _)
}
/**
* Get ability vars and arrows as vars from exports
*
* @param abilityName ability current name in state
* @param abilityNewName ability new name (for renaming)
* @param abilityType ability type
* @param exports exports state to resolve fields
* @return resolved ability vars and arrows as vars (renamed if necessary)
*/
private def getAbilityVars(
abilityName: String,
abilityNewName: Option[String],
abilityType: AbilityType,
exports: Map[String, ValueModel]
): Map[String, ValueModel] = {
val get = getAbilityFields(
abilityName,
abilityNewName,
abilityType,
exports
)
get(_.variables) ++ get(_.arrows).flatMap {
case arrow @ (_, vm: VarModel) =>
arrow.some
case (_, m) =>
logger.error(s"Unexpected: '$m' cannot be an arrow")
None
}
}
/**
* Gather all arrows and variables from abilities recursively (because of possible nested abilities).
* Rename top names if needed in gathered fields and arrows.
* `top` name is a first name, i.e.: `topName.fieldName`.
* Only top name must be renamed to keep all field names unique.
* @param topOldName old name to find all fields in states
* @param topNewName new name to rename all fields in states
* @param abilityType type of current ability
* @param exports where to get values
* @param arrows where to get arrows
* @return
* Get ability arrows from arrows
*
* @param abilityName ability current name in state
* @param abilityNewName ability new name (for renaming)
* @param abilityType ability type
* @param exports exports state to resolve fields
* @param arrows arrows state to resolve arrows
* @return resolved ability arrows (renamed if necessary)
*/
private def getVarsAndArrowsFromAbilities(
topOldName: String,
topNewName: Option[String],
private def getAbilityArrows(
abilityName: String,
abilityNewName: Option[String],
abilityType: AbilityType,
exports: Map[String, ValueModel],
arrows: Map[String, FuncArrow]
): (Map[String, ValueModel], Map[String, FuncArrow]) = {
abilityType.fields.toSortedMap.toList.map { case (fName, fValue) =>
val currentOldName = AbilityType.fullName(topOldName, fName)
// for all nested fields, arrows and abilities only left side must be renamed
val currentNewName = topNewName.map(AbilityType.fullName(_, fName))
fValue match {
case nestedAbilityType @ AbilityType(_, _) =>
getVarsAndArrowsFromAbilities(
currentOldName,
currentNewName,
nestedAbilityType,
exports,
arrows
)
case ArrowType(_, _) =>
Exports
.getLastValue(currentOldName, exports)
.flatMap {
case vm @ VarModel(name, _, _) =>
arrows
.get(name)
.map(fa =>
(
Map(currentNewName.getOrElse(currentOldName) -> vm),
Map(name -> fa)
)
)
case lm @ LiteralModel(_, _) =>
logger.error(s"Unexpected. Literal '$lm' cannot be an arrow")
None
}
.getOrElse((Map.empty, Map.empty))
): Map[String, FuncArrow] = {
val get = getAbilityFields(
abilityName,
abilityNewName,
abilityType,
exports
)
case _ =>
Exports
.getLastValue(currentOldName, exports)
.map { vm =>
(Map(currentNewName.getOrElse(currentOldName) -> vm), Map.empty)
}
.getOrElse((Map.empty, Map.empty))
}
}.foldMapA(_.bimap(_.toList, _.toList)).bimap(_.toMap, _.toMap)
get(_.arrows).flatMap {
case (_, VarModel(name, _, _)) =>
arrows.get(name).map(name -> _)
case (_, m) =>
logger.error(s"Unexpected: '$m' cannot be an arrow")
None
}
}
private def getAbilityArrows[S: Arrows: Exports](
abilityName: String,
abilityType: AbilityType
): State[S, Map[String, FuncArrow]] = for {
exports <- Exports[S].exports
arrows <- Arrows[S].arrows
} yield getAbilityArrows(abilityName, None, abilityType, exports, arrows)
/**
* Prepare the state context for this function call
*
@ -435,25 +469,6 @@ object ArrowInliner extends Logging {
// Result could be renamed; take care about that
} yield (tree, fn.ret.map(_.renameVars(shouldRename)))
private def getAllArrowsFromAbility[S: Exports: Arrows: Mangler](
name: String,
sc: AbilityType
): State[S, Map[String, FuncArrow]] = {
for {
exports <- Exports[S].exports
arrows <- Arrows[S].arrows
} yield {
sc.fields.toSortedMap.toList.flatMap {
case (n, ArrowType(_, _)) =>
exports.get(AbilityType.fullName(name, n)).flatMap {
case VarModel(n, _, _) => arrows.get(n).map(n -> _)
case _ => None
}
case _ => None
}.toMap
}
}
private[inline] def callArrowRet[S: Exports: Arrows: Mangler](
arrow: FuncArrow,
call: CallModel
@ -461,8 +476,8 @@ object ArrowInliner extends Logging {
for {
passArrows <- Arrows[S].pickArrows(call.arrowArgNames)
arrowsFromAbilities <- call.abilityArgs
.traverse(getAllArrowsFromAbility)
.map(_.fold(Map.empty)(_ ++ _))
.traverse(getAbilityArrows.tupled)
.map(_.flatMap(_.toList).toMap)
exports <- Exports[S].exports
streams <- getOutsideStreamNames

View File

@ -3,7 +3,7 @@ package aqua.model.inline
import aqua.model.*
import aqua.model.inline.state.InliningState
import aqua.raw.ops.*
import aqua.raw.value.{ApplyPropertyRaw, FunctorRaw, IntoFieldRaw, IntoIndexRaw, LiteralRaw, VarRaw}
import aqua.raw.value.*
import aqua.types.*
import cats.syntax.show.*
import cats.syntax.option.*
@ -49,7 +49,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
records: *string
cb(records)
*/
"arrow inliner" should "pass stream to callback properly" in {
it should "pass stream to callback properly" in {
val streamType = StreamType(ScalarType.string)
val streamVar = VarRaw("records", streamType)
val streamModel = VarModel("records", StreamType(ScalarType.string))
@ -141,7 +141,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
records: *string
cb(records!)
*/
ignore /*"arrow inliner"*/ should "pass stream with gate to callback properly" in {
ignore /*it*/ should "pass stream with gate to callback properly" in {
val streamType = StreamType(ScalarType.string)
val streamVar = VarRaw("records", streamType)
val streamVarLambda =
@ -242,7 +242,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
append_records(records)
<- records
*/
"arrow inliner" should "work with streams as arguments" in {
it should "work with streams as arguments" in {
val returnType = ArrayType(ArrayType(ScalarType.string))
val streamType = StreamType(ArrayType(ScalarType.string))
@ -338,7 +338,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
* retval = res1 + res2 + res3
* <- retval
*/
"arrow inliner" should "leave meta after function inlining" in {
it should "leave meta after function inlining" in {
val innerName = "inner"
val innerRes = VarRaw("res", ScalarType.u16)
@ -475,7 +475,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
* retval = inner() + inner() + 37
* <- retval
*/
"arrow inliner" should "omit meta if arrow was completely erased" in {
it should "omit meta if arrow was completely erased" in {
val innerName = "inner"
val innerRes = VarRaw("res", ScalarType.u16)
val innerRet = "42"
@ -753,7 +753,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
* retval = 37 + c(1) + c(2)
* <- retval
*/
"arrow inliner" should "leave meta after returned closure inlining" in {
it should "leave meta after returned closure inlining" in {
val innerName = "inner"
val closureName = "closure"
val outterClosureName = "c"
@ -853,7 +853,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
* retval = 37 + c() + c()
* <- retval
*/
"arrow inliner" should "omit meta if returned closure was completely erased" in {
it should "omit meta if returned closure was completely erased" in {
val innerName = "inner"
val closureName = "closure"
@ -1023,7 +1023,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
* retval = 37 + a(1) + b(2) + c{3}
* <- retval
*/
"arrow inliner" should "correctly inline renamed closure [bug LNG-193]" in {
it should "correctly inline renamed closure [bug LNG-193]" in {
val innerName = "inner"
val closureName = "closure"
val outterClosureName = "c"
@ -1151,7 +1151,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
* resT <- accept_closure(closure)
* <- resT
*/
"arrow inliner" should "correctly handle closure as argument [bug LNG-92]" in {
it should "correctly handle closure as argument [bug LNG-92]" in {
val acceptName = "accept_closure"
val closureName = "closure"
val testName = "test"
@ -1274,7 +1274,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
* method = "method"
* test(method)
*/
"arrow inliner" should "not rename service call [bug LNG-199]" in {
it should "not rename service call [bug LNG-199]" in {
val testName = "test"
val argMethodName = "method"
val serviceName = "Test"
@ -1381,7 +1381,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
v = arg.value
OpHa.identity(v)
*/
"arrow inliner" should "hold lambda" in {
it should "hold lambda" in {
val innerName = "inner"
// lambda that will be assigned to another variable
@ -1485,7 +1485,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
func joinIdxLocal(idx: i16, nodes: []string):
join nodes[idx]
*/
"arrow inliner" should "not rename value in index array lambda" in {
it should "not rename value in index array lambda" in {
val innerName = "inner"
// lambda that will be assigned to another variable
@ -1575,7 +1575,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
}
"arrow inliner" should "rename value in arrow with same name as in for" in {
it should "rename value in arrow with same name as in for" in {
val argVar = VarRaw("arg", ScalarType.u32)
val iVar = VarRaw("i", ScalarType.string)
val iVar0 = VarRaw("i-0", ScalarType.string)
@ -1665,7 +1665,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
get() -> string
func inner() -> string:
results <- DTGetter.get_dt()
results <- Get.get()
<- results
func outer() -> []string:
@ -1673,7 +1673,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
results <- use_name1()
<- results
*/
"arrow inliner" should "generate result in right order" in {
it should "generate result in right order" in {
val innerName = "inner"
val results = VarRaw("results", ScalarType.string)
val resultsOut = VarRaw("results", StreamType(ScalarType.string))
@ -1744,4 +1744,210 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers {
) should be(true)
}
/**
* ability Inner:
* arrow(x: i8) -> bool
* field: i8
*
* ability Outer:
* inner: Inner
*
* func accept{Outer}() -> bool:
* res = Outer.inner.arrow(Outer.inner.field)
* <- res
*
* func main() -> bool:
* closure = (x: i8) -> bool:
* res = x > 0
* <- res
* inner = Inner(arrow = closure, field = 42)
* outer = Outer(inner = inner)
* res <- accept{outer}()
* <- res
*/
it should "handle nested abilities" in {
val arrowType = ArrowType(
ProductType(ScalarType.i8 :: Nil),
ProductType(ScalarType.bool :: Nil)
)
val closureType = arrowType.copy(
domain = ProductType.labelled(
"x" -> ScalarType.i8 :: Nil
)
)
val resVar = VarRaw("res", ScalarType.bool)
val innerType = AbilityType(
"Inner",
NonEmptyMap.of(
"arrow" -> arrowType,
"field" -> ScalarType.i8
)
)
val outerType = AbilityType(
"Outer",
NonEmptyMap.of(
"inner" -> innerType
)
)
val acceptBody = SeqTag.wrap(
AssignmentTag(
ApplyPropertyRaw.fromChain(
VarRaw("Outer", outerType),
Chain(
IntoFieldRaw("inner", innerType),
IntoArrowRaw(
"arrow",
arrowType,
List(
ApplyPropertyRaw.fromChain(
VarRaw("Outer", outerType),
Chain(
IntoFieldRaw("inner", innerType),
IntoFieldRaw("field", ScalarType.i8)
)
)
)
)
)
),
resVar.name
).leaf,
ReturnTag(NonEmptyList.one(resVar)).leaf
)
val accept = FuncArrow(
"accept",
acceptBody,
ArrowType(
ProductType.labelled("Outer" -> outerType :: Nil),
ProductType(ScalarType.bool :: Nil)
),
resVar :: Nil,
Map.empty,
Map.empty,
None
)
val closureBody = SeqTag.wrap(
AssignmentTag(
CallArrowRaw.service(
"cmp",
LiteralRaw.quote("cmp"),
"gt",
ArrowType(
ProductType(ScalarType.i64 :: ScalarType.i64 :: Nil),
ProductType(ScalarType.bool :: Nil)
),
List(
VarRaw("x", ScalarType.i8),
LiteralRaw.number(0)
)
),
"res"
).leaf,
ReturnTag(NonEmptyList.one(resVar)).leaf
)
val mainBody = SeqTag.wrap(
ClosureTag(
FuncRaw(
name = "closure",
arrow = ArrowRaw(
`type` = closureType,
ret = resVar :: Nil,
body = closureBody
)
),
detach = true
).leaf,
AssignmentTag(
AbilityRaw(
NonEmptyMap.of(
"arrow" -> VarRaw("closure", closureType),
"field" -> LiteralRaw.number(42)
),
innerType
),
"inner"
).leaf,
AssignmentTag(
AbilityRaw(
NonEmptyMap.of(
"inner" -> VarRaw("inner", innerType)
),
outerType
),
"outer"
).leaf,
CallArrowRawTag
.func(
"accept",
Call(
VarRaw("outer", outerType) :: Nil,
Call.Export(resVar.name, resVar.`type`) :: Nil
)
)
.leaf,
ReturnTag(NonEmptyList.one(resVar)).leaf
)
val main = FuncArrow(
"main",
mainBody,
ArrowType(
ProductType(Nil),
ProductType(ScalarType.bool :: Nil)
),
resVar :: Nil,
Map("accept" -> accept),
Map.empty,
None
)
val model = ArrowInliner
.callArrow[InliningState](
FuncArrow(
"wrapper",
CallArrowRawTag
.func(
"main",
Call(Nil, Nil)
)
.leaf,
ArrowType(
ProductType(Nil),
ProductType(Nil)
),
Nil,
Map("main" -> main),
Map.empty,
None
),
CallModel(Nil, Nil)
)
.runA(InliningState())
.value
val body = SeqModel.wrap(
SeqModel.wrap(
FlattenModel(LiteralModel.number(42), "literal_ap").leaf,
FlattenModel(VarModel("literal_ap", LiteralType.unsigned), "literal_props").leaf
),
CallServiceModel(
LiteralModel.quote("cmp"),
"gt",
CallModel(
VarModel("literal_props", LiteralType.unsigned) :: LiteralModel.number(0) :: Nil,
CallModel.Export("gt", ScalarType.bool) :: Nil
)
).leaf
)
val expected = List("main", "accept", "closure").foldRight(body) { case (name, body) =>
MetaModel.CallArrowModel(name).wrap(body)
}
model.equalsOrShowDiff(expected) shouldEqual true
}
}

View File

@ -2,6 +2,10 @@ package aqua.types
import cats.PartialOrder
import cats.data.NonEmptyMap
import cats.Eval
import cats.syntax.traverse.*
import cats.syntax.applicative.*
import cats.syntax.option.*
sealed trait Type {
@ -252,18 +256,69 @@ case class StructType(name: String, fields: NonEmptyMap[String, Type])
// Ability is an unordered collection of labelled types and arrows
case class AbilityType(name: String, fields: NonEmptyMap[String, Type]) extends NamedType {
lazy val arrows: Map[String, ArrowType] = fields.toNel.collect {
case (name, at @ ArrowType(_, _)) => (name, at)
}.toMap
/**
* Get all arrows defined in this ability and its sub-abilities.
* Paths to arrows are returned **without** ability name
* to allow renaming on call site.
*/
lazy val arrows: Map[String, ArrowType] = {
def getArrowsEval(path: Option[String], ability: AbilityType): Eval[List[(String, ArrowType)]] =
ability.fields.toNel.toList.flatTraverse {
case (abName, abType: AbilityType) =>
val newPath = path.fold(abName)(AbilityType.fullName(_, abName))
getArrowsEval(newPath.some, abType)
case (aName, aType: ArrowType) =>
val newPath = path.fold(aName)(AbilityType.fullName(_, aName))
List(newPath -> aType).pure
case _ => Nil.pure
}
lazy val abilities: List[(String, AbilityType)] = fields.toNel.collect {
case (name, at @ AbilityType(_, _)) => (name, at)
getArrowsEval(None, this).value.toMap
}
lazy val variables: List[(String, Type)] = fields.toNel.filter {
case (_, AbilityType(_, _)) => false
case (_, ArrowType(_, _)) => false
case (_, _) => true
/**
* Get all abilities defined in this ability and its sub-abilities.
* Paths to abilities are returned **without** ability name
* to allow renaming on call site.
*/
lazy val abilities: Map[String, AbilityType] = {
def getAbilitiesEval(
path: Option[String],
ability: AbilityType
): Eval[List[(String, AbilityType)]] =
ability.fields.toNel.toList.flatTraverse {
case (abName, abType: AbilityType) =>
val fullName = path.fold(abName)(AbilityType.fullName(_, abName))
getAbilitiesEval(fullName.some, abType).map(
(fullName -> abType) :: _
)
case _ => Nil.pure
}
getAbilitiesEval(None, this).value.toMap
}
/**
* Get all variables defined in this ability and its sub-abilities.
* Paths to variables are returned **without** ability name
* to allow renaming on call site.
*/
lazy val variables: Map[String, DataType] = {
def getVariablesEval(
path: Option[String],
ability: AbilityType
): Eval[List[(String, DataType)]] =
ability.fields.toNel.toList.flatTraverse {
case (abName, abType: AbilityType) =>
val newPath = path.fold(abName)(AbilityType.fullName(_, abName))
getVariablesEval(newPath.some, abType)
case (dName, dType: DataType) =>
val newPath = path.fold(dName)(AbilityType.fullName(_, dName))
List(newPath -> dType).pure
case _ => Nil.pure
}
getVariablesEval(None, this).value.toMap
}
override def toString: String =