feat(compiler): Optimize math in compile time [LNG-245] (#922)

* Move service calls for math to inlining

* Fix: add predo

* Introduce CallServiceRaw

* Add comment

* Add optimization to inlining

* Add tests

* map -> mapValues

* Refactor type

* Add optimization

* Add optimization test

* Fix unit tests

* Fix PR comments

* Restrict optimization

* Add substraction to optimization

* Apply optimizations in gate

* Fix sign, move optimization to unfold

* Fix type and tests

* Fix unit tests

* Add unit test

* Fix after merge

* Add optimization, fix unit tests

* Fix comment
This commit is contained in:
InversionSpaces 2023-10-09 12:02:26 +02:00 committed by GitHub
parent b298eebf5e
commit 5f6c47ffea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1439 additions and 699 deletions

View File

@ -15,16 +15,20 @@ import aqua.res.*
import aqua.res.ResBuilder import aqua.res.ResBuilder
import aqua.types.{ArrayType, CanonStreamType, LiteralType, ScalarType, StreamType, Type} import aqua.types.{ArrayType, CanonStreamType, LiteralType, ScalarType, StreamType, Type}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import cats.Id import cats.Id
import cats.data.{Chain, NonEmptyChain, NonEmptyMap, Validated, ValidatedNec} import cats.data.{Chain, NonEmptyChain, NonEmptyMap, Validated, ValidatedNec}
import cats.instances.string.* import cats.instances.string.*
import cats.syntax.show.* import cats.syntax.show.*
import cats.syntax.option.* import cats.syntax.option.*
import cats.syntax.either.* import cats.syntax.either.*
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.Inside
import aqua.model.AquaContext
import aqua.model.FlattenModel
import aqua.model.CallServiceModel
class AquaCompilerSpec extends AnyFlatSpec with Matchers { class AquaCompilerSpec extends AnyFlatSpec with Matchers with Inside {
import ModelBuilder.* import ModelBuilder.*
private def aquaSource(src: Map[String, String], imports: Map[String, String]) = { private def aquaSource(src: Map[String, String], imports: Map[String, String]) = {
@ -45,8 +49,13 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
} }
} }
private def compileToContext(src: Map[String, String], imports: Map[String, String]) = private def insideContext(
CompilerAPI src: Map[String, String],
imports: Map[String, String] = Map.empty
)(
test: AquaContext => Any
) = {
val compiled = CompilerAPI
.compileToContext[Id, String, String, Span.S]( .compileToContext[Id, String, String, Span.S](
aquaSource(src, imports), aquaSource(src, imports),
id => txt => Parser.parse(Parser.parserSchema)(txt), id => txt => Parser.parse(Parser.parserSchema)(txt),
@ -56,10 +65,29 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
.value .value
.toValidated .toValidated
inside(compiled) { case Validated.Valid(contexts) =>
inside(contexts.headOption) { case Some(ctx) =>
test(ctx)
}
}
}
private def insideRes(
src: Map[String, String],
imports: Map[String, String] = Map.empty,
transformCfg: TransformConfig = TransformConfig()
)(funcNames: String*)(
test: PartialFunction[List[FuncRes], Any]
) = insideContext(src, imports)(ctx =>
val aquaRes = Transform.contextRes(ctx, transformCfg)
// To preserve order as in funcNames do flatMap
val funcs = funcNames.flatMap(name => aquaRes.funcs.find(_.funcName == name)).toList
inside(funcs)(test)
)
"aqua compiler" should "compile a simple snippet to the right context" in { "aqua compiler" should "compile a simple snippet to the right context" in {
val res = compileToContext( val src = Map(
Map(
"index.aqua" -> "index.aqua" ->
"""module Foo declares X """module Foo declares X
| |
@ -73,16 +101,9 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
|func foo2() -> string: |func foo2() -> string:
| <- "hello2?" | <- "hello2?"
|""".stripMargin |""".stripMargin
),
Map.empty
) )
res.isValid should be(true) insideContext(src) { ctx =>
val Validated.Valid(ctxs) = res
ctxs.length should be(1)
val ctx = ctxs.headOption.get
ctx.allFuncs.contains("foo") should be(true) ctx.allFuncs.contains("foo") should be(true)
ctx.allFuncs.contains("foo_two") should be(true) ctx.allFuncs.contains("foo_two") should be(true)
@ -90,6 +111,7 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
const.nonEmpty should be(true) const.nonEmpty should be(true)
const.get should be(LiteralModel.number(5)) const.get should be(LiteralModel.number(5))
} }
}
def through(peer: ValueModel) = def through(peer: ValueModel) =
MakeRes.hop(peer) MakeRes.hop(peer)
@ -110,10 +132,8 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
private def join(vm: VarModel, size: ValueModel) = private def join(vm: VarModel, size: ValueModel) =
ResBuilder.join(vm, size, init) ResBuilder.join(vm, size, init)
"aqua compiler" should "create right topology" in { it should "create right topology" in {
val src = Map(
val res = compileToContext(
Map(
"index.aqua" -> "index.aqua" ->
"""service Op("op"): """service Op("op"):
| identity(s: string) -> string | identity(s: string) -> string
@ -126,29 +146,19 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
| |
| join results[2] | join results[2]
| <- results""".stripMargin | <- results""".stripMargin
),
Map.empty
) )
res.isValid should be(true)
val Validated.Valid(ctxs) = res
ctxs.length should be(1)
val ctx = ctxs.headOption.get
val transformCfg = TransformConfig() val transformCfg = TransformConfig()
val aquaRes = Transform.contextRes(ctx, transformCfg)
val Some(exec) = aquaRes.funcs.find(_.funcName == "exec")
insideRes(src, transformCfg = transformCfg)("exec") { case exec :: _ =>
val peers = VarModel("-peers-arg-", ArrayType(ScalarType.string)) val peers = VarModel("-peers-arg-", ArrayType(ScalarType.string))
val peer = VarModel("peer-0", ScalarType.string) val peer = VarModel("peer-0", ScalarType.string)
val resultsType = StreamType(ScalarType.string) val resultsType = StreamType(ScalarType.string)
val results = VarModel("results", resultsType) val results = VarModel("results", resultsType)
val canonResult = VarModel("-" + results.name + "-fix-0", CanonStreamType(resultsType.element)) val canonResult =
VarModel("-" + results.name + "-fix-0", CanonStreamType(resultsType.element))
val flatResult = VarModel("-results-flat-0", ArrayType(ScalarType.string)) val flatResult = VarModel("-results-flat-0", ArrayType(ScalarType.string))
val initPeer = LiteralModel.fromRaw(ValueRaw.InitPeerId) val initPeer = LiteralModel.fromRaw(ValueRaw.InitPeerId)
val sizeVar = VarModel("results_size", LiteralType.unsigned)
val retVar = VarModel("ret", ScalarType.string) val retVar = VarModel("ret", ScalarType.string)
val expected = val expected =
@ -188,14 +198,12 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
) )
) )
), ),
ResBuilder.add( join(results, LiteralModel.number(3)), // Compiler optimized addition
LiteralModel.number(2), CanonRes(
LiteralModel.number(1), results,
sizeVar, init,
initPeer CallModel.Export(canonResult.name, canonResult.`type`)
), ).leaf,
join(results, sizeVar),
CanonRes(results, init, CallModel.Export(canonResult.name, canonResult.`type`)).leaf,
ApRes( ApRes(
canonResult, canonResult,
CallModel.Export(flatResult.name, flatResult.`type`) CallModel.Export(flatResult.name, flatResult.`type`)
@ -209,11 +217,11 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
exec.body.equalsOrShowDiff(expected) shouldBe (true) exec.body.equalsOrShowDiff(expected) shouldBe (true)
} }
}
"aqua compiler" should "compile with imports" in { it should "compile with imports" in {
val res = compileToContext( val src = Map(
Map(
"index.aqua" -> "index.aqua" ->
"""module Import """module Import
|import foobar from "export2.aqua" |import foobar from "export2.aqua"
@ -231,8 +239,8 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
| -- Exp.f() returns literal, this func must return literal in AIR as well | -- Exp.f() returns literal, this func must return literal in AIR as well
| <- z | <- z
|""".stripMargin |""".stripMargin
), )
Map( val imports = Map(
"export2.aqua" -> "export2.aqua" ->
"""module Export declares foobar, foo """module Export declares foobar, foo
| |
@ -256,20 +264,13 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
| consume(s: string) | consume(s: string)
|""".stripMargin |""".stripMargin
) )
)
res.isValid should be(true)
val Validated.Valid(ctxs) = res
ctxs.length should be(1)
val ctx = ctxs.headOption.get
val transformCfg = TransformConfig(relayVarName = None) val transformCfg = TransformConfig(relayVarName = None)
val aquaRes = Transform.contextRes(ctx, transformCfg)
val Some(funcWrap) = aquaRes.funcs.find(_.funcName == "wrap")
val Some(barfoo) = aquaRes.funcs.find(_.funcName == "barfoo")
insideRes(src, imports, transformCfg)(
"wrap",
"barfoo"
) { case wrap :: barfoo :: _ =>
val resStreamType = StreamType(ScalarType.string) val resStreamType = StreamType(ScalarType.string)
val resVM = VarModel("res", resStreamType) val resVM = VarModel("res", resStreamType)
val resCanonVM = VarModel("-res-fix-0", CanonStreamType(ScalarType.string)) val resCanonVM = VarModel("-res-fix-0", CanonStreamType(ScalarType.string))
@ -308,6 +309,60 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
) )
barfoo.body.equalsOrShowDiff(expected) should be(true) barfoo.body.equalsOrShowDiff(expected) should be(true)
}
}
it should "optimize math inside stream join" in {
val src = Map(
"main.aqua" -> """
|func main(i: i32):
| stream: *string
| stream <<- "a"
| stream <<- "b"
| join stream[i - 1]
|""".stripMargin
)
val transformCfg = TransformConfig()
val streamName = "stream"
val streamType = StreamType(ScalarType.string)
val argName = "-i-arg-"
val argType = ScalarType.i32
val arg = VarModel(argName, argType)
/**
* NOTE: Compiler generates this unused decrement bc
* it doesn't know that we are inlining just join
* and do not need to access the element.
*/
val decrement = CallServiceRes(
LiteralModel.quote("math"),
"sub",
CallRes(
List(arg, LiteralModel.number(1)),
Some(CallModel.Export("stream_idx", argType))
),
LiteralModel.fromRaw(ValueRaw.InitPeerId)
).leaf
val expected = XorRes.wrap(
SeqRes.wrap(
getDataSrv("-relay-", "-relay-", ScalarType.string),
getDataSrv("i", argName, argType),
RestrictionRes(streamName, streamType).wrap(
SeqRes.wrap(
ApRes(LiteralModel.quote("a"), CallModel.Export(streamName, streamType)).leaf,
ApRes(LiteralModel.quote("b"), CallModel.Export(streamName, streamType)).leaf,
join(VarModel(streamName, streamType), arg),
decrement
)
)
),
errorCall(transformCfg, 0, initPeer)
)
insideRes(src, transformCfg = transformCfg)("main") { case main :: _ =>
main.body.equalsOrShowDiff(expected) should be(true)
}
} }
} }

View File

@ -3,20 +3,12 @@ package aqua.model.inline
import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler} import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler}
import aqua.model.inline.Inline.MergeMode.* import aqua.model.inline.Inline.MergeMode.*
import aqua.model.* import aqua.model.*
import aqua.model.inline.raw.{ import aqua.model.inline.raw.*
ApplyBinaryOpRawInliner,
ApplyFunctorRawInliner,
ApplyPropertiesRawInliner,
ApplyUnaryOpRawInliner,
CallArrowRawInliner,
CollectionRawInliner,
MakeAbilityRawInliner,
StreamGateInliner
}
import aqua.raw.ops.* import aqua.raw.ops.*
import aqua.raw.value.* import aqua.raw.value.*
import aqua.types.{ArrayType, LiteralType, OptionType, StreamType} import aqua.types.{ArrayType, LiteralType, OptionType, StreamType}
import cats.Eval
import cats.syntax.traverse.* import cats.syntax.traverse.*
import cats.syntax.monoid.* import cats.syntax.monoid.*
import cats.syntax.functor.* import cats.syntax.functor.*
@ -34,8 +26,10 @@ object RawValueInliner extends Logging {
private[inline] def unfold[S: Mangler: Exports: Arrows]( private[inline] def unfold[S: Mangler: Exports: Arrows](
raw: ValueRaw, raw: ValueRaw,
propertiesAllowed: Boolean = true propertiesAllowed: Boolean = true
): State[S, (ValueModel, Inline)] = ): State[S, (ValueModel, Inline)] = for {
raw match { optimized <- StateT.liftF(Optimization.optimize(raw))
_ <- StateT.liftF(Eval.later(logger.trace("OPTIMIZIED " + optimized)))
result <- optimized match {
case VarRaw(name, t) => case VarRaw(name, t) =>
for { for {
exports <- Exports[S].exports exports <- Exports[S].exports
@ -65,7 +59,12 @@ object RawValueInliner extends Logging {
case cr: CallArrowRaw => case cr: CallArrowRaw =>
CallArrowRawInliner(cr, propertiesAllowed) CallArrowRawInliner(cr, propertiesAllowed)
case cs: CallServiceRaw =>
CallServiceRawInliner(cs, propertiesAllowed)
} }
} yield result
private[inline] def inlineToTree[S: Mangler: Exports: Arrows]( private[inline] def inlineToTree[S: Mangler: Exports: Arrows](
inline: Inline inline: Inline
@ -101,10 +100,10 @@ object RawValueInliner extends Logging {
def valueToModel[S: Mangler: Exports: Arrows]( def valueToModel[S: Mangler: Exports: Arrows](
value: ValueRaw, value: ValueRaw,
propertiesAllowed: Boolean = true propertiesAllowed: Boolean = true
): State[S, (ValueModel, Option[OpModel.Tree])] = { ): State[S, (ValueModel, Option[OpModel.Tree])] = for {
logger.trace("RAW " + value) _ <- StateT.liftF(Eval.later(logger.trace("RAW " + value)))
toModel(unfold(value, propertiesAllowed)) model <- toModel(unfold(value, propertiesAllowed))
} } yield model
def valueListToModel[S: Mangler: Exports: Arrows]( def valueListToModel[S: Mangler: Exports: Arrows](
values: List[ValueRaw] values: List[ValueRaw]

View File

@ -4,7 +4,7 @@ import aqua.errors.Errors.internalError
import aqua.model.inline.state.{Arrows, Exports, Mangler} import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.model.* import aqua.model.*
import aqua.model.inline.RawValueInliner.collectionToModel import aqua.model.inline.RawValueInliner.collectionToModel
import aqua.model.inline.raw.CallArrowRawInliner import aqua.model.inline.raw.{CallArrowRawInliner, CallServiceRawInliner}
import aqua.raw.value.ApplyBinaryOpRaw.Op as BinOp import aqua.raw.value.ApplyBinaryOpRaw.Op as BinOp
import aqua.raw.ops.* import aqua.raw.ops.*
import aqua.raw.value.* import aqua.raw.value.*
@ -308,8 +308,13 @@ object TagInliner extends Logging {
TagInlined.Empty(prefix = parDesugarPrefix(nel.toList.flatMap(_._2))) TagInlined.Empty(prefix = parDesugarPrefix(nel.toList.flatMap(_._2)))
}) })
case CallArrowRawTag(exportTo, value: CallArrowRaw) => case CallArrowRawTag(exportTo, value: (CallArrowRaw | CallServiceRaw)) =>
CallArrowRawInliner.unfoldArrow(value, exportTo).flatMap { case (_, inline) => (value match {
case ca: CallArrowRaw =>
CallArrowRawInliner.unfold(ca, exportTo)
case cs: CallServiceRaw =>
CallServiceRawInliner.unfold(cs, exportTo)
}).flatMap { case (_, inline) =>
RawValueInliner RawValueInliner
.inlineToTree(inline) .inlineToTree(inline)
.map(tree => .map(tree =>

View File

@ -1,5 +1,6 @@
package aqua.model.inline.raw package aqua.model.inline.raw
import aqua.errors.Errors.internalError
import aqua.model.* import aqua.model.*
import aqua.model.inline.raw.RawInliner import aqua.model.inline.raw.RawInliner
import aqua.model.inline.TagInliner import aqua.model.inline.TagInliner
@ -8,8 +9,9 @@ import aqua.raw.value.{AbilityRaw, LiteralRaw, MakeStructRaw}
import cats.data.{NonEmptyList, NonEmptyMap, State} import cats.data.{NonEmptyList, NonEmptyMap, State}
import aqua.model.inline.Inline import aqua.model.inline.Inline
import aqua.model.inline.RawValueInliner.{unfold, valueToModel} import aqua.model.inline.RawValueInliner.{unfold, valueToModel}
import aqua.types.{ArrowType, ScalarType} import aqua.types.{ArrowType, ScalarType, Type}
import aqua.raw.value.ApplyBinaryOpRaw import aqua.raw.value.ApplyBinaryOpRaw
import aqua.raw.value.ApplyBinaryOpRaw.Op
import aqua.raw.value.ApplyBinaryOpRaw.Op.* import aqua.raw.value.ApplyBinaryOpRaw.Op.*
import aqua.model.inline.Inline.MergeMode import aqua.model.inline.Inline.MergeMode
@ -21,12 +23,10 @@ import cats.syntax.flatMap.*
import cats.syntax.apply.* import cats.syntax.apply.*
import cats.syntax.foldable.* import cats.syntax.foldable.*
import cats.syntax.applicative.* import cats.syntax.applicative.*
import aqua.types.LiteralType
object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] { object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
private type BoolOp = And.type | Or.type
private type EqOp = Eq.type | Neq.type
override def apply[S: Mangler: Exports: Arrows]( override def apply[S: Mangler: Exports: Arrows](
raw: ApplyBinaryOpRaw, raw: ApplyBinaryOpRaw,
propertiesAllowed: Boolean propertiesAllowed: Boolean
@ -37,16 +37,49 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
(rmodel, rinline) = right (rmodel, rinline) = right
result <- raw.op match { result <- raw.op match {
case op @ (And | Or) => inlineBoolOp(lmodel, rmodel, linline, rinline, op) case op: Op.Bool =>
case op @ (Eq | Neq) => inlineBoolOp(
lmodel,
rmodel,
linline,
rinline,
op,
raw.baseType
)
case op: Op.Eq =>
for { for {
// Canonicalize stream operands before comparison // Canonicalize stream operands before comparison
leftStream <- TagInliner.canonicalizeIfStream(lmodel) leftStream <- TagInliner.canonicalizeIfStream(lmodel)
(lmodelStream, linlineStream) = leftStream.map(linline.append) (lmodelStream, linlineStream) = leftStream.map(linline.append)
rightStream <- TagInliner.canonicalizeIfStream(rmodel) rightStream <- TagInliner.canonicalizeIfStream(rmodel)
(rmodelStream, rinlineStream) = rightStream.map(rinline.append) (rmodelStream, rinlineStream) = rightStream.map(rinline.append)
result <- inlineEqOp(lmodelStream, rmodelStream, linlineStream, rinlineStream, op) result <- inlineEqOp(
lmodelStream,
rmodelStream,
linlineStream,
rinlineStream,
op,
raw.baseType
)
} yield result } yield result
case op: Op.Cmp =>
inlineCmpOp(
lmodel,
rmodel,
linline,
rinline,
op,
raw.baseType
)
case op: Op.Math =>
inlineMathOp(
lmodel,
rmodel,
linline,
rinline,
op,
raw.baseType
)
} }
} yield result } yield result
@ -55,7 +88,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
rmodel: ValueModel, rmodel: ValueModel,
linline: Inline, linline: Inline,
rinline: Inline, rinline: Inline,
op: EqOp op: Op.Eq,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match { ): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
// Optimize in case compared values are literals // Optimize in case compared values are literals
// Semantics should check that types are comparable // Semantics should check that types are comparable
@ -69,7 +103,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
}, },
linline.mergeWith(rinline, MergeMode.ParMode) linline.mergeWith(rinline, MergeMode.ParMode)
).pure[State[S, *]] ).pure[State[S, *]]
case _ => fullInlineEqOp(lmodel, rmodel, linline, rinline, op) case _ => fullInlineEqOp(lmodel, rmodel, linline, rinline, op, resType)
} }
private def fullInlineEqOp[S: Mangler: Exports: Arrows]( private def fullInlineEqOp[S: Mangler: Exports: Arrows](
@ -77,7 +111,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
rmodel: ValueModel, rmodel: ValueModel,
linline: Inline, linline: Inline,
rinline: Inline, rinline: Inline,
op: EqOp op: Op.Eq,
resType: Type
): State[S, (ValueModel, Inline)] = { ): State[S, (ValueModel, Inline)] = {
val (name, shouldMatch) = op match { val (name, shouldMatch) = op match {
case Eq => ("eq", true) case Eq => ("eq", true)
@ -114,7 +149,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
) )
) )
result(name, predo) result(name, resType, predo)
} }
private def inlineBoolOp[S: Mangler: Exports: Arrows]( private def inlineBoolOp[S: Mangler: Exports: Arrows](
@ -122,7 +157,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
rmodel: ValueModel, rmodel: ValueModel,
linline: Inline, linline: Inline,
rinline: Inline, rinline: Inline,
op: BoolOp op: Op.Bool,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match { ): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
// Optimize in case of left value is known at compile time // Optimize in case of left value is known at compile time
case (LiteralModel.Bool(lvalue), _) => case (LiteralModel.Bool(lvalue), _) =>
@ -139,7 +175,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
case _ => (lmodel, linline) case _ => (lmodel, linline)
}).pure[State[S, *]] }).pure[State[S, *]]
// Produce unoptimized inline // Produce unoptimized inline
case _ => fullInlineBoolOp(lmodel, rmodel, linline, rinline, op) case _ => fullInlineBoolOp(lmodel, rmodel, linline, rinline, op, resType)
} }
private def fullInlineBoolOp[S: Mangler: Exports: Arrows]( private def fullInlineBoolOp[S: Mangler: Exports: Arrows](
@ -147,7 +183,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
rmodel: ValueModel, rmodel: ValueModel,
linline: Inline, linline: Inline,
rinline: Inline, rinline: Inline,
op: BoolOp op: Op.Bool,
resType: Type
): State[S, (ValueModel, Inline)] = { ): State[S, (ValueModel, Inline)] = {
val (name, compareWith) = op match { val (name, compareWith) = op match {
case And => ("and", false) case And => ("and", false)
@ -190,19 +227,162 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
) )
) )
result(name, predo) result(name, resType, predo)
}
private def inlineCmpOp[S: Mangler: Exports: Arrows](
lmodel: ValueModel,
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: Op.Cmp,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
case (
LiteralModel.Integer(lv, _),
LiteralModel.Integer(rv, _)
) =>
val res = op match {
case Lt => lv < rv
case Lte => lv <= rv
case Gt => lv > rv
case Gte => lv >= rv
}
(
LiteralModel.bool(res),
Inline(linline.predo ++ rinline.predo)
).pure
case _ =>
val fn = op match {
case Lt => "lt"
case Lte => "lte"
case Gt => "gt"
case Gte => "gte"
}
val predo = (resName: String) =>
SeqModel.wrap(
linline.predo ++ rinline.predo :+ CallServiceModel(
serviceId = LiteralModel.quote("cmp"),
funcName = fn,
call = CallModel(
args = lmodel :: rmodel :: Nil,
exportTo = CallModel.Export(resName, resType) :: Nil
)
).leaf
)
result(fn, resType, predo)
}
private def inlineMathOp[S: Mangler: Exports: Arrows](
lmodel: ValueModel,
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: Op.Math,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
case (
LiteralModel.Integer(lv, lt),
LiteralModel.Integer(rv, rt)
) if canOptimizeMath(lv, lt, rv, rt, op) =>
val res = op match {
case Add => lv + rv
case Sub => lv - rv
case Mul => lv * rv
case Div => lv / rv
case Rem => lv % rv
case Pow => intPow(lv, rv)
case _ => internalError(s"Unsupported operation $op for $lv and $rv")
}
(
LiteralModel.number(res),
Inline(linline.predo ++ rinline.predo)
).pure
case _ =>
val fn = op match {
case Add => "add"
case Sub => "sub"
case Mul => "mul"
case FMul => "fmul"
case Div => "div"
case Rem => "rem"
case Pow => "pow"
}
val predo = (resName: String) =>
SeqModel.wrap(
linline.predo ++ rinline.predo :+ CallServiceModel(
serviceId = LiteralModel.quote("math"),
funcName = fn,
call = CallModel(
args = lmodel :: rmodel :: Nil,
exportTo = CallModel.Export(resName, resType) :: Nil
)
).leaf
)
result(fn, resType, predo)
} }
private def result[S: Mangler]( private def result[S: Mangler](
name: String, name: String,
resType: Type,
predo: String => OpModel.Tree predo: String => OpModel.Tree
): State[S, (ValueModel, Inline)] = ): State[S, (ValueModel, Inline)] =
Mangler[S] Mangler[S]
.findAndForbidName(name) .findAndForbidName(name)
.map(resName => .map(resName =>
( (
VarModel(resName, ScalarType.bool), VarModel(resName, resType),
Inline(Chain.one(predo(resName))) Inline(Chain.one(predo(resName)))
) )
) )
/**
* Check if we can optimize math operation
* in compile time.
*
* @param left left operand
* @param leftType type of left operand
* @param right right operand
* @param rightType type of right operand
* @param op operation
* @return true if we can optimize this operation
*/
private def canOptimizeMath(
left: Long,
leftType: ScalarType | LiteralType,
right: Long,
rightType: ScalarType | LiteralType,
op: Op.Math
): Boolean = op match {
// Leave division by zero for runtime
case Op.Div | Op.Rem => right != 0
// Leave negative power for runtime
case Op.Pow => right >= 0
case Op.Sub =>
// Leave subtraction overflow for runtime
ScalarType.isSignedInteger(leftType) ||
ScalarType.isSignedInteger(rightType)
case _ => true
}
/**
* Integer power (binary exponentiation)
*
* @param base
* @param exp >= 0
* @return base ** exp
*/
private def intPow(base: Long, exp: Long): Long = {
def intPowTailRec(base: Long, exp: Long, acc: Long): Long =
if (exp <= 0) acc
else intPowTailRec(base * base, exp / 2, if (exp % 2 == 0) acc else acc * base)
intPowTailRec(base, exp, 1)
}
} }

View File

@ -257,51 +257,59 @@ object ApplyPropertiesRawInliner extends RawInliner[ApplyPropertyRaw] with Loggi
idx: ValueRaw idx: ValueRaw
): State[S, (VarModel, Inline)] = for { ): State[S, (VarModel, Inline)] = for {
/** /**
* Inline idx * Inline size, which is `idx + 1`
* Increment on ValueRaw level to
* apply possible optimizations
*/ */
idxInlined <- unfold(idx) sizeInlined <- unfold(idx.increment)
(sizeVM, sizeInline) = sizeInlined
/**
* Inline idx which is `size - 1`
* TODO: Do not generate it if
* it is not needed, e.g. in `join`
*/
idxInlined <- sizeVM match {
/**
* Micro optimization: if idx is a literal
* do not generate inline.
*/
case LiteralModel.Integer(i, t) =>
(LiteralModel((i - 1).toString, t), Inline.empty).pure[State[S, *]]
case _ =>
Mangler[S].findAndForbidName(s"${streamName}_idx").map { idxName =>
val idxVar = VarModel(idxName, sizeVM.`type`)
val idxInline = Inline.tree(
CallServiceModel(
"math",
funcName = "sub",
args = List(sizeVM, LiteralModel.number(1)),
result = idxVar
).leaf
)
(idxVar, idxInline)
}
}
(idxVM, idxInline) = idxInlined (idxVM, idxInline) = idxInlined
/** /**
* Inline size which is `idx + 1` * Inline join of `size` elements of stream
* TODO: Refactor to apply optimizations
*/ */
sizeName <- Mangler[S].findAndForbidName(s"${streamName}_size") gateInlined <- StreamGateInliner(streamName, streamType, sizeVM)
sizeVar = VarModel(sizeName, idxVM.`type`)
sizeInline = CallServiceModel(
"math",
funcName = "add",
args = List(idxVM, LiteralModel.number(1)),
result = sizeVar
).leaf
gateInlined <- StreamGateInliner(streamName, streamType, sizeVar)
(gateVM, gateInline) = gateInlined (gateVM, gateInline) = gateInlined
/**
* Remove properties from idx
* as we need to use it in index
* TODO: Do not generate it
* if it is not needed,
* e.g. in `join`
*/
idxFlattened <- idxVM match {
case vr: VarModel => removeProperties(vr)
case _ => (idxVM, Inline.empty).pure[State[S, *]]
}
(idxFlat, idxFlatInline) = idxFlattened
/** /**
* Construct stream[idx] * Construct stream[idx]
*/ */
gate = gateVM.withProperty( gate = gateVM.withProperty(
IntoIndexModel IntoIndexModel
.fromValueModel(idxFlat, streamType.element) .fromValueModel(idxVM, streamType.element)
.getOrElse( .getOrElse(
internalError(s"Unexpected: could not convert ($idxFlat) to IntoIndexModel") internalError(s"Unexpected: could not convert ($idxVM) to IntoIndexModel")
) )
) )
} yield gate -> Inline( } yield gate -> Inline(
idxInline.predo sizeInline.predo ++
.append(sizeInline) ++
gateInline.predo ++ gateInline.predo ++
idxFlatInline.predo, idxInline.predo,
mergeMode = SeqMode mergeMode = SeqMode
) )

View File

@ -7,40 +7,21 @@ import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.model.inline.{ArrowInliner, Inline, TagInliner} import aqua.model.inline.{ArrowInliner, Inline, TagInliner}
import aqua.raw.ops.Call import aqua.raw.ops.Call
import aqua.raw.value.CallArrowRaw import aqua.raw.value.CallArrowRaw
import cats.data.{Chain, State} import cats.data.{Chain, State}
import cats.syntax.traverse.* import cats.syntax.traverse.*
import scribe.Logging import scribe.Logging
object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging { object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging {
private[inline] def unfoldArrow[S: Mangler: Exports: Arrows]( private[inline] def unfold[S: Mangler: Exports: Arrows](
value: CallArrowRaw, value: CallArrowRaw,
exportTo: List[Call.Export] exportTo: List[Call.Export]
): State[S, (List[ValueModel], Inline)] = Exports[S].exports.flatMap { exports => ): State[S, (List[ValueModel], Inline)] = Exports[S].exports.flatMap { exports =>
logger.trace(s"${exportTo.mkString(" ")} $value") logger.trace(s"${exportTo.mkString(" ")} $value")
val call = Call(value.arguments, exportTo) val call = Call(value.arguments, exportTo)
value.serviceId match {
case Some(serviceId) =>
logger.trace(Console.BLUE + s"call service id $serviceId" + Console.RESET)
for {
cd <- callToModel(call, true)
(callModel, callInline) = cd
sd <- valueToModel(serviceId)
(serviceIdValue, serviceIdInline) = sd
values = callModel.exportTo.map(e => e.name -> e.asVar.resolveWith(exports)).toMap
inline = Inline(
Chain(
SeqModel.wrap(
serviceIdInline.toList ++ callInline.toList :+
CallServiceModel(serviceIdValue, value.name, callModel).leaf
)
)
)
_ <- Exports[S].resolved(values)
_ <- Mangler[S].forbid(values.keySet)
} yield values.values.toList -> inline
case None =>
/** /**
* Here the back hop happens from [[TagInliner]] to [[ArrowInliner.callArrow]] * Here the back hop happens from [[TagInliner]] to [[ArrowInliner.callArrow]]
*/ */
@ -49,7 +30,6 @@ object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging {
resolveArrow(funcName, call) resolveArrow(funcName, call)
} }
}
private def resolveFuncArrow[S: Mangler: Exports: Arrows]( private def resolveFuncArrow[S: Mangler: Exports: Arrows](
fn: FuncArrow, fn: FuncArrow,
@ -103,7 +83,7 @@ object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging {
Mangler[S] Mangler[S]
.findAndForbidName(raw.name) .findAndForbidName(raw.name)
.flatMap(n => .flatMap(n =>
unfoldArrow(raw, Call.Export(n, raw.`type`) :: Nil).map { unfold(raw, Call.Export(n, raw.`type`) :: Nil).map {
case (Nil, inline) => (VarModel(n, raw.`type`), inline) case (Nil, inline) => (VarModel(n, raw.`type`), inline)
case (h :: _, inline) => (h, inline) case (h :: _, inline) => (h, inline)
} }

View File

@ -0,0 +1,57 @@
package aqua.model.inline.raw
import aqua.errors.Errors.internalError
import aqua.model.*
import aqua.model.inline.RawValueInliner.{callToModel, valueToModel}
import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.model.inline.{ArrowInliner, Inline, TagInliner}
import aqua.raw.ops.Call
import aqua.raw.value.CallServiceRaw
import cats.data.{Chain, State}
import cats.syntax.traverse.*
import scribe.Logging
object CallServiceRawInliner extends RawInliner[CallServiceRaw] with Logging {
private[inline] def unfold[S: Mangler: Exports: Arrows](
value: CallServiceRaw,
exportTo: List[Call.Export]
): State[S, (List[ValueModel], Inline)] = Exports[S].exports.flatMap { exports =>
logger.trace(s"${exportTo.mkString(" ")} $value")
logger.trace(Console.BLUE + s"call service id ${value.serviceId}" + Console.RESET)
val call = Call(value.arguments, exportTo)
for {
cd <- callToModel(call, true)
(callModel, callInline) = cd
sd <- valueToModel(value.serviceId)
(serviceIdValue, serviceIdInline) = sd
values = callModel.exportTo.map(e => e.name -> e.asVar.resolveWith(exports)).toMap
inline = Inline(
Chain(
SeqModel.wrap(
serviceIdInline.toList ++ callInline.toList :+
CallServiceModel(serviceIdValue, value.fnName, callModel).leaf
)
)
)
_ <- Exports[S].resolved(values)
_ <- Mangler[S].forbid(values.keySet)
} yield values.values.toList -> inline
}
override def apply[S: Mangler: Exports: Arrows](
raw: CallServiceRaw,
propertiesAllowed: Boolean
): State[S, (ValueModel, Inline)] =
Mangler[S]
.findAndForbidName(raw.fnName)
.flatMap(n =>
unfold(raw, Call.Export(n, raw.`type`) :: Nil).map {
case (Nil, inline) => (VarModel(n, raw.`type`), inline)
case (h :: _, inline) => (h, inline)
}
)
}

View File

@ -4,7 +4,6 @@ import aqua.errors.Errors.internalError
import aqua.model.* import aqua.model.*
import aqua.model.inline.Inline import aqua.model.inline.Inline
import aqua.model.inline.state.{Arrows, Exports, Mangler} import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.raw.value.{LiteralRaw, VarRaw}
import aqua.model.inline.RawValueInliner.unfold import aqua.model.inline.RawValueInliner.unfold
import aqua.types.{ArrayType, CanonStreamType, ScalarType, StreamType} import aqua.types.{ArrayType, CanonStreamType, ScalarType, StreamType}
@ -25,7 +24,6 @@ object StreamGateInliner extends Logging {
* (seq * (seq
* (fold $stream s * (fold $stream s
* (seq * (seq
* (seq
* (ap s $stream_test) * (ap s $stream_test)
* (canon <peer> $stream_test #stream_iter_canon) * (canon <peer> $stream_test #stream_iter_canon)
* ) * )
@ -35,7 +33,6 @@ object StreamGateInliner extends Logging {
* ) * )
* (next s) * (next s)
* ) * )
* )
* (never) * (never)
* ) * )
* (canon <peer> $stream_test #stream_result_canon) * (canon <peer> $stream_test #stream_result_canon)
@ -100,7 +97,6 @@ object StreamGateInliner extends Logging {
uniqueCanonName <- Mangler[S].findAndForbidName(streamName + "_result_canon") uniqueCanonName <- Mangler[S].findAndForbidName(streamName + "_result_canon")
uniqueResultName <- Mangler[S].findAndForbidName(streamName + "_gate") uniqueResultName <- Mangler[S].findAndForbidName(streamName + "_gate")
uniqueTestName <- Mangler[S].findAndForbidName(streamName + "_test") uniqueTestName <- Mangler[S].findAndForbidName(streamName + "_test")
uniqueIdxIncr <- Mangler[S].findAndForbidName(streamName + "_incr")
uniqueIterCanon <- Mangler[S].findAndForbidName(streamName + "_iter_canon") uniqueIterCanon <- Mangler[S].findAndForbidName(streamName + "_iter_canon")
uniqueIter <- Mangler[S].findAndForbidName(streamName + "_fold_var") uniqueIter <- Mangler[S].findAndForbidName(streamName + "_fold_var")
} yield { } yield {

View File

@ -21,7 +21,7 @@ final case class IfTagInliner(
def inlined[S: Mangler: Exports: Arrows] = def inlined[S: Mangler: Exports: Arrows] =
(valueRaw match { (valueRaw match {
// Optimize in case last operation is equality check // Optimize in case last operation is equality check
case ApplyBinaryOpRaw(op @ (BinOp.Eq | BinOp.Neq), left, right) => case ApplyBinaryOpRaw(op @ (BinOp.Eq | BinOp.Neq), left, right, _) =>
( (
valueToModel(left) >>= canonicalizeIfStream, valueToModel(left) >>= canonicalizeIfStream,
valueToModel(right) >>= canonicalizeIfStream valueToModel(right) >>= canonicalizeIfStream

View File

@ -827,7 +827,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val res1 = VarModel("res", ScalarType.u16) val res1 = VarModel("res", ScalarType.u16)
val res2 = VarModel("res2", ScalarType.u16) val res2 = VarModel("res2", ScalarType.u16)
val res3 = VarModel("res-0", ScalarType.u16) val res3 = VarModel("res-0", ScalarType.u16)
val tempAdd = VarModel("add-0", ScalarType.u16) val tempAdd = VarModel("add", ScalarType.u16)
val expected = SeqModel.wrap( val expected = SeqModel.wrap(
MetaModel MetaModel
@ -843,7 +843,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
), ),
SeqModel.wrap( SeqModel.wrap(
ModelBuilder.add(res2, res3)(tempAdd).leaf, ModelBuilder.add(res2, res3)(tempAdd).leaf,
ModelBuilder.add(res1, tempAdd)(VarModel("add", ScalarType.u16)).leaf ModelBuilder.add(res1, tempAdd)(VarModel("add-0", ScalarType.u16)).leaf
) )
) )
@ -889,15 +889,12 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
capturedTopology = None capturedTopology = None
) )
val innerCall = CallArrowRaw( val innerCall = CallArrowRaw.func(
ability = None, funcName = innerName,
name = innerName,
arguments = Nil,
baseType = ArrowType( baseType = ArrowType(
domain = NilType, domain = NilType,
codomain = ProductType(List(ScalarType.u16)) codomain = ProductType(List(ScalarType.u16))
), )
serviceId = None
) )
val outerAdd = "37" val outerAdd = "37"
@ -943,36 +940,18 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.runA(InliningState()) .runA(InliningState())
.value .value
/* WARNING: This naming is unstable */ // Addition is completely optimized out
val tempAdd0 = VarModel("add-0", ScalarType.u16) model.equalsOrShowDiff(EmptyModel.leaf) shouldEqual true
val tempAdd = VarModel("add", ScalarType.u16)
val expected = SeqModel.wrap(
ModelBuilder
.add(
LiteralModel(innerRet, ScalarType.u16),
LiteralModel(outerAdd, ScalarType.u16)
)(tempAdd0)
.leaf,
ModelBuilder
.add(
LiteralModel(innerRet, ScalarType.u16),
tempAdd0
)(tempAdd)
.leaf
)
model.equalsOrShowDiff(expected) shouldEqual true
} }
/** /**
* closureName = (x: u16) -> u16: * closureName = (x: u16) -> u16:
* retval = x + add * retval <- TestSrv.call(x, add)
* <- retval * <- retval
* *
* @return (closure func, closure type, closure type labelled) * @return (closure func, closure type, closure type labelled)
*/ */
def addClosure( def srvCallClosure(
closureName: String, closureName: String,
add: ValueRaw add: ValueRaw
): (FuncRaw, ArrowType, ArrowType) = { ): (FuncRaw, ArrowType, ArrowType) = {
@ -993,13 +972,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
) )
val closureBody = SeqTag.wrap( val closureBody = SeqTag.wrap(
AssignmentTag( CallArrowRawTag
RawBuilder.add( .service(
closureArg, LiteralRaw.quote("test-srv"),
add funcName = "call",
), Call(
closureRes.name args = List(closureArg, add),
).leaf, exportTo = List(Call.Export(closureRes.name, closureRes.`type`))
)
)
.leaf,
ReturnTag( ReturnTag(
NonEmptyList.one(closureRes) NonEmptyList.one(closureRes)
).leaf ).leaf
@ -1017,10 +999,21 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
(closureFunc, closureType, closureTypeLabelled) (closureFunc, closureType, closureTypeLabelled)
} }
def srvCallModel(
x: ValueModel,
add: ValueModel,
result: VarModel
): CallServiceModel = CallServiceModel(
serviceId = "test-srv",
funcName = "call",
args = List(x, add),
result = result
)
/** /**
* func innerName(arg: u16) -> u16 -> u16: * func innerName(arg: u16) -> u16 -> u16:
* closureName = (x: u16) -> u16: * closureName = (x: u16) -> u16:
* retval = x + arg * retval <- TestSrv.call(x, arg)
* <- retval * <- retval
* <- closureName * <- closureName
* *
@ -1042,7 +1035,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
) )
val (closureFunc, closureType, closureTypeLabelled) = val (closureFunc, closureType, closureTypeLabelled) =
addClosure(closureName, innerArg) srvCallClosure(closureName, innerArg)
val innerRes = VarRaw( val innerRes = VarRaw(
closureName, closureName,
@ -1085,12 +1078,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val innerCall = val innerCall =
CallArrowRawTag( CallArrowRawTag(
List(Call.Export(outterClosure.name, outterClosure.`type`)), List(Call.Export(outterClosure.name, outterClosure.`type`)),
CallArrowRaw( CallArrowRaw.func(
ability = None, funcName = innerName,
name = innerName,
arguments = List(LiteralRaw("42", LiteralType.number)),
baseType = innerType, baseType = innerType,
serviceId = None arguments = List(LiteralRaw("42", LiteralType.number))
) )
).leaf ).leaf
@ -1128,7 +1119,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
/** /**
* func inner(arg: u16) -> u16 -> u16: * func inner(arg: u16) -> u16 -> u16:
* closure = (x: u16) -> u16: * closure = (x: u16) -> u16:
* retval = x + arg * retval <- TestSrv.call(x, arg)
* <- retval * <- retval
* <- closure * <- closure
* *
@ -1144,12 +1135,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val outterResultName = "retval" val outterResultName = "retval"
val closureCall = (closureType: ArrowType, i: String) => val closureCall = (closureType: ArrowType, i: String) =>
CallArrowRaw( CallArrowRaw.func(
ability = None, funcName = outterClosureName,
name = outterClosureName,
arguments = List(LiteralRaw(i, LiteralType.number)),
baseType = closureType, baseType = closureType,
serviceId = None arguments = List(LiteralRaw(i, LiteralType.unsigned))
) )
val body = (closureType: ArrowType) => val body = (closureType: ArrowType) =>
@ -1157,7 +1146,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
AssignmentTag( AssignmentTag(
RawBuilder.add( RawBuilder.add(
RawBuilder.add( RawBuilder.add(
LiteralRaw("37", LiteralType.number), LiteralRaw("37", LiteralType.unsigned),
closureCall(closureType, "1") closureCall(closureType, "1")
), ),
closureCall(closureType, "2") closureCall(closureType, "2")
@ -1180,20 +1169,19 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.wrap( .wrap(
ApplyTopologyModel(closureName) ApplyTopologyModel(closureName)
.wrap( .wrap(
ModelBuilder srvCallModel(
.add( LiteralModel(x, LiteralType.unsigned),
LiteralModel(x, LiteralType.number), LiteralModel("42", LiteralType.unsigned),
LiteralModel("42", LiteralType.number) result = o
)(o) ).leaf
.leaf
) )
) )
/* WARNING: This naming is unstable */ /* WARNING: This naming is unstable */
val tempAdd0 = VarModel("add-0", ScalarType.u16) val retval1 = VarModel("retval-0", ScalarType.u16)
val tempAdd1 = VarModel("add-1", ScalarType.u16) val retval2 = VarModel("retval-1", ScalarType.u16)
val tempAdd2 = VarModel("add-2", ScalarType.u16)
val tempAdd = VarModel("add", ScalarType.u16) val tempAdd = VarModel("add", ScalarType.u16)
val tempAdd0 = VarModel("add-0", ScalarType.u16)
val expected = SeqModel.wrap( val expected = SeqModel.wrap(
MetaModel MetaModel
@ -1202,23 +1190,21 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
CaptureTopologyModel(closureName).leaf CaptureTopologyModel(closureName).leaf
), ),
SeqModel.wrap( SeqModel.wrap(
ParModel.wrap(
SeqModel.wrap( SeqModel.wrap(
closureCallModel("1", tempAdd1), closureCallModel("1", retval1),
closureCallModel("2", retval2),
ModelBuilder ModelBuilder
.add( .add(
LiteralModel("37", LiteralType.number), retval1,
tempAdd1 retval2
)(tempAdd0) )(tempAdd)
.leaf .leaf
), ),
closureCallModel("2", tempAdd2)
),
ModelBuilder ModelBuilder
.add( .add(
tempAdd0, tempAdd,
tempAdd2 LiteralModel.number(37)
)(tempAdd) )(tempAdd0)
.leaf .leaf
) )
) )
@ -1306,22 +1292,18 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val innerCall = val innerCall =
CallArrowRawTag( CallArrowRawTag(
List(Call.Export(outterClosure.name, outterClosure.`type`)), List(Call.Export(outterClosure.name, outterClosure.`type`)),
CallArrowRaw( CallArrowRaw.func(
ability = None, funcName = innerName,
name = innerName,
arguments = Nil,
baseType = innerType, baseType = innerType,
serviceId = None arguments = Nil
) )
).leaf ).leaf
val closureCall = val closureCall =
CallArrowRaw( CallArrowRaw.func(
ability = None, funcName = outterClosure.name,
name = outterClosure.name,
arguments = Nil,
baseType = closureType, baseType = closureType,
serviceId = None arguments = Nil
) )
val outerBody = SeqTag.wrap( val outerBody = SeqTag.wrap(
@ -1365,38 +1347,14 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.runA(InliningState()) .runA(InliningState())
.value .value
/* WARNING: This naming is unstable */ // Addition is completely optimized out
val tempAdd0 = VarModel("add-0", ScalarType.u16) model.equalsOrShowDiff(EmptyModel.leaf) shouldEqual true
val tempAdd = VarModel("add", ScalarType.u16)
val number = (v: String) =>
LiteralModel(
v,
LiteralType.number
)
val expected = SeqModel.wrap(
ModelBuilder
.add(
number("37"),
number("42")
)(tempAdd0)
.leaf,
ModelBuilder
.add(
tempAdd0,
number("42")
)(tempAdd)
.leaf
)
model.equalsOrShowDiff(expected) shouldEqual true
} }
/** /**
* func inner(arg: u16) -> u16 -> u16: * func inner(arg: u16) -> u16 -> u16:
* closure = (x: u16) -> u16: * closure = (x: u16) -> u16:
* retval = x + arg * retval = TestSrv.call(x, arg)
* <- retval * <- retval
* <- closure * <- closure
* *
@ -1404,7 +1362,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
* c <- inner(42) * c <- inner(42)
* b = c * b = c
* a = b * a = b
* retval = 37 + a(1) + b(2) + c{3} * retval = 37 + a(1) + b(2) + c(3)
* <- retval * <- retval
*/ */
it should "correctly inline renamed closure [bug LNG-193]" in { it should "correctly inline renamed closure [bug LNG-193]" in {
@ -1416,12 +1374,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val secondRename = "a" val secondRename = "a"
val closureCall = (name: String, closureType: ArrowType, i: String) => val closureCall = (name: String, closureType: ArrowType, i: String) =>
CallArrowRaw( CallArrowRaw.func(
ability = None, funcName = name,
name = name,
arguments = List(LiteralRaw(i, LiteralType.number)), arguments = List(LiteralRaw(i, LiteralType.number)),
baseType = closureType, baseType = closureType
serviceId = None
) )
val body = (closureType: ArrowType) => val body = (closureType: ArrowType) =>
@ -1458,28 +1414,27 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
body = body body = body
) )
val closureCallModel = (x: String, o: VarModel) => val closureCallModel = (x: Long, o: VarModel) =>
MetaModel MetaModel
.CallArrowModel(closureName) .CallArrowModel(closureName)
.wrap( .wrap(
ApplyTopologyModel(closureName) ApplyTopologyModel(closureName)
.wrap( .wrap(
ModelBuilder srvCallModel(
.add( LiteralModel.number(x),
LiteralModel(x, LiteralType.number), LiteralModel.number(42),
LiteralModel("42", LiteralType.number) result = o
)(o) ).leaf
.leaf
) )
) )
/* WARNING: This naming is unstable */ /* WARNING: This naming is unstable */
val tempAdd = VarModel("add", ScalarType.u16)
val tempAdd0 = VarModel("add-0", ScalarType.u16) val tempAdd0 = VarModel("add-0", ScalarType.u16)
val tempAdd1 = VarModel("add-1", ScalarType.u16) val tempAdd1 = VarModel("add-1", ScalarType.u16)
val tempAdd2 = VarModel("add-2", ScalarType.u16) val retval0 = VarModel("retval-0", ScalarType.u16)
val tempAdd3 = VarModel("add-3", ScalarType.u16) val retval1 = VarModel("retval-1", ScalarType.u16)
val tempAdd4 = VarModel("add-4", ScalarType.u16) val retval2 = VarModel("retval-2", ScalarType.u16)
val tempAdd = VarModel("add", ScalarType.u16)
val expected = SeqModel.wrap( val expected = SeqModel.wrap(
MetaModel MetaModel
@ -1488,35 +1443,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
CaptureTopologyModel(closureName).leaf CaptureTopologyModel(closureName).leaf
), ),
SeqModel.wrap( SeqModel.wrap(
ParModel.wrap(
SeqModel.wrap( SeqModel.wrap(
ParModel.wrap(
SeqModel.wrap( SeqModel.wrap(
closureCallModel("1", tempAdd2), closureCallModel(1, retval0),
ModelBuilder closureCallModel(2, retval1),
.add( ModelBuilder.add(retval0, retval1)(tempAdd).leaf
LiteralModel("37", LiteralType.number),
tempAdd2
)(tempAdd1)
.leaf
), ),
closureCallModel("2", tempAdd3) closureCallModel(3, retval2),
ModelBuilder.add(tempAdd, retval2)(tempAdd0).leaf
), ),
ModelBuilder ModelBuilder.add(tempAdd0, LiteralModel.number(37))(tempAdd1).leaf
.add(
tempAdd1,
tempAdd3
)(tempAdd0)
.leaf
),
closureCallModel("3", tempAdd4)
),
ModelBuilder
.add(
tempAdd0,
tempAdd4
)(tempAdd)
.leaf
) )
) )
@ -1530,7 +1466,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
* *
* func test() -> u16: * func test() -> u16:
* closure = (x: u16) -> u16: * closure = (x: u16) -> u16:
* resC = x + 37 * resC <- TestSrv.call(x, 37)
* <- resC * <- resC
* resT <- accept_closure(closure) * resT <- accept_closure(closure)
* <- resT * <- resT
@ -1543,7 +1479,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val testRes = VarRaw("resT", ScalarType.u16) val testRes = VarRaw("resT", ScalarType.u16)
val (closureFunc, closureType, closureTypeLabelled) = val (closureFunc, closureType, closureTypeLabelled) =
addClosure(closureName, LiteralRaw("37", LiteralType.number)) srvCallClosure(closureName, LiteralRaw.number(37))
val acceptType = ArrowType( val acceptType = ArrowType(
domain = ProductType.labelled(List(closureName -> closureType)), domain = ProductType.labelled(List(closureName -> closureType)),
@ -1553,12 +1489,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val acceptBody = SeqTag.wrap( val acceptBody = SeqTag.wrap(
CallArrowRawTag( CallArrowRawTag(
List(Call.Export(acceptRes.name, acceptRes.baseType)), List(Call.Export(acceptRes.name, acceptRes.baseType)),
CallArrowRaw( CallArrowRaw.func(
ability = None, funcName = closureName,
name = closureName,
arguments = List(LiteralRaw("42", LiteralType.number)),
baseType = closureType, baseType = closureType,
serviceId = None arguments = List(LiteralRaw.number(42))
) )
).leaf, ).leaf,
ReturnTag( ReturnTag(
@ -1586,12 +1520,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
).leaf, ).leaf,
CallArrowRawTag( CallArrowRawTag(
List(Call.Export(testRes.name, testRes.baseType)), List(Call.Export(testRes.name, testRes.baseType)),
CallArrowRaw( CallArrowRaw.func(
ability = None, funcName = acceptName,
name = acceptName,
arguments = List(VarRaw(closureName, closureTypeLabelled)),
baseType = acceptFunc.arrowType, baseType = acceptFunc.arrowType,
serviceId = None arguments = List(VarRaw(closureName, closureTypeLabelled))
) )
).leaf, ).leaf,
ReturnTag( ReturnTag(
@ -1621,7 +1553,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.value .value
/* WARNING: This naming is unstable */ /* WARNING: This naming is unstable */
val tempAdd = VarModel("add", ScalarType.u16) val retval = VarModel("retval", ScalarType.u16)
val expected = SeqModel.wrap( val expected = SeqModel.wrap(
CaptureTopologyModel(closureName).leaf, CaptureTopologyModel(closureName).leaf,
@ -1632,12 +1564,11 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.CallArrowModel(closureName) .CallArrowModel(closureName)
.wrap( .wrap(
ApplyTopologyModel(closureName).wrap( ApplyTopologyModel(closureName).wrap(
ModelBuilder srvCallModel(
.add( LiteralModel.number(42),
LiteralModel("42", LiteralType.number), LiteralModel.number(37),
LiteralModel("37", LiteralType.number) retval
)(tempAdd) ).leaf
.leaf
) )
) )
) )
@ -1673,17 +1604,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val testBody = SeqTag.wrap( val testBody = SeqTag.wrap(
CallArrowRawTag CallArrowRawTag
.service( .service(
serviceId = serviceId, srvId = serviceId,
fnName = argMethodName, funcName = argMethodName,
call = Call( call = Call(
args = VarRaw(argMethodName, ScalarType.string) :: Nil, args = VarRaw(argMethodName, ScalarType.string) :: Nil,
exportTo = Call.Export(res.name, res.`type`) :: Nil exportTo = Call.Export(res.name, res.`type`) :: Nil
), ),
name = serviceName,
arrowType = ArrowType( arrowType = ArrowType(
domain = ProductType.labelled(List(argMethodName -> ScalarType.string)), domain = ProductType.labelled(List(argMethodName -> ScalarType.string)),
codomain = ProductType(ScalarType.string :: Nil) codomain = ProductType(ScalarType.string :: Nil)
) ).some
) )
.leaf, .leaf,
ReturnTag( ReturnTag(
@ -2014,7 +1944,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
FuncArrow( FuncArrow(
"dumb_func", "dumb_func",
SeqTag.wrap( SeqTag.wrap(
AssignmentTag(LiteralRaw("1", LiteralType.number), argVar.name).leaf, AssignmentTag(LiteralRaw.number(1), argVar.name).leaf,
foldOp foldOp
), ),
ArrowType( ArrowType(
@ -2037,7 +1967,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
CallServiceModel( CallServiceModel(
LiteralModel.fromRaw(serviceId), LiteralModel.fromRaw(serviceId),
fnName, fnName,
CallModel(LiteralModel("1", LiteralType.number) :: Nil, Nil) CallModel(LiteralModel.number(1) :: Nil, Nil)
).leaf, ).leaf,
NextModel(iVar0.name).leaf NextModel(iVar0.name).leaf
) )
@ -2215,8 +2145,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val closureBody = SeqTag.wrap( val closureBody = SeqTag.wrap(
AssignmentTag( AssignmentTag(
CallArrowRaw.service( CallServiceRaw(
"cmp",
LiteralRaw.quote("cmp"), LiteralRaw.quote("cmp"),
"gt", "gt",
ArrowType( ArrowType(

View File

@ -6,6 +6,7 @@ import aqua.model.inline.state.InliningState
import aqua.raw.ops.* import aqua.raw.ops.*
import aqua.raw.value.* import aqua.raw.value.*
import aqua.types.* import aqua.types.*
import cats.data.{Chain, NonEmptyList, NonEmptyMap} import cats.data.{Chain, NonEmptyList, NonEmptyMap}
import cats.syntax.show.* import cats.syntax.show.*
import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.flatspec.AnyFlatSpec
@ -24,12 +25,11 @@ class CopyInlinerSpec extends AnyFlatSpec with Matchers {
val length = FunctorRaw("length", ScalarType.u32) val length = FunctorRaw("length", ScalarType.u32)
val lengthValue = VarRaw("l", arrType).withProperty(length) val lengthValue = VarRaw("l", arrType).withProperty(length)
val getField = CallArrowRaw( val getField = CallServiceRaw(
None, LiteralRaw.quote("serv"),
"get_field", "get_field",
Nil,
ArrowType(NilType, UnlabeledConsType(ScalarType.string, NilType)), ArrowType(NilType, UnlabeledConsType(ScalarType.string, NilType)),
Option(LiteralRaw.quote("serv")) Nil
) )
val copyRaw = val copyRaw =
@ -63,9 +63,22 @@ class CopyInlinerSpec extends AnyFlatSpec with Matchers {
).leaf ).leaf
), ),
RestrictionModel(streamMapName, streamMapType).wrap( RestrictionModel(streamMapName, streamMapType).wrap(
InsertKeyValueModel(LiteralModel.quote("field1"), VarModel("l_length", ScalarType.u32), streamMapName, streamMapType).leaf, InsertKeyValueModel(
InsertKeyValueModel(LiteralModel.quote("field2"), VarModel("get_field", ScalarType.string), streamMapName, streamMapType).leaf, LiteralModel.quote("field1"),
CanonicalizeModel(VarModel(streamMapName, streamMapType), CallModel.Export(result.name, result.`type`)).leaf VarModel("l_length", ScalarType.u32),
streamMapName,
streamMapType
).leaf,
InsertKeyValueModel(
LiteralModel.quote("field2"),
VarModel("get_field", ScalarType.string),
streamMapName,
streamMapType
).leaf,
CanonicalizeModel(
VarModel(streamMapName, streamMapType),
CallModel.Export(result.name, result.`type`)
).leaf
) )
) )
) shouldBe true ) shouldBe true

View File

@ -6,6 +6,7 @@ import aqua.model.inline.state.InliningState
import aqua.raw.ops.* import aqua.raw.ops.*
import aqua.raw.value.* import aqua.raw.value.*
import aqua.types.* import aqua.types.*
import cats.data.{Chain, NonEmptyList, NonEmptyMap} import cats.data.{Chain, NonEmptyList, NonEmptyMap}
import cats.syntax.show.* import cats.syntax.show.*
import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.flatspec.AnyFlatSpec
@ -24,12 +25,11 @@ class MakeStructInlinerSpec extends AnyFlatSpec with Matchers {
val length = FunctorRaw("length", ScalarType.u32) val length = FunctorRaw("length", ScalarType.u32)
val lengthValue = VarRaw("l", arrType).withProperty(length) val lengthValue = VarRaw("l", arrType).withProperty(length)
val getField = CallArrowRaw( val getField = CallServiceRaw(
None, LiteralRaw.quote("serv"),
"get_field", "get_field",
Nil,
ArrowType(NilType, UnlabeledConsType(ScalarType.string, NilType)), ArrowType(NilType, UnlabeledConsType(ScalarType.string, NilType)),
Option(LiteralRaw.quote("serv")) Nil
) )
val makeStruct = val makeStruct =
@ -62,9 +62,22 @@ class MakeStructInlinerSpec extends AnyFlatSpec with Matchers {
).leaf ).leaf
), ),
RestrictionModel(streamMapName, streamMapType).wrap( RestrictionModel(streamMapName, streamMapType).wrap(
InsertKeyValueModel(LiteralModel.quote("field1"), VarModel("l_length", ScalarType.u32), streamMapName, streamMapType).leaf, InsertKeyValueModel(
InsertKeyValueModel(LiteralModel.quote("field2"), VarModel("get_field", ScalarType.string), streamMapName, streamMapType).leaf, LiteralModel.quote("field1"),
CanonicalizeModel(VarModel(streamMapName, streamMapType), CallModel.Export(result.name, result.`type`)).leaf VarModel("l_length", ScalarType.u32),
streamMapName,
streamMapType
).leaf,
InsertKeyValueModel(
LiteralModel.quote("field2"),
VarModel("get_field", ScalarType.string),
streamMapName,
streamMapType
).leaf,
CanonicalizeModel(
VarModel(streamMapName, streamMapType),
CallModel.Export(result.name, result.`type`)
).leaf
) )
) )
) shouldBe true ) shouldBe true

View File

@ -1,21 +1,10 @@
package aqua.model.inline package aqua.model.inline
import aqua.raw.value.{CallArrowRaw, LiteralRaw, ValueRaw} import aqua.raw.value.{ApplyBinaryOpRaw, ValueRaw}
import aqua.types.{ArrowType, ProductType, ScalarType} import aqua.types.{ArrowType, ProductType, ScalarType}
object RawBuilder { object RawBuilder {
def add(l: ValueRaw, r: ValueRaw): ValueRaw = def add(l: ValueRaw, r: ValueRaw): ValueRaw =
CallArrowRaw.service( ApplyBinaryOpRaw.Add(l, r)
abilityName = "math",
serviceId = LiteralRaw.quote("math"),
funcName = "add",
baseType = ArrowType(
ProductType(List(ScalarType.i64, ScalarType.i64)),
ProductType(
List(l.`type` `` r.`type`)
)
),
arguments = List(l, r)
)
} }

View File

@ -1,34 +1,48 @@
package aqua.model.inline package aqua.model.inline
import aqua.model.inline.raw.ApplyPropertiesRawInliner import aqua.model.inline.raw.{ApplyPropertiesRawInliner, StreamGateInliner}
import aqua.model.{ import aqua.model.*
EmptyModel,
FlattenModel,
FunctorModel,
IntoFieldModel,
IntoIndexModel,
ParModel,
SeqModel,
ValueModel,
VarModel
}
import aqua.model.inline.state.InliningState import aqua.model.inline.state.InliningState
import aqua.raw.value.{ApplyPropertyRaw, FunctorRaw, IntoIndexRaw, LiteralRaw, VarRaw} import aqua.raw.value.{ApplyPropertyRaw, FunctorRaw, IntoIndexRaw, LiteralRaw, VarRaw}
import aqua.types.* import aqua.types.*
import aqua.raw.value.*
import cats.Eval
import cats.data.NonEmptyMap import cats.data.NonEmptyMap
import cats.data.Chain import cats.data.Chain
import cats.syntax.show.* import cats.syntax.show.*
import cats.syntax.foldable.*
import cats.free.Cofree
import scala.collection.immutable.SortedMap
import scala.math
import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import org.scalatest.Inside
import scala.collection.immutable.SortedMap class RawValueInlinerSpec extends AnyFlatSpec with Matchers with Inside {
import aqua.raw.value.ApplyBinaryOpRaw
import aqua.raw.value.CallArrowRaw
class RawValueInlinerSpec extends AnyFlatSpec with Matchers {
import RawValueInliner.valueToModel import RawValueInliner.valueToModel
def join(stream: VarModel, size: ValueModel) =
stream match {
case VarModel(
streamName,
streamType: StreamType,
Chain.`nil`
) =>
StreamGateInliner.joinStreamOnIndexModel(
streamName = streamName,
streamType = streamType,
sizeModel = size,
testName = streamName + "_test",
iterName = streamName + "_fold_var",
canonName = streamName + "_result_canon",
iterCanonName = streamName + "_iter_canon",
resultName = streamName + "_gate"
)
case _ => ???
}
private def numVarWithLength(name: String) = private def numVarWithLength(name: String) =
VarRaw(name, ArrayType(ScalarType.u32)).withProperty( VarRaw(name, ArrayType(ScalarType.u32)).withProperty(
FunctorRaw("length", ScalarType.u32) FunctorRaw("length", ScalarType.u32)
@ -126,6 +140,51 @@ class RawValueInlinerSpec extends AnyFlatSpec with Matchers {
IntoIndexRaw(ysVarRaw(1), ScalarType.string) IntoIndexRaw(ysVarRaw(1), ScalarType.string)
) )
def int(i: Int): LiteralRaw = LiteralRaw.number(i)
extension (l: ValueRaw) {
def cmp(op: ApplyBinaryOpRaw.Op.Cmp)(r: ValueRaw): ApplyBinaryOpRaw =
ApplyBinaryOpRaw(op, l, r, ScalarType.bool)
def math(op: ApplyBinaryOpRaw.Op.Math)(r: ValueRaw): ApplyBinaryOpRaw =
ApplyBinaryOpRaw(op, l, r, ScalarType.i64) // result type is not important here
def `<`(r: ValueRaw): ApplyBinaryOpRaw =
cmp(ApplyBinaryOpRaw.Op.Lt)(r)
def `<=`(r: ValueRaw): ApplyBinaryOpRaw =
cmp(ApplyBinaryOpRaw.Op.Lte)(r)
def `>`(r: ValueRaw): ApplyBinaryOpRaw =
cmp(ApplyBinaryOpRaw.Op.Gt)(r)
def `>=`(r: ValueRaw): ApplyBinaryOpRaw =
cmp(ApplyBinaryOpRaw.Op.Gte)(r)
def `+`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Add)(r)
def `-`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Sub)(r)
def `*`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Mul)(r)
def `/`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Div)(r)
def `%`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Rem)(r)
def `**`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Pow)(r)
}
private def ivar(name: String, t: Option[Type] = None): VarRaw =
VarRaw(name, t.getOrElse(ScalarType.i32))
"raw value inliner" should "desugarize a single non-recursive raw value" in { "raw value inliner" should "desugarize a single non-recursive raw value" in {
// x[y] // x[y]
valueToModel[InliningState](`raw x[y]`) valueToModel[InliningState](`raw x[y]`)
@ -305,24 +364,53 @@ class RawValueInlinerSpec extends AnyFlatSpec with Matchers {
} }
it should "desugarize stream with gate" in { it should "desugarize stream with gate" in {
val streamWithProps = val stream = VarRaw("x", StreamType(ScalarType.string))
VarRaw("x", StreamType(ScalarType.string)).withProperty( val streamModel = VarModel.fromVarRaw(stream)
IntoIndexRaw(ysVarRaw(1), ScalarType.string) val idxRaw = ysVarRaw(1)
val streamWithProps = stream.withProperty(
IntoIndexRaw(idxRaw, ScalarType.string)
) )
val (resVal, resTree) = valueToModel[InliningState](streamWithProps) val initState = InliningState(noNames = Set("x", "ys"))
.runA(InliningState(noNames = Set("x", "ys")))
.value // Here retrieve how size is inlined
val (afterSizeState, (sizeModel, sizeTree)) =
valueToModel[InliningState](idxRaw.increment).run(initState).value
val (resVal, resTree) =
valueToModel[InliningState](streamWithProps).runA(initState).value
val idxModel = VarModel("x_idx", ScalarType.i8)
val decrement = CallServiceModel(
"math",
"sub",
List(
sizeModel,
LiteralModel.number(1)
),
idxModel
).leaf
val expected = SeqModel.wrap(
sizeTree.toList :+
join(streamModel, sizeModel) :+
decrement
)
resVal should be( resVal should be(
VarModel( VarModel(
"x_gate", "x_gate",
ArrayType(ScalarType.string), ArrayType(ScalarType.string),
Chain( Chain(
IntoIndexModel("ys_flat", ScalarType.string) IntoIndexModel(idxModel.name, ScalarType.string)
) )
) )
) )
inside(resTree) { case Some(tree) =>
tree.equalsOrShowDiff(expected) should be(true)
}
} }
it should "desugarize stream with length" in { it should "desugarize stream with length" in {
@ -388,4 +476,165 @@ class RawValueInlinerSpec extends AnyFlatSpec with Matchers {
) )
) should be(true) ) should be(true)
} }
it should "optimize constants comparison" in {
for {
l <- -100 to 100
r <- -100 to 100
} {
val lt = valueToModel[InliningState](
int(l) `<` int(r)
).runA(InliningState()).value
lt shouldBe (
LiteralModel.bool(l < r) -> None
)
val lte = valueToModel[InliningState](
int(l) `<=` int(r)
).runA(InliningState()).value
lte shouldBe (
LiteralModel.bool(l <= r) -> None
)
val gt = valueToModel[InliningState](
int(l) `>` int(r)
).runA(InliningState()).value
gt shouldBe (
LiteralModel.bool(l > r) -> None
)
val gte = valueToModel[InliningState](
int(l) `>=` int(r)
).runA(InliningState()).value
gte shouldBe (
LiteralModel.bool(l >= r) -> None
)
}
}
it should "optimize constants math" in {
for {
l <- -100 to 100
r <- -100 to 100
} {
val add = valueToModel[InliningState](
int(l) `+` int(r)
).runA(InliningState()).value
add shouldBe (
LiteralModel.number(l + r) -> None
)
val sub = valueToModel[InliningState](
int(l) `-` int(r)
).runA(InliningState()).value
sub shouldBe (
LiteralModel.number(l - r) -> None
)
val mul = valueToModel[InliningState](
int(l) `*` int(r)
).runA(InliningState()).value
mul shouldBe (
LiteralModel.number(l * r) -> None
)
val div = valueToModel[InliningState](
int(l) `/` int(r)
).runA(InliningState()).value
val rem = valueToModel[InliningState](
int(l) `%` int(r)
).runA(InliningState()).value
if (r != 0)
div shouldBe (
LiteralModel.number(l / r) -> None
)
rem shouldBe (
LiteralModel.number(l % r) -> None
)
else {
val (dmodel, dtree) = div
dmodel shouldBe a[VarModel]
dtree.nonEmpty shouldBe (true)
val (rmodel, rtree) = rem
rmodel shouldBe a[VarModel]
rtree.nonEmpty shouldBe (true)
}
if (r >= 0 && r <= 5) {
val pow = valueToModel[InliningState](
int(l) `**` int(r)
).runA(InliningState()).value
pow shouldBe (
LiteralModel.number(scala.math.pow(l, r).toLong) -> None
)
}
}
}
it should "optimize addition in expressions" in {
def test(numVars: Int, numLiterals: Int) = {
val vars = (1 to numVars).map(i => ivar(s"v$i")).toList
val literals = (1 to numLiterals).map(i => LiteralRaw.number(i)).toList
val values = vars ++ literals
/**
* Enumerate all possible binary trees of vals
*/
def genAllExprs(vals: List[ValueRaw]): List[ValueRaw] =
if (vals.length <= 1) vals
else
for {
split <- (1 until vals.length).toList
(left, right) = vals.splitAt(split)
l <- genAllExprs(left)
r <- genAllExprs(right)
} yield l `+` r
for {
perm <- values.permutations.toList
expr <- genAllExprs(perm)
} {
val state = InliningState(
resolvedExports = vars.map(v => v.name -> VarModel.fromVarRaw(v)).toMap
)
val (model, inline) = valueToModel[InliningState](expr).runA(state).value
model shouldBe a[VarModel]
inside(inline) { case Some(tree) =>
val numberOfAdditions = Cofree
.cata(tree) { (model, count: Chain[Int]) =>
Eval.later {
count.combineAll + (model match {
case CallServiceModel(_, "add", _) => 1
case _ => 0
})
}
}
.value
numberOfAdditions shouldEqual numVars
}
}
}
/**
* Number of expressions grows exponentially
* So we test only small cases
*/
test(2, 2)
test(3, 2)
test(2, 3)
}
} }

View File

@ -2,7 +2,7 @@ package aqua.raw.ops
import aqua.raw.arrow.FuncRaw import aqua.raw.arrow.FuncRaw
import aqua.raw.ops.RawTag.Tree import aqua.raw.ops.RawTag.Tree
import aqua.raw.value.{CallArrowRaw, ValueRaw} import aqua.raw.value.{CallArrowRaw, CallServiceRaw, ValueRaw}
import aqua.tree.{TreeNode, TreeNodeCompanion} import aqua.tree.{TreeNode, TreeNodeCompanion}
import aqua.types.{ArrowType, DataType, ServiceType} import aqua.types.{ArrowType, DataType, ServiceType}
@ -224,26 +224,6 @@ object CallArrowRawTag {
) )
) )
def service(
serviceId: ValueRaw,
fnName: String,
call: Call,
name: String = null,
arrowType: ArrowType = null
): CallArrowRawTag =
CallArrowRawTag(
call.exportTo,
CallArrowRaw(
Option(name),
fnName,
call.args,
Option(arrowType).getOrElse(
call.arrowType
),
Some(serviceId)
)
)
def func(fnName: String, call: Call): CallArrowRawTag = def func(fnName: String, call: Call): CallArrowRawTag =
CallArrowRawTag( CallArrowRawTag(
call.exportTo, call.exportTo,
@ -253,6 +233,22 @@ object CallArrowRawTag {
arguments = call.args arguments = call.args
) )
) )
def service(
srvId: ValueRaw,
funcName: String,
call: Call,
arrowType: Option[ArrowType] = None
): CallArrowRawTag =
CallArrowRawTag(
call.exportTo,
CallServiceRaw(
srvId,
funcName,
arrowType.getOrElse(call.arrowType),
call.args
)
)
} }
case class DeclareStreamTag( case class DeclareStreamTag(

View File

@ -0,0 +1,119 @@
package aqua.raw.value
import cats.Eval
import cats.data.Ior
import cats.Semigroup
import cats.syntax.applicative.*
import cats.syntax.apply.*
import cats.syntax.functor.*
import cats.syntax.semigroup.*
object Optimization {
/**
* Optimize raw value.
*
* A lot more optimizations could be done,
* it is here just as a proof of concept.
*/
def optimize(value: ValueRaw): Eval[ValueRaw] =
Addition.optimize(value)
object Addition {
def optimize(value: ValueRaw): Eval[ValueRaw] =
gatherLiteralsInAddition(value).map { res =>
res.fold(
// TODO: Type of literal is not preserved
LiteralRaw.number,
identity,
(literal, nonLiteral) =>
if (literal > 0) {
ApplyBinaryOpRaw.Add(nonLiteral, LiteralRaw.number(literal))
} else if (literal < 0) {
ApplyBinaryOpRaw.Sub(nonLiteral, LiteralRaw.number(-literal))
} else nonLiteral
)
}
private def gatherLiteralsInAddition(
value: ValueRaw
): Eval[Ior[Long, ValueRaw]] =
value match {
case ApplyBinaryOpRaw.Add(left, right) =>
(
gatherLiteralsInAddition(left),
gatherLiteralsInAddition(right)
).mapN(_ add _)
case ApplyBinaryOpRaw.Sub(
left,
/**
* We support subtraction only with literal at the right side
* because we don't have unary minus operator for values.
* (Or this algo should be much more complex to support it)
* But this case is pretty commonly generated by compiler
* in gates: `join stream[len - 1]` (see `StreamGateInliner`)
*/
LiteralRaw.Integer(i)
) =>
(
gatherLiteralsInAddition(left),
Eval.now(Ior.left(-i))
// NOTE: Use add as sign is stored inside Long
).mapN(_ add _)
case LiteralRaw.Integer(i) =>
Ior.left(i).pure
case _ =>
// Optimize expressions inside this value
Ior.right(value.mapValues(v => optimize(v).value)).pure
}
/**
* Rewritten `Ior.combine` method
* with custom combination functions.
*/
private def combineWith(
l: Ior[Long, ValueRaw],
r: Ior[Long, ValueRaw]
)(
lf: (Long, Long) => Long,
rf: (ValueRaw, ValueRaw) => ValueRaw
): Ior[Long, ValueRaw] =
l match {
case Ior.Left(lvl) =>
r match {
case Ior.Left(rvl) =>
Ior.left(lf(lvl, rvl))
case Ior.Right(rvr) =>
Ior.both(lvl, rvr)
case Ior.Both(rvl, rvr) =>
Ior.both(lf(lvl, rvl), rvr)
}
case Ior.Right(lvr) =>
r match {
case Ior.Left(rvl) =>
Ior.both(rvl, lvr)
case Ior.Right(rvr) =>
Ior.right(rf(lvr, rvr))
case Ior.Both(rvl, rvr) =>
Ior.both(rvl, rf(lvr, rvr))
}
case Ior.Both(lvl, lvr) =>
r match {
case Ior.Left(rvl) =>
Ior.both(lf(lvl, rvl), lvr)
case Ior.Right(rvr) =>
Ior.both(lvl, rf(lvr, rvr))
case Ior.Both(rvl, rvr) =>
Ior.both(lf(lvl, rvl), rf(lvr, rvr))
}
}
extension (l: Ior[Long, ValueRaw]) {
def add(r: Ior[Long, ValueRaw]): Ior[Long, ValueRaw] =
combineWith(l, r)(_ + _, ApplyBinaryOpRaw.Add(_, _))
}
}
}

View File

@ -6,6 +6,9 @@ import cats.data.NonEmptyMap
sealed trait PropertyRaw { sealed trait PropertyRaw {
def `type`: Type def `type`: Type
/**
* Apply function to values in this property
*/
def map(f: ValueRaw => ValueRaw): PropertyRaw def map(f: ValueRaw => ValueRaw): PropertyRaw
def renameVars(vals: Map[String, String]): PropertyRaw = this def renameVars(vals: Map[String, String]): PropertyRaw = this
@ -24,7 +27,8 @@ case class IntoArrowRaw(name: String, arrowType: Type, arguments: List[ValueRaw]
override def `type`: Type = arrowType override def `type`: Type = arrowType
override def map(f: ValueRaw => ValueRaw): PropertyRaw = this override def map(f: ValueRaw => ValueRaw): PropertyRaw =
copy(arguments = arguments.map(f))
override def varNames: Set[String] = arguments.flatMap(_.varNames).toSet override def varNames: Set[String] = arguments.flatMap(_.varNames).toSet

View File

@ -15,7 +15,16 @@ sealed trait ValueRaw {
def renameVars(map: Map[String, String]): ValueRaw def renameVars(map: Map[String, String]): ValueRaw
def map(f: ValueRaw => ValueRaw): ValueRaw /**
* Apply function to all values in the tree
*/
final def map(f: ValueRaw => ValueRaw): ValueRaw =
f(mapValues(_.map(f)))
/**
* Apply function to values in this value
*/
def mapValues(f: ValueRaw => ValueRaw): ValueRaw
def varNames: Set[String] def varNames: Set[String]
} }
@ -60,6 +69,12 @@ object ValueRaw {
type ApplyRaw = ApplyPropertyRaw | CallArrowRaw | CollectionRaw | ApplyBinaryOpRaw | type ApplyRaw = ApplyPropertyRaw | CallArrowRaw | CollectionRaw | ApplyBinaryOpRaw |
ApplyUnaryOpRaw ApplyUnaryOpRaw
extension (v: ValueRaw) {
def add(a: ValueRaw): ValueRaw = ApplyBinaryOpRaw.Add(v, a)
def increment: ValueRaw = ApplyBinaryOpRaw.Add(v, LiteralRaw.number(1))
}
} }
case class ApplyPropertyRaw(value: ValueRaw, property: PropertyRaw) extends ValueRaw { case class ApplyPropertyRaw(value: ValueRaw, property: PropertyRaw) extends ValueRaw {
@ -70,8 +85,8 @@ case class ApplyPropertyRaw(value: ValueRaw, property: PropertyRaw) extends Valu
override def renameVars(map: Map[String, String]): ValueRaw = override def renameVars(map: Map[String, String]): ValueRaw =
ApplyPropertyRaw(value.renameVars(map), property.renameVars(map)) ApplyPropertyRaw(value.renameVars(map), property.renameVars(map))
override def map(f: ValueRaw => ValueRaw): ValueRaw = override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
f(ApplyPropertyRaw(f(value), property.map(_.map(f)))) ApplyPropertyRaw(f(value), property.map(f))
override def toString: String = s"$value.$property" override def toString: String = s"$value.$property"
@ -96,7 +111,7 @@ object ApplyPropertyRaw {
case class VarRaw(name: String, baseType: Type) extends ValueRaw { case class VarRaw(name: String, baseType: Type) extends ValueRaw {
override def map(f: ValueRaw => ValueRaw): ValueRaw = f(this) override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = this
override def renameVars(map: Map[String, String]): ValueRaw = override def renameVars(map: Map[String, String]): ValueRaw =
copy(name = map.getOrElse(name, name)) copy(name = map.getOrElse(name, name))
@ -110,7 +125,7 @@ case class VarRaw(name: String, baseType: Type) extends ValueRaw {
} }
case class LiteralRaw(value: String, baseType: Type) extends ValueRaw { case class LiteralRaw(value: String, baseType: Type) extends ValueRaw {
override def map(f: ValueRaw => ValueRaw): ValueRaw = f(this) override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = this
override def toString: String = s"{$value: ${baseType}}" override def toString: String = s"{$value: ${baseType}}"
@ -122,12 +137,25 @@ case class LiteralRaw(value: String, baseType: Type) extends ValueRaw {
object LiteralRaw { object LiteralRaw {
def quote(value: String): LiteralRaw = LiteralRaw("\"" + value + "\"", LiteralType.string) def quote(value: String): LiteralRaw = LiteralRaw("\"" + value + "\"", LiteralType.string)
def number(value: Int): LiteralRaw = LiteralRaw(value.toString, LiteralType.forInt(value)) def number(value: Long): LiteralRaw = LiteralRaw(value.toString, LiteralType.forInt(value))
val Zero: LiteralRaw = number(0) val Zero: LiteralRaw = number(0)
val True: LiteralRaw = LiteralRaw("true", LiteralType.bool) val True: LiteralRaw = LiteralRaw("true", LiteralType.bool)
val False: LiteralRaw = LiteralRaw("false", LiteralType.bool) val False: LiteralRaw = LiteralRaw("false", LiteralType.bool)
object Integer {
/*
* Used to match integer literals in pattern matching
*/
def unapply(value: ValueRaw): Option[Long] =
value match {
case LiteralRaw(value, t) if ScalarType.integer.exists(_.acceptsValueOf(t)) =>
value.toLongOption
case _ => none
}
}
} }
case class CollectionRaw(values: NonEmptyList[ValueRaw], boxType: BoxType) extends ValueRaw { case class CollectionRaw(values: NonEmptyList[ValueRaw], boxType: BoxType) extends ValueRaw {
@ -136,10 +164,10 @@ case class CollectionRaw(values: NonEmptyList[ValueRaw], boxType: BoxType) exten
override lazy val baseType: Type = boxType override lazy val baseType: Type = boxType
override def map(f: ValueRaw => ValueRaw): ValueRaw = { override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = {
val vals = values.map(f) val vals = values.map(f)
val el = vals.map(_.`type`).reduceLeft(_ `∩` _) val el = vals.map(_.`type`).reduceLeft(_ `∩` _)
f(copy(vals, boxType.withElement(el))) copy(vals, boxType.withElement(el))
} }
override def varNames: Set[String] = values.toList.flatMap(_.varNames).toSet override def varNames: Set[String] = values.toList.flatMap(_.varNames).toSet
@ -153,7 +181,8 @@ case class MakeStructRaw(fields: NonEmptyMap[String, ValueRaw], structType: Stru
override def baseType: Type = structType override def baseType: Type = structType
override def map(f: ValueRaw => ValueRaw): ValueRaw = f(copy(fields = fields.map(f))) override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
copy(fields = fields.map(f))
override def varNames: Set[String] = { override def varNames: Set[String] = {
fields.toSortedMap.values.flatMap(_.varNames).toSet fields.toSortedMap.values.flatMap(_.varNames).toSet
@ -168,8 +197,8 @@ case class AbilityRaw(fieldsAndArrows: NonEmptyMap[String, ValueRaw], abilityTyp
override def baseType: Type = abilityType override def baseType: Type = abilityType
override def map(f: ValueRaw => ValueRaw): ValueRaw = override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
f(copy(fieldsAndArrows = fieldsAndArrows.map(f))) copy(fieldsAndArrows = fieldsAndArrows.map(f))
override def varNames: Set[String] = { override def varNames: Set[String] = {
fieldsAndArrows.toSortedMap.values.flatMap(_.varNames).toSet fieldsAndArrows.toSortedMap.values.flatMap(_.varNames).toSet
@ -182,29 +211,79 @@ case class AbilityRaw(fieldsAndArrows: NonEmptyMap[String, ValueRaw], abilityTyp
case class ApplyBinaryOpRaw( case class ApplyBinaryOpRaw(
op: ApplyBinaryOpRaw.Op, op: ApplyBinaryOpRaw.Op,
left: ValueRaw, left: ValueRaw,
right: ValueRaw right: ValueRaw,
// TODO: Refactor type, get rid of `LiteralType`
resultType: ScalarType | LiteralType
) extends ValueRaw { ) extends ValueRaw {
// Only boolean operations are supported for now override val baseType: Type = resultType
override def baseType: Type = ScalarType.bool
override def map(f: ValueRaw => ValueRaw): ValueRaw = override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
f(copy(left = f(left), right = f(right))) copy(left = f(left), right = f(right))
override def varNames: Set[String] = left.varNames ++ right.varNames override def varNames: Set[String] = left.varNames ++ right.varNames
override def renameVars(map: Map[String, String]): ValueRaw = override def renameVars(map: Map[String, String]): ValueRaw =
copy(left = left.renameVars(map), right = right.renameVars(map)) copy(left = left.renameVars(map), right = right.renameVars(map))
override def toString(): String =
s"(${left} ${op} ${right}) :: ${resultType}"
} }
object ApplyBinaryOpRaw { object ApplyBinaryOpRaw {
enum Op { enum Op {
case And case And, Or
case Or case Eq, Neq
case Lt, Lte, Gt, Gte
case Add, Sub, Mul, FMul, Div, Pow, Rem
}
case Eq object Op {
case Neq
type Bool = And.type | Or.type
type Eq = Eq.type | Neq.type
type Cmp = Lt.type | Lte.type | Gt.type | Gte.type
type Math = Add.type | Sub.type | Mul.type | FMul.type | Div.type | Pow.type | Rem.type
}
object Add {
def apply(left: ValueRaw, right: ValueRaw): ValueRaw =
ApplyBinaryOpRaw(
Op.Add,
left,
right,
ScalarType.resolveMathOpType(left.`type`, right.`type`).`type`
)
def unapply(value: ValueRaw): Option[(ValueRaw, ValueRaw)] =
value match {
case ApplyBinaryOpRaw(Op.Add, left, right, _) =>
(left, right).some
case _ => none
}
}
object Sub {
def apply(left: ValueRaw, right: ValueRaw): ValueRaw =
ApplyBinaryOpRaw(
Op.Sub,
left,
right,
ScalarType.resolveMathOpType(left.`type`, right.`type`).`type`
)
def unapply(value: ValueRaw): Option[(ValueRaw, ValueRaw)] =
value match {
case ApplyBinaryOpRaw(Op.Sub, left, right, _) =>
(left, right).some
case _ => none
}
} }
} }
@ -216,8 +295,8 @@ case class ApplyUnaryOpRaw(
// Only boolean operations are supported for now // Only boolean operations are supported for now
override def baseType: Type = ScalarType.bool override def baseType: Type = ScalarType.bool
override def map(f: ValueRaw => ValueRaw): ValueRaw = override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
f(copy(value = f(value))) copy(value = f(value))
override def varNames: Set[String] = value.varNames override def varNames: Set[String] = value.varNames
@ -237,37 +316,28 @@ case class CallArrowRaw(
ability: Option[String], ability: Option[String],
name: String, name: String,
arguments: List[ValueRaw], arguments: List[ValueRaw],
baseType: ArrowType, baseType: ArrowType
// TODO: there should be no serviceId there
serviceId: Option[ValueRaw]
) extends ValueRaw { ) extends ValueRaw {
override def `type`: Type = baseType.codomain.uncons.map(_._1).getOrElse(baseType) override def `type`: Type = baseType.codomain.headOption.getOrElse(baseType)
override def map(f: ValueRaw => ValueRaw): ValueRaw = override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
f( copy(arguments = arguments.map(f))
copy(
arguments = arguments.map(_.map(f)),
serviceId = serviceId.map(_.map(f))
)
)
override def varNames: Set[String] = name.some override def varNames: Set[String] = name.some
.filterNot(_ => ability.isDefined || serviceId.isDefined) .filterNot(_ => ability.isDefined)
.toSet ++ arguments.flatMap(_.varNames).toSet .toSet ++ arguments.flatMap(_.varNames).toSet
override def renameVars(map: Map[String, String]): ValueRaw = override def renameVars(map: Map[String, String]): ValueRaw =
copy( copy(
name = map name = map
.get(name) .get(name)
// Rename only if it is **not** a service or ability call, see [bug LNG-199] // Rename only if it is **not** an ability call, see [bug LNG-199]
.filterNot(_ => ability.isDefined) .filterNot(_ => ability.isDefined)
.filterNot(_ => serviceId.isDefined)
.getOrElse(name) .getOrElse(name)
) )
override def toString: String = override def toString: String =
s"(call ${ability.fold("")(a => s"|$a| ")} (${serviceId.fold("")(_.toString + " ")}$name) [${arguments s"${ability.fold("")(a => s"$a.")}$name(${arguments.mkString(",")}) :: $baseType)"
.mkString(" ")}] :: $baseType)"
} }
object CallArrowRaw { object CallArrowRaw {
@ -280,8 +350,7 @@ object CallArrowRaw {
ability = None, ability = None,
name = funcName, name = funcName,
arguments = arguments, arguments = arguments,
baseType = baseType, baseType = baseType
serviceId = None
) )
def ability( def ability(
@ -293,22 +362,46 @@ object CallArrowRaw {
ability = None, ability = None,
name = AbilityType.fullName(abilityName, funcName), name = AbilityType.fullName(abilityName, funcName),
arguments = arguments, arguments = arguments,
baseType = baseType, baseType = baseType
serviceId = None
)
def service(
abilityName: String,
serviceId: ValueRaw,
funcName: String,
baseType: ArrowType,
arguments: List[ValueRaw] = Nil
): CallArrowRaw = CallArrowRaw(
ability = abilityName.some,
name = funcName,
arguments = arguments,
baseType = baseType,
serviceId = Some(serviceId)
) )
} }
/**
* WARNING: This class is internal and is used to generate code.
* Calls to services in aqua code are represented as [[CallArrowRaw]]
* and resolved through ability resolution.
*
* @param serviceId service id
* @param fnName service method name
* @param baseType type of the service method
* @param arguments call arguments
*/
case class CallServiceRaw(
serviceId: ValueRaw,
fnName: String,
baseType: ArrowType,
arguments: List[ValueRaw]
) extends ValueRaw {
override def `type`: Type = baseType.codomain.headOption.getOrElse(baseType)
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
copy(
serviceId = f(serviceId),
arguments = arguments.map(f)
)
override def varNames: Set[String] =
arguments
.flatMap(_.varNames)
.toSet ++ serviceId.varNames
override def renameVars(map: Map[String, String]): ValueRaw =
copy(
serviceId = serviceId.renameVars(map),
arguments = arguments.map(_.renameVars(map))
)
override def toString: String =
s"call (${serviceId}) $fnName(${arguments.mkString(",")}) :: $baseType)"
}

View File

@ -144,17 +144,12 @@ object AquaContext extends Logging {
blank.copy( blank.copy(
module = Some(sm.name), module = Some(sm.name),
funcs = sm.`type`.arrows.map { case (fnName, arrowType) => funcs = sm.`type`.arrows.map { case (fnName, arrowType) =>
val (args, call, ret) = ArgsCall.arrowToArgsCallRet(arrowType) fnName -> FuncArrow.fromServiceMethod(
fnName -> fnName,
FuncArrow( sm.name,
fnName, fnName,
// TODO: capture ability resolution, get ID from the call context
CallArrowRawTag.service(serviceId, fnName, call, sm.name).leaf,
arrowType, arrowType,
ret.map(_.toRaw), serviceId
Map.empty,
Map.empty,
None
) )
} }
) )

View File

@ -2,10 +2,12 @@ package aqua.model
import aqua.raw.Raw import aqua.raw.Raw
import aqua.raw.arrow.FuncRaw import aqua.raw.arrow.FuncRaw
import aqua.raw.ops.{Call, CallArrowRawTag, RawTag} import aqua.raw.ops.{Call, CallArrowRawTag, EmptyTag, RawTag}
import aqua.raw.value.{ValueRaw, VarRaw} import aqua.raw.value.{ValueRaw, VarRaw}
import aqua.types.{ArrowType, ServiceType, Type} import aqua.types.{ArrowType, ServiceType, Type}
import cats.syntax.option.*
case class FuncArrow( case class FuncArrow(
funcName: String, funcName: String,
body: RawTag.Tree, body: RawTag.Tree,
@ -58,21 +60,28 @@ object FuncArrow {
serviceName: String, serviceName: String,
methodName: String, methodName: String,
methodType: ArrowType, methodType: ArrowType,
idValue: ValueModel idValue: ValueModel | ValueRaw
): FuncArrow = { ): FuncArrow = {
val id = VarRaw("id", idValue.`type`) val (id, capturedValues) = idValue match {
case i: ValueModel =>
(
VarRaw("id", i.`type`),
Map("id" -> i)
)
case i: ValueRaw => (i, Map.empty[String, ValueModel])
}
val retVar = methodType.res.map(t => VarRaw("ret", t)) val retVar = methodType.res.map(t => VarRaw("ret", t))
val call = Call( val call = Call(
methodType.domain.toLabelledList().map(VarRaw.apply), methodType.domain.toLabelledList().map(VarRaw.apply),
retVar.map(r => Call.Export(r.name, r.`type`)).toList retVar.map(r => Call.Export(r.name, r.`type`)).toList
) )
val body = CallArrowRawTag.service( val body = CallArrowRawTag.service(
serviceId = id, srvId = id,
fnName = methodName, funcName = methodName,
call = call, call = call,
name = serviceName, arrowType = methodType.some
arrowType = methodType
) )
FuncArrow( FuncArrow(
@ -81,9 +90,7 @@ object FuncArrow {
arrowType = methodType, arrowType = methodType,
ret = retVar.toList, ret = retVar.toList,
capturedArrows = Map.empty, capturedArrows = Map.empty,
capturedValues = Map( capturedValues = capturedValues,
id.name -> idValue
),
capturedTopology = None capturedTopology = None
) )
} }

View File

@ -7,6 +7,7 @@ import aqua.types.*
import cats.Eq import cats.Eq
import cats.data.{Chain, NonEmptyMap} import cats.data.{Chain, NonEmptyMap}
import cats.syntax.option.* import cats.syntax.option.*
import cats.syntax.apply.*
import scribe.Logging import scribe.Logging
sealed trait ValueModel { sealed trait ValueModel {
@ -91,6 +92,22 @@ object LiteralModel {
} }
} }
/*
* Used to match integer literals in pattern matching
*/
object Integer {
def unapply(lm: LiteralModel): Option[(Long, ScalarType | LiteralType)] =
lm match {
case LiteralModel(value, t) if ScalarType.integer.exists(_.acceptsValueOf(t)) =>
(
value.toLongOption,
t.some.collect { case t: (ScalarType | LiteralType) => t }
).tupled
case _ => none
}
}
// AquaVM will return 0 for // AquaVM will return 0 for
// :error:.$.error_code if there is no :error: // :error:.$.error_code if there is no :error:
val emptyErrorCode = number(0) val emptyErrorCode = number(0)
@ -102,7 +119,7 @@ object LiteralModel {
def quote(str: String): LiteralModel = LiteralModel(s"\"$str\"", LiteralType.string) def quote(str: String): LiteralModel = LiteralModel(s"\"$str\"", LiteralType.string)
def number(n: Int): LiteralModel = LiteralModel(n.toString, LiteralType.forInt(n)) def number(n: Long): LiteralModel = LiteralModel(n.toString, LiteralType.forInt(n))
def bool(b: Boolean): LiteralModel = LiteralModel(b.toString.toLowerCase, LiteralType.bool) def bool(b: Boolean): LiteralModel = LiteralModel(b.toString.toLowerCase, LiteralType.bool)
} }

View File

@ -50,7 +50,7 @@ class CallArrowSem[S[_]](val expr: CallArrowExpr[S]) extends AnyVal {
) )
} yield tag.map(_.funcOpLeaf) } yield tag.map(_.funcOpLeaf)
def program[Alg[_]: Monad](implicit def program[Alg[_]: Monad](using
N: NamesAlgebra[S, Alg], N: NamesAlgebra[S, Alg],
T: TypesAlgebra[S, Alg], T: TypesAlgebra[S, Alg],
V: ValuesAlgebra[S, Alg] V: ValuesAlgebra[S, Alg]

View File

@ -184,7 +184,6 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using
case (Some(leftRaw), Some(rightRaw)) => case (Some(leftRaw), Some(rightRaw)) =>
val lType = leftRaw.`type` val lType = leftRaw.`type`
val rType = rightRaw.`type` val rType = rightRaw.`type`
lazy val uType = lType `` rType
it.op match { it.op match {
case InfOp.Bool(bop) => case InfOp.Bool(bop) =>
@ -200,7 +199,8 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using
case BoolOp.Or => ApplyBinaryOpRaw.Op.Or case BoolOp.Or => ApplyBinaryOpRaw.Op.Or
}, },
left = leftRaw, left = leftRaw,
right = rightRaw right = rightRaw,
resultType = ScalarType.bool
) )
) )
case InfOp.Eq(eop) => case InfOp.Eq(eop) =>
@ -216,7 +216,8 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using
case EqOp.Neq => ApplyBinaryOpRaw.Op.Neq case EqOp.Neq => ApplyBinaryOpRaw.Op.Neq
}, },
left = leftRaw, left = leftRaw,
right = rightRaw right = rightRaw,
resultType = ScalarType.bool
) )
) )
) )
@ -228,70 +229,62 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using
case InfOp.Cmp(op) => op case InfOp.Cmp(op) => op
} }
val hasFloat = List(lType, rType).exists( lazy val hasFloat = List(lType, rType).exists(
_ acceptsValueOf LiteralType.float _ acceptsValueOf LiteralType.float
) )
// See https://github.com/fluencelabs/aqua-lib/blob/main/math.aqua val bop = iop match {
val (id, fn) = iop match { case MathOp.Add => ApplyBinaryOpRaw.Op.Add
case MathOp.Add => ("math", "add") case MathOp.Sub => ApplyBinaryOpRaw.Op.Sub
case MathOp.Sub => ("math", "sub") case MathOp.Mul if hasFloat => ApplyBinaryOpRaw.Op.FMul
case MathOp.Mul if hasFloat => ("math", "fmul") case MathOp.Mul => ApplyBinaryOpRaw.Op.Mul
case MathOp.Mul => ("math", "mul") case MathOp.Div => ApplyBinaryOpRaw.Op.Div
case MathOp.Div => ("math", "div") case MathOp.Rem => ApplyBinaryOpRaw.Op.Rem
case MathOp.Rem => ("math", "rem") case MathOp.Pow => ApplyBinaryOpRaw.Op.Pow
case MathOp.Pow => ("math", "pow") case CmpOp.Gt => ApplyBinaryOpRaw.Op.Gt
case CmpOp.Gt => ("cmp", "gt") case CmpOp.Gte => ApplyBinaryOpRaw.Op.Gte
case CmpOp.Gte => ("cmp", "gte") case CmpOp.Lt => ApplyBinaryOpRaw.Op.Lt
case CmpOp.Lt => ("cmp", "lt") case CmpOp.Lte => ApplyBinaryOpRaw.Op.Lte
case CmpOp.Lte => ("cmp", "lte")
} }
/* lazy val numbersTypeBounded: Alg[ScalarType | LiteralType] = {
* If `uType == TopType`, it means that we don't val resType = ScalarType.resolveMathOpType(lType, rType)
* have type big enough to hold the result of operation. report
* e.g. We will use `i64` for result of `i32 * u64` .warning(
* TODO: Handle this more gracefully it,
* (use warning system when it is implemented) s"Result type of ($lType ${it.op} $rType) is unknown, " +
*/ s"using ${resType.`type`} instead"
def uTypeBounded = if (uType == TopType) {
val bounded = ScalarType.i64
logger.warn(
s"Result type of ($lType ${it.op} $rType) is $TopType, " +
s"using $bounded instead"
) )
bounded .whenA(resType.overflow)
} else uType .as(resType.`type`)
}
// Expected type sets of left and right operands, result type // Expected type sets of left and right operands, result type
val (leftExp, rightExp, resType) = iop match { val (leftExp, rightExp, resTypeM) = iop match {
case MathOp.Add | MathOp.Sub | MathOp.Div | MathOp.Rem => case MathOp.Add | MathOp.Sub | MathOp.Div | MathOp.Rem =>
(ScalarType.integer, ScalarType.integer, uTypeBounded) (ScalarType.integer, ScalarType.integer, numbersTypeBounded)
case MathOp.Pow => case MathOp.Pow =>
(ScalarType.integer, ScalarType.unsigned, uTypeBounded) (ScalarType.integer, ScalarType.unsigned, numbersTypeBounded)
case MathOp.Mul if hasFloat => case MathOp.Mul if hasFloat =>
(ScalarType.float, ScalarType.float, ScalarType.i64) (ScalarType.float, ScalarType.float, ScalarType.i64.pure)
case MathOp.Mul => case MathOp.Mul =>
(ScalarType.integer, ScalarType.integer, uTypeBounded) (ScalarType.integer, ScalarType.integer, numbersTypeBounded)
case CmpOp.Gt | CmpOp.Lt | CmpOp.Gte | CmpOp.Lte => case CmpOp.Gt | CmpOp.Lt | CmpOp.Gte | CmpOp.Lte =>
(ScalarType.integer, ScalarType.integer, ScalarType.bool) (ScalarType.integer, ScalarType.integer, ScalarType.bool.pure)
} }
for { for {
leftChecked <- T.ensureTypeOneOf(l, leftExp, lType) leftChecked <- T.ensureTypeOneOf(l, leftExp, lType)
rightChecked <- T.ensureTypeOneOf(r, rightExp, rType) rightChecked <- T.ensureTypeOneOf(r, rightExp, rType)
resType <- resTypeM
} yield Option.when( } yield Option.when(
leftChecked.isDefined && rightChecked.isDefined leftChecked.isDefined && rightChecked.isDefined
)( )(
CallArrowRaw.service( ApplyBinaryOpRaw(
abilityName = id, op = bop,
serviceId = LiteralRaw.quote(id), left = leftRaw,
funcName = fn, right = rightRaw,
baseType = ArrowType( resultType = resType
ProductType(lType :: rType :: Nil),
ProductType(resType :: Nil)
),
arguments = leftRaw :: rightRaw :: Nil
) )
) )

View File

@ -108,10 +108,10 @@ class SemanticsSpec extends AnyFlatSpec with Matchers with Inside {
.leaf .leaf
def equ(left: ValueRaw, right: ValueRaw): ApplyBinaryOpRaw = def equ(left: ValueRaw, right: ValueRaw): ApplyBinaryOpRaw =
ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Eq, left, right) ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Eq, left, right, ScalarType.bool)
def neq(left: ValueRaw, right: ValueRaw): ApplyBinaryOpRaw = def neq(left: ValueRaw, right: ValueRaw): ApplyBinaryOpRaw =
ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Neq, left, right) ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Neq, left, right, ScalarType.bool)
def declareStreamPush( def declareStreamPush(
name: String, name: String,

View File

@ -267,7 +267,7 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside {
.run(state) .run(state)
.value .value
inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _)) => inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _, ScalarType.bool)) =>
bop shouldBe (op match { bop shouldBe (op match {
case InfixToken.BoolOp.And => ApplyBinaryOpRaw.Op.And case InfixToken.BoolOp.And => ApplyBinaryOpRaw.Op.And
case InfixToken.BoolOp.Or => ApplyBinaryOpRaw.Op.Or case InfixToken.BoolOp.Or => ApplyBinaryOpRaw.Op.Or
@ -319,7 +319,7 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside {
.run(state) .run(state)
.value .value
inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _)) => inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _, ScalarType.bool)) =>
bop shouldBe (op match { bop shouldBe (op match {
case InfixToken.EqOp.Eq => ApplyBinaryOpRaw.Op.Eq case InfixToken.EqOp.Eq => ApplyBinaryOpRaw.Op.Eq
case InfixToken.EqOp.Neq => ApplyBinaryOpRaw.Op.Neq case InfixToken.EqOp.Neq => ApplyBinaryOpRaw.Op.Neq

View File

@ -53,6 +53,11 @@ sealed trait ProductType extends Type {
case _ => None case _ => None
} }
def headOption: Option[Type] = this match {
case ConsType(t, _) => Some(t)
case _ => None
}
lazy val toList: List[Type] = this match { lazy val toList: List[Type] = this match {
case ConsType(t, pt) => t :: pt.toList case ConsType(t, pt) => t :: pt.toList
case _ => Nil case _ => Nil
@ -182,6 +187,45 @@ object ScalarType {
val integer = signed ++ unsigned val integer = signed ++ unsigned
val number = float ++ integer val number = float ++ integer
val all = number ++ Set(bool, string) val all = number ++ Set(bool, string)
final case class MathOpType(
`type`: ScalarType | LiteralType,
overflow: Boolean
)
/**
* Resolve type of math operation
* on two given types.
*
* WARNING: General `Type` is accepted
* but only integer `ScalarType` and `LiteralType`
* are actually expected.
*/
def resolveMathOpType(
lType: Type,
rType: Type
): MathOpType = {
val uType = lType `` rType
uType match {
case t: (ScalarType | LiteralType) => MathOpType(t, false)
case _ => MathOpType(ScalarType.i64, true)
}
}
/**
* Check if given type is signed.
*
* NOTE: Only integer types are expected.
* But it is impossible to enforce it.
*/
def isSignedInteger(t: ScalarType | LiteralType): Boolean =
t match {
case st: ScalarType => signed.contains(st)
/**
* WARNING: LiteralType.unsigned is signed integer!
*/
case lt: LiteralType => lt.oneOf.exists(signed.contains)
}
} }
case class LiteralType private (oneOf: Set[ScalarType], name: String) extends DataType { case class LiteralType private (oneOf: Set[ScalarType], name: String) extends DataType {
@ -200,7 +244,7 @@ object LiteralType {
val bool = LiteralType(Set(ScalarType.bool), "bool") val bool = LiteralType(Set(ScalarType.bool), "bool")
val string = LiteralType(Set(ScalarType.string), "string") val string = LiteralType(Set(ScalarType.string), "string")
def forInt(n: Int): LiteralType = if (n < 0) signed else unsigned def forInt(n: Long): LiteralType = if (n < 0) signed else unsigned
} }
sealed trait BoxType extends DataType { sealed trait BoxType extends DataType {
@ -323,8 +367,7 @@ case class StructType(name: String, fields: NonEmptyMap[String, Type])
s"$name{${fields.map(_.toString).toNel.toList.map(kv => kv._1 + ": " + kv._2).mkString(", ")}}" s"$name{${fields.map(_.toString).toNel.toList.map(kv => kv._1 + ": " + kv._2).mkString(", ")}}"
} }
case class StreamMapType(element: Type) case class StreamMapType(element: Type) extends DataType {
extends DataType {
override def toString: String = s"%$element" override def toString: String = s"%$element"
} }