From a1576efad9d6a3d5fb20c01182b0d0a948c50ec4 Mon Sep 17 00:00:00 2001 From: Dima Date: Thu, 2 Nov 2023 20:13:25 +0700 Subject: [PATCH] fix(compiler): Code generate wrong stream name in AIR [LNG-276] (#958) --- aqua-src/antithesis.aqua | 2 +- .../aqua/model/inline/ArrowInliner.scala | 3 +- .../scala/aqua/model/inline/TagInliner.scala | 75 +++++++++++----- .../aqua/model/inline/ArrowInlinerSpec.scala | 90 +++++++++++++++++++ .../transform/pre/FuncPreTransformer.scala | 10 +-- .../rules/names/NamesInterpreter.scala | 10 +-- 6 files changed, 155 insertions(+), 35 deletions(-) diff --git a/aqua-src/antithesis.aqua b/aqua-src/antithesis.aqua index bb9364ec..17c7db1d 100644 --- a/aqua-src/antithesis.aqua +++ b/aqua-src/antithesis.aqua @@ -1,2 +1,2 @@ func arr(strs: []string) -> []string - <- strs \ No newline at end of file + <- strs diff --git a/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala b/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala index 6ffae791..701146a2 100644 --- a/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/ArrowInliner.scala @@ -317,7 +317,7 @@ object ArrowInliner extends Logging { ) defineRenames <- Mangler[S].findAndForbidNames(defineNames) - renaming = ( + renaming = data.renames ++ streamRenames ++ arrowRenames ++ @@ -325,7 +325,6 @@ object ArrowInliner extends Logging { capturedValues.renames ++ capturedArrows.renames ++ defineRenames - ) /** * TODO: Optimize resolve. diff --git a/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala b/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala index daa7a0ff..81bb1386 100644 --- a/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala @@ -39,34 +39,45 @@ object TagInliner extends Logging { * * @param prefix Previous instructions */ - enum TagInlined(prefix: Option[OpModel.Tree]) { + enum TagInlined[T](prefix: Option[OpModel.Tree]) { /** * Tag inlining emitted nothing */ - case Empty( + case Empty[S]( prefix: Option[OpModel.Tree] = None - ) extends TagInlined(prefix) + ) extends TagInlined[S](prefix) /** * Tag inlining emitted one parent model * * @param model Model which will wrap children */ - case Single( + case Single[S]( model: OpModel, prefix: Option[OpModel.Tree] = None - ) extends TagInlined(prefix) + ) extends TagInlined[S](prefix) /** * Tag inling emitted complex transformation * * @param toModel Function from children results to result of this tag */ - case Mapping( + case Mapping[S]( toModel: Chain[OpModel.Tree] => OpModel.Tree, prefix: Option[OpModel.Tree] = None - ) extends TagInlined(prefix) + ) extends TagInlined[S](prefix) + + /** + * Tag inlining emitted computation + * that should be executed after children + * + * @param model computation producing model + */ + case After[S]( + model: State[S, OpModel], + prefix: Option[OpModel.Tree] = None + ) extends TagInlined[S](prefix) /** * Finalize inlining, construct a tree @@ -74,23 +85,34 @@ object TagInliner extends Logging { * @param children Children results * @return Result of inlining */ - def build(children: Chain[OpModel.Tree]): OpModel.Tree = { - val inlined = this match { - case Empty(_) => children - case Single(model, _) => - Chain.one(model.wrap(children)) - case Mapping(toModel, _) => - Chain.one(toModel(children)) + def build(children: Chain[OpModel.Tree]): State[T, OpModel.Tree] = { + def toSeqModel(tree: OpModel.Tree | Chain[OpModel.Tree]): State[T, OpModel.Tree] = { + val treeChain = tree match { + case c: Chain[OpModel.Tree] => c + case t: OpModel.Tree => Chain.one(t) + } + + State.pure(SeqModel.wrap(Chain.fromOption(prefix) ++ treeChain)) + } + + this match { + case Empty(_) => + toSeqModel(children) + case Single(model, _) => + toSeqModel(model.wrap(children)) + case Mapping(toModel, _) => + toSeqModel(toModel(children)) + case After(model, _) => + model.flatMap(m => toSeqModel(m.wrap(children))) } - SeqModel.wrap(Chain.fromOption(prefix) ++ inlined) } } - private def pure[S](op: OpModel): State[S, TagInlined] = + private def pure[S](op: OpModel): State[S, TagInlined[S]] = TagInlined.Single(model = op).pure - private def none[S]: State[S, TagInlined] = + private def none[S]: State[S, TagInlined[S]] = TagInlined.Empty().pure private def combineOpsWithSeq(l: Option[OpModel.Tree], r: Option[OpModel.Tree]) = @@ -174,7 +196,7 @@ object TagInliner extends Logging { */ def tagToModel[S: Mangler: Arrows: Exports]( tag: RawTag - ): State[S, TagInlined] = + ): State[S, TagInlined[S]] = tag match { case OnTag(peerId, via, strategy) => for { @@ -371,7 +393,17 @@ object TagInliner extends Logging { } yield model.fold(TagInlined.Empty())(m => TagInlined.Single(model = m)) case RestrictionTag(name, typ) => - pure(RestrictionModel(name, typ)) + // Rename restriction after children are inlined with new exports + TagInlined + .After( + for { + exps <- Exports[S].exports + model = exps.get(name).collect { case VarModel(n, _, _) => + RestrictionModel(n, typ) + } + } yield model.getOrElse(RestrictionModel(name, typ)) + ) + .pure case DeclareStreamTag(value) => value match @@ -438,13 +470,14 @@ object TagInliner extends Logging { private def traverseS[S]( cf: RawTag.Tree, - f: RawTag => State[S, TagInlined] + f: RawTag => State[S, TagInlined[S]] ): State[S, OpModel.Tree] = for { headInlined <- f(cf.head) tail <- StateT.liftF(cf.tail) children <- tail.traverse(traverseS[S](_, f)) - } yield headInlined.build(children) + inlined <- headInlined.build(children) + } yield inlined def handleTree[S: Exports: Mangler: Arrows]( tree: RawTag.Tree diff --git a/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala b/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala index 5fae85b3..124c07c8 100644 --- a/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala +++ b/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala @@ -183,6 +183,96 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { } + /** + * func returnNil() -> *string: + * someStr: *string + * <- someStr + * + * func newFunc() -> []string: + * stream <- returnNil() + * stream <<- "asd" + * <- stream + */ + it should "rename restricted stream correctly" in { + val streamType = StreamType(ScalarType.string) + val someStr = VarRaw("someStr", streamType) + + val returnStreamArrowType = ArrowType( + ProductType(Nil), + ProductType(streamType :: Nil) + ) + + val returnNil = FuncArrow( + "returnNil", + SeqTag.wrap( + DeclareStreamTag(someStr).leaf, + ReturnTag( + NonEmptyList.one(someStr) + ).leaf + ), + returnStreamArrowType, + List(someStr), + Map.empty, + Map.empty, + None + ) + + val streamVar = VarRaw("stream", streamType) + val canonStreamVar = VarRaw( + s"-${streamVar.name}-canon-0", + CanonStreamType(ScalarType.string) + ) + val flatStreamVar = VarRaw( + s"-${streamVar.name}-flat-0", + ArrayType(ScalarType.string) + ) + + val newFunc = FuncArrow( + "newFunc", + RestrictionTag(streamVar.name, streamType).wrap( + SeqTag.wrap( + CallArrowRawTag + .func( + returnNil.funcName, + Call(Nil, Call.Export(streamVar.name, streamType) :: Nil) + ) + .leaf, + PushToStreamTag( + LiteralRaw.quote("asd"), + Call.Export(streamVar.name, streamVar.`type`) + ).leaf, + CanonicalizeTag( + streamVar, + Call.Export(canonStreamVar.name, canonStreamVar.`type`) + ).leaf, + FlattenTag( + canonStreamVar, + flatStreamVar.name + ).leaf, + ReturnTag( + NonEmptyList.one(flatStreamVar) + ).leaf + ) + ), + ArrowType( + ProductType(Nil), + ProductType(ArrayType(ScalarType.string) :: Nil) + ), + List(flatStreamVar), + Map(returnNil.funcName -> returnNil), + Map.empty, + None + ) + + val model = callFuncModel(newFunc) + + val restrictionName = model.collect { + case RestrictionModel(name, _) => name + }.headOption + + restrictionName shouldBe Some(someStr.name) + } + /** * func returnStream() -> *string: * stream: *string diff --git a/model/transform/src/main/scala/aqua/model/transform/pre/FuncPreTransformer.scala b/model/transform/src/main/scala/aqua/model/transform/pre/FuncPreTransformer.scala index 145a1c57..b3eea8ff 100644 --- a/model/transform/src/main/scala/aqua/model/transform/pre/FuncPreTransformer.scala +++ b/model/transform/src/main/scala/aqua/model/transform/pre/FuncPreTransformer.scala @@ -1,12 +1,10 @@ package aqua.model.transform.pre -import aqua.model.FuncArrow -import aqua.model.ArgsCall -import aqua.raw.ops.{Call, CallArrowRawTag, RawTag, SeqTag, TryTag} -import aqua.raw.value.{ValueRaw, VarRaw} +import aqua.model.{ArgsCall, FuncArrow} +import aqua.raw.ops.* +import aqua.raw.value.VarRaw import aqua.types.* -import cats.syntax.show.* import cats.syntax.option.* /** @@ -47,7 +45,7 @@ case class FuncPreTransformer( * @return FuncArrow that can be called and delegates the call to a client-registered callback */ private def arrowToCallback(name: String, arrowType: ArrowType): FuncArrow = { - val (args, call, ret) = ArgsCall.arrowToArgsCallRet(arrowType) + val (_, call, ret) = ArgsCall.arrowToArgsCallRet(arrowType) FuncArrow( arrowCallbackPrefix + name, callback(name, call), diff --git a/semantics/src/main/scala/aqua/semantics/rules/names/NamesInterpreter.scala b/semantics/src/main/scala/aqua/semantics/rules/names/NamesInterpreter.scala index bd718aea..b55b53a7 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/names/NamesInterpreter.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/names/NamesInterpreter.scala @@ -3,15 +3,15 @@ package aqua.semantics.rules.names import aqua.parser.lexer.{Name, Token} import aqua.semantics.Levenshtein import aqua.semantics.rules.StackInterpreter -import aqua.semantics.rules.report.ReportAlgebra import aqua.semantics.rules.locations.LocationsAlgebra -import aqua.types.{AbilityType, ArrowType, StreamType, Type} +import aqua.semantics.rules.report.ReportAlgebra +import aqua.types.{ArrowType, StreamType, Type} import cats.data.{OptionT, State} +import cats.syntax.all.* +import cats.syntax.applicative.* import cats.syntax.flatMap.* import cats.syntax.functor.* -import cats.syntax.applicative.* -import cats.syntax.all.* import monocle.Lens import monocle.macros.GenLens @@ -160,7 +160,7 @@ class NamesInterpreter[S[_], X](using mapStackHead(Map.empty) { frame => frame -> frame.names.collect { case (n, st @ StreamType(_)) => n -> st - }.toMap + } } override def beginScope(token: Token[S]): SX[Unit] =