From 443e65e3d8bca4774f5bdb6db5e526c5f2201c89 Mon Sep 17 00:00:00 2001 From: InversionSpaces Date: Fri, 1 Sep 2023 15:05:32 +0200 Subject: [PATCH] fix(compiler): Fix closure stream capture [fixes LNG-58] (#857) --- .../aqua/compiler/AquaCompilerSpec.scala | 10 +- .../aqua/examples/streamCapture.aqua | 50 ++ .../aqua/examples/streamReturn.aqua | 24 + .../src/__test__/examples.spec.ts | 25 +- .../src/examples/streamCapture.ts | 9 + .../src/examples/streamReturn.ts | 5 + .../aqua/model/inline/ArrowInliner.scala | 471 ++++++------------ .../aqua/model/inline/RawValueInliner.scala | 5 +- .../scala/aqua/model/inline/TagInliner.scala | 60 ++- .../inline/raw/CallArrowRawInliner.scala | 23 +- .../aqua/model/inline/state/Arrows.scala | 25 +- .../aqua/model/inline/state/Exports.scala | 19 +- .../aqua/model/inline/state/Mangler.scala | 6 + .../aqua/model/inline/ArrowInlinerSpec.scala | 402 ++++++++++++++- .../main/scala/aqua/raw/arrow/FuncRaw.scala | 10 + .../src/main/scala/aqua/raw/ops/RawTag.scala | 48 +- .../scala/aqua/raw/ops/RawTagGivens.scala | 29 +- .../main/scala/aqua/raw/value/ValueRaw.scala | 30 +- .../src/main/scala/aqua/model/ArgsCall.scala | 89 +++- .../src/main/scala/aqua/model/FuncArrow.scala | 8 +- .../aqua/model/transform/Transform.scala | 11 +- .../model/transform/pre/ArgsProvider.scala | 33 +- .../transform/pre/FuncPreTransformer.scala | 30 +- .../aqua/semantics/expr/func/ArrowSem.scala | 206 ++++---- .../semantics/expr/func/AssignmentSem.scala | 2 +- .../semantics/rules/StackInterpreter.scala | 41 +- .../abilities/AbilitiesInterpreter.scala | 50 +- .../rules/names/NamesInterpreter.scala | 34 +- .../rules/types/TypesInterpreter.scala | 101 ++-- .../scala/aqua/semantics/ArrowSemSpec.scala | 2 +- 30 files changed, 1177 insertions(+), 681 deletions(-) create mode 100644 integration-tests/aqua/examples/streamCapture.aqua create mode 100644 integration-tests/aqua/examples/streamReturn.aqua create mode 100644 integration-tests/src/examples/streamCapture.ts create mode 100644 integration-tests/src/examples/streamReturn.ts diff --git a/compiler/src/test/scala/aqua/compiler/AquaCompilerSpec.scala b/compiler/src/test/scala/aqua/compiler/AquaCompilerSpec.scala index 2fdd2c24..e74da781 100644 --- a/compiler/src/test/scala/aqua/compiler/AquaCompilerSpec.scala +++ b/compiler/src/test/scala/aqua/compiler/AquaCompilerSpec.scala @@ -101,11 +101,11 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers { val relay = VarRaw("-relay-", ScalarType.string) - def getDataSrv(name: String, t: Type) = { + def getDataSrv(name: String, varName: String, t: Type) = { CallServiceRes( LiteralModel.fromRaw(LiteralRaw.quote("getDataSrv")), name, - CallRes(Nil, Some(CallModel.Export(name, t))), + CallRes(Nil, Some(CallModel.Export(varName, t))), LiteralModel.fromRaw(ValueRaw.InitPeerId) ).leaf } @@ -146,7 +146,7 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers { val Some(exec) = aquaRes.funcs.find(_.funcName == "exec") - val peers = VarModel("peers", ArrayType(ScalarType.string)) + val peers = VarModel("-peers-arg-", ArrayType(ScalarType.string)) val peer = VarModel("peer-0", ScalarType.string) val resultsType = StreamType(ScalarType.string) val results = VarModel("results", resultsType) @@ -156,8 +156,8 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers { val expected = SeqRes.wrap( - getDataSrv("-relay-", ScalarType.string), - getDataSrv(peers.name, peers.`type`), + getDataSrv("-relay-", "-relay-", ScalarType.string), + getDataSrv("peers", peers.name, peers.`type`), XorRes.wrap( RestrictionRes(results.name, resultsType).wrap( SeqRes.wrap( diff --git a/integration-tests/aqua/examples/streamCapture.aqua b/integration-tests/aqua/examples/streamCapture.aqua new file mode 100644 index 00000000..9f448d27 --- /dev/null +++ b/integration-tests/aqua/examples/streamCapture.aqua @@ -0,0 +1,50 @@ +aqua StreamCapture + +export testStreamCaptureSimple, testStreamCaptureReturn + +-- SIMPLE + +func useCaptureSimple(push: string -> ()): + push("two") + +func testStreamCaptureSimple() -> []string: + stream: *string + + stream <<- "one" + + push = (s: string): + stream <<- s + + useCaptureSimple(push) + push("three") + + <- stream + +-- RETURN + +func captureStream() -> (string -> []string): + stream: *string + + stream <<- "one" + + capture = (s: string) -> []string: + stream <<- s + <- stream + + capture("two") + + <- capture + +func useCaptureReturn(capture: string -> []string): + capture("three") + +func rereturnCapture() -> (string -> []string): + capture <- captureStream() + useCaptureReturn(capture) + capture("four") + <- capture + +func testStreamCaptureReturn() -> []string: + on HOST_PEER_ID: + capture <- rereturnCapture() + <- capture("five") diff --git a/integration-tests/aqua/examples/streamReturn.aqua b/integration-tests/aqua/examples/streamReturn.aqua new file mode 100644 index 00000000..754e93c5 --- /dev/null +++ b/integration-tests/aqua/examples/streamReturn.aqua @@ -0,0 +1,24 @@ +aqua StreamReturn + +export testReturnStream + +func returnStream() -> *string: + stream: *string + stream <<- "one" + <- stream + +func useStream(stream: *string) -> *string: + stream <<- "two" + <- stream + +func rereturnStream() -> *string: + stream <- returnStream() + useStream(stream) + stream <<- "three" + <- stream + +func testReturnStream() -> []string: + on HOST_PEER_ID: + stream <- rereturnStream() + stream <<- "four" + <- stream \ No newline at end of file diff --git a/integration-tests/src/__test__/examples.spec.ts b/integration-tests/src/__test__/examples.spec.ts index caf58ec9..c5119a33 100644 --- a/integration-tests/src/__test__/examples.spec.ts +++ b/integration-tests/src/__test__/examples.spec.ts @@ -38,6 +38,9 @@ import { coCall } from '../examples/coCall.js'; import { bugLNG60Call, passArgsCall } from '../examples/passArgsCall.js'; import { streamArgsCall } from '../examples/streamArgsCall.js'; import { streamResultsCall } from '../examples/streamResultsCall.js'; +import { structuralTypingCall } from '../examples/structuralTypingCall'; +import { streamReturnCall } from '../examples/streamReturn.js'; +import { streamCaptureSimpleCall, streamCaptureReturnCall } from '../examples/streamCapture.js'; import { streamIfCall, streamForCall, streamTryCall, streamComplexCall } from '../examples/streamScopes.js'; import { pushToStreamCall } from '../examples/pushToStreamCall.js'; import { literalCall } from '../examples/returnLiteralCall.js'; @@ -79,7 +82,7 @@ export const relay2 = config.relays[1]; const relayPeerId2 = relay2.peerId; import log from 'loglevel'; -import {structuralTypingCall} from "../examples/structuralTypingCall"; + // log.setDefaultLevel("debug") async function start() { @@ -245,7 +248,7 @@ describe('Testing examples', () => { it('structuraltyping.aqua', async () => { let result = await structuralTypingCall(); - expect(result).toEqual("some_stringsome_stringsome_stringab_string"); + expect(result).toEqual('some_stringsome_stringsome_stringab_string'); }); it('collectionSugar array', async () => { @@ -389,7 +392,7 @@ describe('Testing examples', () => { expect(result).toStrictEqual([false, true]); }); - it('ability.aqua complex', async () => { + it('ability.aqua ability calls', async () => { let result = await checkAbCallsCall(); expect(result).toStrictEqual([true, false]); }); @@ -419,6 +422,22 @@ describe('Testing examples', () => { expect(streamResultsResult).toEqual(['new_name', 'new_name', 'new_name']); }); + it('streamReturn.aqua', async () => { + let streamReturnResult = await streamReturnCall(); + expect(streamReturnResult).toEqual(['one', 'two', 'three', 'four']); + }); + + it('streamCapture.aqua simple', async () => { + let streamCaptureResult = await streamCaptureSimpleCall(); + expect(streamCaptureResult).toEqual(['one', 'two', 'three']); + }); + + // TODO: Unskip this after LNG-226 is fixed + it.skip('streamCapture.aqua return', async () => { + let streamCaptureResult = await streamCaptureReturnCall(); + expect(streamCaptureResult).toEqual(['one', 'two', 'three', 'four', 'five']); + }); + it('assignment.aqua', async () => { let assignmentResult = await assignmentCall(); expect(assignmentResult).toEqual(['abc', 'hello']); diff --git a/integration-tests/src/examples/streamCapture.ts b/integration-tests/src/examples/streamCapture.ts new file mode 100644 index 00000000..2326f257 --- /dev/null +++ b/integration-tests/src/examples/streamCapture.ts @@ -0,0 +1,9 @@ +import { testStreamCaptureSimple, testStreamCaptureReturn } from '../compiled/examples/streamCapture.js'; + +export async function streamCaptureSimpleCall() { + return await testStreamCaptureSimple(); +} + +export async function streamCaptureReturnCall() { + return await testStreamCaptureReturn(); +} diff --git a/integration-tests/src/examples/streamReturn.ts b/integration-tests/src/examples/streamReturn.ts new file mode 100644 index 00000000..dd7b5ed5 --- /dev/null +++ b/integration-tests/src/examples/streamReturn.ts @@ -0,0 +1,5 @@ +import { testReturnStream } from '../compiled/examples/streamReturn.js'; + +export async function streamReturnCall() { + return await testReturnStream(); +} diff --git a/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala b/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala index 5e86f52c..9f2efa29 100644 --- a/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala @@ -7,11 +7,14 @@ import aqua.raw.ops.RawTag import aqua.raw.value.{ValueRaw, VarRaw} import aqua.types.{AbilityType, ArrowType, BoxType, StreamType, Type} +import cats.data.StateT import cats.data.{Chain, IndexedStateT, State} +import cats.syntax.applicative.* import cats.syntax.bifunctor.* import cats.syntax.foldable.* import cats.syntax.traverse.* import cats.syntax.option.* +import cats.syntax.show.* import cats.{Eval, Monoid} import scribe.Logging @@ -26,7 +29,7 @@ object ArrowInliner extends Logging { arrow: FuncArrow, call: CallModel ): State[S, OpModel.Tree] = - callArrowRet(arrow, call).map(_._1) + callArrowRet(arrow, call).map { case (tree, _) => tree } // Get streams that was declared outside of a function private def getOutsideStreamNames[S: Exports]: State[S, Set[String]] = @@ -41,40 +44,26 @@ object ArrowInliner extends Logging { private def pushStreamResults[S: Mangler: Exports: Arrows]( outsideStreamNames: Set[String], exportTo: List[CallModel.Export], - results: List[ValueRaw], - body: OpModel.Tree + results: List[ValueRaw] ): State[S, (List[OpModel.Tree], List[ValueModel])] = - for { - // Fix return values with exports collected in the body - resolvedResult <- RawValueInliner.valueListToModel(results) - } yield { - // Fix the return values - val (ops, rets) = (exportTo zip resolvedResult).map { - case ( - CallModel.Export(n, StreamType(_)), - (res @ VarModel(_, StreamType(_), _), resDesugar) - ) if !outsideStreamNames.contains(n) => - resDesugar.toList -> res - case (CallModel.Export(exp, st @ StreamType(_)), (res, resDesugar)) => - // pass nested function results to a stream - (resDesugar.toList :+ PushToStreamModel( - res, - CallModel.Export(exp, st) - ).leaf) -> VarModel( - exp, - st, - Chain.empty - ) - case (_, (res, resDesugar)) => - resDesugar.toList -> res - }.foldLeft[(List[OpModel.Tree], List[ValueModel])]( - (body :: Nil, Nil) - ) { case ((ops, rets), (fo, r)) => - (fo ::: ops, r :: rets) - } - - (ops, rets) - } + // Fix return values with exports collected in the body + RawValueInliner + .valueListToModel(results) + .map(resolvedResults => + // Fix the return values + (exportTo zip resolvedResults).map { + case ( + CallModel.Export(n, StreamType(_)), + (res @ VarModel(_, StreamType(_), _), resDesugar) + ) if !outsideStreamNames.contains(n) => + resDesugar.toList -> res + case (cexp @ CallModel.Export(exp, st @ StreamType(_)), (res, resDesugar)) => + // pass nested function results to a stream + (resDesugar.toList :+ PushToStreamModel(res, cexp).leaf) -> cexp.asVar + case (_, (res, resDesugar)) => + resDesugar.toList -> res + }.unzip.leftMap(_.flatten) + ) /** * @param tree generated tree after inlining a function @@ -94,189 +83,43 @@ object ArrowInliner extends Logging { fn: FuncArrow, call: CallModel, outsideDeclaredStreams: Set[String] - ): State[S, InlineResult] = - for { - // Register captured values as available exports - _ <- Exports[S].resolved(fn.capturedValues) - _ <- Mangler[S].forbid(fn.capturedValues.keySet) + ): State[S, InlineResult] = for { + callableFuncBodyNoTopology <- TagInliner.handleTree(fn.body, fn.funcName) + callableFuncBody = + fn.capturedTopology + .fold(SeqModel)(ApplyTopologyModel.apply) + .wrap(callableFuncBodyNoTopology) - // Now, substitute the arrows that were received as function arguments - // Use the new op tree (args are replaced with values, names are unique & safe) - callableFuncBodyNoTopology <- TagInliner.handleTree(fn.body, fn.funcName) - callableFuncBody = - fn.capturedTopology - .fold[OpModel](SeqModel)(ApplyTopologyModel.apply) - .wrap(callableFuncBodyNoTopology) - - opsAndRets <- pushStreamResults( - outsideDeclaredStreams, - call.exportTo, - fn.ret, - callableFuncBody - ) - (ops, rets) = opsAndRets - - exports <- Exports[S].exports - arrows <- Arrows[S].arrows - // gather all arrows and variables from abilities - returnedAbilities = rets.collect { case VarModel(name, at: AbilityType, _) => name -> at } - varsFromAbilities = returnedAbilities.flatMap { case (name, at) => - getAbilityVars(name, None, at, exports) - }.toMap - arrowsFromAbilities = returnedAbilities.flatMap { case (name, at) => - getAbilityArrows(name, None, at, exports, arrows) - }.toMap - - // find and get resolved arrows if we return them from the function - returnedArrows = rets.collect { case VarModel(name, _: ArrowType, _) => name }.toSet - arrowsToSave <- Arrows[S].pickArrows(returnedArrows) - } yield InlineResult( - SeqModel.wrap(ops.reverse), - rets.reverse, - varsFromAbilities, - arrowsFromAbilities ++ arrowsToSave + opsAndRets <- pushStreamResults( + outsideStreamNames = outsideDeclaredStreams, + exportTo = call.exportTo, + results = fn.ret ) + (ops, rets) = opsAndRets - /** - * Get all arrows that is arguments from outer Arrows. - * Purge and push captured arrows and arrows as arguments into state. - * Grab all arrows that must be renamed. - * - * @param argsToArrowsRaw arguments with ArrowType - * @param func function where captured and returned may exist - * @param abilityArrows arrows from abilities that should be renamed - * @return all arrows that must be renamed in function body - */ - private def updateArrowsAndRenameArrowArgs[S: Mangler: Arrows: Exports]( - argsToArrowsRaw: Map[String, FuncArrow], - func: FuncArrow, - abilityArrows: Map[String, String] - ): State[S, Map[String, String]] = { - for { - argsToArrowsShouldRename <- Mangler[S] - .findNewNames( - argsToArrowsRaw.keySet - ) - .map(_ ++ abilityArrows) - argsToArrows = argsToArrowsRaw.map { case (k, v) => - argsToArrowsShouldRename.getOrElse(k, k) -> v - } - returnedArrows = func.ret.collect { case VarRaw(name, ArrowType(_, _)) => - name - }.toSet + exports <- Exports[S].exports + arrows <- Arrows[S].arrows + // gather all arrows and variables from abilities + returnedAbilities = rets.collect { case VarModel(name, at: AbilityType, _) => name -> at } + varsFromAbilities = returnedAbilities.flatMap { case (name, at) => + getAbilityVars(name, None, at, exports) + }.toMap + arrowsFromAbilities = returnedAbilities.flatMap { case (name, at) => + getAbilityArrows(name, None, at, exports, arrows) + }.toMap - returnedArrowsShouldRename <- Mangler[S].findNewNames(returnedArrows) - renamedCapturedArrows = func.capturedArrows.map { case (k, v) => - returnedArrowsShouldRename.getOrElse(k, k) -> v - } + // find and get resolved arrows if we return them from the function + returnedArrows = rets.collect { case VarModel(name, _: ArrowType, _) => name }.toSet + arrowsToSave <- Arrows[S].pickArrows(returnedArrows) - _ <- Arrows[S].resolved(renamedCapturedArrows ++ argsToArrows) - } yield { - argsToArrowsShouldRename ++ returnedArrowsShouldRename - } - } - - /** - * @param argsToDataRaw data arguments to rename - * @param abilityValues values from abilities to rename - * @return all values that must be renamed in function body - */ - private def updateExportsAndRenameDataArgs[S: Mangler: Arrows: Exports]( - argsToDataRaw: Map[String, ValueModel], - abilityValues: Map[String, String] - ): State[S, Map[String, String]] = { - for { - // Find all duplicates in arguments - // we should not find new names for 'abilityValues' arguments that will be renamed by 'streamToRename' - argsToDataShouldRename <- Mangler[S] - .findNewNames( - argsToDataRaw.keySet - ) - .map(_ ++ abilityValues) - - // Do not rename arguments if they just match external names - argsToData = argsToDataRaw.map { case (k, v) => - argsToDataShouldRename.getOrElse(k, k) -> v - } - - _ <- Exports[S].resolved(argsToData) - } yield argsToDataShouldRename - } - - // Rename all exports-to-stream for streams that passed as arguments - private def renameStreams( - tree: RawTag.Tree, - streamArgs: Map[String, VarModel] - ): RawTag.Tree = { - // collect arguments with stream type - // to exclude it from resolving and rename it with a higher-level stream that passed by argument - val streamsToRename = streamArgs.view.mapValues(_.name).toMap - - if (streamsToRename.isEmpty) tree - else - tree - .map(_.mapValues(_.map { - // if an argument is a BoxType (Array or Option), but we pass a stream, - // change a type as stream to not miss `$` sign in air - // @see ArrowInlinerSpec `pass stream to callback properly` test - case v @ VarRaw(name, baseType: BoxType) if streamsToRename.contains(name) => - v.copy(baseType = StreamType(baseType.element)) - case v: VarRaw if streamsToRename.contains(v.name) => - v.copy(baseType = StreamType(v.baseType)) - case v => v - })) - .renameExports(streamsToRename) - } - - case class AbilityResolvingResult( - namesToRename: Map[String, String], - renamedExports: Map[String, ValueModel], - renamedArrows: Map[String, FuncArrow] + body = SeqModel.wrap(callableFuncBody :: ops) + } yield InlineResult( + body, + rets, + varsFromAbilities, + arrowsFromAbilities ++ arrowsToSave ) - given Monoid[AbilityResolvingResult] with - - override val empty: AbilityResolvingResult = - AbilityResolvingResult(Map.empty, Map.empty, Map.empty) - - override def combine( - a: AbilityResolvingResult, - b: AbilityResolvingResult - ): AbilityResolvingResult = - AbilityResolvingResult( - a.namesToRename ++ b.namesToRename, - a.renamedExports ++ b.renamedExports, - a.renamedArrows ++ b.renamedArrows - ) - - /** - * Generate new names for all ability fields and arrows if necessary. - * Gather all fields and arrows from Arrows and Exports states - * @param name ability name in state - * @param vm ability variable - * @param t ability type - * @param exports previous Exports - * @param arrows previous Arrows - * @return names to rename, Exports and Arrows with all ability fields and arrows - */ - private def renameAndResolveAbilities[S: Mangler: Arrows: Exports]( - name: String, - vm: VarModel, - t: AbilityType, - exports: Map[String, ValueModel], - arrows: Map[String, FuncArrow] - ): State[S, AbilityResolvingResult] = { - for { - newName <- Mangler[S].findNewName(name) - newFieldsName = t.fields.mapBoth { case (n, _) => - AbilityType.fullName(name, n) -> AbilityType.fullName(newName, n) - } - allNewNames = newFieldsName.add((name, newName)).toSortedMap - allVars = getAbilityVars(vm.name, newName.some, t, exports) - allArrows = getAbilityArrows(vm.name, newName.some, t, exports, arrows) - } yield AbilityResolvingResult(allNewNames, allVars, allArrows) - } - /** * Get ability fields (vars or arrows) from exports * @@ -374,135 +217,129 @@ object ArrowInliner extends Logging { arrows <- Arrows[S].arrows } yield getAbilityArrows(abilityName, None, abilityType, exports, arrows) + final case class Renamed[T]( + renames: Map[String, String], + renamed: Map[String, T] + ) + /** - * Prepare the state context for this function call + * Rename values and forbid new names * - * @param fn - * Function that will be called - * @param call - * Call object - * @tparam S - * State - * @return - * Tree with substituted values, list of return values prior to function calling/inlining + * @param values Mapping name -> value + * @return Renamed values and renames + */ + private def findNewNames[S: Mangler, T](values: Map[String, T]): State[S, Renamed[T]] = + Mangler[S].findAndForbidNames(values.keySet).map { renames => + Renamed( + renames, + values.map { case (name, value) => + renames.getOrElse(name, name) -> value + } + ) + } + + /** + * Prepare the function and the context for inlining + * + * @param fn Function that will be called + * @param call Call object + * @param exports Exports state before calling/inlining + * @param arrows Arrows that are available for callee + * @return Prepared function */ private def prelude[S: Mangler: Arrows: Exports]( fn: FuncArrow, call: CallModel, - oldExports: Map[String, ValueModel], + exports: Map[String, ValueModel], arrows: Map[String, FuncArrow] - ): State[S, (RawTag.Tree, List[ValueRaw])] = - for { - // Collect all arguments: what names are used inside the function, what values are received - args <- State.pure(ArgsCall(fn.arrowType.domain, call.args)) + ): State[S, FuncArrow] = for { + args <- ArgsCall(fn.arrowType.domain, call.args).pure[State[S, *]] - abArgs = args.abilityArgs + argNames = args.argNames + capturedNames = fn.capturedValues.keySet ++ fn.capturedArrows.keySet - abilityResolvingResult <- abArgs.toList.traverse { case (str, (vm, sct)) => - renameAndResolveAbilities(str, vm, sct, oldExports, arrows) - }.map(_.combineAll) + /** + * Substitute all arguments inside function body. + * Data arguments could be passed as variables or values (expressions), + * so we need to resolve them in `Exports`. + * Streams, arrows, abilities are passed as variables only, + * so we just rename them in the function body to match + * the names in the current context. + */ + data <- findNewNames(args.dataArgs) + streamRenames = args.streamArgsRenames + arrowRenames = args.arrowArgsRenames + abRenames = args.abilityArgsRenames - absRenames = abilityResolvingResult.namesToRename - absVars = abilityResolvingResult.renamedExports - absArrows = abilityResolvingResult.renamedArrows + /** + * Find new names for captured values and arrows + * to avoid collisions, then resolve them in context. + */ + capturedValues <- findNewNames(fn.capturedValues) + capturedArrows <- findNewNames(fn.capturedArrows) - arrowArgs = args.arrowArgs(arrows) - // Update states and rename tags - renamedArrows <- updateArrowsAndRenameArrowArgs(arrowArgs ++ absArrows, fn, absRenames) - - argsToDataShouldRename <- updateExportsAndRenameDataArgs(args.dataArgs ++ absVars, absRenames) - - // rename variables that store arrows - _ <- Exports[S].renameVariables(renamedArrows) - - /* - * Don't rename arrows from abilities in a body, because we link arrows by VarModel - * and they won't be called directly. - * They could intersect with arrows defined inside the body - * - * ability Simple: - * arrow() -> bool - * - * func foo{Simple}() -> bool, bool: - * closure = () -> bool: - * <- true - * <- closure(), Simple.arrow() - * - * func main() -> bool, bool: - * closure = () -> bool: - * <- false - * MySimple = Simple(arrow = closure) - * -- here we will rename arrow in Arrows[S] to 'closure-0' - * -- and link to arrow as 'Simple.arrow' -> VarModel('closure-0') - * -- and it will work well with closure with the same name 'closure' inside 'foo' - * foo{MySimple}() - */ - allShouldRename = argsToDataShouldRename ++ (renamedArrows -- absArrows.keySet) ++ absRenames - - // Rename all renamed arguments in the body - treeRenamed = fn.body.rename(allShouldRename) - treeStreamsRenamed = renameStreams( - treeRenamed, - args.streamArgs.map { case (k, v) => argsToDataShouldRename.getOrElse(k, k) -> v } + /** + * Function defines variables inside its body. + * We rename and forbid all those names so that when we inline + * **another function inside this one** we would know what names + * are prohibited because they are used inside **this function**. + */ + defineNames <- StateT.liftF( + fn.body.definesVarNames.map( + _ -- argNames -- capturedNames ) + ) + defineRenames <- Mangler[S].findAndForbidNames(defineNames) - // Function body on its own defines some values; collect their names - // except stream arguments. They should be already renamed - treeDefines = - treeStreamsRenamed.definesVarNames.value -- - args.streamArgs.keySet -- - args.streamArgs.values.map(_.name) -- - call.exportTo.filter { exp => - exp.`type` match { - case StreamType(_) => false - case _ => true - } - }.map(_.name) + renaming = ( + data.renames ++ + streamRenames ++ + arrowRenames ++ + abRenames ++ + capturedValues.renames ++ + capturedArrows.renames ++ + defineRenames + ) - // We have some names in scope (forbiddenNames), can't introduce them again; so find new names - shouldRename <- Mangler[S].findNewNames(treeDefines).map(_ ++ allShouldRename) - _ <- Mangler[S].forbid(treeDefines ++ shouldRename.values.toSet) + arrowsResolved = arrows ++ capturedArrows.renamed + exportsResolved = exports ++ data.renamed ++ capturedValues.renamed - // If there was a collision, rename exports and usages with new names - tree = treeStreamsRenamed.rename(shouldRename) + tree = fn.body.rename(renaming) + ret = fn.ret.map(_.renameVars(renaming)) - // Result could be renamed; take care about that - } yield (tree, fn.ret.map(_.renameVars(shouldRename))) + _ <- Arrows[S].resolved(arrowsResolved) + _ <- Exports[S].resolved(exportsResolved) + } yield fn.copy(body = tree, ret = ret) private[inline] def callArrowRet[S: Exports: Arrows: Mangler]( arrow: FuncArrow, call: CallModel - ): State[S, (OpModel.Tree, List[ValueModel])] = - for { - passArrows <- Arrows[S].pickArrows(call.arrowArgNames) - arrowsFromAbilities <- call.abilityArgs - .traverse(getAbilityArrows.tupled) - .map(_.flatMap(_.toList).toMap) + ): State[S, (OpModel.Tree, List[ValueModel])] = for { + passArrows <- Arrows[S].pickArrows(call.arrowArgNames) + arrowsFromAbilities <- call.abilityArgs + .traverse(getAbilityArrows.tupled) + .map(_.flatMap(_.toList).toMap) - exports <- Exports[S].exports - streams <- getOutsideStreamNames + exports <- Exports[S].exports + streams <- getOutsideStreamNames - inlineResult <- Exports[S].scope( - Arrows[S].scope( - for { - // Process renamings, prepare environment - tr <- prelude[S](arrow, call, exports, passArrows ++ arrowsFromAbilities) - (tree, results) = tr - inlineResult <- ArrowInliner.inline( - arrow.copy(body = tree, ret = results), - call, - streams - ) - } yield inlineResult - ) + inlineResult <- Exports[S].scope( + Arrows[S].scope( + for { + // Process renamings, prepare environment + fn <- ArrowInliner.prelude(arrow, call, exports, passArrows ++ arrowsFromAbilities) + inlineResult <- ArrowInliner.inline(fn, call, streams) + } yield inlineResult ) + ) - _ <- Arrows[S].resolved(inlineResult.arrowsToSave) - _ <- Exports[S].resolved( - call.exportTo - .map(_.name) - .zip(inlineResult.returnedValues) - .toMap ++ inlineResult.exportsToSave - ) - } yield inlineResult.tree -> inlineResult.returnedValues + exportTo = call.exportTo.map(_.name) + _ <- Arrows[S].resolved(inlineResult.arrowsToSave) + _ <- Exports[S].resolved( + exportTo + .zip(inlineResult.returnedValues) + .toMap ++ inlineResult.exportsToSave + ) + _ <- Mangler[S].forbid(exportTo.toSet) + } yield inlineResult.tree -> inlineResult.returnedValues } diff --git a/model/inline/src/main/scala/aqua/model/inline/RawValueInliner.scala b/model/inline/src/main/scala/aqua/model/inline/RawValueInliner.scala index 8195209f..46ecc26a 100644 --- a/model/inline/src/main/scala/aqua/model/inline/RawValueInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/RawValueInliner.scala @@ -37,7 +37,10 @@ object RawValueInliner extends Logging { ): State[S, (ValueModel, Inline)] = raw match { case VarRaw(name, t) => - Exports[S].exports.map(VarModel(name, t, Chain.empty).resolveWith).map(_ -> Inline.empty) + for { + exports <- Exports[S].exports + model = VarModel(name, t, Chain.empty).resolveWith(exports) + } yield model -> Inline.empty case LiteralRaw(value, t) => State.pure(LiteralModel(value, t) -> Inline.empty) diff --git a/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala b/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala index 03293187..38fc8f74 100644 --- a/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala @@ -7,7 +7,7 @@ import aqua.model.inline.raw.CallArrowRawInliner import aqua.raw.value.ApplyBinaryOpRaw.Op as BinOp import aqua.raw.ops.* import aqua.raw.value.* -import aqua.types.{BoxType, CanonStreamType, StreamType} +import aqua.types.{BoxType, CanonStreamType, DataType, StreamType} import aqua.model.inline.Inline.parDesugarPrefixOpt import cats.syntax.traverse.* @@ -295,27 +295,45 @@ object TagInliner extends Logging { ) case PushToStreamTag(operand, exportTo) => - valueToModel(operand).map { case (v, p) => - TagInlined.Single( - model = PushToStreamModel(v, CallModel.callExport(exportTo)), - prefix = p - ) + ( + valueToModel(operand), + // We need to resolve stream because it could + // be actually pointing to another var. + // TODO: Looks like a hack, refator resolving + valueToModel(exportTo.toRaw) + ).mapN { + case ((v, p), (VarModel(name, st, Chain.nil), None)) => + TagInlined.Single( + model = PushToStreamModel(v, CallModel.Export(name, st)), + prefix = p + ) + case (_, (vm, prefix)) => + logger.error( + s"Unexpected: stream (${exportTo}) resolved " + + s"to ($vm) with prefix ($prefix)" + ) + TagInlined.Empty() } case CanonicalizeTag(operand, exportTo) => valueToModel(operand).flatMap { // pass literals as is case (l @ LiteralModel(_, _), p) => - for { - _ <- Exports[S].resolved(exportTo.name, l) - } yield TagInlined.Empty(prefix = p) + Exports[S] + .resolved(exportTo.name, l) + .as(TagInlined.Empty(prefix = p)) case (v, p) => - TagInlined - .Single( - model = CanonicalizeModel(v, CallModel.callExport(exportTo)), - prefix = p + Exports[S] + .resolved( + exportTo.name, + VarModel(exportTo.name, exportTo.`type`) + ) + .as( + TagInlined.Single( + model = CanonicalizeModel(v, CallModel.callExport(exportTo)), + prefix = p + ) ) - .pure } case FlattenTag(operand, assignTo) => @@ -356,8 +374,6 @@ object TagInliner extends Logging { case AssignmentTag(value, assignTo) => for { - // NOTE: Name should not exist yet - _ <- Mangler[S].forbidName(assignTo) modelAndPrefix <- value match { // if we assign collection to a stream, we must use it's name, because it is already created with 'new' case c @ CollectionRaw(_, _: StreamType) => @@ -372,10 +388,9 @@ object TagInliner extends Logging { case ClosureTag(arrow, detach) => if (detach) Arrows[S].resolved(arrow, None).as(TagInlined.Empty()) else - for { - t <- Mangler[S].findAndForbidName(arrow.name) - _ <- Arrows[S].resolved(arrow, Some(t)) - } yield TagInlined.Single(model = CaptureTopologyModel(t)) + Arrows[S] + .resolved(arrow, arrow.name.some) + .as(TagInlined.Single(model = CaptureTopologyModel(arrow.name))) case NextTag(item) => for { @@ -393,8 +408,9 @@ object TagInliner extends Logging { case VarRaw(name, _) => for { cd <- valueToModel(value) - _ <- Exports[S].resolved(name, cd._1) - } yield TagInlined.Empty(prefix = cd._2) + (vm, prefix) = cd + _ <- Exports[S].resolved(name, vm) + } yield TagInlined.Empty(prefix = prefix) case _ => none case _: SeqGroupTag => pure(SeqModel) diff --git a/model/inline/src/main/scala/aqua/model/inline/raw/CallArrowRawInliner.scala b/model/inline/src/main/scala/aqua/model/inline/raw/CallArrowRawInliner.scala index 4abe6679..80a6b6d5 100644 --- a/model/inline/src/main/scala/aqua/model/inline/raw/CallArrowRawInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/raw/CallArrowRawInliner.scala @@ -8,6 +8,7 @@ import aqua.model.inline.state.{Arrows, Exports, Mangler} import aqua.raw.ops.Call import aqua.types.ArrowType import aqua.raw.value.CallArrowRaw + import cats.data.{Chain, State} import scribe.Logging @@ -27,15 +28,21 @@ object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging { logger.trace(Console.BLUE + s"call service id $serviceId" + Console.RESET) for { cd <- callToModel(call, true) + (callModel, callInline) = cd sd <- valueToModel(serviceId) - } yield cd._1.exportTo.map(_.asVar.resolveWith(exports)) -> Inline( - Chain( - SeqModel.wrap( - sd._2.toList ++ - cd._2.toList :+ CallServiceModel(sd._1, value.name, cd._1).leaf + (serviceIdValue, serviceIdInline) = sd + values = callModel.exportTo.map(e => e.name -> e.asVar.resolveWith(exports)).toMap + inline = Inline( + Chain( + SeqModel.wrap( + serviceIdInline.toList ++ callInline.toList :+ + CallServiceModel(serviceIdValue, value.name, callModel).leaf + ) ) ) - ) + _ <- Exports[S].resolved(values) + _ <- Mangler[S].forbid(values.keySet) + } yield values.values.toList -> inline case None => /** * Here the back hop happens from [[TagInliner]] to [[ArrowInliner.callArrow]] @@ -61,9 +68,7 @@ object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging { // Leave meta information in tree after inlining MetaModel .CallArrowModel(fn.funcName) - .wrap( - SeqModel.wrap(p.toList :+ body: _*) - ) + .wrap(SeqModel.wrap(p.toList :+ body)) ) ) } diff --git a/model/inline/src/main/scala/aqua/model/inline/state/Arrows.scala b/model/inline/src/main/scala/aqua/model/inline/state/Arrows.scala index e6640343..d92c65b8 100644 --- a/model/inline/src/main/scala/aqua/model/inline/state/Arrows.scala +++ b/model/inline/src/main/scala/aqua/model/inline/state/Arrows.scala @@ -2,10 +2,12 @@ package aqua.model.inline.state import aqua.model.{ArgsCall, FuncArrow} import aqua.raw.arrow.FuncRaw + import cats.data.State import cats.instances.list.* import cats.syntax.functor.* import cats.syntax.traverse.* +import cats.syntax.show.* /** * State algebra for resolved arrows @@ -20,16 +22,23 @@ trait Arrows[S] extends Scoped[S] { /** * Arrow is resolved – save it to the state [[S]] * - * @param arrow - * resolved arrow - * @param e - * contextual Exports that an arrow captures + * @param arrow resolved arrow + * @param topology captured topology */ - final def resolved(arrow: FuncRaw, topology: Option[String])(implicit e: Exports[S]): State[S, Unit] = + final def resolved( + arrow: FuncRaw, + topology: Option[String] + )(using Exports[S]): State[S, Unit] = for { - exps <- e.exports + exps <- Exports[S].exports arrs <- arrows - funcArrow = FuncArrow.fromRaw(arrow, arrs, exps, topology) + // _ = println(s"Resolved arrow: ${arrow.name}") + // _ = println(s"Captured var names: ${arrow.capturedVars}") + captuedVars = exps.filterKeys(arrow.capturedVars).toMap + capturedArrows = arrs.filterKeys(arrow.capturedVars).toMap + // _ = println(s"Captured vars: ${captuedVars}") + // _ = println(s"Captured arrows: ${capturedArrows}") + funcArrow = FuncArrow.fromRaw(arrow, capturedArrows, captuedVars, topology) _ <- save(arrow.name, funcArrow) } yield () @@ -63,7 +72,7 @@ trait Arrows[S] extends Scoped[S] { * @return */ def argsArrows(args: ArgsCall): State[S, Map[String, FuncArrow]] = - arrows.map(args.arrowArgs) + arrows.map(args.arrowArgsMap) /** * Changes the [[S]] type to [[R]] diff --git a/model/inline/src/main/scala/aqua/model/inline/state/Exports.scala b/model/inline/src/main/scala/aqua/model/inline/state/Exports.scala index f524dc06..efd173d7 100644 --- a/model/inline/src/main/scala/aqua/model/inline/state/Exports.scala +++ b/model/inline/src/main/scala/aqua/model/inline/state/Exports.scala @@ -127,10 +127,10 @@ object Exports { // Get last linked VarModel def getLastValue(name: String, state: Map[String, ValueModel]): Option[ValueModel] = { state.get(name) match { - case Some(vm@VarModel(n, _, _)) => + case Some(vm @ VarModel(n, _, _)) => if (name == n) Option(vm) else getLastValue(n, state).orElse(Option(vm)) - case lm@Some(LiteralModel(_, _)) => + case lm @ Some(LiteralModel(_, _)) => lm case _ => None @@ -140,9 +140,14 @@ object Exports { object Simple extends Exports[Map[String, ValueModel]] { // Make links from one set of abilities to another (for ability assignment) - private def getAbilityPairs(oldName: String, newName: String, at: AbilityType, state: Map[String, ValueModel]): NonEmptyList[(String, ValueModel)] = { + private def getAbilityPairs( + oldName: String, + newName: String, + at: AbilityType, + state: Map[String, ValueModel] + ): NonEmptyList[(String, ValueModel)] = { at.fields.toNel.flatMap { - case (n, at@AbilityType(_, _)) => + case (n, at @ AbilityType(_, _)) => val newFullName = AbilityType.fullName(newName, n) val oldFullName = AbilityType.fullName(oldName, n) getAbilityPairs(oldFullName, newFullName, at, state) @@ -168,11 +173,7 @@ object Exports { } override def getLastVarName(name: String): State[Map[String, ValueModel], Option[String]] = - State.get.map(st => getLastValue(name, st).flatMap { - case VarModel(name, _, _) => Option(name) - case LiteralModel(_, _) => - None - }) + State.get.map(st => getLastValue(name, st).collect { case VarModel(name, _, _) => name }) override def resolved(exports: Map[String, ValueModel]): State[Map[String, ValueModel], Unit] = State.modify(_ ++ exports) diff --git a/model/inline/src/main/scala/aqua/model/inline/state/Mangler.scala b/model/inline/src/main/scala/aqua/model/inline/state/Mangler.scala index bcc54ef3..15a73edb 100644 --- a/model/inline/src/main/scala/aqua/model/inline/state/Mangler.scala +++ b/model/inline/src/main/scala/aqua/model/inline/state/Mangler.scala @@ -17,6 +17,12 @@ trait Mangler[S] { _ <- forbid(Set(n)) } yield n + def findAndForbidNames(introduce: Set[String]): State[S, Map[String, String]] = + for { + n <- findNewNames(introduce) + _ <- forbid(introduce ++ n.values.toSet) + } yield n + def forbid(names: Set[String]): State[S, Unit] def forbidName(name: String): State[S, Unit] = diff --git a/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala b/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala index f9003148..f79fa353 100644 --- a/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala +++ b/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala @@ -5,15 +5,57 @@ import aqua.model.inline.state.InliningState import aqua.raw.ops.* import aqua.raw.value.* import aqua.types.* -import cats.syntax.show.* -import cats.syntax.option.* -import cats.data.{Chain, NonEmptyList, NonEmptyMap} -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.should.Matchers import aqua.raw.value.{CallArrowRaw, ValueRaw} import aqua.raw.arrow.{ArrowRaw, FuncRaw} -class ArrowInlinerSpec extends AnyFlatSpec with Matchers { +import cats.Eval +import cats.syntax.show.* +import cats.syntax.option.* +import cats.syntax.flatMap.* +import cats.free.Cofree +import cats.data.{Chain, NonEmptyList, NonEmptyMap} +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.Inside + +class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { + + extension (tree: OpModel.Tree) { + + def collect[A](pf: PartialFunction[OpModel, A]): Chain[A] = + Cofree + .cata(tree)((op, children: Chain[Chain[A]]) => + Eval.later( + Chain.fromOption(pf.lift(op)) ++ children.flatten + ) + ) + .value + } + + def callFuncModel(func: FuncArrow): OpModel.Tree = + ArrowInliner + .callArrow[InliningState]( + FuncArrow( + "wrapper", + CallArrowRawTag + .func( + func.funcName, + Call(Nil, Nil) + ) + .leaf, + ArrowType( + ProductType(Nil), + ProductType(Nil) + ), + Nil, + Map(func.funcName -> func), + Map.empty, + None + ), + CallModel(Nil, Nil) + ) + .runA(InliningState()) + .value "arrow inliner" should "convert simple arrow" in { @@ -104,15 +146,20 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { ProductType(Nil) ), Nil, - Map("cb" -> cbArrow), + Map.empty, Map.empty, None ), CallModel(cbVal :: Nil, Nil) ) - .run(InliningState()) + .runA( + InliningState( + resolvedArrows = Map( + cbVal.name -> cbArrow + ) + ) + ) .value - ._2 model.equalsOrShowDiff( RestrictionModel(streamVar.name, streamType).wrap( @@ -136,6 +183,343 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers { } + /** + * func returnStream() -> *string: + * stream: *string + * stream <<- "one" + * <- stream + * + * func rereturnStream() -> *string: + * stream <- returnStream() + * stream <<- "two" + * <- stream + * + * func testReturnStream() -> []string: + * stream <- rereturnStream() + * stream <<- "three" + * <- stream + */ + it should "handle returned stream" in { + val streamType = StreamType(ScalarType.string) + val streamVar = VarRaw("stream", streamType) + val canonStreamVar = VarRaw( + s"-${streamVar.name}-canon-0", + CanonStreamType(ScalarType.string) + ) + val flatStreamVar = VarRaw( + s"-${streamVar.name}-flat-0", + ArrayType(ScalarType.string) + ) + val returnStreamArrowType = ArrowType( + ProductType(Nil), + ProductType(streamType :: Nil) + ) + + val returnStream = FuncArrow( + "returnStream", + SeqTag.wrap( + DeclareStreamTag(streamVar).leaf, + PushToStreamTag( + LiteralRaw.quote("one"), + Call.Export(streamVar.name, streamVar.`type`) + ).leaf, + ReturnTag( + NonEmptyList.one(streamVar) + ).leaf + ), + returnStreamArrowType, + List(streamVar), + Map.empty, + Map.empty, + None + ) + + val rereturnStream = FuncArrow( + "rereturnStream", + SeqTag.wrap( + CallArrowRawTag + .func( + returnStream.funcName, + Call(Nil, Call.Export(streamVar.name, streamType) :: Nil) + ) + .leaf, + PushToStreamTag( + LiteralRaw.quote("two"), + Call.Export(streamVar.name, streamVar.`type`) + ).leaf, + ReturnTag( + NonEmptyList.one(streamVar) + ).leaf + ), + returnStreamArrowType, + List(streamVar), + Map(returnStream.funcName -> returnStream), + Map.empty, + None + ) + + val testReturnStream = FuncArrow( + "testReturnStream", + RestrictionTag(streamVar.name, streamType).wrap( + SeqTag.wrap( + CallArrowRawTag + .func( + rereturnStream.funcName, + Call(Nil, Call.Export(streamVar.name, streamType) :: Nil) + ) + .leaf, + PushToStreamTag( + LiteralRaw.quote("three"), + Call.Export(streamVar.name, streamVar.`type`) + ).leaf, + CanonicalizeTag( + streamVar, + Call.Export(canonStreamVar.name, canonStreamVar.`type`) + ).leaf, + FlattenTag( + canonStreamVar, + flatStreamVar.name + ).leaf, + ReturnTag( + NonEmptyList.one(flatStreamVar) + ).leaf + ) + ), + ArrowType( + ProductType(Nil), + ProductType(ArrayType(ScalarType.string) :: Nil) + ), + List(flatStreamVar), + Map(rereturnStream.funcName -> rereturnStream), + Map.empty, + None + ) + + val model = callFuncModel(testReturnStream) + + val result = model.collect { + case p: PushToStreamModel => p + case c: CanonicalizeModel => c + case f: FlattenModel => f + } + + val streamName = result.collectFirst { case PushToStreamModel(value, exportTo) => + exportTo.name + } + + val canonFlatNames = result.collect { + case FlattenModel(VarModel(name, _, Chain.nil), assingTo) => + (name, assingTo) + }.headOption + + inside(streamName) { case Some(streamName) => + inside(canonFlatNames) { case Some((canonName, flatName)) => + val canonExport = CallModel.Export( + canonName, + CanonStreamType(ScalarType.string) + ) + + val expected = Chain("one", "two", "three").map(s => + PushToStreamModel( + LiteralModel.quote(s), + CallModel.Export(streamName, streamType) + ) + ) ++ Chain( + CanonicalizeModel( + VarModel(streamName, streamType), + canonExport + ), + FlattenModel( + canonExport.asVar, + flatName + ) + ) + + result shouldEqual expected + } + } + } + + /** + * func return() -> (-> []string): + * result: *string + * + * result <<- "one" + * + * closure = () -> []string: + * result <<- "two-three" + * <- result + * + * closure() + * + * <- closure + * + * func testReturn() -> []string: + * closure <- return() + * res <- closure() + * <- res + */ + it should "handle stream captured in closure" in { + val streamType = StreamType(ScalarType.string) + val streamVar = VarRaw("status", streamType) + val resType = ArrayType(ScalarType.string) + val resVar = VarRaw("res", resType) + val canonStreamVar = VarRaw( + s"-${streamVar.name}-canon-0", + CanonStreamType(ScalarType.string) + ) + val flatStreamVar = VarRaw( + s"-${streamVar.name}-flat-0", + ArrayType(ScalarType.string) + ) + val closureType = ArrowType( + ProductType(Nil), + ProductType(resType :: Nil) + ) + val closureVar = VarRaw("closure", closureType) + + val closureFunc = FuncRaw( + "closure", + ArrowRaw( + ArrowType( + ProductType(Nil), + ProductType(ArrayType(ScalarType.string) :: Nil) + ), + List(flatStreamVar), + SeqTag.wrap( + PushToStreamTag( + LiteralRaw.quote("two-three"), + Call.Export(streamVar.name, streamVar.`type`) + ).leaf, + CanonicalizeTag( + streamVar, + Call.Export(canonStreamVar.name, canonStreamVar.`type`) + ).leaf, + FlattenTag( + canonStreamVar, + flatStreamVar.name + ).leaf + ) + ) + ) + + val returnFunc = FuncArrow( + "return", + SeqTag.wrap( + DeclareStreamTag(streamVar).leaf, + PushToStreamTag( + LiteralRaw.quote("one"), + Call.Export(streamVar.name, streamVar.`type`) + ).leaf, + ClosureTag( + closureFunc, + detach = false + ).leaf, + CallArrowRawTag + .func( + closureVar.name, + Call(Nil, Nil) + ) + .leaf, + ReturnTag( + NonEmptyList.one(closureVar) + ).leaf + ), + ArrowType( + ProductType(Nil), + ProductType(closureType :: Nil) + ), + List(closureVar), + Map.empty, + Map.empty, + None + ) + + val testFunc = FuncArrow( + "test", + SeqTag.wrap( + CallArrowRawTag + .func( + returnFunc.funcName, + Call(Nil, Call.Export(closureVar.name, closureType) :: Nil) + ) + .leaf, + CallArrowRawTag + .func( + closureVar.name, + Call(Nil, Call.Export(resVar.name, ArrayType(ScalarType.string)) :: Nil) + ) + .leaf, + ReturnTag( + NonEmptyList.one(resVar) + ).leaf + ), + ArrowType( + ProductType(Nil), + ProductType(ArrayType(ScalarType.string) :: Nil) + ), + List(resVar), + Map(returnFunc.funcName -> returnFunc), + Map.empty, + None + ) + + val model = callFuncModel(testFunc) + + val result = model.collect { + case p: PushToStreamModel => p + case c: CanonicalizeModel => c + case f: FlattenModel => f + } + + val streamName = model.collect { case PushToStreamModel(value, CallModel.Export(name, _)) => + name + }.headOption + + val canonFlatNames = model.collect { + case FlattenModel(VarModel(name, _, Chain.nil), assingTo) => + (name, assingTo) + }.toList + + // WARNING: This test does not take + // stream restriction into account + inside(streamName) { case Some(streamName) => + inside(canonFlatNames) { case (canonName1, flatName1) :: (canonName2, flatName2) :: Nil => + def canon(canonName: String, flatName: String) = { + val canonExport = CallModel.Export( + canonName, + CanonStreamType(ScalarType.string) + ) + + Chain( + CanonicalizeModel( + VarModel(streamName, streamType), + canonExport + ), + FlattenModel( + canonExport.asVar, + flatName + ) + ) + } + + val expected = Chain("one", "two-three").map(s => + PushToStreamModel( + LiteralModel.quote(s), + CallModel.Export(streamName, streamType) + ) + ) ++ canon(canonName1, flatName1) ++ Chain.one( + PushToStreamModel( + LiteralModel.quote("two-three"), + CallModel.Export(streamName, streamType) + ) + ) ++ canon(canonName2, flatName2) + + result shouldEqual expected + } + } + } + /* func stream-callback(cb: string -> ()): records: *string diff --git a/model/raw/src/main/scala/aqua/raw/arrow/FuncRaw.scala b/model/raw/src/main/scala/aqua/raw/arrow/FuncRaw.scala index b7813cde..36e0afce 100644 --- a/model/raw/src/main/scala/aqua/raw/arrow/FuncRaw.scala +++ b/model/raw/src/main/scala/aqua/raw/arrow/FuncRaw.scala @@ -11,4 +11,14 @@ case class FuncRaw( override def rename(s: String): RawPart = copy(name = s) override def rawPartType: Type = arrow.`type` + + def capturedVars: Set[String] = { + val freeBodyVars = arrow.body.usesVarNames.value + val argsNames = arrow.`type`.domain + .toLabelledList() + .map { case (name, _) => name } + .toSet + + freeBodyVars -- argsNames + } } diff --git a/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala b/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala index dfa5ee7b..e7162ac5 100644 --- a/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala +++ b/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala @@ -5,8 +5,10 @@ import aqua.raw.ops.RawTag.Tree import aqua.raw.value.{CallArrowRaw, ValueRaw} import aqua.tree.{TreeNode, TreeNodeCompanion} import aqua.types.{ArrowType, DataType} + import cats.Show import cats.data.{Chain, NonEmptyList} +import cats.syntax.foldable.* import cats.free.Cofree sealed trait RawTag extends TreeNode[RawTag] { @@ -20,6 +22,9 @@ sealed trait RawTag extends TreeNode[RawTag] { // All variable names introduced by this tag def definesVarNames: Set[String] = exportsVarNames ++ restrictsVarNames + // Variable names used by this tag (not introduced by it) + def usesVarNames: Set[String] = Set.empty + def mapValues(f: ValueRaw => ValueRaw): RawTag def renameExports(map: Map[String, String]): RawTag = this @@ -83,6 +88,8 @@ case object ParTag extends ParGroupTag { case class IfTag(value: ValueRaw) extends GroupTag { + override def usesVarNames: Set[String] = value.varNames + override def mapValues(f: ValueRaw => ValueRaw): RawTag = IfTag(value.map(f)) } @@ -119,6 +126,8 @@ case class OnTag( strategy: Option[OnTag.ReturnStrategy] = None ) extends SeqGroupTag { + override def usesVarNames: Set[String] = peerId.varNames ++ via.foldMap(_.varNames) + override def mapValues(f: ValueRaw => ValueRaw): RawTag = OnTag(peerId.map(f), via.map(_.map(f)), strategy) @@ -146,6 +155,8 @@ case class NextTag(item: String) extends RawTag { override def renameExports(map: Map[String, String]): RawTag = copy(item = map.getOrElse(item, item)) + override def usesVarNames: Set[String] = Set(item) + override def mapValues(f: ValueRaw => ValueRaw): RawTag = this } @@ -162,6 +173,8 @@ case class ForTag(item: String, iterable: ValueRaw, mode: Option[ForTag.Mode] = override def restrictsVarNames: Set[String] = Set(item) + override def usesVarNames: Set[String] = iterable.varNames + override def mapValues(f: ValueRaw => ValueRaw): RawTag = ForTag(item, iterable.map(f), mode) @@ -184,6 +197,8 @@ case class CallArrowRawTag( override def exportsVarNames: Set[String] = exportTo.map(_.name).toSet + override def usesVarNames: Set[String] = value.varNames + override def mapValues(f: ValueRaw => ValueRaw): RawTag = CallArrowRawTag(exportTo, value.map(f)) @@ -227,9 +242,13 @@ object CallArrowRawTag { } case class DeclareStreamTag( + // TODO: Why is it ValueRaw and + // not just (stream name, stream type)? value: ValueRaw ) extends RawTag { + override def exportsVarNames: Set[String] = value.varNames + override def mapValues(f: ValueRaw => ValueRaw): RawTag = DeclareStreamTag(value.map(f)) } @@ -239,6 +258,10 @@ case class AssignmentTag( assignTo: String ) extends NoExecTag { + override def exportsVarNames: Set[String] = Set(assignTo) + + override def usesVarNames: Set[String] = value.varNames + override def renameExports(map: Map[String, String]): RawTag = copy(assignTo = map.getOrElse(assignTo, assignTo)) @@ -251,6 +274,11 @@ case class ClosureTag( detach: Boolean ) extends NoExecTag { + override def exportsVarNames: Set[String] = Set(func.name) + + // FIXME: Is it correct? + override def usesVarNames: Set[String] = Set.empty + override def renameExports(map: Map[String, String]): RawTag = copy(func = func.copy(name = map.getOrElse(func.name, func.name))) @@ -269,6 +297,8 @@ case class ReturnTag( values: NonEmptyList[ValueRaw] ) extends NoExecTag { + override def usesVarNames: Set[String] = values.foldMap(_.varNames) + override def mapValues(f: ValueRaw => ValueRaw): RawTag = ReturnTag(values.map(_.map(f))) } @@ -282,13 +312,23 @@ case class AbilityIdTag( service: String ) extends NoExecTag { + override def usesVarNames: Set[String] = value.varNames + override def mapValues(f: ValueRaw => ValueRaw): RawTag = AbilityIdTag(value.map(f), service) } case class PushToStreamTag(operand: ValueRaw, exportTo: Call.Export) extends RawTag { - override def exportsVarNames: Set[String] = Set(exportTo.name) + /** + * NOTE: Pushing to a stream will create it, but we suppose + * that `DeclareStreamTag` exports stream and this tag does not + * to distinguish cases when stream is captured from outside. + * This is why `exportTo` is not in `exportsVarNames`. + */ + override def exportsVarNames: Set[String] = Set.empty + + override def usesVarNames: Set[String] = operand.varNames + exportTo.name override def mapValues(f: ValueRaw => ValueRaw): RawTag = PushToStreamTag(operand.map(f), exportTo) @@ -303,6 +343,8 @@ case class FlattenTag(operand: ValueRaw, assignTo: String) extends RawTag { override def exportsVarNames: Set[String] = Set(assignTo) + override def usesVarNames: Set[String] = operand.varNames + override def mapValues(f: ValueRaw => ValueRaw): RawTag = FlattenTag(operand.map(f), assignTo) @@ -316,6 +358,8 @@ case class CanonicalizeTag(operand: ValueRaw, exportTo: Call.Export) extends Raw override def exportsVarNames: Set[String] = Set(exportTo.name) + override def usesVarNames: Set[String] = operand.varNames + override def mapValues(f: ValueRaw => ValueRaw): RawTag = CanonicalizeTag(operand.map(f), exportTo) @@ -327,6 +371,8 @@ case class CanonicalizeTag(operand: ValueRaw, exportTo: Call.Export) extends Raw case class JoinTag(operands: NonEmptyList[ValueRaw]) extends RawTag { + override def usesVarNames: Set[String] = operands.foldMap(_.varNames) + override def mapValues(f: ValueRaw => ValueRaw): RawTag = JoinTag(operands.map(_.map(f))) diff --git a/model/raw/src/main/scala/aqua/raw/ops/RawTagGivens.scala b/model/raw/src/main/scala/aqua/raw/ops/RawTagGivens.scala index d8b6cdb4..2738adb5 100644 --- a/model/raw/src/main/scala/aqua/raw/ops/RawTagGivens.scala +++ b/model/raw/src/main/scala/aqua/raw/ops/RawTagGivens.scala @@ -1,11 +1,14 @@ package aqua.raw.ops -import aqua.raw.value.LiteralRaw +import aqua.raw.value.{LiteralRaw, ValueRaw} + import cats.free.Cofree import cats.data.Chain import cats.{Eval, Semigroup} import cats.syntax.apply.* import cats.syntax.semigroup.* +import cats.syntax.foldable.* +import cats.syntax.all.* trait RawTagGivens { @@ -31,6 +34,9 @@ trait RawTagGivens { if (vals.isEmpty) tree else tree.map(_.mapValues(_.renameVars(vals)).renameExports(vals)) + def mapValues(f: ValueRaw => ValueRaw): RawTag.Tree = + tree.map(_.mapValues(f)) + def renameExports(vals: Map[String, String]): RawTag.Tree = if (vals.isEmpty) tree else tree.map(_.renameExports(vals)) @@ -39,4 +45,25 @@ trait RawTagGivens { Cofree.cata(tree) { case (tag, acc) => Eval.later(acc.foldLeft(tag.definesVarNames)(_ ++ _)) } + + /** + * Get all variable names used by this tree + * but not exported in it (free variables). + */ + def usesVarNames: Eval[Set[String]] = + Cofree + .cata(tree)((tag, childs: Chain[(Set[String], Set[String])]) => + Eval.later { + val (childExports, childUses) = childs.combineAll + val exports = tag.exportsVarNames ++ childExports -- tag.restrictsVarNames + val uses = tag.usesVarNames ++ childUses -- exports + (exports, uses) + } + ) + .map { case (_, uses) => uses } + + private def collect[A](pf: PartialFunction[RawTag, A]): Eval[Chain[A]] = + Cofree.cata(tree)((tag, acc: Chain[Chain[A]]) => + Eval.later(Chain.fromOption(pf.lift(tag)) ++ acc.flatten) + ) } diff --git a/model/raw/src/main/scala/aqua/raw/value/ValueRaw.scala b/model/raw/src/main/scala/aqua/raw/value/ValueRaw.scala index 6e634c1c..f5f46702 100644 --- a/model/raw/src/main/scala/aqua/raw/value/ValueRaw.scala +++ b/model/raw/src/main/scala/aqua/raw/value/ValueRaw.scala @@ -45,6 +45,9 @@ object ValueRaw { "%last_error%", lastErrorType ) + + type ApplyRaw = ApplyGateRaw | ApplyPropertyRaw | CallArrowRaw | CollectionRaw | + ApplyBinaryOpRaw | ApplyUnaryOpRaw } case class ApplyPropertyRaw(value: ValueRaw, property: PropertyRaw) extends ValueRaw { @@ -55,9 +58,8 @@ case class ApplyPropertyRaw(value: ValueRaw, property: PropertyRaw) extends Valu override def renameVars(map: Map[String, String]): ValueRaw = ApplyPropertyRaw(value.renameVars(map), property.renameVars(map)) - override def map(f: ValueRaw => ValueRaw): ValueRaw = f( - ApplyPropertyRaw(f(value), property.map(f)) - ) + override def map(f: ValueRaw => ValueRaw): ValueRaw = + f(ApplyPropertyRaw(f(value), property.map(_.map(f)))) override def toString: String = s"$value.$property" @@ -88,7 +90,8 @@ case class ApplyGateRaw(name: String, streamType: StreamType, idx: ValueRaw) ext override def renameVars(map: Map[String, String]): ValueRaw = copy(name = map.getOrElse(name, name), idx = idx.renameVars(map)) - override def map(f: ValueRaw => ValueRaw): ValueRaw = this + override def map(f: ValueRaw => ValueRaw): ValueRaw = + f(copy(idx = f(idx))) override def toString: String = s"gate $name.$idx" @@ -100,7 +103,7 @@ case class VarRaw(name: String, baseType: Type) extends ValueRaw { override def map(f: ValueRaw => ValueRaw): ValueRaw = f(this) override def renameVars(map: Map[String, String]): ValueRaw = - copy(map.getOrElse(name, name)) + copy(name = map.getOrElse(name, name)) override def toString: String = s"var{$name: " + baseType + s"}" @@ -169,9 +172,8 @@ case class AbilityRaw(fieldsAndArrows: NonEmptyMap[String, ValueRaw], abilityTyp override def baseType: Type = abilityType - override def map(f: ValueRaw => ValueRaw): ValueRaw = f( - copy(fieldsAndArrows = fieldsAndArrows.map(f)) - ) + override def map(f: ValueRaw => ValueRaw): ValueRaw = + f(copy(fieldsAndArrows = fieldsAndArrows.map(f))) override def varNames: Set[String] = { fieldsAndArrows.toSortedMap.values.flatMap(_.varNames).toSet @@ -246,7 +248,12 @@ case class CallArrowRaw( override def `type`: Type = baseType.codomain.uncons.map(_._1).getOrElse(baseType) override def map(f: ValueRaw => ValueRaw): ValueRaw = - f(copy(arguments = arguments.map(f))) + f( + copy( + arguments = arguments.map(f), + serviceId = serviceId.map(f) + ) + ) override def varNames: Set[String] = arguments.flatMap(_.varNames).toSet @@ -256,9 +263,8 @@ case class CallArrowRaw( .get(name) // Rename only if it is **not** a service or ability call, see [bug LNG-199] .filterNot(_ => ability.isDefined) - .getOrElse(name), - arguments = arguments.map(_.renameVars(map)), - serviceId = serviceId.map(_.renameVars(map)) + .filterNot(_ => serviceId.isDefined) + .getOrElse(name) ) override def toString: String = diff --git a/model/src/main/scala/aqua/model/ArgsCall.scala b/model/src/main/scala/aqua/model/ArgsCall.scala index 7befb3c2..7244f7eb 100644 --- a/model/src/main/scala/aqua/model/ArgsCall.scala +++ b/model/src/main/scala/aqua/model/ArgsCall.scala @@ -5,6 +5,8 @@ import aqua.raw.ops.Call import aqua.raw.value.{ValueRaw, VarRaw} import aqua.types.* +import cats.syntax.foldable.* + /** * Wraps argument definitions of a function, along with values provided when this function is called * @@ -18,26 +20,91 @@ case class ArgsCall(args: ProductType, callWith: List[ValueModel]) { // and values (value models and types how they seen on the call site) private lazy val zipped: List[((String, Type), ValueModel)] = args.toLabelledList() zip callWith + /** + * Names of arguments as they defined in the function definition + */ + lazy val argNames: Set[String] = args + .toLabelledList() + .map { case (name, _) => name } + .toSet + + /** + * Data arguments (except streams) as mapping + * Name of argument -> value passed in the call + */ lazy val dataArgs: Map[String, ValueModel] = - zipped.collect { case ((name, _: DataType), value) => - name -> value + zipped.collect { + case ((name, _: DataType), value) if !streamArgs.contains(name) => + name -> value }.toMap + /** + * Ability arguments as mapping + * Name of argument -> (variable passed in the call, ability type) + */ lazy val abilityArgs: Map[String, (VarModel, AbilityType)] = - zipped.collect { case (k, vr@VarModel(_, t@AbilityType(_, _), _)) => - k._1 -> (vr, t) + zipped.collect { case ((name, _), vr @ VarModel(_, t @ AbilityType(_, _), _)) => + name -> (vr, t) }.toMap - lazy val streamArgs: Map[String, VarModel] = - dataArgs.collect { case (k, vr @ VarModel(n, StreamType(_), _)) => - (k, vr) + /** + * All renamings from ability arguments as mapping + * Name inside function body -> name in the call context + */ + lazy val abilityArgsRenames: Map[String, String] = + abilityArgs.toList.foldMap { case (name, (vm, at)) => + at.arrows.keys + .map(arrowPath => + val fullName = AbilityType.fullName(name, arrowPath) + val newFullName = AbilityType.fullName(vm.name, arrowPath) + fullName -> newFullName + ) + .toMap + .updated(name, vm.name) } - def arrowArgs[T](arrowsInScope: Map[String, T]): Map[String, T] = - zipped.collect { - case ((name, _: ArrowType), VarModel(value, _, _)) if arrowsInScope.contains(value) => - name -> arrowsInScope(value) + /** + * Stream arguments as mapping + * Name of argument -> variable passed in the call + * NOTE: Argument is stream if it is passed as stream + * on the call site. Type of argument in the function + * definition does not matter. + */ + lazy val streamArgs: Map[String, VarModel] = + zipped.collect { case ((name, _), vr @ VarModel(_, StreamType(_), _)) => + name -> vr }.toMap + + /** + * All renamings from stream arguments as mapping + * Name inside function body -> name in the call context + */ + lazy val streamArgsRenames: Map[String, String] = + streamArgs.view.mapValues(_.name).toMap + + /** + * Arrow arguments as mapping + * Name of argument -> variable passed in the call + */ + lazy val arrowArgs: Map[String, VarModel] = + zipped.collect { case ((name, _: ArrowType), vm: VarModel) => + name -> vm + }.toMap + + /** + * All renamings from arrow arguments as mapping + * Name inside function body -> name in the call context + */ + lazy val arrowArgsRenames: Map[String, String] = + arrowArgs.view.mapValues(_.name).toMap + + def arrowArgsMap[T](arrows: Map[String, T]): Map[String, T] = + arrowArgs.view + .mapValues(_.name) + .flatMap { case (name, argName) => + arrows.get(argName).map(name -> _) + } + .toMap } object ArgsCall { diff --git a/model/src/main/scala/aqua/model/FuncArrow.scala b/model/src/main/scala/aqua/model/FuncArrow.scala index aa61b3bd..16651f0b 100644 --- a/model/src/main/scala/aqua/model/FuncArrow.scala +++ b/model/src/main/scala/aqua/model/FuncArrow.scala @@ -3,7 +3,7 @@ package aqua.model import aqua.raw.Raw import aqua.raw.arrow.FuncRaw import aqua.raw.ops.RawTag -import aqua.raw.value.ValueRaw +import aqua.raw.value.{ValueRaw, VarRaw} import aqua.types.{ArrowType, Type} case class FuncArrow( @@ -17,7 +17,11 @@ case class FuncArrow( ) { lazy val args: List[(String, Type)] = arrowType.domain.toLabelledList() - lazy val argNames: List[String] = args.map(_._1) + + lazy val argNames: List[String] = args.map { case (name, _) => name } + + lazy val returnedArrows: Set[String] = + ret.collect { case VarRaw(name, _: ArrowType) => name }.toSet } diff --git a/model/transform/src/main/scala/aqua/model/transform/Transform.scala b/model/transform/src/main/scala/aqua/model/transform/Transform.scala index 7917d916..1d57e32f 100644 --- a/model/transform/src/main/scala/aqua/model/transform/Transform.scala +++ b/model/transform/src/main/scala/aqua/model/transform/Transform.scala @@ -1,8 +1,5 @@ package aqua.model.transform -import cats.syntax.show.* -import cats.syntax.traverse.* -import cats.instances.list.* import aqua.model.inline.ArrowInliner import aqua.model.inline.state.InliningState import aqua.model.transform.funcop.* @@ -13,13 +10,17 @@ import aqua.raw.ops.RawTag import aqua.raw.value.VarRaw import aqua.res.* import aqua.types.ScalarType +import aqua.model.transform.TransformConfig.TracingConfig +import aqua.model.transform.pre.{CallbackErrorHandler, ErrorHandler} + import cats.Eval import cats.data.Chain import cats.free.Cofree import cats.syntax.option.* +import cats.syntax.show.* +import cats.syntax.traverse.* +import cats.instances.list.* import scribe.Logging -import aqua.model.transform.TransformConfig.TracingConfig -import aqua.model.transform.pre.{CallbackErrorHandler, ErrorHandler} // API for transforming RawTag to Res object Transform extends Logging { diff --git a/model/transform/src/main/scala/aqua/model/transform/pre/ArgsProvider.scala b/model/transform/src/main/scala/aqua/model/transform/pre/ArgsProvider.scala index 316db843..6d378e66 100644 --- a/model/transform/src/main/scala/aqua/model/transform/pre/ArgsProvider.scala +++ b/model/transform/src/main/scala/aqua/model/transform/pre/ArgsProvider.scala @@ -3,15 +3,28 @@ package aqua.model.transform.pre import aqua.raw.ops.* import aqua.raw.value.{ValueRaw, VarRaw} import aqua.types.{ArrayType, DataType, StreamType} + import cats.data.Chain trait ArgsProvider { - def provideArgs(args: List[(String, DataType)]): List[RawTag.Tree] + def provideArgs(args: List[ArgsProvider.Arg]): List[RawTag.Tree] +} + +object ArgsProvider { + + final case class Arg( + // Actual name of the argument + name: String, + // Variable name to store the value of the argument + varName: String, + // Type of the argument + t: DataType + ) } case class ArgsFromService(dataServiceId: ValueRaw) extends ArgsProvider { - private def getStreamDataOp(name: String, t: StreamType): RawTag.Tree = { + private def getStreamDataOp(name: String, varName: String, t: StreamType): RawTag.Tree = { val iter = s"$name-iter" val item = s"$name-item" SeqTag.wrap( @@ -24,28 +37,28 @@ case class ArgsFromService(dataServiceId: ValueRaw) extends ArgsProvider { .leaf, ForTag(item, VarRaw(iter, ArrayType(t.element))).wrap( SeqTag.wrap( - PushToStreamTag(VarRaw(item, t.element), Call.Export(name, t)).leaf, + PushToStreamTag(VarRaw(item, t.element), Call.Export(varName, t)).leaf, NextTag(item).leaf ) ) ) } - def getDataOp(name: String, t: DataType): RawTag.Tree = - t match { + def getDataOp(arg: ArgsProvider.Arg): RawTag.Tree = + arg.t match { case st: StreamType => - getStreamDataOp(name, st) + getStreamDataOp(arg.name, arg.varName, st) case _ => CallArrowRawTag .service( dataServiceId, - name, - Call(Nil, Call.Export(name, t) :: Nil) + arg.name, + Call(Nil, Call.Export(arg.varName, arg.t) :: Nil) ) .leaf } - override def provideArgs(args: List[(String, DataType)]): List[RawTag.Tree] = - args.map(getDataOp.tupled) + override def provideArgs(args: List[ArgsProvider.Arg]): List[RawTag.Tree] = + args.map(getDataOp) } diff --git a/model/transform/src/main/scala/aqua/model/transform/pre/FuncPreTransformer.scala b/model/transform/src/main/scala/aqua/model/transform/pre/FuncPreTransformer.scala index 13a64cf2..b72a8cfd 100644 --- a/model/transform/src/main/scala/aqua/model/transform/pre/FuncPreTransformer.scala +++ b/model/transform/src/main/scala/aqua/model/transform/pre/FuncPreTransformer.scala @@ -22,7 +22,7 @@ case class FuncPreTransformer( private val returnVar: String = "-return-" - private val relayVar = relayVarName.map(_ -> ScalarType.string) + private val relayArg = relayVarName.map(name => ArgsProvider.Arg(name, name, ScalarType.string)) /** * Convert an arrow-type argument to init user's callback @@ -59,13 +59,30 @@ case class FuncPreTransformer( case t => t }).toLabelledList(returnVar) + /** + * Arguments list (argument name, variable name, argument type). + * We need to give other names to arguments because they can + * collide with the name of the function itself. + */ + val args = func.arrowType.domain.toLabelledList().map { case (name, typ) => + (name, s"-$name-arg-", typ) + } + + val dataArgs = args.collect { case (name, varName, t: DataType) => + ArgsProvider.Arg(name, varName, t) + } + + val arrowArgs = args.collect { case (name, argName, arrowType: ArrowType) => + argName -> arrowToCallback(name, arrowType) + }.toMap + val funcCall = Call( - func.arrowType.domain.toLabelledList().map(ad => VarRaw(ad._1, ad._2)), + args.map { case (_, varName, t) => VarRaw(varName, t) }, returnType.map { case (l, t) => Call.Export(l, t) } ) val provideArgs = argsProvider.provideArgs( - relayVar.toList ::: func.arrowType.domain.labelledData + relayArg.toList ::: dataArgs ) val handleResults = resultsHandler.handleResults( @@ -90,12 +107,7 @@ case class FuncPreTransformer( body, ArrowType(ConsType.cons(func.funcName, func.arrowType, NilType), NilType), Nil, - func.arrowType.domain - .toLabelledList() - .collect { case (argName, arrowType: ArrowType) => - argName -> arrowToCallback(argName, arrowType) - } - .toMap, + arrowArgs, Map.empty, None ) diff --git a/semantics/src/main/scala/aqua/semantics/expr/func/ArrowSem.scala b/semantics/src/main/scala/aqua/semantics/expr/func/ArrowSem.scala index 009e5a5c..a5739e7f 100644 --- a/semantics/src/main/scala/aqua/semantics/expr/func/ArrowSem.scala +++ b/semantics/src/main/scala/aqua/semantics/expr/func/ArrowSem.scala @@ -14,10 +14,14 @@ import aqua.semantics.rules.locations.LocationsAlgebra import aqua.semantics.rules.names.NamesAlgebra import aqua.semantics.rules.types.TypesAlgebra import aqua.types.{ArrayType, ArrowType, CanonStreamType, ProductType, StreamType, Type} + +import cats.Eval import cats.data.{Chain, NonEmptyList} import cats.free.{Cofree, Free} import cats.syntax.applicative.* import cats.syntax.apply.* +import cats.syntax.foldable.* +import cats.syntax.bifunctor.* import cats.syntax.flatMap.* import cats.syntax.functor.* import cats.syntax.traverse.* @@ -32,137 +36,93 @@ class ArrowSem[S[_]](val expr: ArrowExpr[S]) extends AnyVal { N: NamesAlgebra[S, Alg], A: AbilitiesAlgebra[S, Alg], L: LocationsAlgebra[S, Alg] - ): Alg[ArrowType] = - // Begin scope -- for mangling - A.beginScope(arrowTypeExpr) *> L.beginScope() *> N.beginScope(arrowTypeExpr) *> T - .beginArrowScope( - arrowTypeExpr - ) - .flatMap((arrowType: ArrowType) => - // Create local variables - expr.arrowTypeExpr.args - .flatMap(_._1) - .zip( - arrowType.domain.toList - ) - .traverse { - case (argName, t: ArrowType) => - N.defineArrow(argName, t, isRoot = false) - case (argName, t) => - N.define(argName, t) - } - .as(arrowType) - ) + ): Alg[ArrowType] = for { + arrowType <- T.beginArrowScope(arrowTypeExpr) + // Create local variables + _ <- expr.arrowTypeExpr.args.flatMap { case (name, _) => name } + .zip(arrowType.domain.toList) + .traverse { + case (argName, t: ArrowType) => + N.defineArrow(argName, t, isRoot = false) + case (argName, t) => + N.define(argName, t) + } + } yield arrowType - private def assignRaw( - v: ValueRaw, - idx: Int, - body: RawTag.Tree, - returnAcc: Chain[ValueRaw] - ): (SeqTag.Tree, Chain[ValueRaw], Int) = { - val assignedReturnVar = VarRaw(s"-return-fix-$idx", v.`type`) - ( - SeqTag.wrap( - body :: AssignmentTag( - v, - assignedReturnVar.name - ).leaf :: Nil: _* - ), - returnAcc :+ assignedReturnVar, - idx + 1 - ) - } - - def after[Alg[_]: Monad](funcArrow: ArrowType, bodyGen: Raw)(implicit + def after[Alg[_]: Monad]( + funcArrow: ArrowType, + bodyGen: Raw + )(using T: TypesAlgebra[S, Alg], N: NamesAlgebra[S, Alg], A: AbilitiesAlgebra[S, Alg], L: LocationsAlgebra[S, Alg] - ): Alg[Raw] = - A.endScope() *> ( - N.streamsDefinedWithinScope(), - T.endArrowScope(expr.arrowTypeExpr) - .flatMap(retValues => N.getDerivedFrom(retValues.map(_.varNames)).map(retValues -> _)) - ).mapN { - case ( - streamsInScope: Map[String, StreamType], - (retValues: List[ValueRaw], retValuesDerivedFrom: List[Set[String]]) - ) => - bodyGen match { - case FuncOp(bodyModel) => - // TODO: wrap with local on...via... + ): Alg[Raw] = for { + streamsInScope <- N.streamsDefinedWithinScope() + retValues <- T.endArrowScope(expr.arrowTypeExpr) + retValuesDerivedFrom <- N.getDerivedFrom(retValues.map(_.varNames)) + res = bodyGen match { + case FuncOp(bodyModel) => + // TODO: wrap with local on...via... - // These streams are returned as streams - val retStreams: Map[String, Option[Type]] = - (retValues zip funcArrow.codomain.toList).collect { - case (VarRaw(n, StreamType(_)), StreamType(_)) => n -> None - case (VarRaw(n, StreamType(_)), t) => n -> Some(t) - }.toMap + val retsAndArgs = retValues zip funcArrow.codomain.toList - val streamsThatReturnAsStreams = retStreams.collect { case (n, None) => - n - }.toSet + val argNames = funcArrow.domain.labelledData.map { case (name, _) => name } + val streamsThatReturnAsStreams = retsAndArgs.collect { + case (VarRaw(n, StreamType(_)), StreamType(_)) => n + }.toSet - val streamArguments = funcArrow.domain.labelledData.map(_._1) + // Remove arguments, and values returned as streams + val localStreams = streamsInScope -- argNames -- streamsThatReturnAsStreams - // Remove stream arguments, and values returned as streams - val localStreams = streamsInScope -- streamArguments -- streamsThatReturnAsStreams + // process stream that returns as not streams and all Apply*Raw + val (bodyRets, retVals) = retsAndArgs.mapWithIndex { + case ((v @ VarRaw(_, StreamType(_)), StreamType(_)), _) => + (Chain.empty, v) + // canonicalize and change return value + case ((VarRaw(streamName, streamType @ StreamType(streamElement)), _), idx) => + val canonReturnVar = VarRaw(s"-$streamName-fix-$idx", CanonStreamType(streamElement)) + val returnVar = VarRaw(s"-$streamName-flat-$idx", ArrayType(streamElement)) + val body = Chain( + CanonicalizeTag( + VarRaw(streamName, streamType), + Call.Export(canonReturnVar.name, canonReturnVar.`type`) + ).leaf, + FlattenTag( + canonReturnVar, + returnVar.name + ).leaf + ) - // process stream that returns as not streams and all Apply*Raw - val (bodyModified, returnValuesModified, _) = (retValues zip funcArrow.codomain.toList) - .foldLeft[(RawTag.Tree, Chain[ValueRaw], Int)]((bodyModel, Chain.empty, 0)) { - case ((bodyAcc, returnAcc, idx), rets) => - rets match { - // do nothing - case (v @ VarRaw(_, StreamType(_)), StreamType(_)) => - (bodyAcc, returnAcc :+ v, idx) - // canonicalize and change return value - case (VarRaw(streamName, streamType @ StreamType(streamElement)), _) => - val canonReturnVar = - VarRaw(s"-$streamName-fix-$idx", CanonStreamType(streamElement)) + (body, returnVar) + // assign and change return value for all `Apply*Raw` + case ((v: ValueRaw.ApplyRaw, _), idx) => + val assignedReturnVar = VarRaw(s"-return-fix-$idx", v.`type`) + val body = Chain.one( + AssignmentTag( + v, + assignedReturnVar.name + ).leaf + ) - val returnVar = - VarRaw(s"-$streamName-flat-$idx", ArrayType(streamElement)) + (body, assignedReturnVar) + case ((v, _), _) => (Chain.empty, v) + }.unzip.leftMap(_.combineAll) - ( - SeqTag.wrap( - bodyAcc, - CanonicalizeTag( - VarRaw(streamName, streamType), - Call.Export(canonReturnVar.name, canonReturnVar.`type`) - ).leaf, - FlattenTag( - canonReturnVar, - returnVar.name - ).leaf - ), - returnAcc :+ returnVar, - idx + 1 - ) - // assign and change return value for all `Apply*Raw` - case ( - v: (ApplyGateRaw | ApplyPropertyRaw | CallArrowRaw | CollectionRaw | - ApplyBinaryOpRaw | ApplyUnaryOpRaw), - _ - ) => - assignRaw(v, idx, bodyAcc, returnAcc) + val bodyModified = SeqTag.wrap( + bodyModel +: bodyRets + ) - case (v, _) => (bodyAcc, returnAcc :+ v, idx) - } - - } - - // wrap streams with restrictions - val bodyWithRestrictions = localStreams.foldLeft(bodyModified) { - case (bm, (streamName, streamType)) => - RestrictionTag(streamName, streamType).wrap(bm) - } - - ArrowRaw(funcArrow, returnValuesModified.toList, bodyWithRestrictions) - case bodyModel => - bodyModel + // wrap streams with restrictions + val bodyWithRestrictions = localStreams.foldLeft(bodyModified) { + case (bm, (streamName, streamType)) => + RestrictionTag(streamName, streamType).wrap(bm) } - } <* N.endScope() <* L.endScope() + + ArrowRaw(funcArrow, retVals, bodyWithRestrictions) + case _ => Raw.error("Invalid arrow body") + } + } yield res def program[Alg[_]: Monad](implicit T: TypesAlgebra[S, Alg], @@ -170,9 +130,13 @@ class ArrowSem[S[_]](val expr: ArrowExpr[S]) extends AnyVal { A: AbilitiesAlgebra[S, Alg], L: LocationsAlgebra[S, Alg] ): Prog[Alg, Raw] = - Prog.around( - before[Alg], - after[Alg] - ) + Prog + .around( + before[Alg], + after[Alg] + ) + .abilitiesScope(expr.arrowTypeExpr) + .namesScope(expr.arrowTypeExpr) + .locationsScope() } diff --git a/semantics/src/main/scala/aqua/semantics/expr/func/AssignmentSem.scala b/semantics/src/main/scala/aqua/semantics/expr/func/AssignmentSem.scala index 2186926b..6f401e61 100644 --- a/semantics/src/main/scala/aqua/semantics/expr/func/AssignmentSem.scala +++ b/semantics/src/main/scala/aqua/semantics/expr/func/AssignmentSem.scala @@ -3,7 +3,7 @@ package aqua.semantics.expr.func import aqua.raw.Raw import aqua.types.ArrowType import aqua.raw.value.CallArrowRaw -import aqua.raw.ops.{AssignmentTag, ClosureTag} +import aqua.raw.ops.AssignmentTag import aqua.parser.expr.func.AssignmentExpr import aqua.raw.arrow.FuncRaw import aqua.semantics.Prog diff --git a/semantics/src/main/scala/aqua/semantics/rules/StackInterpreter.scala b/semantics/src/main/scala/aqua/semantics/rules/StackInterpreter.scala index 084b1b41..dd7e79f7 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/StackInterpreter.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/StackInterpreter.scala @@ -1,12 +1,16 @@ package aqua.semantics.rules import aqua.parser.lexer.Token -import cats.data.State -import monocle.Lens -import cats.syntax.functor.* import aqua.semantics.rules.errors.ReportErrors -case class StackInterpreter[S[_], X, St, Fr](stackLens: Lens[St, List[Fr]])(implicit +import cats.data.State +import cats.syntax.functor.* +import cats.syntax.applicative.* +import monocle.Lens + +case class StackInterpreter[S[_], X, St, Fr]( + stackLens: Lens[St, List[Fr]] +)(using lens: Lens[X, St], error: ReportErrors[S, X] ) { @@ -24,28 +28,19 @@ case class StackInterpreter[S[_], X, St, Fr](stackLens: Lens[St, List[Fr]])(impl def modify(f: St => St): SX[Unit] = State.modify(lens.modify(f)) - def mapStackHead[A](ifStackEmpty: SX[A])(f: Fr => (Fr, A)): SX[A] = - getState.map(stackLens.get).flatMap { - case h :: tail => - val (updated, result) = f(h) - modify(stackLens.replace(updated :: tail)).as(result) - case Nil => - ifStackEmpty - } + def mapStackHead[A](ifStackEmpty: A)(f: Fr => (Fr, A)): SX[A] = + mapStackHeadM(ifStackEmpty.pure)(f.andThen(_.pure)) - def mapStackHeadE[A]( - ifStackEmpty: SX[A] - )(f: Fr => Either[(Token[S], String, A), (Fr, A)]): SX[A] = + def mapStackHead_(f: Fr => Fr): SX[Unit] = + mapStackHead(())(f.andThen(_ -> ())) + + def mapStackHeadM[A](ifStackEmpty: SX[A])(f: Fr => SX[(Fr, A)]): SX[A] = getState.map(stackLens.get).flatMap { - case h :: tail => - f(h) match { - case Right((updated, result)) => - modify(stackLens.replace(updated :: tail)).as(result) - case Left((tkn, hint, result)) => - report(tkn, hint).as(result) + case head :: tail => + f(head).flatMap { case (updated, result) => + modify(stackLens.replace(updated :: tail)).as(result) } - case Nil => - ifStackEmpty + case Nil => ifStackEmpty } def endScope: SX[Unit] = diff --git a/semantics/src/main/scala/aqua/semantics/rules/abilities/AbilitiesInterpreter.scala b/semantics/src/main/scala/aqua/semantics/rules/abilities/AbilitiesInterpreter.scala index 6ff60c63..23a8b5e0 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/abilities/AbilitiesInterpreter.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/abilities/AbilitiesInterpreter.scala @@ -6,11 +6,14 @@ import aqua.raw.{RawContext, ServiceRaw} import aqua.semantics.Levenshtein import aqua.semantics.rules.errors.ReportErrors import aqua.semantics.rules.locations.LocationsAlgebra -import aqua.semantics.rules.{StackInterpreter, abilities} +import aqua.semantics.rules.{abilities, StackInterpreter} import aqua.types.ArrowType + import cats.data.{NonEmptyMap, State} import cats.syntax.functor.* +import cats.syntax.foldable.* import cats.syntax.traverse.* +import cats.syntax.applicative.* import monocle.Lens import monocle.macros.GenLens @@ -26,7 +29,7 @@ class AbilitiesInterpreter[S[_], X](implicit GenLens[AbilitiesState[S]](_.stack) ) - import stackInt.{getState, mapStackHead, modify, report} + import stackInt.{getState, mapStackHead, mapStackHeadM, modify, report} override def defineService( name: NamedTypeToken[S], @@ -35,30 +38,34 @@ class AbilitiesInterpreter[S[_], X](implicit ): SX[Boolean] = getService(name.value).flatMap { case Some(_) => - getState.map(_.definitions.get(name.value).exists(_ == name)).flatMap { - case true => State.pure(false) - case false => report(name, "Service with this name was already defined").as(false) - - } + getState + .map(_.definitions.get(name.value).exists(_ == name)) + .flatMap(exists => + report( + name, + "Service with this name was already defined" + ).whenA(!exists) + ) + .as(false) case None => - arrows.toNel.map(_._2).collect { - case (n, arr) if arr.codomain.length > 1 => + for { + _ <- arrows.toNel.traverse_ { case (_, (n, arr)) => report(n, "Service functions cannot have multiple results") - }.sequence.flatMap{ _ => - modify(s => + .whenA(arr.codomain.length > 1) + } + _ <- modify(s => s.copy( services = s.services .updated(name.value, ServiceRaw(name.value, arrows.map(_._2), defaultId)), definitions = s.definitions.updated(name.value, name) ) - ).flatMap { _ => - locations.addTokenWithFields( - name.value, - name, - arrows.toNel.toList.map(t => t._1 -> t._2._1) - ) - }.as(true) - } + ) + _ <- locations.addTokenWithFields( + name.value, + name, + arrows.toNel.toList.map(t => t._1 -> t._2._1) + ) + } yield true } // adds location from token to its definition @@ -107,11 +114,10 @@ class AbilitiesInterpreter[S[_], X](implicit override def setServiceId(name: NamedTypeToken[S], id: ValueToken[S], vm: ValueRaw): SX[Boolean] = getService(name.value).flatMap { case Some(_) => - mapStackHead( + mapStackHeadM( modify(st => st.copy(rootServiceIds = st.rootServiceIds.updated(name.value, id -> vm))) .as(true) - )(h => h.copy(serviceIds = h.serviceIds.updated(name.value, id -> vm)) -> true) - + )(h => (h.copy(serviceIds = h.serviceIds.updated(name.value, id -> vm)) -> true).pure) case None => report(name, "Service with this name is not registered, can't set its ID").as(false) } diff --git a/semantics/src/main/scala/aqua/semantics/rules/names/NamesInterpreter.scala b/semantics/src/main/scala/aqua/semantics/rules/names/NamesInterpreter.scala index b8976c24..7a0c9121 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/names/NamesInterpreter.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/names/NamesInterpreter.scala @@ -6,9 +6,12 @@ import aqua.semantics.rules.StackInterpreter import aqua.semantics.rules.errors.ReportErrors import aqua.semantics.rules.locations.LocationsAlgebra import aqua.types.{AbilityType, ArrowType, StreamType, Type} + import cats.data.{OptionT, State} import cats.syntax.flatMap.* import cats.syntax.functor.* +import cats.syntax.applicative.* +import cats.syntax.all.* import monocle.Lens import monocle.macros.GenLens @@ -22,7 +25,7 @@ class NamesInterpreter[S[_], X](implicit GenLens[NamesState[S]](_.stack) ) - import stackInt.{getState, mapStackHead, modify, report} + import stackInt.{getState, mapStackHead, mapStackHeadM, mapStackHead_, modify, report} type SX[A] = State[X, A] @@ -98,22 +101,21 @@ class NamesInterpreter[S[_], X](implicit case false => report(name, "This name was already defined in the scope").as(false) } case None => - mapStackHead( - report(name, "Cannot define a variable in the root scope") - .as(false) - )(fr => fr.addName(name, `type`) -> true).flatTap(_ => locations.addToken(name.value, name)) + mapStackHeadM(report(name, "Cannot define a variable in the root scope").as(false))(fr => + (fr.addName(name, `type`) -> true).pure + ) <* locations.addToken(name.value, name) } override def derive(name: Name[S], `type`: Type, derivedFrom: Set[String]): State[X, Boolean] = - define(name, `type`).flatMap { - case true => - mapStackHead(State.pure(true))(_.derived(name, derivedFrom) -> true) - case false => State.pure(false) - }.flatTap(_ => locations.addToken(name.value, name)) + define(name, `type`).flatTap(defined => + mapStackHead_(_.derived(name, derivedFrom)).whenA(defined) + ) <* locations.addToken(name.value, name) override def getDerivedFrom(fromNames: List[Set[String]]): State[X, List[Set[String]]] = - mapStackHead(State.pure(Nil))(fr => - fr -> fromNames.map(ns => fr.derivedFrom.view.filterKeys(ns).values.foldLeft(ns)(_ ++ _)) + mapStackHead(Nil)(frame => + frame -> fromNames.map(ns => + frame.derivedFrom.view.filterKeys(ns).values.toList.combineAll ++ ns + ) ) override def defineConstant(name: Name[S], `type`: Type): SX[Boolean] = @@ -137,7 +139,7 @@ class NamesInterpreter[S[_], X](implicit } case None => - mapStackHead( + mapStackHeadM( if (isRoot) modify(st => st.copy( @@ -149,14 +151,14 @@ class NamesInterpreter[S[_], X](implicit else report(name, "Cannot define a variable in the root scope") .as(false) - )(fr => fr.addArrow(name, arrowType) -> true) + )(fr => (fr.addArrow(name, arrowType) -> true).pure) }.flatTap(_ => locations.addToken(name.value, name)) override def streamsDefinedWithinScope(): SX[Map[String, StreamType]] = - stackInt.mapStackHead(State.pure(Map.empty[String, StreamType])) { frame => + mapStackHead(Map.empty) { frame => frame -> frame.names.collect { case (n, st @ StreamType(_)) => n -> st - } + }.toMap } override def beginScope(token: Token[S]): SX[Unit] = diff --git a/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala b/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala index 3d66ddd3..d84c4934 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/types/TypesInterpreter.scala @@ -14,6 +14,7 @@ import aqua.semantics.rules.locations.LocationsAlgebra import aqua.semantics.rules.StackInterpreter import aqua.semantics.rules.errors.ReportErrors import aqua.types.* + import cats.data.Validated.{Invalid, Valid} import cats.data.{Chain, NonEmptyList, NonEmptyMap, State} import cats.instances.list.* @@ -414,80 +415,54 @@ class TypesInterpreter[S[_], X](implicit override def checkArrowReturn( values: NonEmptyList[(ValueToken[S], ValueRaw)] ): State[X, Boolean] = - mapStackHeadE[Boolean]( + mapStackHeadM[Boolean]( report(values.head._1, "Fatal: checkArrowReturn has no matching beginArrowScope").as(false) - )((frame: TypesState.Frame[S]) => + )(frame => if (frame.retVals.nonEmpty) - Left( - ( - values.head._1, - "Return expression was already used in scope; you can use only one Return in an arrow declaration, use conditional return pattern if you need to return based on condition", - false - ) - ) + report( + values.head._1, + "Return expression was already used in scope; you can use only one Return in an arrow declaration, use conditional return pattern if you need to return based on condition" + ).as(frame -> false) else if (frame.token.res.isEmpty) - Left( - ( - values.head._1, - "No return type declared for this arrow, please remove `<- ...` expression or add `-> ...` return type(s) declaration to the arrow", - false - ) - ) + report( + values.head._1, + "No return type declared for this arrow, please remove `<- ...` expression or add `-> ...` return type(s) declaration to the arrow" + ).as(frame -> false) else if (frame.token.res.length > values.length) - Left( - ( - values.last._1, - s"Expected ${frame.token.res.length - values.length} more values to be returned, see return type declaration", - false - ) - ) + report( + values.last._1, + s"Expected ${frame.token.res.length - values.length} more values to be returned, see return type declaration" + ).as(frame -> false) else if (frame.token.res.length < values.length) - Left( - ( - values.toList.drop(frame.token.res.length).headOption.getOrElse(values.last)._1, - s"Too many values are returned from this arrow, this one is unexpected. Defined return type: ${frame.arrowType.codomain}", - false - ) - ) - else { + report( + values.toList.drop(frame.token.res.length).headOption.getOrElse(values.last)._1, + s"Too many values are returned from this arrow, this one is unexpected. Defined return type: ${frame.arrowType.codomain}" + ).as(frame -> false) + else frame.arrowType.codomain.toList - .lazyZip(values.toList) - .foldLeft[Either[(Token[S], String, Boolean), List[ValueRaw]]](Right(Nil)) { - case (acc, (returnType, (_, returnValue))) => - acc.flatMap { a => - if (!returnType.acceptsValueOf(returnValue.`type`)) - Left( - ( - values.toList - .drop(frame.token.res.length) - .headOption - .getOrElse(values.last) - ._1, - s"Wrong value type, expected: $returnType, given: ${returnValue.`type`}", - false - ) - ) - else Right(a :+ returnValue) - } + .zip(values.toList) + .traverse { case (returnType, (token, returnValue)) => + if (!returnType.acceptsValueOf(returnValue.`type`)) + report( + token, + s"Wrong value type, expected: $returnType, given: ${returnValue.`type`}" + ).as(none) + else returnValue.some.pure[SX] } - .map(res => frame.copy(retVals = Some(res)) -> true) - } + .map(_.sequence) + .map(res => frame.copy(retVals = res) -> res.isDefined) ) override def endArrowScope(token: Token[S]): State[X, List[ValueRaw]] = - mapStackHeadE[List[ValueRaw]]( + mapStackHeadM( report(token, "Fatal: endArrowScope has no matching beginArrowScope").as(Nil) )(frame => - if (frame.token.res.isEmpty) { - Right(frame -> Nil) - } else if (frame.retVals.isEmpty) { - Left( - ( - frame.token.res.headOption.getOrElse(frame.token), - "Return type is defined for the arrow, but nothing returned. Use `<- value, ...` as the last expression inside function body.", - Nil - ) - ) - } else Right(frame -> frame.retVals.getOrElse(Nil)) + if (frame.token.res.isEmpty) (frame -> Nil).pure + else if (frame.retVals.isEmpty) + report( + frame.token.res.headOption.getOrElse(frame.token), + "Return type is defined for the arrow, but nothing returned. Use `<- value, ...` as the last expression inside function body." + ).as(frame -> Nil) + else (frame -> frame.retVals.getOrElse(Nil)).pure ) <* stack.endScope } diff --git a/semantics/src/test/scala/aqua/semantics/ArrowSemSpec.scala b/semantics/src/test/scala/aqua/semantics/ArrowSemSpec.scala index b8847baf..35d203cc 100644 --- a/semantics/src/test/scala/aqua/semantics/ArrowSemSpec.scala +++ b/semantics/src/test/scala/aqua/semantics/ArrowSemSpec.scala @@ -28,7 +28,7 @@ class ArrowSemSpec extends AnyFlatSpec with Matchers with EitherValues { "sem" should "create empty model" in { val model = getModel(program("(a: string, b: u32) -> u8")) - model shouldBe (Raw.Empty("empty")) + model shouldBe (Raw.Empty("Invalid arrow body")) } "sem" should "create error model" ignore {