From 5f6c47ffea5ab6e32df918a33414482129b00fd7 Mon Sep 17 00:00:00 2001 From: InversionSpaces Date: Mon, 9 Oct 2023 12:02:26 +0200 Subject: [PATCH] feat(compiler): Optimize math in compile time [LNG-245] (#922) * Move service calls for math to inlining * Fix: add predo * Introduce CallServiceRaw * Add comment * Add optimization to inlining * Add tests * map -> mapValues * Refactor type * Add optimization * Add optimization test * Fix unit tests * Fix PR comments * Restrict optimization * Add substraction to optimization * Apply optimizations in gate * Fix sign, move optimization to unfold * Fix type and tests * Fix unit tests * Add unit test * Fix after merge * Add optimization, fix unit tests * Fix comment --- .../aqua/compiler/AquaCompilerSpec.scala | 461 ++++++++++-------- .../aqua/model/inline/RawValueInliner.scala | 31 +- .../scala/aqua/model/inline/TagInliner.scala | 11 +- .../inline/raw/ApplyBinaryOpRawInliner.scala | 212 +++++++- .../raw/ApplyPropertiesRawInliner.scala | 68 +-- .../inline/raw/CallArrowRawInliner.scala | 40 +- .../inline/raw/CallServiceRawInliner.scala | 57 +++ .../model/inline/raw/StreamGateInliner.scala | 18 +- .../aqua/model/inline/tag/IfTagInliner.scala | 2 +- .../aqua/model/inline/ArrowInlinerSpec.scala | 291 +++++------ .../aqua/model/inline/CopyInlinerSpec.scala | 27 +- .../model/inline/MakeStructInlinerSpec.scala | 27 +- .../scala/aqua/model/inline/RawBuilder.scala | 15 +- .../model/inline/RawValueInlinerSpec.scala | 299 +++++++++++- .../src/main/scala/aqua/raw/ops/RawTag.scala | 38 +- .../scala/aqua/raw/value/Optimization.scala | 119 +++++ .../scala/aqua/raw/value/PropertyRaw.scala | 6 +- .../main/scala/aqua/raw/value/ValueRaw.scala | 205 +++++--- .../main/scala/aqua/model/AquaContext.scala | 19 +- .../src/main/scala/aqua/model/FuncArrow.scala | 27 +- .../main/scala/aqua/model/ValueModel.scala | 19 +- .../semantics/expr/func/CallArrowSem.scala | 2 +- .../aqua/semantics/rules/ValuesAlgebra.scala | 87 ++-- .../scala/aqua/semantics/SemanticsSpec.scala | 4 +- .../aqua/semantics/ValuesAlgebraSpec.scala | 4 +- types/src/main/scala/aqua/types/Type.scala | 49 +- 26 files changed, 1439 insertions(+), 699 deletions(-) create mode 100644 model/inline/src/main/scala/aqua/model/inline/raw/CallServiceRawInliner.scala create mode 100644 model/raw/src/main/scala/aqua/raw/value/Optimization.scala diff --git a/compiler/src/test/scala/aqua/compiler/AquaCompilerSpec.scala b/compiler/src/test/scala/aqua/compiler/AquaCompilerSpec.scala index 3b8dd0a7..fa753df1 100644 --- a/compiler/src/test/scala/aqua/compiler/AquaCompilerSpec.scala +++ b/compiler/src/test/scala/aqua/compiler/AquaCompilerSpec.scala @@ -15,16 +15,20 @@ import aqua.res.* import aqua.res.ResBuilder import aqua.types.{ArrayType, CanonStreamType, LiteralType, ScalarType, StreamType, Type} -import org.scalatest.flatspec.AnyFlatSpec -import org.scalatest.matchers.should.Matchers import cats.Id import cats.data.{Chain, NonEmptyChain, NonEmptyMap, Validated, ValidatedNec} import cats.instances.string.* import cats.syntax.show.* import cats.syntax.option.* import cats.syntax.either.* +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import org.scalatest.Inside +import aqua.model.AquaContext +import aqua.model.FlattenModel +import aqua.model.CallServiceModel -class AquaCompilerSpec extends AnyFlatSpec with Matchers { +class AquaCompilerSpec extends AnyFlatSpec with Matchers with Inside { import ModelBuilder.* private def aquaSource(src: Map[String, String], imports: Map[String, String]) = { @@ -45,8 +49,13 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers { } } - private def compileToContext(src: Map[String, String], imports: Map[String, String]) = - CompilerAPI + private def insideContext( + src: Map[String, String], + imports: Map[String, String] = Map.empty + )( + test: AquaContext => Any + ) = { + val compiled = CompilerAPI .compileToContext[Id, String, String, Span.S]( aquaSource(src, imports), id => txt => Parser.parse(Parser.parserSchema)(txt), @@ -56,39 +65,52 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers { .value .toValidated + inside(compiled) { case Validated.Valid(contexts) => + inside(contexts.headOption) { case Some(ctx) => + test(ctx) + } + } + } + + private def insideRes( + src: Map[String, String], + imports: Map[String, String] = Map.empty, + transformCfg: TransformConfig = TransformConfig() + )(funcNames: String*)( + test: PartialFunction[List[FuncRes], Any] + ) = insideContext(src, imports)(ctx => + val aquaRes = Transform.contextRes(ctx, transformCfg) + // To preserve order as in funcNames do flatMap + val funcs = funcNames.flatMap(name => aquaRes.funcs.find(_.funcName == name)).toList + inside(funcs)(test) + ) + "aqua compiler" should "compile a simple snippet to the right context" in { - val res = compileToContext( - Map( - "index.aqua" -> - """module Foo declares X - | - |export foo, foo2 as foo_two, X - | - |const X = 5 - | - |func foo() -> string: - | <- "hello?" - | - |func foo2() -> string: - | <- "hello2?" - |""".stripMargin - ), - Map.empty + val src = Map( + "index.aqua" -> + """module Foo declares X + | + |export foo, foo2 as foo_two, X + | + |const X = 5 + | + |func foo() -> string: + | <- "hello?" + | + |func foo2() -> string: + | <- "hello2?" + |""".stripMargin ) - res.isValid should be(true) - val Validated.Valid(ctxs) = res + insideContext(src) { ctx => + ctx.allFuncs.contains("foo") should be(true) + ctx.allFuncs.contains("foo_two") should be(true) - ctxs.length should be(1) - val ctx = ctxs.headOption.get - - ctx.allFuncs.contains("foo") should be(true) - ctx.allFuncs.contains("foo_two") should be(true) - - val const = ctx.allValues.get("X") - const.nonEmpty should be(true) - const.get should be(LiteralModel.number(5)) + val const = ctx.allValues.get("X") + const.nonEmpty should be(true) + const.get should be(LiteralModel.number(5)) + } } def through(peer: ValueModel) = @@ -110,204 +132,237 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers { private def join(vm: VarModel, size: ValueModel) = ResBuilder.join(vm, size, init) - "aqua compiler" should "create right topology" in { - - val res = compileToContext( - Map( - "index.aqua" -> - """service Op("op"): - | identity(s: string) -> string - | - |func exec(peers: []string) -> []string: - | results: *string - | for peer <- peers par: - | on peer: - | results <- Op.identity("hahahahah") - | - | join results[2] - | <- results""".stripMargin - ), - Map.empty + it should "create right topology" in { + val src = Map( + "index.aqua" -> + """service Op("op"): + | identity(s: string) -> string + | + |func exec(peers: []string) -> []string: + | results: *string + | for peer <- peers par: + | on peer: + | results <- Op.identity("hahahahah") + | + | join results[2] + | <- results""".stripMargin ) - res.isValid should be(true) - val Validated.Valid(ctxs) = res - - ctxs.length should be(1) - val ctx = ctxs.headOption.get - val transformCfg = TransformConfig() - val aquaRes = Transform.contextRes(ctx, transformCfg) - val Some(exec) = aquaRes.funcs.find(_.funcName == "exec") + insideRes(src, transformCfg = transformCfg)("exec") { case exec :: _ => + val peers = VarModel("-peers-arg-", ArrayType(ScalarType.string)) + val peer = VarModel("peer-0", ScalarType.string) + val resultsType = StreamType(ScalarType.string) + val results = VarModel("results", resultsType) + val canonResult = + VarModel("-" + results.name + "-fix-0", CanonStreamType(resultsType.element)) + val flatResult = VarModel("-results-flat-0", ArrayType(ScalarType.string)) + val initPeer = LiteralModel.fromRaw(ValueRaw.InitPeerId) + val retVar = VarModel("ret", ScalarType.string) - val peers = VarModel("-peers-arg-", ArrayType(ScalarType.string)) - val peer = VarModel("peer-0", ScalarType.string) - val resultsType = StreamType(ScalarType.string) - val results = VarModel("results", resultsType) - val canonResult = VarModel("-" + results.name + "-fix-0", CanonStreamType(resultsType.element)) - val flatResult = VarModel("-results-flat-0", ArrayType(ScalarType.string)) - val initPeer = LiteralModel.fromRaw(ValueRaw.InitPeerId) - val sizeVar = VarModel("results_size", LiteralType.unsigned) - val retVar = VarModel("ret", ScalarType.string) - - val expected = - XorRes.wrap( - SeqRes.wrap( - getDataSrv("-relay-", "-relay-", ScalarType.string), - getDataSrv("peers", peers.name, peers.`type`), - RestrictionRes(results.name, resultsType).wrap( - SeqRes.wrap( - ParRes.wrap( - FoldRes(peer.name, peers, ForModel.Mode.Never.some).wrap( - ParRes.wrap( - XorRes.wrap( - // better if first relay will be outside `for` - SeqRes.wrap( - through(ValueModel.fromRaw(relay)), - CallServiceRes( - LiteralModel.fromRaw(LiteralRaw.quote("op")), - "identity", - CallRes( - LiteralModel.fromRaw(LiteralRaw.quote("hahahahah")) :: Nil, - Some(CallModel.Export(retVar.name, retVar.`type`)) - ), - peer - ).leaf, - ApRes(retVar, CallModel.Export(results.name, results.`type`)).leaf, - through(ValueModel.fromRaw(relay)), - through(initPeer) + val expected = + XorRes.wrap( + SeqRes.wrap( + getDataSrv("-relay-", "-relay-", ScalarType.string), + getDataSrv("peers", peers.name, peers.`type`), + RestrictionRes(results.name, resultsType).wrap( + SeqRes.wrap( + ParRes.wrap( + FoldRes(peer.name, peers, ForModel.Mode.Never.some).wrap( + ParRes.wrap( + XorRes.wrap( + // better if first relay will be outside `for` + SeqRes.wrap( + through(ValueModel.fromRaw(relay)), + CallServiceRes( + LiteralModel.fromRaw(LiteralRaw.quote("op")), + "identity", + CallRes( + LiteralModel.fromRaw(LiteralRaw.quote("hahahahah")) :: Nil, + Some(CallModel.Export(retVar.name, retVar.`type`)) + ), + peer + ).leaf, + ApRes(retVar, CallModel.Export(results.name, results.`type`)).leaf, + through(ValueModel.fromRaw(relay)), + through(initPeer) + ), + SeqRes.wrap( + through(ValueModel.fromRaw(relay)), + through(initPeer), + failErrorRes + ) ), - SeqRes.wrap( - through(ValueModel.fromRaw(relay)), - through(initPeer), - failErrorRes - ) - ), - NextRes(peer.name).leaf + NextRes(peer.name).leaf + ) ) - ) - ), - ResBuilder.add( - LiteralModel.number(2), - LiteralModel.number(1), - sizeVar, - initPeer - ), - join(results, sizeVar), - CanonRes(results, init, CallModel.Export(canonResult.name, canonResult.`type`)).leaf, + ), + join(results, LiteralModel.number(3)), // Compiler optimized addition + CanonRes( + results, + init, + CallModel.Export(canonResult.name, canonResult.`type`) + ).leaf, + ApRes( + canonResult, + CallModel.Export(flatResult.name, flatResult.`type`) + ).leaf + ) + ), + respCall(transformCfg, flatResult, initPeer) + ), + errorCall(transformCfg, 0, initPeer) + ) + + exec.body.equalsOrShowDiff(expected) shouldBe (true) + } + } + + it should "compile with imports" in { + + val src = Map( + "index.aqua" -> + """module Import + |import foobar from "export2.aqua" + | + |use foo as f from "export2.aqua" as Exp + | + |import "../gen/OneMore.aqua" + | + |export foo_wrapper as wrap, foobar as barfoo + | + |func foo_wrapper() -> string: + | z <- Exp.f() + | OneMore "hello" + | OneMore.more_call() + | -- Exp.f() returns literal, this func must return literal in AIR as well + | <- z + |""".stripMargin + ) + val imports = Map( + "export2.aqua" -> + """module Export declares foobar, foo + | + |func bar() -> string: + | <- " I am MyFooBar bar" + | + |func foo() -> string: + | <- "I am MyFooBar foo" + | + |func foobar() -> []string: + | res: *string + | res <- foo() + | res <- bar() + | <- res + | + |""".stripMargin, + "../gen/OneMore.aqua" -> + """ + |service OneMore: + | more_call() + | consume(s: string) + |""".stripMargin + ) + + val transformCfg = TransformConfig(relayVarName = None) + + insideRes(src, imports, transformCfg)( + "wrap", + "barfoo" + ) { case wrap :: barfoo :: _ => + val resStreamType = StreamType(ScalarType.string) + val resVM = VarModel("res", resStreamType) + val resCanonVM = VarModel("-res-fix-0", CanonStreamType(ScalarType.string)) + val resFlatVM = VarModel("-res-flat-0", ArrayType(ScalarType.string)) + + val expected = XorRes.wrap( + SeqRes.wrap( + RestrictionRes(resVM.name, resStreamType).wrap( + SeqRes.wrap( + // res <- foo() ApRes( - canonResult, - CallModel.Export(flatResult.name, flatResult.`type`) + LiteralModel.fromRaw(LiteralRaw.quote("I am MyFooBar foo")), + CallModel.Export(resVM.name, resVM.`type`) + ).leaf, + // res <- bar() + ApRes( + LiteralModel.fromRaw(LiteralRaw.quote(" I am MyFooBar bar")), + CallModel.Export(resVM.name, resVM.`type`) + ).leaf, + // canonicalization + CanonRes( + resVM, + LiteralModel.fromRaw(ValueRaw.InitPeerId), + CallModel.Export(resCanonVM.name, resCanonVM.`type`) + ).leaf, + // flattening + ApRes( + VarModel(resCanonVM.name, resCanonVM.`type`), + CallModel.Export(resFlatVM.name, resFlatVM.`type`) ).leaf ) ), - respCall(transformCfg, flatResult, initPeer) + respCall(transformCfg, resFlatVM, initPeer) ), errorCall(transformCfg, 0, initPeer) ) - exec.body.equalsOrShowDiff(expected) shouldBe (true) + barfoo.body.equalsOrShowDiff(expected) should be(true) + } } - "aqua compiler" should "compile with imports" in { - - val res = compileToContext( - Map( - "index.aqua" -> - """module Import - |import foobar from "export2.aqua" - | - |use foo as f from "export2.aqua" as Exp - | - |import "../gen/OneMore.aqua" - | - |export foo_wrapper as wrap, foobar as barfoo - | - |func foo_wrapper() -> string: - | z <- Exp.f() - | OneMore "hello" - | OneMore.more_call() - | -- Exp.f() returns literal, this func must return literal in AIR as well - | <- z - |""".stripMargin - ), - Map( - "export2.aqua" -> - """module Export declares foobar, foo - | - |func bar() -> string: - | <- " I am MyFooBar bar" - | - |func foo() -> string: - | <- "I am MyFooBar foo" - | - |func foobar() -> []string: - | res: *string - | res <- foo() - | res <- bar() - | <- res - | - |""".stripMargin, - "../gen/OneMore.aqua" -> - """ - |service OneMore: - | more_call() - | consume(s: string) - |""".stripMargin - ) + it should "optimize math inside stream join" in { + val src = Map( + "main.aqua" -> """ + |func main(i: i32): + | stream: *string + | stream <<- "a" + | stream <<- "b" + | join stream[i - 1] + |""".stripMargin ) - res.isValid should be(true) - val Validated.Valid(ctxs) = res + val transformCfg = TransformConfig() + val streamName = "stream" + val streamType = StreamType(ScalarType.string) + val argName = "-i-arg-" + val argType = ScalarType.i32 + val arg = VarModel(argName, argType) - ctxs.length should be(1) - val ctx = ctxs.headOption.get - - val transformCfg = TransformConfig(relayVarName = None) - val aquaRes = Transform.contextRes(ctx, transformCfg) - - val Some(funcWrap) = aquaRes.funcs.find(_.funcName == "wrap") - val Some(barfoo) = aquaRes.funcs.find(_.funcName == "barfoo") - - val resStreamType = StreamType(ScalarType.string) - val resVM = VarModel("res", resStreamType) - val resCanonVM = VarModel("-res-fix-0", CanonStreamType(ScalarType.string)) - val resFlatVM = VarModel("-res-flat-0", ArrayType(ScalarType.string)) + /** + * NOTE: Compiler generates this unused decrement bc + * it doesn't know that we are inlining just join + * and do not need to access the element. + */ + val decrement = CallServiceRes( + LiteralModel.quote("math"), + "sub", + CallRes( + List(arg, LiteralModel.number(1)), + Some(CallModel.Export("stream_idx", argType)) + ), + LiteralModel.fromRaw(ValueRaw.InitPeerId) + ).leaf val expected = XorRes.wrap( SeqRes.wrap( - RestrictionRes(resVM.name, resStreamType).wrap( + getDataSrv("-relay-", "-relay-", ScalarType.string), + getDataSrv("i", argName, argType), + RestrictionRes(streamName, streamType).wrap( SeqRes.wrap( - // res <- foo() - ApRes( - LiteralModel.fromRaw(LiteralRaw.quote("I am MyFooBar foo")), - CallModel.Export(resVM.name, resVM.`type`) - ).leaf, - // res <- bar() - ApRes( - LiteralModel.fromRaw(LiteralRaw.quote(" I am MyFooBar bar")), - CallModel.Export(resVM.name, resVM.`type`) - ).leaf, - // canonicalization - CanonRes( - resVM, - LiteralModel.fromRaw(ValueRaw.InitPeerId), - CallModel.Export(resCanonVM.name, resCanonVM.`type`) - ).leaf, - // flattening - ApRes( - VarModel(resCanonVM.name, resCanonVM.`type`), - CallModel.Export(resFlatVM.name, resFlatVM.`type`) - ).leaf + ApRes(LiteralModel.quote("a"), CallModel.Export(streamName, streamType)).leaf, + ApRes(LiteralModel.quote("b"), CallModel.Export(streamName, streamType)).leaf, + join(VarModel(streamName, streamType), arg), + decrement ) - ), - respCall(transformCfg, resFlatVM, initPeer) + ) ), errorCall(transformCfg, 0, initPeer) ) - barfoo.body.equalsOrShowDiff(expected) should be(true) - + insideRes(src, transformCfg = transformCfg)("main") { case main :: _ => + main.body.equalsOrShowDiff(expected) should be(true) + } } } diff --git a/model/inline/src/main/scala/aqua/model/inline/RawValueInliner.scala b/model/inline/src/main/scala/aqua/model/inline/RawValueInliner.scala index b94e510f..3caba68f 100644 --- a/model/inline/src/main/scala/aqua/model/inline/RawValueInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/RawValueInliner.scala @@ -3,20 +3,12 @@ package aqua.model.inline import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler} import aqua.model.inline.Inline.MergeMode.* import aqua.model.* -import aqua.model.inline.raw.{ - ApplyBinaryOpRawInliner, - ApplyFunctorRawInliner, - ApplyPropertiesRawInliner, - ApplyUnaryOpRawInliner, - CallArrowRawInliner, - CollectionRawInliner, - MakeAbilityRawInliner, - StreamGateInliner -} +import aqua.model.inline.raw.* import aqua.raw.ops.* import aqua.raw.value.* import aqua.types.{ArrayType, LiteralType, OptionType, StreamType} +import cats.Eval import cats.syntax.traverse.* import cats.syntax.monoid.* import cats.syntax.functor.* @@ -34,8 +26,10 @@ object RawValueInliner extends Logging { private[inline] def unfold[S: Mangler: Exports: Arrows]( raw: ValueRaw, propertiesAllowed: Boolean = true - ): State[S, (ValueModel, Inline)] = - raw match { + ): State[S, (ValueModel, Inline)] = for { + optimized <- StateT.liftF(Optimization.optimize(raw)) + _ <- StateT.liftF(Eval.later(logger.trace("OPTIMIZIED " + optimized))) + result <- optimized match { case VarRaw(name, t) => for { exports <- Exports[S].exports @@ -65,7 +59,12 @@ object RawValueInliner extends Logging { case cr: CallArrowRaw => CallArrowRawInliner(cr, propertiesAllowed) + + case cs: CallServiceRaw => + CallServiceRawInliner(cs, propertiesAllowed) + } + } yield result private[inline] def inlineToTree[S: Mangler: Exports: Arrows]( inline: Inline @@ -101,10 +100,10 @@ object RawValueInliner extends Logging { def valueToModel[S: Mangler: Exports: Arrows]( value: ValueRaw, propertiesAllowed: Boolean = true - ): State[S, (ValueModel, Option[OpModel.Tree])] = { - logger.trace("RAW " + value) - toModel(unfold(value, propertiesAllowed)) - } + ): State[S, (ValueModel, Option[OpModel.Tree])] = for { + _ <- StateT.liftF(Eval.later(logger.trace("RAW " + value))) + model <- toModel(unfold(value, propertiesAllowed)) + } yield model def valueListToModel[S: Mangler: Exports: Arrows]( values: List[ValueRaw] diff --git a/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala b/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala index 27d1ac4f..fcd996df 100644 --- a/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/TagInliner.scala @@ -4,7 +4,7 @@ import aqua.errors.Errors.internalError import aqua.model.inline.state.{Arrows, Exports, Mangler} import aqua.model.* import aqua.model.inline.RawValueInliner.collectionToModel -import aqua.model.inline.raw.CallArrowRawInliner +import aqua.model.inline.raw.{CallArrowRawInliner, CallServiceRawInliner} import aqua.raw.value.ApplyBinaryOpRaw.Op as BinOp import aqua.raw.ops.* import aqua.raw.value.* @@ -308,8 +308,13 @@ object TagInliner extends Logging { TagInlined.Empty(prefix = parDesugarPrefix(nel.toList.flatMap(_._2))) }) - case CallArrowRawTag(exportTo, value: CallArrowRaw) => - CallArrowRawInliner.unfoldArrow(value, exportTo).flatMap { case (_, inline) => + case CallArrowRawTag(exportTo, value: (CallArrowRaw | CallServiceRaw)) => + (value match { + case ca: CallArrowRaw => + CallArrowRawInliner.unfold(ca, exportTo) + case cs: CallServiceRaw => + CallServiceRawInliner.unfold(cs, exportTo) + }).flatMap { case (_, inline) => RawValueInliner .inlineToTree(inline) .map(tree => diff --git a/model/inline/src/main/scala/aqua/model/inline/raw/ApplyBinaryOpRawInliner.scala b/model/inline/src/main/scala/aqua/model/inline/raw/ApplyBinaryOpRawInliner.scala index c41cd28f..48c06cad 100644 --- a/model/inline/src/main/scala/aqua/model/inline/raw/ApplyBinaryOpRawInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/raw/ApplyBinaryOpRawInliner.scala @@ -1,5 +1,6 @@ package aqua.model.inline.raw +import aqua.errors.Errors.internalError import aqua.model.* import aqua.model.inline.raw.RawInliner import aqua.model.inline.TagInliner @@ -8,8 +9,9 @@ import aqua.raw.value.{AbilityRaw, LiteralRaw, MakeStructRaw} import cats.data.{NonEmptyList, NonEmptyMap, State} import aqua.model.inline.Inline import aqua.model.inline.RawValueInliner.{unfold, valueToModel} -import aqua.types.{ArrowType, ScalarType} +import aqua.types.{ArrowType, ScalarType, Type} import aqua.raw.value.ApplyBinaryOpRaw +import aqua.raw.value.ApplyBinaryOpRaw.Op import aqua.raw.value.ApplyBinaryOpRaw.Op.* import aqua.model.inline.Inline.MergeMode @@ -21,12 +23,10 @@ import cats.syntax.flatMap.* import cats.syntax.apply.* import cats.syntax.foldable.* import cats.syntax.applicative.* +import aqua.types.LiteralType object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { - private type BoolOp = And.type | Or.type - private type EqOp = Eq.type | Neq.type - override def apply[S: Mangler: Exports: Arrows]( raw: ApplyBinaryOpRaw, propertiesAllowed: Boolean @@ -37,16 +37,49 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { (rmodel, rinline) = right result <- raw.op match { - case op @ (And | Or) => inlineBoolOp(lmodel, rmodel, linline, rinline, op) - case op @ (Eq | Neq) => + case op: Op.Bool => + inlineBoolOp( + lmodel, + rmodel, + linline, + rinline, + op, + raw.baseType + ) + case op: Op.Eq => for { // Canonicalize stream operands before comparison leftStream <- TagInliner.canonicalizeIfStream(lmodel) (lmodelStream, linlineStream) = leftStream.map(linline.append) rightStream <- TagInliner.canonicalizeIfStream(rmodel) (rmodelStream, rinlineStream) = rightStream.map(rinline.append) - result <- inlineEqOp(lmodelStream, rmodelStream, linlineStream, rinlineStream, op) + result <- inlineEqOp( + lmodelStream, + rmodelStream, + linlineStream, + rinlineStream, + op, + raw.baseType + ) } yield result + case op: Op.Cmp => + inlineCmpOp( + lmodel, + rmodel, + linline, + rinline, + op, + raw.baseType + ) + case op: Op.Math => + inlineMathOp( + lmodel, + rmodel, + linline, + rinline, + op, + raw.baseType + ) } } yield result @@ -55,7 +88,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { rmodel: ValueModel, linline: Inline, rinline: Inline, - op: EqOp + op: Op.Eq, + resType: Type ): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match { // Optimize in case compared values are literals // Semantics should check that types are comparable @@ -69,7 +103,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { }, linline.mergeWith(rinline, MergeMode.ParMode) ).pure[State[S, *]] - case _ => fullInlineEqOp(lmodel, rmodel, linline, rinline, op) + case _ => fullInlineEqOp(lmodel, rmodel, linline, rinline, op, resType) } private def fullInlineEqOp[S: Mangler: Exports: Arrows]( @@ -77,7 +111,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { rmodel: ValueModel, linline: Inline, rinline: Inline, - op: EqOp + op: Op.Eq, + resType: Type ): State[S, (ValueModel, Inline)] = { val (name, shouldMatch) = op match { case Eq => ("eq", true) @@ -114,7 +149,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { ) ) - result(name, predo) + result(name, resType, predo) } private def inlineBoolOp[S: Mangler: Exports: Arrows]( @@ -122,7 +157,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { rmodel: ValueModel, linline: Inline, rinline: Inline, - op: BoolOp + op: Op.Bool, + resType: Type ): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match { // Optimize in case of left value is known at compile time case (LiteralModel.Bool(lvalue), _) => @@ -139,7 +175,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { case _ => (lmodel, linline) }).pure[State[S, *]] // Produce unoptimized inline - case _ => fullInlineBoolOp(lmodel, rmodel, linline, rinline, op) + case _ => fullInlineBoolOp(lmodel, rmodel, linline, rinline, op, resType) } private def fullInlineBoolOp[S: Mangler: Exports: Arrows]( @@ -147,7 +183,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { rmodel: ValueModel, linline: Inline, rinline: Inline, - op: BoolOp + op: Op.Bool, + resType: Type ): State[S, (ValueModel, Inline)] = { val (name, compareWith) = op match { case And => ("and", false) @@ -190,19 +227,162 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { ) ) - result(name, predo) + result(name, resType, predo) + } + + private def inlineCmpOp[S: Mangler: Exports: Arrows]( + lmodel: ValueModel, + rmodel: ValueModel, + linline: Inline, + rinline: Inline, + op: Op.Cmp, + resType: Type + ): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match { + case ( + LiteralModel.Integer(lv, _), + LiteralModel.Integer(rv, _) + ) => + val res = op match { + case Lt => lv < rv + case Lte => lv <= rv + case Gt => lv > rv + case Gte => lv >= rv + } + + ( + LiteralModel.bool(res), + Inline(linline.predo ++ rinline.predo) + ).pure + case _ => + val fn = op match { + case Lt => "lt" + case Lte => "lte" + case Gt => "gt" + case Gte => "gte" + } + + val predo = (resName: String) => + SeqModel.wrap( + linline.predo ++ rinline.predo :+ CallServiceModel( + serviceId = LiteralModel.quote("cmp"), + funcName = fn, + call = CallModel( + args = lmodel :: rmodel :: Nil, + exportTo = CallModel.Export(resName, resType) :: Nil + ) + ).leaf + ) + + result(fn, resType, predo) + } + + private def inlineMathOp[S: Mangler: Exports: Arrows]( + lmodel: ValueModel, + rmodel: ValueModel, + linline: Inline, + rinline: Inline, + op: Op.Math, + resType: Type + ): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match { + case ( + LiteralModel.Integer(lv, lt), + LiteralModel.Integer(rv, rt) + ) if canOptimizeMath(lv, lt, rv, rt, op) => + val res = op match { + case Add => lv + rv + case Sub => lv - rv + case Mul => lv * rv + case Div => lv / rv + case Rem => lv % rv + case Pow => intPow(lv, rv) + case _ => internalError(s"Unsupported operation $op for $lv and $rv") + } + + ( + LiteralModel.number(res), + Inline(linline.predo ++ rinline.predo) + ).pure + case _ => + val fn = op match { + case Add => "add" + case Sub => "sub" + case Mul => "mul" + case FMul => "fmul" + case Div => "div" + case Rem => "rem" + case Pow => "pow" + } + + val predo = (resName: String) => + SeqModel.wrap( + linline.predo ++ rinline.predo :+ CallServiceModel( + serviceId = LiteralModel.quote("math"), + funcName = fn, + call = CallModel( + args = lmodel :: rmodel :: Nil, + exportTo = CallModel.Export(resName, resType) :: Nil + ) + ).leaf + ) + + result(fn, resType, predo) } private def result[S: Mangler]( name: String, + resType: Type, predo: String => OpModel.Tree ): State[S, (ValueModel, Inline)] = Mangler[S] .findAndForbidName(name) .map(resName => ( - VarModel(resName, ScalarType.bool), + VarModel(resName, resType), Inline(Chain.one(predo(resName))) ) ) + + /** + * Check if we can optimize math operation + * in compile time. + * + * @param left left operand + * @param leftType type of left operand + * @param right right operand + * @param rightType type of right operand + * @param op operation + * @return true if we can optimize this operation + */ + private def canOptimizeMath( + left: Long, + leftType: ScalarType | LiteralType, + right: Long, + rightType: ScalarType | LiteralType, + op: Op.Math + ): Boolean = op match { + // Leave division by zero for runtime + case Op.Div | Op.Rem => right != 0 + // Leave negative power for runtime + case Op.Pow => right >= 0 + case Op.Sub => + // Leave subtraction overflow for runtime + ScalarType.isSignedInteger(leftType) || + ScalarType.isSignedInteger(rightType) + case _ => true + } + + /** + * Integer power (binary exponentiation) + * + * @param base + * @param exp >= 0 + * @return base ** exp + */ + private def intPow(base: Long, exp: Long): Long = { + def intPowTailRec(base: Long, exp: Long, acc: Long): Long = + if (exp <= 0) acc + else intPowTailRec(base * base, exp / 2, if (exp % 2 == 0) acc else acc * base) + + intPowTailRec(base, exp, 1) + } } diff --git a/model/inline/src/main/scala/aqua/model/inline/raw/ApplyPropertiesRawInliner.scala b/model/inline/src/main/scala/aqua/model/inline/raw/ApplyPropertiesRawInliner.scala index 1f0318a1..df9a34da 100644 --- a/model/inline/src/main/scala/aqua/model/inline/raw/ApplyPropertiesRawInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/raw/ApplyPropertiesRawInliner.scala @@ -257,51 +257,59 @@ object ApplyPropertiesRawInliner extends RawInliner[ApplyPropertyRaw] with Loggi idx: ValueRaw ): State[S, (VarModel, Inline)] = for { /** - * Inline idx + * Inline size, which is `idx + 1` + * Increment on ValueRaw level to + * apply possible optimizations */ - idxInlined <- unfold(idx) + sizeInlined <- unfold(idx.increment) + (sizeVM, sizeInline) = sizeInlined + /** + * Inline idx which is `size - 1` + * TODO: Do not generate it if + * it is not needed, e.g. in `join` + */ + idxInlined <- sizeVM match { + /** + * Micro optimization: if idx is a literal + * do not generate inline. + */ + case LiteralModel.Integer(i, t) => + (LiteralModel((i - 1).toString, t), Inline.empty).pure[State[S, *]] + case _ => + Mangler[S].findAndForbidName(s"${streamName}_idx").map { idxName => + val idxVar = VarModel(idxName, sizeVM.`type`) + val idxInline = Inline.tree( + CallServiceModel( + "math", + funcName = "sub", + args = List(sizeVM, LiteralModel.number(1)), + result = idxVar + ).leaf + ) + + (idxVar, idxInline) + } + } (idxVM, idxInline) = idxInlined /** - * Inline size which is `idx + 1` - * TODO: Refactor to apply optimizations + * Inline join of `size` elements of stream */ - sizeName <- Mangler[S].findAndForbidName(s"${streamName}_size") - sizeVar = VarModel(sizeName, idxVM.`type`) - sizeInline = CallServiceModel( - "math", - funcName = "add", - args = List(idxVM, LiteralModel.number(1)), - result = sizeVar - ).leaf - gateInlined <- StreamGateInliner(streamName, streamType, sizeVar) + gateInlined <- StreamGateInliner(streamName, streamType, sizeVM) (gateVM, gateInline) = gateInlined - /** - * Remove properties from idx - * as we need to use it in index - * TODO: Do not generate it - * if it is not needed, - * e.g. in `join` - */ - idxFlattened <- idxVM match { - case vr: VarModel => removeProperties(vr) - case _ => (idxVM, Inline.empty).pure[State[S, *]] - } - (idxFlat, idxFlatInline) = idxFlattened /** * Construct stream[idx] */ gate = gateVM.withProperty( IntoIndexModel - .fromValueModel(idxFlat, streamType.element) + .fromValueModel(idxVM, streamType.element) .getOrElse( - internalError(s"Unexpected: could not convert ($idxFlat) to IntoIndexModel") + internalError(s"Unexpected: could not convert ($idxVM) to IntoIndexModel") ) ) } yield gate -> Inline( - idxInline.predo - .append(sizeInline) ++ + sizeInline.predo ++ gateInline.predo ++ - idxFlatInline.predo, + idxInline.predo, mergeMode = SeqMode ) diff --git a/model/inline/src/main/scala/aqua/model/inline/raw/CallArrowRawInliner.scala b/model/inline/src/main/scala/aqua/model/inline/raw/CallArrowRawInliner.scala index 1074b8e2..700deaca 100644 --- a/model/inline/src/main/scala/aqua/model/inline/raw/CallArrowRawInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/raw/CallArrowRawInliner.scala @@ -7,48 +7,28 @@ import aqua.model.inline.state.{Arrows, Exports, Mangler} import aqua.model.inline.{ArrowInliner, Inline, TagInliner} import aqua.raw.ops.Call import aqua.raw.value.CallArrowRaw + import cats.data.{Chain, State} import cats.syntax.traverse.* import scribe.Logging object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging { - private[inline] def unfoldArrow[S: Mangler: Exports: Arrows]( + private[inline] def unfold[S: Mangler: Exports: Arrows]( value: CallArrowRaw, exportTo: List[Call.Export] ): State[S, (List[ValueModel], Inline)] = Exports[S].exports.flatMap { exports => logger.trace(s"${exportTo.mkString(" ")} $value") val call = Call(value.arguments, exportTo) - value.serviceId match { - case Some(serviceId) => - logger.trace(Console.BLUE + s"call service id $serviceId" + Console.RESET) - for { - cd <- callToModel(call, true) - (callModel, callInline) = cd - sd <- valueToModel(serviceId) - (serviceIdValue, serviceIdInline) = sd - values = callModel.exportTo.map(e => e.name -> e.asVar.resolveWith(exports)).toMap - inline = Inline( - Chain( - SeqModel.wrap( - serviceIdInline.toList ++ callInline.toList :+ - CallServiceModel(serviceIdValue, value.name, callModel).leaf - ) - ) - ) - _ <- Exports[S].resolved(values) - _ <- Mangler[S].forbid(values.keySet) - } yield values.values.toList -> inline - case None => - /** - * Here the back hop happens from [[TagInliner]] to [[ArrowInliner.callArrow]] - */ - val funcName = value.ability.fold(value.name)(_ + "." + value.name) - logger.trace(s" $funcName") - resolveArrow(funcName, call) - } + /** + * Here the back hop happens from [[TagInliner]] to [[ArrowInliner.callArrow]] + */ + val funcName = value.ability.fold(value.name)(_ + "." + value.name) + logger.trace(s" $funcName") + + resolveArrow(funcName, call) } private def resolveFuncArrow[S: Mangler: Exports: Arrows]( @@ -103,7 +83,7 @@ object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging { Mangler[S] .findAndForbidName(raw.name) .flatMap(n => - unfoldArrow(raw, Call.Export(n, raw.`type`) :: Nil).map { + unfold(raw, Call.Export(n, raw.`type`) :: Nil).map { case (Nil, inline) => (VarModel(n, raw.`type`), inline) case (h :: _, inline) => (h, inline) } diff --git a/model/inline/src/main/scala/aqua/model/inline/raw/CallServiceRawInliner.scala b/model/inline/src/main/scala/aqua/model/inline/raw/CallServiceRawInliner.scala new file mode 100644 index 00000000..eca1a633 --- /dev/null +++ b/model/inline/src/main/scala/aqua/model/inline/raw/CallServiceRawInliner.scala @@ -0,0 +1,57 @@ +package aqua.model.inline.raw + +import aqua.errors.Errors.internalError +import aqua.model.* +import aqua.model.inline.RawValueInliner.{callToModel, valueToModel} +import aqua.model.inline.state.{Arrows, Exports, Mangler} +import aqua.model.inline.{ArrowInliner, Inline, TagInliner} +import aqua.raw.ops.Call +import aqua.raw.value.CallServiceRaw + +import cats.data.{Chain, State} +import cats.syntax.traverse.* +import scribe.Logging + +object CallServiceRawInliner extends RawInliner[CallServiceRaw] with Logging { + + private[inline] def unfold[S: Mangler: Exports: Arrows]( + value: CallServiceRaw, + exportTo: List[Call.Export] + ): State[S, (List[ValueModel], Inline)] = Exports[S].exports.flatMap { exports => + logger.trace(s"${exportTo.mkString(" ")} $value") + logger.trace(Console.BLUE + s"call service id ${value.serviceId}" + Console.RESET) + + val call = Call(value.arguments, exportTo) + + for { + cd <- callToModel(call, true) + (callModel, callInline) = cd + sd <- valueToModel(value.serviceId) + (serviceIdValue, serviceIdInline) = sd + values = callModel.exportTo.map(e => e.name -> e.asVar.resolveWith(exports)).toMap + inline = Inline( + Chain( + SeqModel.wrap( + serviceIdInline.toList ++ callInline.toList :+ + CallServiceModel(serviceIdValue, value.fnName, callModel).leaf + ) + ) + ) + _ <- Exports[S].resolved(values) + _ <- Mangler[S].forbid(values.keySet) + } yield values.values.toList -> inline + } + + override def apply[S: Mangler: Exports: Arrows]( + raw: CallServiceRaw, + propertiesAllowed: Boolean + ): State[S, (ValueModel, Inline)] = + Mangler[S] + .findAndForbidName(raw.fnName) + .flatMap(n => + unfold(raw, Call.Export(n, raw.`type`) :: Nil).map { + case (Nil, inline) => (VarModel(n, raw.`type`), inline) + case (h :: _, inline) => (h, inline) + } + ) +} diff --git a/model/inline/src/main/scala/aqua/model/inline/raw/StreamGateInliner.scala b/model/inline/src/main/scala/aqua/model/inline/raw/StreamGateInliner.scala index 0c52bee1..3c55e15a 100644 --- a/model/inline/src/main/scala/aqua/model/inline/raw/StreamGateInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/raw/StreamGateInliner.scala @@ -4,7 +4,6 @@ import aqua.errors.Errors.internalError import aqua.model.* import aqua.model.inline.Inline import aqua.model.inline.state.{Arrows, Exports, Mangler} -import aqua.raw.value.{LiteralRaw, VarRaw} import aqua.model.inline.RawValueInliner.unfold import aqua.types.{ArrayType, CanonStreamType, ScalarType, StreamType} @@ -25,16 +24,14 @@ object StreamGateInliner extends Logging { * (seq * (fold $stream s * (seq - * (seq - * (ap s $stream_test) - * (canon $stream_test #stream_iter_canon) - * ) - * (xor - * (match #stream_iter_canon.length size - * (null) - * ) - * (next s) + * (ap s $stream_test) + * (canon $stream_test #stream_iter_canon) + * ) + * (xor + * (match #stream_iter_canon.length size + * (null) * ) + * (next s) * ) * (never) * ) @@ -100,7 +97,6 @@ object StreamGateInliner extends Logging { uniqueCanonName <- Mangler[S].findAndForbidName(streamName + "_result_canon") uniqueResultName <- Mangler[S].findAndForbidName(streamName + "_gate") uniqueTestName <- Mangler[S].findAndForbidName(streamName + "_test") - uniqueIdxIncr <- Mangler[S].findAndForbidName(streamName + "_incr") uniqueIterCanon <- Mangler[S].findAndForbidName(streamName + "_iter_canon") uniqueIter <- Mangler[S].findAndForbidName(streamName + "_fold_var") } yield { diff --git a/model/inline/src/main/scala/aqua/model/inline/tag/IfTagInliner.scala b/model/inline/src/main/scala/aqua/model/inline/tag/IfTagInliner.scala index 900a8283..e563c609 100644 --- a/model/inline/src/main/scala/aqua/model/inline/tag/IfTagInliner.scala +++ b/model/inline/src/main/scala/aqua/model/inline/tag/IfTagInliner.scala @@ -21,7 +21,7 @@ final case class IfTagInliner( def inlined[S: Mangler: Exports: Arrows] = (valueRaw match { // Optimize in case last operation is equality check - case ApplyBinaryOpRaw(op @ (BinOp.Eq | BinOp.Neq), left, right) => + case ApplyBinaryOpRaw(op @ (BinOp.Eq | BinOp.Neq), left, right, _) => ( valueToModel(left) >>= canonicalizeIfStream, valueToModel(right) >>= canonicalizeIfStream diff --git a/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala b/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala index 8d483b2f..b44d5b32 100644 --- a/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala +++ b/model/inline/src/test/scala/aqua/model/inline/ArrowInlinerSpec.scala @@ -827,7 +827,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { val res1 = VarModel("res", ScalarType.u16) val res2 = VarModel("res2", ScalarType.u16) val res3 = VarModel("res-0", ScalarType.u16) - val tempAdd = VarModel("add-0", ScalarType.u16) + val tempAdd = VarModel("add", ScalarType.u16) val expected = SeqModel.wrap( MetaModel @@ -843,7 +843,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { ), SeqModel.wrap( ModelBuilder.add(res2, res3)(tempAdd).leaf, - ModelBuilder.add(res1, tempAdd)(VarModel("add", ScalarType.u16)).leaf + ModelBuilder.add(res1, tempAdd)(VarModel("add-0", ScalarType.u16)).leaf ) ) @@ -889,15 +889,12 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { capturedTopology = None ) - val innerCall = CallArrowRaw( - ability = None, - name = innerName, - arguments = Nil, + val innerCall = CallArrowRaw.func( + funcName = innerName, baseType = ArrowType( domain = NilType, codomain = ProductType(List(ScalarType.u16)) - ), - serviceId = None + ) ) val outerAdd = "37" @@ -943,36 +940,18 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { .runA(InliningState()) .value - /* WARNING: This naming is unstable */ - val tempAdd0 = VarModel("add-0", ScalarType.u16) - val tempAdd = VarModel("add", ScalarType.u16) - - val expected = SeqModel.wrap( - ModelBuilder - .add( - LiteralModel(innerRet, ScalarType.u16), - LiteralModel(outerAdd, ScalarType.u16) - )(tempAdd0) - .leaf, - ModelBuilder - .add( - LiteralModel(innerRet, ScalarType.u16), - tempAdd0 - )(tempAdd) - .leaf - ) - - model.equalsOrShowDiff(expected) shouldEqual true + // Addition is completely optimized out + model.equalsOrShowDiff(EmptyModel.leaf) shouldEqual true } /** * closureName = (x: u16) -> u16: - * retval = x + add + * retval <- TestSrv.call(x, add) * <- retval * * @return (closure func, closure type, closure type labelled) */ - def addClosure( + def srvCallClosure( closureName: String, add: ValueRaw ): (FuncRaw, ArrowType, ArrowType) = { @@ -993,13 +972,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { ) val closureBody = SeqTag.wrap( - AssignmentTag( - RawBuilder.add( - closureArg, - add - ), - closureRes.name - ).leaf, + CallArrowRawTag + .service( + LiteralRaw.quote("test-srv"), + funcName = "call", + Call( + args = List(closureArg, add), + exportTo = List(Call.Export(closureRes.name, closureRes.`type`)) + ) + ) + .leaf, ReturnTag( NonEmptyList.one(closureRes) ).leaf @@ -1017,10 +999,21 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { (closureFunc, closureType, closureTypeLabelled) } + def srvCallModel( + x: ValueModel, + add: ValueModel, + result: VarModel + ): CallServiceModel = CallServiceModel( + serviceId = "test-srv", + funcName = "call", + args = List(x, add), + result = result + ) + /** * func innerName(arg: u16) -> u16 -> u16: * closureName = (x: u16) -> u16: - * retval = x + arg + * retval <- TestSrv.call(x, arg) * <- retval * <- closureName * @@ -1042,7 +1035,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { ) val (closureFunc, closureType, closureTypeLabelled) = - addClosure(closureName, innerArg) + srvCallClosure(closureName, innerArg) val innerRes = VarRaw( closureName, @@ -1085,12 +1078,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { val innerCall = CallArrowRawTag( List(Call.Export(outterClosure.name, outterClosure.`type`)), - CallArrowRaw( - ability = None, - name = innerName, - arguments = List(LiteralRaw("42", LiteralType.number)), + CallArrowRaw.func( + funcName = innerName, baseType = innerType, - serviceId = None + arguments = List(LiteralRaw("42", LiteralType.number)) ) ).leaf @@ -1128,7 +1119,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { /** * func inner(arg: u16) -> u16 -> u16: * closure = (x: u16) -> u16: - * retval = x + arg + * retval <- TestSrv.call(x, arg) * <- retval * <- closure * @@ -1144,12 +1135,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { val outterResultName = "retval" val closureCall = (closureType: ArrowType, i: String) => - CallArrowRaw( - ability = None, - name = outterClosureName, - arguments = List(LiteralRaw(i, LiteralType.number)), + CallArrowRaw.func( + funcName = outterClosureName, baseType = closureType, - serviceId = None + arguments = List(LiteralRaw(i, LiteralType.unsigned)) ) val body = (closureType: ArrowType) => @@ -1157,7 +1146,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { AssignmentTag( RawBuilder.add( RawBuilder.add( - LiteralRaw("37", LiteralType.number), + LiteralRaw("37", LiteralType.unsigned), closureCall(closureType, "1") ), closureCall(closureType, "2") @@ -1180,20 +1169,19 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { .wrap( ApplyTopologyModel(closureName) .wrap( - ModelBuilder - .add( - LiteralModel(x, LiteralType.number), - LiteralModel("42", LiteralType.number) - )(o) - .leaf + srvCallModel( + LiteralModel(x, LiteralType.unsigned), + LiteralModel("42", LiteralType.unsigned), + result = o + ).leaf ) ) /* WARNING: This naming is unstable */ - val tempAdd0 = VarModel("add-0", ScalarType.u16) - val tempAdd1 = VarModel("add-1", ScalarType.u16) - val tempAdd2 = VarModel("add-2", ScalarType.u16) + val retval1 = VarModel("retval-0", ScalarType.u16) + val retval2 = VarModel("retval-1", ScalarType.u16) val tempAdd = VarModel("add", ScalarType.u16) + val tempAdd0 = VarModel("add-0", ScalarType.u16) val expected = SeqModel.wrap( MetaModel @@ -1202,23 +1190,21 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { CaptureTopologyModel(closureName).leaf ), SeqModel.wrap( - ParModel.wrap( - SeqModel.wrap( - closureCallModel("1", tempAdd1), - ModelBuilder - .add( - LiteralModel("37", LiteralType.number), - tempAdd1 - )(tempAdd0) - .leaf - ), - closureCallModel("2", tempAdd2) + SeqModel.wrap( + closureCallModel("1", retval1), + closureCallModel("2", retval2), + ModelBuilder + .add( + retval1, + retval2 + )(tempAdd) + .leaf ), ModelBuilder .add( - tempAdd0, - tempAdd2 - )(tempAdd) + tempAdd, + LiteralModel.number(37) + )(tempAdd0) .leaf ) ) @@ -1306,22 +1292,18 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { val innerCall = CallArrowRawTag( List(Call.Export(outterClosure.name, outterClosure.`type`)), - CallArrowRaw( - ability = None, - name = innerName, - arguments = Nil, + CallArrowRaw.func( + funcName = innerName, baseType = innerType, - serviceId = None + arguments = Nil ) ).leaf val closureCall = - CallArrowRaw( - ability = None, - name = outterClosure.name, - arguments = Nil, + CallArrowRaw.func( + funcName = outterClosure.name, baseType = closureType, - serviceId = None + arguments = Nil ) val outerBody = SeqTag.wrap( @@ -1365,38 +1347,14 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { .runA(InliningState()) .value - /* WARNING: This naming is unstable */ - val tempAdd0 = VarModel("add-0", ScalarType.u16) - val tempAdd = VarModel("add", ScalarType.u16) - - val number = (v: String) => - LiteralModel( - v, - LiteralType.number - ) - - val expected = SeqModel.wrap( - ModelBuilder - .add( - number("37"), - number("42") - )(tempAdd0) - .leaf, - ModelBuilder - .add( - tempAdd0, - number("42") - )(tempAdd) - .leaf - ) - - model.equalsOrShowDiff(expected) shouldEqual true + // Addition is completely optimized out + model.equalsOrShowDiff(EmptyModel.leaf) shouldEqual true } /** * func inner(arg: u16) -> u16 -> u16: * closure = (x: u16) -> u16: - * retval = x + arg + * retval = TestSrv.call(x, arg) * <- retval * <- closure * @@ -1404,7 +1362,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { * c <- inner(42) * b = c * a = b - * retval = 37 + a(1) + b(2) + c{3} + * retval = 37 + a(1) + b(2) + c(3) * <- retval */ it should "correctly inline renamed closure [bug LNG-193]" in { @@ -1416,12 +1374,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { val secondRename = "a" val closureCall = (name: String, closureType: ArrowType, i: String) => - CallArrowRaw( - ability = None, - name = name, + CallArrowRaw.func( + funcName = name, arguments = List(LiteralRaw(i, LiteralType.number)), - baseType = closureType, - serviceId = None + baseType = closureType ) val body = (closureType: ArrowType) => @@ -1458,28 +1414,27 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { body = body ) - val closureCallModel = (x: String, o: VarModel) => + val closureCallModel = (x: Long, o: VarModel) => MetaModel .CallArrowModel(closureName) .wrap( ApplyTopologyModel(closureName) .wrap( - ModelBuilder - .add( - LiteralModel(x, LiteralType.number), - LiteralModel("42", LiteralType.number) - )(o) - .leaf + srvCallModel( + LiteralModel.number(x), + LiteralModel.number(42), + result = o + ).leaf ) ) /* WARNING: This naming is unstable */ + val tempAdd = VarModel("add", ScalarType.u16) val tempAdd0 = VarModel("add-0", ScalarType.u16) val tempAdd1 = VarModel("add-1", ScalarType.u16) - val tempAdd2 = VarModel("add-2", ScalarType.u16) - val tempAdd3 = VarModel("add-3", ScalarType.u16) - val tempAdd4 = VarModel("add-4", ScalarType.u16) - val tempAdd = VarModel("add", ScalarType.u16) + val retval0 = VarModel("retval-0", ScalarType.u16) + val retval1 = VarModel("retval-1", ScalarType.u16) + val retval2 = VarModel("retval-2", ScalarType.u16) val expected = SeqModel.wrap( MetaModel @@ -1488,35 +1443,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { CaptureTopologyModel(closureName).leaf ), SeqModel.wrap( - ParModel.wrap( + SeqModel.wrap( SeqModel.wrap( - ParModel.wrap( - SeqModel.wrap( - closureCallModel("1", tempAdd2), - ModelBuilder - .add( - LiteralModel("37", LiteralType.number), - tempAdd2 - )(tempAdd1) - .leaf - ), - closureCallModel("2", tempAdd3) - ), - ModelBuilder - .add( - tempAdd1, - tempAdd3 - )(tempAdd0) - .leaf + closureCallModel(1, retval0), + closureCallModel(2, retval1), + ModelBuilder.add(retval0, retval1)(tempAdd).leaf ), - closureCallModel("3", tempAdd4) + closureCallModel(3, retval2), + ModelBuilder.add(tempAdd, retval2)(tempAdd0).leaf ), - ModelBuilder - .add( - tempAdd0, - tempAdd4 - )(tempAdd) - .leaf + ModelBuilder.add(tempAdd0, LiteralModel.number(37))(tempAdd1).leaf ) ) @@ -1530,7 +1466,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { * * func test() -> u16: * closure = (x: u16) -> u16: - * resC = x + 37 + * resC <- TestSrv.call(x, 37) * <- resC * resT <- accept_closure(closure) * <- resT @@ -1543,7 +1479,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { val testRes = VarRaw("resT", ScalarType.u16) val (closureFunc, closureType, closureTypeLabelled) = - addClosure(closureName, LiteralRaw("37", LiteralType.number)) + srvCallClosure(closureName, LiteralRaw.number(37)) val acceptType = ArrowType( domain = ProductType.labelled(List(closureName -> closureType)), @@ -1553,12 +1489,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { val acceptBody = SeqTag.wrap( CallArrowRawTag( List(Call.Export(acceptRes.name, acceptRes.baseType)), - CallArrowRaw( - ability = None, - name = closureName, - arguments = List(LiteralRaw("42", LiteralType.number)), + CallArrowRaw.func( + funcName = closureName, baseType = closureType, - serviceId = None + arguments = List(LiteralRaw.number(42)) ) ).leaf, ReturnTag( @@ -1586,12 +1520,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { ).leaf, CallArrowRawTag( List(Call.Export(testRes.name, testRes.baseType)), - CallArrowRaw( - ability = None, - name = acceptName, - arguments = List(VarRaw(closureName, closureTypeLabelled)), + CallArrowRaw.func( + funcName = acceptName, baseType = acceptFunc.arrowType, - serviceId = None + arguments = List(VarRaw(closureName, closureTypeLabelled)) ) ).leaf, ReturnTag( @@ -1621,7 +1553,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { .value /* WARNING: This naming is unstable */ - val tempAdd = VarModel("add", ScalarType.u16) + val retval = VarModel("retval", ScalarType.u16) val expected = SeqModel.wrap( CaptureTopologyModel(closureName).leaf, @@ -1632,12 +1564,11 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { .CallArrowModel(closureName) .wrap( ApplyTopologyModel(closureName).wrap( - ModelBuilder - .add( - LiteralModel("42", LiteralType.number), - LiteralModel("37", LiteralType.number) - )(tempAdd) - .leaf + srvCallModel( + LiteralModel.number(42), + LiteralModel.number(37), + retval + ).leaf ) ) ) @@ -1673,17 +1604,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { val testBody = SeqTag.wrap( CallArrowRawTag .service( - serviceId = serviceId, - fnName = argMethodName, + srvId = serviceId, + funcName = argMethodName, call = Call( args = VarRaw(argMethodName, ScalarType.string) :: Nil, exportTo = Call.Export(res.name, res.`type`) :: Nil ), - name = serviceName, arrowType = ArrowType( domain = ProductType.labelled(List(argMethodName -> ScalarType.string)), codomain = ProductType(ScalarType.string :: Nil) - ) + ).some ) .leaf, ReturnTag( @@ -2014,7 +1944,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { FuncArrow( "dumb_func", SeqTag.wrap( - AssignmentTag(LiteralRaw("1", LiteralType.number), argVar.name).leaf, + AssignmentTag(LiteralRaw.number(1), argVar.name).leaf, foldOp ), ArrowType( @@ -2037,7 +1967,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { CallServiceModel( LiteralModel.fromRaw(serviceId), fnName, - CallModel(LiteralModel("1", LiteralType.number) :: Nil, Nil) + CallModel(LiteralModel.number(1) :: Nil, Nil) ).leaf, NextModel(iVar0.name).leaf ) @@ -2215,8 +2145,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside { val closureBody = SeqTag.wrap( AssignmentTag( - CallArrowRaw.service( - "cmp", + CallServiceRaw( LiteralRaw.quote("cmp"), "gt", ArrowType( diff --git a/model/inline/src/test/scala/aqua/model/inline/CopyInlinerSpec.scala b/model/inline/src/test/scala/aqua/model/inline/CopyInlinerSpec.scala index ab0e7436..844f144d 100644 --- a/model/inline/src/test/scala/aqua/model/inline/CopyInlinerSpec.scala +++ b/model/inline/src/test/scala/aqua/model/inline/CopyInlinerSpec.scala @@ -6,6 +6,7 @@ import aqua.model.inline.state.InliningState import aqua.raw.ops.* import aqua.raw.value.* import aqua.types.* + import cats.data.{Chain, NonEmptyList, NonEmptyMap} import cats.syntax.show.* import org.scalatest.flatspec.AnyFlatSpec @@ -24,12 +25,11 @@ class CopyInlinerSpec extends AnyFlatSpec with Matchers { val length = FunctorRaw("length", ScalarType.u32) val lengthValue = VarRaw("l", arrType).withProperty(length) - val getField = CallArrowRaw( - None, + val getField = CallServiceRaw( + LiteralRaw.quote("serv"), "get_field", - Nil, ArrowType(NilType, UnlabeledConsType(ScalarType.string, NilType)), - Option(LiteralRaw.quote("serv")) + Nil ) val copyRaw = @@ -63,9 +63,22 @@ class CopyInlinerSpec extends AnyFlatSpec with Matchers { ).leaf ), RestrictionModel(streamMapName, streamMapType).wrap( - InsertKeyValueModel(LiteralModel.quote("field1"), VarModel("l_length", ScalarType.u32), streamMapName, streamMapType).leaf, - InsertKeyValueModel(LiteralModel.quote("field2"), VarModel("get_field", ScalarType.string), streamMapName, streamMapType).leaf, - CanonicalizeModel(VarModel(streamMapName, streamMapType), CallModel.Export(result.name, result.`type`)).leaf + InsertKeyValueModel( + LiteralModel.quote("field1"), + VarModel("l_length", ScalarType.u32), + streamMapName, + streamMapType + ).leaf, + InsertKeyValueModel( + LiteralModel.quote("field2"), + VarModel("get_field", ScalarType.string), + streamMapName, + streamMapType + ).leaf, + CanonicalizeModel( + VarModel(streamMapName, streamMapType), + CallModel.Export(result.name, result.`type`) + ).leaf ) ) ) shouldBe true diff --git a/model/inline/src/test/scala/aqua/model/inline/MakeStructInlinerSpec.scala b/model/inline/src/test/scala/aqua/model/inline/MakeStructInlinerSpec.scala index 9d147370..6ab85395 100644 --- a/model/inline/src/test/scala/aqua/model/inline/MakeStructInlinerSpec.scala +++ b/model/inline/src/test/scala/aqua/model/inline/MakeStructInlinerSpec.scala @@ -6,6 +6,7 @@ import aqua.model.inline.state.InliningState import aqua.raw.ops.* import aqua.raw.value.* import aqua.types.* + import cats.data.{Chain, NonEmptyList, NonEmptyMap} import cats.syntax.show.* import org.scalatest.flatspec.AnyFlatSpec @@ -24,12 +25,11 @@ class MakeStructInlinerSpec extends AnyFlatSpec with Matchers { val length = FunctorRaw("length", ScalarType.u32) val lengthValue = VarRaw("l", arrType).withProperty(length) - val getField = CallArrowRaw( - None, + val getField = CallServiceRaw( + LiteralRaw.quote("serv"), "get_field", - Nil, ArrowType(NilType, UnlabeledConsType(ScalarType.string, NilType)), - Option(LiteralRaw.quote("serv")) + Nil ) val makeStruct = @@ -62,9 +62,22 @@ class MakeStructInlinerSpec extends AnyFlatSpec with Matchers { ).leaf ), RestrictionModel(streamMapName, streamMapType).wrap( - InsertKeyValueModel(LiteralModel.quote("field1"), VarModel("l_length", ScalarType.u32), streamMapName, streamMapType).leaf, - InsertKeyValueModel(LiteralModel.quote("field2"), VarModel("get_field", ScalarType.string), streamMapName, streamMapType).leaf, - CanonicalizeModel(VarModel(streamMapName, streamMapType), CallModel.Export(result.name, result.`type`)).leaf + InsertKeyValueModel( + LiteralModel.quote("field1"), + VarModel("l_length", ScalarType.u32), + streamMapName, + streamMapType + ).leaf, + InsertKeyValueModel( + LiteralModel.quote("field2"), + VarModel("get_field", ScalarType.string), + streamMapName, + streamMapType + ).leaf, + CanonicalizeModel( + VarModel(streamMapName, streamMapType), + CallModel.Export(result.name, result.`type`) + ).leaf ) ) ) shouldBe true diff --git a/model/inline/src/test/scala/aqua/model/inline/RawBuilder.scala b/model/inline/src/test/scala/aqua/model/inline/RawBuilder.scala index 329576f2..67dae910 100644 --- a/model/inline/src/test/scala/aqua/model/inline/RawBuilder.scala +++ b/model/inline/src/test/scala/aqua/model/inline/RawBuilder.scala @@ -1,21 +1,10 @@ package aqua.model.inline -import aqua.raw.value.{CallArrowRaw, LiteralRaw, ValueRaw} +import aqua.raw.value.{ApplyBinaryOpRaw, ValueRaw} import aqua.types.{ArrowType, ProductType, ScalarType} object RawBuilder { def add(l: ValueRaw, r: ValueRaw): ValueRaw = - CallArrowRaw.service( - abilityName = "math", - serviceId = LiteralRaw.quote("math"), - funcName = "add", - baseType = ArrowType( - ProductType(List(ScalarType.i64, ScalarType.i64)), - ProductType( - List(l.`type` `∪` r.`type`) - ) - ), - arguments = List(l, r) - ) + ApplyBinaryOpRaw.Add(l, r) } diff --git a/model/inline/src/test/scala/aqua/model/inline/RawValueInlinerSpec.scala b/model/inline/src/test/scala/aqua/model/inline/RawValueInlinerSpec.scala index 021cdadd..8248986e 100644 --- a/model/inline/src/test/scala/aqua/model/inline/RawValueInlinerSpec.scala +++ b/model/inline/src/test/scala/aqua/model/inline/RawValueInlinerSpec.scala @@ -1,34 +1,48 @@ package aqua.model.inline -import aqua.model.inline.raw.ApplyPropertiesRawInliner -import aqua.model.{ - EmptyModel, - FlattenModel, - FunctorModel, - IntoFieldModel, - IntoIndexModel, - ParModel, - SeqModel, - ValueModel, - VarModel -} +import aqua.model.inline.raw.{ApplyPropertiesRawInliner, StreamGateInliner} +import aqua.model.* import aqua.model.inline.state.InliningState import aqua.raw.value.{ApplyPropertyRaw, FunctorRaw, IntoIndexRaw, LiteralRaw, VarRaw} import aqua.types.* +import aqua.raw.value.* + +import cats.Eval import cats.data.NonEmptyMap import cats.data.Chain import cats.syntax.show.* +import cats.syntax.foldable.* +import cats.free.Cofree +import scala.collection.immutable.SortedMap +import scala.math import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers +import org.scalatest.Inside -import scala.collection.immutable.SortedMap -import aqua.raw.value.ApplyBinaryOpRaw -import aqua.raw.value.CallArrowRaw - -class RawValueInlinerSpec extends AnyFlatSpec with Matchers { +class RawValueInlinerSpec extends AnyFlatSpec with Matchers with Inside { import RawValueInliner.valueToModel + def join(stream: VarModel, size: ValueModel) = + stream match { + case VarModel( + streamName, + streamType: StreamType, + Chain.`nil` + ) => + StreamGateInliner.joinStreamOnIndexModel( + streamName = streamName, + streamType = streamType, + sizeModel = size, + testName = streamName + "_test", + iterName = streamName + "_fold_var", + canonName = streamName + "_result_canon", + iterCanonName = streamName + "_iter_canon", + resultName = streamName + "_gate" + ) + case _ => ??? + } + private def numVarWithLength(name: String) = VarRaw(name, ArrayType(ScalarType.u32)).withProperty( FunctorRaw("length", ScalarType.u32) @@ -126,6 +140,51 @@ class RawValueInlinerSpec extends AnyFlatSpec with Matchers { IntoIndexRaw(ysVarRaw(1), ScalarType.string) ) + def int(i: Int): LiteralRaw = LiteralRaw.number(i) + + extension (l: ValueRaw) { + + def cmp(op: ApplyBinaryOpRaw.Op.Cmp)(r: ValueRaw): ApplyBinaryOpRaw = + ApplyBinaryOpRaw(op, l, r, ScalarType.bool) + + def math(op: ApplyBinaryOpRaw.Op.Math)(r: ValueRaw): ApplyBinaryOpRaw = + ApplyBinaryOpRaw(op, l, r, ScalarType.i64) // result type is not important here + + def `<`(r: ValueRaw): ApplyBinaryOpRaw = + cmp(ApplyBinaryOpRaw.Op.Lt)(r) + + def `<=`(r: ValueRaw): ApplyBinaryOpRaw = + cmp(ApplyBinaryOpRaw.Op.Lte)(r) + + def `>`(r: ValueRaw): ApplyBinaryOpRaw = + cmp(ApplyBinaryOpRaw.Op.Gt)(r) + + def `>=`(r: ValueRaw): ApplyBinaryOpRaw = + cmp(ApplyBinaryOpRaw.Op.Gte)(r) + + def `+`(r: ValueRaw): ApplyBinaryOpRaw = + math(ApplyBinaryOpRaw.Op.Add)(r) + + def `-`(r: ValueRaw): ApplyBinaryOpRaw = + math(ApplyBinaryOpRaw.Op.Sub)(r) + + def `*`(r: ValueRaw): ApplyBinaryOpRaw = + math(ApplyBinaryOpRaw.Op.Mul)(r) + + def `/`(r: ValueRaw): ApplyBinaryOpRaw = + math(ApplyBinaryOpRaw.Op.Div)(r) + + def `%`(r: ValueRaw): ApplyBinaryOpRaw = + math(ApplyBinaryOpRaw.Op.Rem)(r) + + def `**`(r: ValueRaw): ApplyBinaryOpRaw = + math(ApplyBinaryOpRaw.Op.Pow)(r) + + } + + private def ivar(name: String, t: Option[Type] = None): VarRaw = + VarRaw(name, t.getOrElse(ScalarType.i32)) + "raw value inliner" should "desugarize a single non-recursive raw value" in { // x[y] valueToModel[InliningState](`raw x[y]`) @@ -305,24 +364,53 @@ class RawValueInlinerSpec extends AnyFlatSpec with Matchers { } it should "desugarize stream with gate" in { - val streamWithProps = - VarRaw("x", StreamType(ScalarType.string)).withProperty( - IntoIndexRaw(ysVarRaw(1), ScalarType.string) - ) + val stream = VarRaw("x", StreamType(ScalarType.string)) + val streamModel = VarModel.fromVarRaw(stream) + val idxRaw = ysVarRaw(1) + val streamWithProps = stream.withProperty( + IntoIndexRaw(idxRaw, ScalarType.string) + ) - val (resVal, resTree) = valueToModel[InliningState](streamWithProps) - .runA(InliningState(noNames = Set("x", "ys"))) - .value + val initState = InliningState(noNames = Set("x", "ys")) + + // Here retrieve how size is inlined + val (afterSizeState, (sizeModel, sizeTree)) = + valueToModel[InliningState](idxRaw.increment).run(initState).value + + val (resVal, resTree) = + valueToModel[InliningState](streamWithProps).runA(initState).value + + val idxModel = VarModel("x_idx", ScalarType.i8) + + val decrement = CallServiceModel( + "math", + "sub", + List( + sizeModel, + LiteralModel.number(1) + ), + idxModel + ).leaf + + val expected = SeqModel.wrap( + sizeTree.toList :+ + join(streamModel, sizeModel) :+ + decrement + ) resVal should be( VarModel( "x_gate", ArrayType(ScalarType.string), Chain( - IntoIndexModel("ys_flat", ScalarType.string) + IntoIndexModel(idxModel.name, ScalarType.string) ) ) ) + + inside(resTree) { case Some(tree) => + tree.equalsOrShowDiff(expected) should be(true) + } } it should "desugarize stream with length" in { @@ -388,4 +476,165 @@ class RawValueInlinerSpec extends AnyFlatSpec with Matchers { ) ) should be(true) } + + it should "optimize constants comparison" in { + + for { + l <- -100 to 100 + r <- -100 to 100 + } { + val lt = valueToModel[InliningState]( + int(l) `<` int(r) + ).runA(InliningState()).value + + lt shouldBe ( + LiteralModel.bool(l < r) -> None + ) + + val lte = valueToModel[InliningState]( + int(l) `<=` int(r) + ).runA(InliningState()).value + + lte shouldBe ( + LiteralModel.bool(l <= r) -> None + ) + + val gt = valueToModel[InliningState]( + int(l) `>` int(r) + ).runA(InliningState()).value + + gt shouldBe ( + LiteralModel.bool(l > r) -> None + ) + + val gte = valueToModel[InliningState]( + int(l) `>=` int(r) + ).runA(InliningState()).value + + gte shouldBe ( + LiteralModel.bool(l >= r) -> None + ) + } + } + + it should "optimize constants math" in { + for { + l <- -100 to 100 + r <- -100 to 100 + } { + val add = valueToModel[InliningState]( + int(l) `+` int(r) + ).runA(InliningState()).value + + add shouldBe ( + LiteralModel.number(l + r) -> None + ) + + val sub = valueToModel[InliningState]( + int(l) `-` int(r) + ).runA(InliningState()).value + + sub shouldBe ( + LiteralModel.number(l - r) -> None + ) + + val mul = valueToModel[InliningState]( + int(l) `*` int(r) + ).runA(InliningState()).value + + mul shouldBe ( + LiteralModel.number(l * r) -> None + ) + + val div = valueToModel[InliningState]( + int(l) `/` int(r) + ).runA(InliningState()).value + + val rem = valueToModel[InliningState]( + int(l) `%` int(r) + ).runA(InliningState()).value + + if (r != 0) + div shouldBe ( + LiteralModel.number(l / r) -> None + ) + rem shouldBe ( + LiteralModel.number(l % r) -> None + ) + else { + val (dmodel, dtree) = div + dmodel shouldBe a[VarModel] + dtree.nonEmpty shouldBe (true) + + val (rmodel, rtree) = rem + rmodel shouldBe a[VarModel] + rtree.nonEmpty shouldBe (true) + } + + if (r >= 0 && r <= 5) { + val pow = valueToModel[InliningState]( + int(l) `**` int(r) + ).runA(InliningState()).value + + pow shouldBe ( + LiteralModel.number(scala.math.pow(l, r).toLong) -> None + ) + } + } + } + + it should "optimize addition in expressions" in { + def test(numVars: Int, numLiterals: Int) = { + val vars = (1 to numVars).map(i => ivar(s"v$i")).toList + val literals = (1 to numLiterals).map(i => LiteralRaw.number(i)).toList + val values = vars ++ literals + + /** + * Enumerate all possible binary trees of vals + */ + def genAllExprs(vals: List[ValueRaw]): List[ValueRaw] = + if (vals.length <= 1) vals + else + for { + split <- (1 until vals.length).toList + (left, right) = vals.splitAt(split) + l <- genAllExprs(left) + r <- genAllExprs(right) + } yield l `+` r + + for { + perm <- values.permutations.toList + expr <- genAllExprs(perm) + } { + val state = InliningState( + resolvedExports = vars.map(v => v.name -> VarModel.fromVarRaw(v)).toMap + ) + val (model, inline) = valueToModel[InliningState](expr).runA(state).value + + model shouldBe a[VarModel] + inside(inline) { case Some(tree) => + val numberOfAdditions = Cofree + .cata(tree) { (model, count: Chain[Int]) => + Eval.later { + count.combineAll + (model match { + case CallServiceModel(_, "add", _) => 1 + case _ => 0 + }) + } + } + .value + + numberOfAdditions shouldEqual numVars + } + } + } + + /** + * Number of expressions grows exponentially + * So we test only small cases + */ + test(2, 2) + test(3, 2) + test(2, 3) + } } diff --git a/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala b/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala index 43d49278..34d9389d 100644 --- a/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala +++ b/model/raw/src/main/scala/aqua/raw/ops/RawTag.scala @@ -2,7 +2,7 @@ package aqua.raw.ops import aqua.raw.arrow.FuncRaw import aqua.raw.ops.RawTag.Tree -import aqua.raw.value.{CallArrowRaw, ValueRaw} +import aqua.raw.value.{CallArrowRaw, CallServiceRaw, ValueRaw} import aqua.tree.{TreeNode, TreeNodeCompanion} import aqua.types.{ArrowType, DataType, ServiceType} @@ -224,26 +224,6 @@ object CallArrowRawTag { ) ) - def service( - serviceId: ValueRaw, - fnName: String, - call: Call, - name: String = null, - arrowType: ArrowType = null - ): CallArrowRawTag = - CallArrowRawTag( - call.exportTo, - CallArrowRaw( - Option(name), - fnName, - call.args, - Option(arrowType).getOrElse( - call.arrowType - ), - Some(serviceId) - ) - ) - def func(fnName: String, call: Call): CallArrowRawTag = CallArrowRawTag( call.exportTo, @@ -253,6 +233,22 @@ object CallArrowRawTag { arguments = call.args ) ) + + def service( + srvId: ValueRaw, + funcName: String, + call: Call, + arrowType: Option[ArrowType] = None + ): CallArrowRawTag = + CallArrowRawTag( + call.exportTo, + CallServiceRaw( + srvId, + funcName, + arrowType.getOrElse(call.arrowType), + call.args + ) + ) } case class DeclareStreamTag( diff --git a/model/raw/src/main/scala/aqua/raw/value/Optimization.scala b/model/raw/src/main/scala/aqua/raw/value/Optimization.scala new file mode 100644 index 00000000..412de4cf --- /dev/null +++ b/model/raw/src/main/scala/aqua/raw/value/Optimization.scala @@ -0,0 +1,119 @@ +package aqua.raw.value + +import cats.Eval +import cats.data.Ior +import cats.Semigroup +import cats.syntax.applicative.* +import cats.syntax.apply.* +import cats.syntax.functor.* +import cats.syntax.semigroup.* + +object Optimization { + + /** + * Optimize raw value. + * + * A lot more optimizations could be done, + * it is here just as a proof of concept. + */ + def optimize(value: ValueRaw): Eval[ValueRaw] = + Addition.optimize(value) + + object Addition { + + def optimize(value: ValueRaw): Eval[ValueRaw] = + gatherLiteralsInAddition(value).map { res => + res.fold( + // TODO: Type of literal is not preserved + LiteralRaw.number, + identity, + (literal, nonLiteral) => + if (literal > 0) { + ApplyBinaryOpRaw.Add(nonLiteral, LiteralRaw.number(literal)) + } else if (literal < 0) { + ApplyBinaryOpRaw.Sub(nonLiteral, LiteralRaw.number(-literal)) + } else nonLiteral + ) + } + + private def gatherLiteralsInAddition( + value: ValueRaw + ): Eval[Ior[Long, ValueRaw]] = + value match { + case ApplyBinaryOpRaw.Add(left, right) => + ( + gatherLiteralsInAddition(left), + gatherLiteralsInAddition(right) + ).mapN(_ add _) + case ApplyBinaryOpRaw.Sub( + left, + /** + * We support subtraction only with literal at the right side + * because we don't have unary minus operator for values. + * (Or this algo should be much more complex to support it) + * But this case is pretty commonly generated by compiler + * in gates: `join stream[len - 1]` (see `StreamGateInliner`) + */ + LiteralRaw.Integer(i) + ) => + ( + gatherLiteralsInAddition(left), + Eval.now(Ior.left(-i)) + // NOTE: Use add as sign is stored inside Long + ).mapN(_ add _) + case LiteralRaw.Integer(i) => + Ior.left(i).pure + case _ => + // Optimize expressions inside this value + Ior.right(value.mapValues(v => optimize(v).value)).pure + } + + /** + * Rewritten `Ior.combine` method + * with custom combination functions. + */ + private def combineWith( + l: Ior[Long, ValueRaw], + r: Ior[Long, ValueRaw] + )( + lf: (Long, Long) => Long, + rf: (ValueRaw, ValueRaw) => ValueRaw + ): Ior[Long, ValueRaw] = + l match { + case Ior.Left(lvl) => + r match { + case Ior.Left(rvl) => + Ior.left(lf(lvl, rvl)) + case Ior.Right(rvr) => + Ior.both(lvl, rvr) + case Ior.Both(rvl, rvr) => + Ior.both(lf(lvl, rvl), rvr) + } + case Ior.Right(lvr) => + r match { + case Ior.Left(rvl) => + Ior.both(rvl, lvr) + case Ior.Right(rvr) => + Ior.right(rf(lvr, rvr)) + case Ior.Both(rvl, rvr) => + Ior.both(rvl, rf(lvr, rvr)) + } + case Ior.Both(lvl, lvr) => + r match { + case Ior.Left(rvl) => + Ior.both(lf(lvl, rvl), lvr) + case Ior.Right(rvr) => + Ior.both(lvl, rf(lvr, rvr)) + case Ior.Both(rvl, rvr) => + Ior.both(lf(lvl, rvl), rf(lvr, rvr)) + } + } + + extension (l: Ior[Long, ValueRaw]) { + + def add(r: Ior[Long, ValueRaw]): Ior[Long, ValueRaw] = + combineWith(l, r)(_ + _, ApplyBinaryOpRaw.Add(_, _)) + } + } + +} diff --git a/model/raw/src/main/scala/aqua/raw/value/PropertyRaw.scala b/model/raw/src/main/scala/aqua/raw/value/PropertyRaw.scala index e5357ca5..8e4a60c3 100644 --- a/model/raw/src/main/scala/aqua/raw/value/PropertyRaw.scala +++ b/model/raw/src/main/scala/aqua/raw/value/PropertyRaw.scala @@ -6,6 +6,9 @@ import cats.data.NonEmptyMap sealed trait PropertyRaw { def `type`: Type + /** + * Apply function to values in this property + */ def map(f: ValueRaw => ValueRaw): PropertyRaw def renameVars(vals: Map[String, String]): PropertyRaw = this @@ -24,7 +27,8 @@ case class IntoArrowRaw(name: String, arrowType: Type, arguments: List[ValueRaw] override def `type`: Type = arrowType - override def map(f: ValueRaw => ValueRaw): PropertyRaw = this + override def map(f: ValueRaw => ValueRaw): PropertyRaw = + copy(arguments = arguments.map(f)) override def varNames: Set[String] = arguments.flatMap(_.varNames).toSet diff --git a/model/raw/src/main/scala/aqua/raw/value/ValueRaw.scala b/model/raw/src/main/scala/aqua/raw/value/ValueRaw.scala index 87fefdce..f5424521 100644 --- a/model/raw/src/main/scala/aqua/raw/value/ValueRaw.scala +++ b/model/raw/src/main/scala/aqua/raw/value/ValueRaw.scala @@ -15,7 +15,16 @@ sealed trait ValueRaw { def renameVars(map: Map[String, String]): ValueRaw - def map(f: ValueRaw => ValueRaw): ValueRaw + /** + * Apply function to all values in the tree + */ + final def map(f: ValueRaw => ValueRaw): ValueRaw = + f(mapValues(_.map(f))) + + /** + * Apply function to values in this value + */ + def mapValues(f: ValueRaw => ValueRaw): ValueRaw def varNames: Set[String] } @@ -60,6 +69,12 @@ object ValueRaw { type ApplyRaw = ApplyPropertyRaw | CallArrowRaw | CollectionRaw | ApplyBinaryOpRaw | ApplyUnaryOpRaw + + extension (v: ValueRaw) { + def add(a: ValueRaw): ValueRaw = ApplyBinaryOpRaw.Add(v, a) + + def increment: ValueRaw = ApplyBinaryOpRaw.Add(v, LiteralRaw.number(1)) + } } case class ApplyPropertyRaw(value: ValueRaw, property: PropertyRaw) extends ValueRaw { @@ -70,8 +85,8 @@ case class ApplyPropertyRaw(value: ValueRaw, property: PropertyRaw) extends Valu override def renameVars(map: Map[String, String]): ValueRaw = ApplyPropertyRaw(value.renameVars(map), property.renameVars(map)) - override def map(f: ValueRaw => ValueRaw): ValueRaw = - f(ApplyPropertyRaw(f(value), property.map(_.map(f)))) + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = + ApplyPropertyRaw(f(value), property.map(f)) override def toString: String = s"$value.$property" @@ -96,7 +111,7 @@ object ApplyPropertyRaw { case class VarRaw(name: String, baseType: Type) extends ValueRaw { - override def map(f: ValueRaw => ValueRaw): ValueRaw = f(this) + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = this override def renameVars(map: Map[String, String]): ValueRaw = copy(name = map.getOrElse(name, name)) @@ -110,7 +125,7 @@ case class VarRaw(name: String, baseType: Type) extends ValueRaw { } case class LiteralRaw(value: String, baseType: Type) extends ValueRaw { - override def map(f: ValueRaw => ValueRaw): ValueRaw = f(this) + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = this override def toString: String = s"{$value: ${baseType}}" @@ -122,12 +137,25 @@ case class LiteralRaw(value: String, baseType: Type) extends ValueRaw { object LiteralRaw { def quote(value: String): LiteralRaw = LiteralRaw("\"" + value + "\"", LiteralType.string) - def number(value: Int): LiteralRaw = LiteralRaw(value.toString, LiteralType.forInt(value)) + def number(value: Long): LiteralRaw = LiteralRaw(value.toString, LiteralType.forInt(value)) val Zero: LiteralRaw = number(0) val True: LiteralRaw = LiteralRaw("true", LiteralType.bool) val False: LiteralRaw = LiteralRaw("false", LiteralType.bool) + + object Integer { + + /* + * Used to match integer literals in pattern matching + */ + def unapply(value: ValueRaw): Option[Long] = + value match { + case LiteralRaw(value, t) if ScalarType.integer.exists(_.acceptsValueOf(t)) => + value.toLongOption + case _ => none + } + } } case class CollectionRaw(values: NonEmptyList[ValueRaw], boxType: BoxType) extends ValueRaw { @@ -136,10 +164,10 @@ case class CollectionRaw(values: NonEmptyList[ValueRaw], boxType: BoxType) exten override lazy val baseType: Type = boxType - override def map(f: ValueRaw => ValueRaw): ValueRaw = { + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = { val vals = values.map(f) val el = vals.map(_.`type`).reduceLeft(_ `∩` _) - f(copy(vals, boxType.withElement(el))) + copy(vals, boxType.withElement(el)) } override def varNames: Set[String] = values.toList.flatMap(_.varNames).toSet @@ -153,7 +181,8 @@ case class MakeStructRaw(fields: NonEmptyMap[String, ValueRaw], structType: Stru override def baseType: Type = structType - override def map(f: ValueRaw => ValueRaw): ValueRaw = f(copy(fields = fields.map(f))) + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = + copy(fields = fields.map(f)) override def varNames: Set[String] = { fields.toSortedMap.values.flatMap(_.varNames).toSet @@ -168,8 +197,8 @@ case class AbilityRaw(fieldsAndArrows: NonEmptyMap[String, ValueRaw], abilityTyp override def baseType: Type = abilityType - override def map(f: ValueRaw => ValueRaw): ValueRaw = - f(copy(fieldsAndArrows = fieldsAndArrows.map(f))) + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = + copy(fieldsAndArrows = fieldsAndArrows.map(f)) override def varNames: Set[String] = { fieldsAndArrows.toSortedMap.values.flatMap(_.varNames).toSet @@ -182,29 +211,79 @@ case class AbilityRaw(fieldsAndArrows: NonEmptyMap[String, ValueRaw], abilityTyp case class ApplyBinaryOpRaw( op: ApplyBinaryOpRaw.Op, left: ValueRaw, - right: ValueRaw + right: ValueRaw, + // TODO: Refactor type, get rid of `LiteralType` + resultType: ScalarType | LiteralType ) extends ValueRaw { - // Only boolean operations are supported for now - override def baseType: Type = ScalarType.bool + override val baseType: Type = resultType - override def map(f: ValueRaw => ValueRaw): ValueRaw = - f(copy(left = f(left), right = f(right))) + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = + copy(left = f(left), right = f(right)) override def varNames: Set[String] = left.varNames ++ right.varNames override def renameVars(map: Map[String, String]): ValueRaw = copy(left = left.renameVars(map), right = right.renameVars(map)) + + override def toString(): String = + s"(${left} ${op} ${right}) :: ${resultType}" } object ApplyBinaryOpRaw { enum Op { - case And - case Or + case And, Or + case Eq, Neq + case Lt, Lte, Gt, Gte + case Add, Sub, Mul, FMul, Div, Pow, Rem + } - case Eq - case Neq + object Op { + + type Bool = And.type | Or.type + + type Eq = Eq.type | Neq.type + + type Cmp = Lt.type | Lte.type | Gt.type | Gte.type + + type Math = Add.type | Sub.type | Mul.type | FMul.type | Div.type | Pow.type | Rem.type + } + + object Add { + + def apply(left: ValueRaw, right: ValueRaw): ValueRaw = + ApplyBinaryOpRaw( + Op.Add, + left, + right, + ScalarType.resolveMathOpType(left.`type`, right.`type`).`type` + ) + + def unapply(value: ValueRaw): Option[(ValueRaw, ValueRaw)] = + value match { + case ApplyBinaryOpRaw(Op.Add, left, right, _) => + (left, right).some + case _ => none + } + } + + object Sub { + + def apply(left: ValueRaw, right: ValueRaw): ValueRaw = + ApplyBinaryOpRaw( + Op.Sub, + left, + right, + ScalarType.resolveMathOpType(left.`type`, right.`type`).`type` + ) + + def unapply(value: ValueRaw): Option[(ValueRaw, ValueRaw)] = + value match { + case ApplyBinaryOpRaw(Op.Sub, left, right, _) => + (left, right).some + case _ => none + } } } @@ -216,8 +295,8 @@ case class ApplyUnaryOpRaw( // Only boolean operations are supported for now override def baseType: Type = ScalarType.bool - override def map(f: ValueRaw => ValueRaw): ValueRaw = - f(copy(value = f(value))) + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = + copy(value = f(value)) override def varNames: Set[String] = value.varNames @@ -237,37 +316,28 @@ case class CallArrowRaw( ability: Option[String], name: String, arguments: List[ValueRaw], - baseType: ArrowType, - // TODO: there should be no serviceId there - serviceId: Option[ValueRaw] + baseType: ArrowType ) extends ValueRaw { - override def `type`: Type = baseType.codomain.uncons.map(_._1).getOrElse(baseType) + override def `type`: Type = baseType.codomain.headOption.getOrElse(baseType) - override def map(f: ValueRaw => ValueRaw): ValueRaw = - f( - copy( - arguments = arguments.map(_.map(f)), - serviceId = serviceId.map(_.map(f)) - ) - ) + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = + copy(arguments = arguments.map(f)) override def varNames: Set[String] = name.some - .filterNot(_ => ability.isDefined || serviceId.isDefined) + .filterNot(_ => ability.isDefined) .toSet ++ arguments.flatMap(_.varNames).toSet override def renameVars(map: Map[String, String]): ValueRaw = copy( name = map .get(name) - // Rename only if it is **not** a service or ability call, see [bug LNG-199] + // Rename only if it is **not** an ability call, see [bug LNG-199] .filterNot(_ => ability.isDefined) - .filterNot(_ => serviceId.isDefined) .getOrElse(name) ) override def toString: String = - s"(call ${ability.fold("")(a => s"|$a| ")} (${serviceId.fold("")(_.toString + " ")}$name) [${arguments - .mkString(" ")}] :: $baseType)" + s"${ability.fold("")(a => s"$a.")}$name(${arguments.mkString(",")}) :: $baseType)" } object CallArrowRaw { @@ -280,8 +350,7 @@ object CallArrowRaw { ability = None, name = funcName, arguments = arguments, - baseType = baseType, - serviceId = None + baseType = baseType ) def ability( @@ -293,22 +362,46 @@ object CallArrowRaw { ability = None, name = AbilityType.fullName(abilityName, funcName), arguments = arguments, - baseType = baseType, - serviceId = None - ) - - def service( - abilityName: String, - serviceId: ValueRaw, - funcName: String, - baseType: ArrowType, - arguments: List[ValueRaw] = Nil - ): CallArrowRaw = CallArrowRaw( - ability = abilityName.some, - name = funcName, - arguments = arguments, - baseType = baseType, - serviceId = Some(serviceId) + baseType = baseType ) } + +/** + * WARNING: This class is internal and is used to generate code. + * Calls to services in aqua code are represented as [[CallArrowRaw]] + * and resolved through ability resolution. + * + * @param serviceId service id + * @param fnName service method name + * @param baseType type of the service method + * @param arguments call arguments + */ +case class CallServiceRaw( + serviceId: ValueRaw, + fnName: String, + baseType: ArrowType, + arguments: List[ValueRaw] +) extends ValueRaw { + override def `type`: Type = baseType.codomain.headOption.getOrElse(baseType) + + override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = + copy( + serviceId = f(serviceId), + arguments = arguments.map(f) + ) + + override def varNames: Set[String] = + arguments + .flatMap(_.varNames) + .toSet ++ serviceId.varNames + + override def renameVars(map: Map[String, String]): ValueRaw = + copy( + serviceId = serviceId.renameVars(map), + arguments = arguments.map(_.renameVars(map)) + ) + + override def toString: String = + s"call (${serviceId}) $fnName(${arguments.mkString(",")}) :: $baseType)" +} diff --git a/model/src/main/scala/aqua/model/AquaContext.scala b/model/src/main/scala/aqua/model/AquaContext.scala index 739a41e1..e057c36f 100644 --- a/model/src/main/scala/aqua/model/AquaContext.scala +++ b/model/src/main/scala/aqua/model/AquaContext.scala @@ -144,18 +144,13 @@ object AquaContext extends Logging { blank.copy( module = Some(sm.name), funcs = sm.`type`.arrows.map { case (fnName, arrowType) => - val (args, call, ret) = ArgsCall.arrowToArgsCallRet(arrowType) - fnName -> - FuncArrow( - fnName, - // TODO: capture ability resolution, get ID from the call context - CallArrowRawTag.service(serviceId, fnName, call, sm.name).leaf, - arrowType, - ret.map(_.toRaw), - Map.empty, - Map.empty, - None - ) + fnName -> FuncArrow.fromServiceMethod( + fnName, + sm.name, + fnName, + arrowType, + serviceId + ) } ) diff --git a/model/src/main/scala/aqua/model/FuncArrow.scala b/model/src/main/scala/aqua/model/FuncArrow.scala index c9d46822..34894159 100644 --- a/model/src/main/scala/aqua/model/FuncArrow.scala +++ b/model/src/main/scala/aqua/model/FuncArrow.scala @@ -2,10 +2,12 @@ package aqua.model import aqua.raw.Raw import aqua.raw.arrow.FuncRaw -import aqua.raw.ops.{Call, CallArrowRawTag, RawTag} +import aqua.raw.ops.{Call, CallArrowRawTag, EmptyTag, RawTag} import aqua.raw.value.{ValueRaw, VarRaw} import aqua.types.{ArrowType, ServiceType, Type} +import cats.syntax.option.* + case class FuncArrow( funcName: String, body: RawTag.Tree, @@ -58,21 +60,28 @@ object FuncArrow { serviceName: String, methodName: String, methodType: ArrowType, - idValue: ValueModel + idValue: ValueModel | ValueRaw ): FuncArrow = { - val id = VarRaw("id", idValue.`type`) + val (id, capturedValues) = idValue match { + case i: ValueModel => + ( + VarRaw("id", i.`type`), + Map("id" -> i) + ) + case i: ValueRaw => (i, Map.empty[String, ValueModel]) + } val retVar = methodType.res.map(t => VarRaw("ret", t)) val call = Call( methodType.domain.toLabelledList().map(VarRaw.apply), retVar.map(r => Call.Export(r.name, r.`type`)).toList ) + val body = CallArrowRawTag.service( - serviceId = id, - fnName = methodName, + srvId = id, + funcName = methodName, call = call, - name = serviceName, - arrowType = methodType + arrowType = methodType.some ) FuncArrow( @@ -81,9 +90,7 @@ object FuncArrow { arrowType = methodType, ret = retVar.toList, capturedArrows = Map.empty, - capturedValues = Map( - id.name -> idValue - ), + capturedValues = capturedValues, capturedTopology = None ) } diff --git a/model/src/main/scala/aqua/model/ValueModel.scala b/model/src/main/scala/aqua/model/ValueModel.scala index 4a243127..71c578d4 100644 --- a/model/src/main/scala/aqua/model/ValueModel.scala +++ b/model/src/main/scala/aqua/model/ValueModel.scala @@ -7,6 +7,7 @@ import aqua.types.* import cats.Eq import cats.data.{Chain, NonEmptyMap} import cats.syntax.option.* +import cats.syntax.apply.* import scribe.Logging sealed trait ValueModel { @@ -91,6 +92,22 @@ object LiteralModel { } } + /* + * Used to match integer literals in pattern matching + */ + object Integer { + + def unapply(lm: LiteralModel): Option[(Long, ScalarType | LiteralType)] = + lm match { + case LiteralModel(value, t) if ScalarType.integer.exists(_.acceptsValueOf(t)) => + ( + value.toLongOption, + t.some.collect { case t: (ScalarType | LiteralType) => t } + ).tupled + case _ => none + } + } + // AquaVM will return 0 for // :error:.$.error_code if there is no :error: val emptyErrorCode = number(0) @@ -102,7 +119,7 @@ object LiteralModel { def quote(str: String): LiteralModel = LiteralModel(s"\"$str\"", LiteralType.string) - def number(n: Int): LiteralModel = LiteralModel(n.toString, LiteralType.forInt(n)) + def number(n: Long): LiteralModel = LiteralModel(n.toString, LiteralType.forInt(n)) def bool(b: Boolean): LiteralModel = LiteralModel(b.toString.toLowerCase, LiteralType.bool) } diff --git a/semantics/src/main/scala/aqua/semantics/expr/func/CallArrowSem.scala b/semantics/src/main/scala/aqua/semantics/expr/func/CallArrowSem.scala index 970fb5f8..d3c33edd 100644 --- a/semantics/src/main/scala/aqua/semantics/expr/func/CallArrowSem.scala +++ b/semantics/src/main/scala/aqua/semantics/expr/func/CallArrowSem.scala @@ -50,7 +50,7 @@ class CallArrowSem[S[_]](val expr: CallArrowExpr[S]) extends AnyVal { ) } yield tag.map(_.funcOpLeaf) - def program[Alg[_]: Monad](implicit + def program[Alg[_]: Monad](using N: NamesAlgebra[S, Alg], T: TypesAlgebra[S, Alg], V: ValuesAlgebra[S, Alg] diff --git a/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala b/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala index b3d0ffa4..2e37f648 100644 --- a/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala +++ b/semantics/src/main/scala/aqua/semantics/rules/ValuesAlgebra.scala @@ -184,7 +184,6 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using case (Some(leftRaw), Some(rightRaw)) => val lType = leftRaw.`type` val rType = rightRaw.`type` - lazy val uType = lType `∪` rType it.op match { case InfOp.Bool(bop) => @@ -200,7 +199,8 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using case BoolOp.Or => ApplyBinaryOpRaw.Op.Or }, left = leftRaw, - right = rightRaw + right = rightRaw, + resultType = ScalarType.bool ) ) case InfOp.Eq(eop) => @@ -216,7 +216,8 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using case EqOp.Neq => ApplyBinaryOpRaw.Op.Neq }, left = leftRaw, - right = rightRaw + right = rightRaw, + resultType = ScalarType.bool ) ) ) @@ -228,70 +229,62 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using case InfOp.Cmp(op) => op } - val hasFloat = List(lType, rType).exists( + lazy val hasFloat = List(lType, rType).exists( _ acceptsValueOf LiteralType.float ) - // See https://github.com/fluencelabs/aqua-lib/blob/main/math.aqua - val (id, fn) = iop match { - case MathOp.Add => ("math", "add") - case MathOp.Sub => ("math", "sub") - case MathOp.Mul if hasFloat => ("math", "fmul") - case MathOp.Mul => ("math", "mul") - case MathOp.Div => ("math", "div") - case MathOp.Rem => ("math", "rem") - case MathOp.Pow => ("math", "pow") - case CmpOp.Gt => ("cmp", "gt") - case CmpOp.Gte => ("cmp", "gte") - case CmpOp.Lt => ("cmp", "lt") - case CmpOp.Lte => ("cmp", "lte") + val bop = iop match { + case MathOp.Add => ApplyBinaryOpRaw.Op.Add + case MathOp.Sub => ApplyBinaryOpRaw.Op.Sub + case MathOp.Mul if hasFloat => ApplyBinaryOpRaw.Op.FMul + case MathOp.Mul => ApplyBinaryOpRaw.Op.Mul + case MathOp.Div => ApplyBinaryOpRaw.Op.Div + case MathOp.Rem => ApplyBinaryOpRaw.Op.Rem + case MathOp.Pow => ApplyBinaryOpRaw.Op.Pow + case CmpOp.Gt => ApplyBinaryOpRaw.Op.Gt + case CmpOp.Gte => ApplyBinaryOpRaw.Op.Gte + case CmpOp.Lt => ApplyBinaryOpRaw.Op.Lt + case CmpOp.Lte => ApplyBinaryOpRaw.Op.Lte } - /* - * If `uType == TopType`, it means that we don't - * have type big enough to hold the result of operation. - * e.g. We will use `i64` for result of `i32 * u64` - * TODO: Handle this more gracefully - * (use warning system when it is implemented) - */ - def uTypeBounded = if (uType == TopType) { - val bounded = ScalarType.i64 - logger.warn( - s"Result type of ($lType ${it.op} $rType) is $TopType, " + - s"using $bounded instead" - ) - bounded - } else uType + lazy val numbersTypeBounded: Alg[ScalarType | LiteralType] = { + val resType = ScalarType.resolveMathOpType(lType, rType) + report + .warning( + it, + s"Result type of ($lType ${it.op} $rType) is unknown, " + + s"using ${resType.`type`} instead" + ) + .whenA(resType.overflow) + .as(resType.`type`) + } // Expected type sets of left and right operands, result type - val (leftExp, rightExp, resType) = iop match { + val (leftExp, rightExp, resTypeM) = iop match { case MathOp.Add | MathOp.Sub | MathOp.Div | MathOp.Rem => - (ScalarType.integer, ScalarType.integer, uTypeBounded) + (ScalarType.integer, ScalarType.integer, numbersTypeBounded) case MathOp.Pow => - (ScalarType.integer, ScalarType.unsigned, uTypeBounded) + (ScalarType.integer, ScalarType.unsigned, numbersTypeBounded) case MathOp.Mul if hasFloat => - (ScalarType.float, ScalarType.float, ScalarType.i64) + (ScalarType.float, ScalarType.float, ScalarType.i64.pure) case MathOp.Mul => - (ScalarType.integer, ScalarType.integer, uTypeBounded) + (ScalarType.integer, ScalarType.integer, numbersTypeBounded) case CmpOp.Gt | CmpOp.Lt | CmpOp.Gte | CmpOp.Lte => - (ScalarType.integer, ScalarType.integer, ScalarType.bool) + (ScalarType.integer, ScalarType.integer, ScalarType.bool.pure) } for { leftChecked <- T.ensureTypeOneOf(l, leftExp, lType) rightChecked <- T.ensureTypeOneOf(r, rightExp, rType) + resType <- resTypeM } yield Option.when( leftChecked.isDefined && rightChecked.isDefined )( - CallArrowRaw.service( - abilityName = id, - serviceId = LiteralRaw.quote(id), - funcName = fn, - baseType = ArrowType( - ProductType(lType :: rType :: Nil), - ProductType(resType :: Nil) - ), - arguments = leftRaw :: rightRaw :: Nil + ApplyBinaryOpRaw( + op = bop, + left = leftRaw, + right = rightRaw, + resultType = resType ) ) diff --git a/semantics/src/test/scala/aqua/semantics/SemanticsSpec.scala b/semantics/src/test/scala/aqua/semantics/SemanticsSpec.scala index 0331a2c6..ff327b1c 100644 --- a/semantics/src/test/scala/aqua/semantics/SemanticsSpec.scala +++ b/semantics/src/test/scala/aqua/semantics/SemanticsSpec.scala @@ -108,10 +108,10 @@ class SemanticsSpec extends AnyFlatSpec with Matchers with Inside { .leaf def equ(left: ValueRaw, right: ValueRaw): ApplyBinaryOpRaw = - ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Eq, left, right) + ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Eq, left, right, ScalarType.bool) def neq(left: ValueRaw, right: ValueRaw): ApplyBinaryOpRaw = - ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Neq, left, right) + ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Neq, left, right, ScalarType.bool) def declareStreamPush( name: String, diff --git a/semantics/src/test/scala/aqua/semantics/ValuesAlgebraSpec.scala b/semantics/src/test/scala/aqua/semantics/ValuesAlgebraSpec.scala index a8a4bd9f..8e86e160 100644 --- a/semantics/src/test/scala/aqua/semantics/ValuesAlgebraSpec.scala +++ b/semantics/src/test/scala/aqua/semantics/ValuesAlgebraSpec.scala @@ -267,7 +267,7 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside { .run(state) .value - inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _)) => + inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _, ScalarType.bool)) => bop shouldBe (op match { case InfixToken.BoolOp.And => ApplyBinaryOpRaw.Op.And case InfixToken.BoolOp.Or => ApplyBinaryOpRaw.Op.Or @@ -319,7 +319,7 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside { .run(state) .value - inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _)) => + inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _, ScalarType.bool)) => bop shouldBe (op match { case InfixToken.EqOp.Eq => ApplyBinaryOpRaw.Op.Eq case InfixToken.EqOp.Neq => ApplyBinaryOpRaw.Op.Neq diff --git a/types/src/main/scala/aqua/types/Type.scala b/types/src/main/scala/aqua/types/Type.scala index 218a4e3b..c0492e63 100644 --- a/types/src/main/scala/aqua/types/Type.scala +++ b/types/src/main/scala/aqua/types/Type.scala @@ -53,6 +53,11 @@ sealed trait ProductType extends Type { case _ => None } + def headOption: Option[Type] = this match { + case ConsType(t, _) => Some(t) + case _ => None + } + lazy val toList: List[Type] = this match { case ConsType(t, pt) => t :: pt.toList case _ => Nil @@ -182,6 +187,45 @@ object ScalarType { val integer = signed ++ unsigned val number = float ++ integer val all = number ++ Set(bool, string) + + final case class MathOpType( + `type`: ScalarType | LiteralType, + overflow: Boolean + ) + + /** + * Resolve type of math operation + * on two given types. + * + * WARNING: General `Type` is accepted + * but only integer `ScalarType` and `LiteralType` + * are actually expected. + */ + def resolveMathOpType( + lType: Type, + rType: Type + ): MathOpType = { + val uType = lType `∪` rType + uType match { + case t: (ScalarType | LiteralType) => MathOpType(t, false) + case _ => MathOpType(ScalarType.i64, true) + } + } + + /** + * Check if given type is signed. + * + * NOTE: Only integer types are expected. + * But it is impossible to enforce it. + */ + def isSignedInteger(t: ScalarType | LiteralType): Boolean = + t match { + case st: ScalarType => signed.contains(st) + /** + * WARNING: LiteralType.unsigned is signed integer! + */ + case lt: LiteralType => lt.oneOf.exists(signed.contains) + } } case class LiteralType private (oneOf: Set[ScalarType], name: String) extends DataType { @@ -200,7 +244,7 @@ object LiteralType { val bool = LiteralType(Set(ScalarType.bool), "bool") val string = LiteralType(Set(ScalarType.string), "string") - def forInt(n: Int): LiteralType = if (n < 0) signed else unsigned + def forInt(n: Long): LiteralType = if (n < 0) signed else unsigned } sealed trait BoxType extends DataType { @@ -323,8 +367,7 @@ case class StructType(name: String, fields: NonEmptyMap[String, Type]) s"$name{${fields.map(_.toString).toNel.toList.map(kv => kv._1 + ": " + kv._2).mkString(", ")}}" } -case class StreamMapType(element: Type) - extends DataType { +case class StreamMapType(element: Type) extends DataType { override def toString: String = s"%$element" }