From b22762ca6d6ed443a930780c928c7b72ae6a50f6 Mon Sep 17 00:00:00 2001 From: Dima Date: Wed, 21 Feb 2024 14:11:21 +0300 Subject: [PATCH] fix(compiler): Recursively find abilities [LNG-338] (#1086) --- aqua-src/antithesis.aqua | 44 +++++------ .../aqua/examples/abilitiesClosure.aqua | 30 ++++++- .../src/__test__/examples.spec.ts | 7 +- .../src/examples/abilityClosureCall.ts | 6 +- .../aqua/model/inline/ArrowInliner.scala | 79 ++++++++++++++++--- 5 files changed, 129 insertions(+), 37 deletions(-) diff --git a/aqua-src/antithesis.aqua b/aqua-src/antithesis.aqua index 923c46ba..6ba99770 100644 --- a/aqua-src/antithesis.aqua +++ b/aqua-src/antithesis.aqua @@ -1,31 +1,29 @@ aqua A -import "aqua-src/gen/OneMore.aqua" +export haveFun -export main +ability Compute: + job() -> string -alias SomeAlias: string +func lift() -> Compute: + job = () -> string: + <- "job done" + <- Compute(job) -data NestedStruct: - a: SomeAlias +ability Function: + run() -> string -data SomeStruct: - al: SomeAlias - nested: NestedStruct +func roundtrip{Function}() -> string: + res <- Function.run() + <- res -ability SomeAbility: - someStr: SomeStruct - nested: NestedStruct - al: SomeAlias - someFunc(ss: SomeStruct, nest: NestedStruct, al: SomeAlias) -> NestedStruct, SomeStruct, SomeAlias +func disjoint_run{Compute}() -> Function: + run = func () -> string: + <- Compute.job() + <- Function(run = run) -service Srv("a"): - check(ss: SomeStruct, nest: NestedStruct, al: SomeAlias) -> NestedStruct - check2() -> SomeStruct - check3() -> SomeAlias - -func withAb{SomeAbility}() -> SomeStruct: - <- SomeAbility.someStr - -func main(ss: SomeStruct, nest: NestedStruct, al: SomeAlias) -> string: - <- "" \ No newline at end of file +func haveFun() -> string: + comp = lift() + fn = disjoint_run{comp}() + res <- roundtrip{fn}() + <- res \ No newline at end of file diff --git a/integration-tests/aqua/examples/abilitiesClosure.aqua b/integration-tests/aqua/examples/abilitiesClosure.aqua index 7b16c9b5..1fd903b8 100644 --- a/integration-tests/aqua/examples/abilitiesClosure.aqua +++ b/integration-tests/aqua/examples/abilitiesClosure.aqua @@ -1,6 +1,6 @@ aqua M -export bugLNG314 +export bugLNG314, bugLNG338 ability WorkerJob: runOnSingleWorker(w: string) -> string @@ -20,4 +20,30 @@ func bugLNG314() -> string: worker_job = WorkerJob(runOnSingleWorker = job2) subnet_job <- disjoint_run{worker_job}() res <- runJob(subnet_job) - <- res \ No newline at end of file + <- res + +ability Compute: + job() -> string + +func lift() -> Compute: + job = () -> string: + <- "job done" + <- Compute(job) + +ability Function: + run() -> string + +func roundtrip{Function}() -> string: + res <- Function.run() + <- res + +func disj{Compute}() -> Function: + run = func () -> string: + <- Compute.job() + <- Function(run = run) + +func bugLNG338() -> string: + comp = lift() + fn = disj{comp}() + res <- roundtrip{fn}() + <- res \ No newline at end of file diff --git a/integration-tests/src/__test__/examples.spec.ts b/integration-tests/src/__test__/examples.spec.ts index 76e7a286..4fa629d0 100644 --- a/integration-tests/src/__test__/examples.spec.ts +++ b/integration-tests/src/__test__/examples.spec.ts @@ -40,7 +40,7 @@ import { multipleAbilityWithClosureCall, returnSrvAsAbilityCall, } from "../examples/abilityCall.js"; -import { bugLNG314Call } from "../examples/abilityClosureCall.js"; +import { bugLNG314Call, bugLNG338Call } from "../examples/abilityClosureCall.js"; import { nilLengthCall, nilLiteralCall, @@ -665,6 +665,11 @@ describe("Testing examples", () => { expect(result).toEqual("strstrstr"); }); + it("abilitiesClosure.aqua bug LNG-338", async () => { + let result = await bugLNG338Call(); + expect(result).toEqual("job done"); + }); + it("functors.aqua LNG-119 bug", async () => { let result = await bugLng119Call(); expect(result).toEqual([1]); diff --git a/integration-tests/src/examples/abilityClosureCall.ts b/integration-tests/src/examples/abilityClosureCall.ts index 5cbb5c5e..cc06b999 100644 --- a/integration-tests/src/examples/abilityClosureCall.ts +++ b/integration-tests/src/examples/abilityClosureCall.ts @@ -1,7 +1,11 @@ import { - bugLNG314 + bugLNG314, bugLNG338 } from "../compiled/examples/abilitiesClosure.js"; export async function bugLNG314Call(): Promise { return await bugLNG314(); } + +export async function bugLNG338Call(): Promise { + return await bugLNG338(); +} diff --git a/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala b/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala index f5fd6327..0287394d 100644 --- a/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala @@ -17,6 +17,7 @@ import cats.syntax.option.* import cats.syntax.semigroup.* import cats.syntax.traverse.* import cats.{Eval, Monoid} +import scala.annotation.tailrec import scribe.Logging /** @@ -82,6 +83,62 @@ object ArrowInliner extends Logging { arrowsToSave: Map[String, FuncArrow] ) + /** + * Find abilities recursively, because ability can hold arrow with another ability in it. + * @param abilitiesToGather gather all fields for these abilities + * @param varsFromAbs already gathered variables + * @param arrowsFromAbs already gathered arrows + * @param processedAbs already processed abilities + * @return all needed variables and arrows + */ + @tailrec + private def arrowsAndVarsFromAbilities( + abilitiesToGather: Map[String, GeneralAbilityType], + exports: Map[String, ValueModel], + arrows: Map[String, FuncArrow], + varsFromAbs: Map[String, ValueModel] = Map.empty, + arrowsFromAbs: Map[String, FuncArrow] = Map.empty, + processedAbs: Set[String] = Set.empty + ): (Map[String, ValueModel], Map[String, FuncArrow]) = { + val varsFromAbilities = abilitiesToGather.flatMap { case (name, at) => + getAbilityVars(name, None, at, exports) + } + val arrowsFromAbilities = abilitiesToGather.flatMap { case (name, at) => + getAbilityArrows(name, None, at, exports, arrows) + } + + val allProcessed = abilitiesToGather.keySet ++ processedAbs + + // find all names that is used in arrows + val namesUsage = arrowsFromAbilities.values.flatMap(_.body.usesVarNames.value).toSet + + // check if there is abilities that we didn't gather + val abilitiesUsage = namesUsage.toList + .flatMap(exports.get) + .collect { + case ValueModel.Ability(vm, at) if !allProcessed.contains(vm.name) => + vm.name -> at + } + .toMap + + val allVars = varsFromAbilities ++ varsFromAbs + val allArrows = arrowsFromAbilities ++ arrowsFromAbs + + if (abilitiesUsage.isEmpty) { + (allVars, allArrows) + } else { + arrowsAndVarsFromAbilities( + abilitiesUsage, + exports, + arrows, + allVars, + allArrows, + allProcessed + ) + } + + } + // Apply a callable function, get its fully resolved body & optional value, if any private def inline[S: Mangler: Arrows: Exports]( fn: FuncArrow, @@ -104,15 +161,15 @@ object ArrowInliner extends Logging { exports <- Exports[S].exports arrows <- Arrows[S].arrows // gather all arrows and variables from abilities - returnedAbilities = rets.collect { case ValueModel.Ability(vm, at) => + abilitiesToGather = rets.collect { case ValueModel.Ability(vm, at) => vm.name -> at } - varsFromAbilities = returnedAbilities.flatMap { case (name, at) => - getAbilityVars(name, None, at, exports) - }.toMap - arrowsFromAbilities = returnedAbilities.flatMap { case (name, at) => - getAbilityArrows(name, None, at, exports, arrows) - }.toMap + arrsVars = arrowsAndVarsFromAbilities( + abilitiesToGather.toMap, + exports, + arrows + ) + (varsFromAbilities, arrowsFromAbilities) = arrsVars // find and get resolved arrows if we return them from the function returnedArrows = rets.collect { case VarModel(name, _: ArrowType, _) => name }.toSet @@ -172,9 +229,11 @@ object ArrowInliner extends Logging { abilityType, exports ) + val abilityExport = + exports.get(abilityName).map(vm => abilityNewName.getOrElse(abilityName) -> vm).toMap - get(_.variables) ++ get(_.arrows).flatMap { - case arrow @ (_, vm @ ValueModel.Arrow(_, _)) => + abilityExport ++ get(_.variables) ++ get(_.arrows).flatMap { + case arrow @ (_, ValueModel.Arrow(_, _)) => arrow.some case (_, m) => internalError(s"($m) cannot be an arrow") @@ -497,7 +556,7 @@ object ArrowInliner extends Logging { exports <- Exports[S].exports streams <- getOutsideStreamNames arrows = passArrows ++ arrowsFromAbilities - + inlineResult <- Exports[S].scope( Arrows[S].scope( for {