feat(compiler): Optimize math in compile time [LNG-245] (#922)

* Move service calls for math to inlining

* Fix: add predo

* Introduce CallServiceRaw

* Add comment

* Add optimization to inlining

* Add tests

* map -> mapValues

* Refactor type

* Add optimization

* Add optimization test

* Fix unit tests

* Fix PR comments

* Restrict optimization

* Add substraction to optimization

* Apply optimizations in gate

* Fix sign, move optimization to unfold

* Fix type and tests

* Fix unit tests

* Add unit test

* Fix after merge

* Add optimization, fix unit tests

* Fix comment
This commit is contained in:
InversionSpaces 2023-10-09 12:02:26 +02:00 committed by GitHub
parent b298eebf5e
commit 5f6c47ffea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 1439 additions and 699 deletions

View File

@ -15,16 +15,20 @@ import aqua.res.*
import aqua.res.ResBuilder
import aqua.types.{ArrayType, CanonStreamType, LiteralType, ScalarType, StreamType, Type}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import cats.Id
import cats.data.{Chain, NonEmptyChain, NonEmptyMap, Validated, ValidatedNec}
import cats.instances.string.*
import cats.syntax.show.*
import cats.syntax.option.*
import cats.syntax.either.*
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.Inside
import aqua.model.AquaContext
import aqua.model.FlattenModel
import aqua.model.CallServiceModel
class AquaCompilerSpec extends AnyFlatSpec with Matchers {
class AquaCompilerSpec extends AnyFlatSpec with Matchers with Inside {
import ModelBuilder.*
private def aquaSource(src: Map[String, String], imports: Map[String, String]) = {
@ -45,8 +49,13 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
}
}
private def compileToContext(src: Map[String, String], imports: Map[String, String]) =
CompilerAPI
private def insideContext(
src: Map[String, String],
imports: Map[String, String] = Map.empty
)(
test: AquaContext => Any
) = {
val compiled = CompilerAPI
.compileToContext[Id, String, String, Span.S](
aquaSource(src, imports),
id => txt => Parser.parse(Parser.parserSchema)(txt),
@ -56,39 +65,52 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
.value
.toValidated
inside(compiled) { case Validated.Valid(contexts) =>
inside(contexts.headOption) { case Some(ctx) =>
test(ctx)
}
}
}
private def insideRes(
src: Map[String, String],
imports: Map[String, String] = Map.empty,
transformCfg: TransformConfig = TransformConfig()
)(funcNames: String*)(
test: PartialFunction[List[FuncRes], Any]
) = insideContext(src, imports)(ctx =>
val aquaRes = Transform.contextRes(ctx, transformCfg)
// To preserve order as in funcNames do flatMap
val funcs = funcNames.flatMap(name => aquaRes.funcs.find(_.funcName == name)).toList
inside(funcs)(test)
)
"aqua compiler" should "compile a simple snippet to the right context" in {
val res = compileToContext(
Map(
"index.aqua" ->
"""module Foo declares X
|
|export foo, foo2 as foo_two, X
|
|const X = 5
|
|func foo() -> string:
| <- "hello?"
|
|func foo2() -> string:
| <- "hello2?"
|""".stripMargin
),
Map.empty
val src = Map(
"index.aqua" ->
"""module Foo declares X
|
|export foo, foo2 as foo_two, X
|
|const X = 5
|
|func foo() -> string:
| <- "hello?"
|
|func foo2() -> string:
| <- "hello2?"
|""".stripMargin
)
res.isValid should be(true)
val Validated.Valid(ctxs) = res
insideContext(src) { ctx =>
ctx.allFuncs.contains("foo") should be(true)
ctx.allFuncs.contains("foo_two") should be(true)
ctxs.length should be(1)
val ctx = ctxs.headOption.get
ctx.allFuncs.contains("foo") should be(true)
ctx.allFuncs.contains("foo_two") should be(true)
val const = ctx.allValues.get("X")
const.nonEmpty should be(true)
const.get should be(LiteralModel.number(5))
val const = ctx.allValues.get("X")
const.nonEmpty should be(true)
const.get should be(LiteralModel.number(5))
}
}
def through(peer: ValueModel) =
@ -110,204 +132,237 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers {
private def join(vm: VarModel, size: ValueModel) =
ResBuilder.join(vm, size, init)
"aqua compiler" should "create right topology" in {
val res = compileToContext(
Map(
"index.aqua" ->
"""service Op("op"):
| identity(s: string) -> string
|
|func exec(peers: []string) -> []string:
| results: *string
| for peer <- peers par:
| on peer:
| results <- Op.identity("hahahahah")
|
| join results[2]
| <- results""".stripMargin
),
Map.empty
it should "create right topology" in {
val src = Map(
"index.aqua" ->
"""service Op("op"):
| identity(s: string) -> string
|
|func exec(peers: []string) -> []string:
| results: *string
| for peer <- peers par:
| on peer:
| results <- Op.identity("hahahahah")
|
| join results[2]
| <- results""".stripMargin
)
res.isValid should be(true)
val Validated.Valid(ctxs) = res
ctxs.length should be(1)
val ctx = ctxs.headOption.get
val transformCfg = TransformConfig()
val aquaRes = Transform.contextRes(ctx, transformCfg)
val Some(exec) = aquaRes.funcs.find(_.funcName == "exec")
insideRes(src, transformCfg = transformCfg)("exec") { case exec :: _ =>
val peers = VarModel("-peers-arg-", ArrayType(ScalarType.string))
val peer = VarModel("peer-0", ScalarType.string)
val resultsType = StreamType(ScalarType.string)
val results = VarModel("results", resultsType)
val canonResult =
VarModel("-" + results.name + "-fix-0", CanonStreamType(resultsType.element))
val flatResult = VarModel("-results-flat-0", ArrayType(ScalarType.string))
val initPeer = LiteralModel.fromRaw(ValueRaw.InitPeerId)
val retVar = VarModel("ret", ScalarType.string)
val peers = VarModel("-peers-arg-", ArrayType(ScalarType.string))
val peer = VarModel("peer-0", ScalarType.string)
val resultsType = StreamType(ScalarType.string)
val results = VarModel("results", resultsType)
val canonResult = VarModel("-" + results.name + "-fix-0", CanonStreamType(resultsType.element))
val flatResult = VarModel("-results-flat-0", ArrayType(ScalarType.string))
val initPeer = LiteralModel.fromRaw(ValueRaw.InitPeerId)
val sizeVar = VarModel("results_size", LiteralType.unsigned)
val retVar = VarModel("ret", ScalarType.string)
val expected =
XorRes.wrap(
SeqRes.wrap(
getDataSrv("-relay-", "-relay-", ScalarType.string),
getDataSrv("peers", peers.name, peers.`type`),
RestrictionRes(results.name, resultsType).wrap(
SeqRes.wrap(
ParRes.wrap(
FoldRes(peer.name, peers, ForModel.Mode.Never.some).wrap(
ParRes.wrap(
XorRes.wrap(
// better if first relay will be outside `for`
SeqRes.wrap(
through(ValueModel.fromRaw(relay)),
CallServiceRes(
LiteralModel.fromRaw(LiteralRaw.quote("op")),
"identity",
CallRes(
LiteralModel.fromRaw(LiteralRaw.quote("hahahahah")) :: Nil,
Some(CallModel.Export(retVar.name, retVar.`type`))
),
peer
).leaf,
ApRes(retVar, CallModel.Export(results.name, results.`type`)).leaf,
through(ValueModel.fromRaw(relay)),
through(initPeer)
val expected =
XorRes.wrap(
SeqRes.wrap(
getDataSrv("-relay-", "-relay-", ScalarType.string),
getDataSrv("peers", peers.name, peers.`type`),
RestrictionRes(results.name, resultsType).wrap(
SeqRes.wrap(
ParRes.wrap(
FoldRes(peer.name, peers, ForModel.Mode.Never.some).wrap(
ParRes.wrap(
XorRes.wrap(
// better if first relay will be outside `for`
SeqRes.wrap(
through(ValueModel.fromRaw(relay)),
CallServiceRes(
LiteralModel.fromRaw(LiteralRaw.quote("op")),
"identity",
CallRes(
LiteralModel.fromRaw(LiteralRaw.quote("hahahahah")) :: Nil,
Some(CallModel.Export(retVar.name, retVar.`type`))
),
peer
).leaf,
ApRes(retVar, CallModel.Export(results.name, results.`type`)).leaf,
through(ValueModel.fromRaw(relay)),
through(initPeer)
),
SeqRes.wrap(
through(ValueModel.fromRaw(relay)),
through(initPeer),
failErrorRes
)
),
SeqRes.wrap(
through(ValueModel.fromRaw(relay)),
through(initPeer),
failErrorRes
)
),
NextRes(peer.name).leaf
NextRes(peer.name).leaf
)
)
)
),
ResBuilder.add(
LiteralModel.number(2),
LiteralModel.number(1),
sizeVar,
initPeer
),
join(results, sizeVar),
CanonRes(results, init, CallModel.Export(canonResult.name, canonResult.`type`)).leaf,
),
join(results, LiteralModel.number(3)), // Compiler optimized addition
CanonRes(
results,
init,
CallModel.Export(canonResult.name, canonResult.`type`)
).leaf,
ApRes(
canonResult,
CallModel.Export(flatResult.name, flatResult.`type`)
).leaf
)
),
respCall(transformCfg, flatResult, initPeer)
),
errorCall(transformCfg, 0, initPeer)
)
exec.body.equalsOrShowDiff(expected) shouldBe (true)
}
}
it should "compile with imports" in {
val src = Map(
"index.aqua" ->
"""module Import
|import foobar from "export2.aqua"
|
|use foo as f from "export2.aqua" as Exp
|
|import "../gen/OneMore.aqua"
|
|export foo_wrapper as wrap, foobar as barfoo
|
|func foo_wrapper() -> string:
| z <- Exp.f()
| OneMore "hello"
| OneMore.more_call()
| -- Exp.f() returns literal, this func must return literal in AIR as well
| <- z
|""".stripMargin
)
val imports = Map(
"export2.aqua" ->
"""module Export declares foobar, foo
|
|func bar() -> string:
| <- " I am MyFooBar bar"
|
|func foo() -> string:
| <- "I am MyFooBar foo"
|
|func foobar() -> []string:
| res: *string
| res <- foo()
| res <- bar()
| <- res
|
|""".stripMargin,
"../gen/OneMore.aqua" ->
"""
|service OneMore:
| more_call()
| consume(s: string)
|""".stripMargin
)
val transformCfg = TransformConfig(relayVarName = None)
insideRes(src, imports, transformCfg)(
"wrap",
"barfoo"
) { case wrap :: barfoo :: _ =>
val resStreamType = StreamType(ScalarType.string)
val resVM = VarModel("res", resStreamType)
val resCanonVM = VarModel("-res-fix-0", CanonStreamType(ScalarType.string))
val resFlatVM = VarModel("-res-flat-0", ArrayType(ScalarType.string))
val expected = XorRes.wrap(
SeqRes.wrap(
RestrictionRes(resVM.name, resStreamType).wrap(
SeqRes.wrap(
// res <- foo()
ApRes(
canonResult,
CallModel.Export(flatResult.name, flatResult.`type`)
LiteralModel.fromRaw(LiteralRaw.quote("I am MyFooBar foo")),
CallModel.Export(resVM.name, resVM.`type`)
).leaf,
// res <- bar()
ApRes(
LiteralModel.fromRaw(LiteralRaw.quote(" I am MyFooBar bar")),
CallModel.Export(resVM.name, resVM.`type`)
).leaf,
// canonicalization
CanonRes(
resVM,
LiteralModel.fromRaw(ValueRaw.InitPeerId),
CallModel.Export(resCanonVM.name, resCanonVM.`type`)
).leaf,
// flattening
ApRes(
VarModel(resCanonVM.name, resCanonVM.`type`),
CallModel.Export(resFlatVM.name, resFlatVM.`type`)
).leaf
)
),
respCall(transformCfg, flatResult, initPeer)
respCall(transformCfg, resFlatVM, initPeer)
),
errorCall(transformCfg, 0, initPeer)
)
exec.body.equalsOrShowDiff(expected) shouldBe (true)
barfoo.body.equalsOrShowDiff(expected) should be(true)
}
}
"aqua compiler" should "compile with imports" in {
val res = compileToContext(
Map(
"index.aqua" ->
"""module Import
|import foobar from "export2.aqua"
|
|use foo as f from "export2.aqua" as Exp
|
|import "../gen/OneMore.aqua"
|
|export foo_wrapper as wrap, foobar as barfoo
|
|func foo_wrapper() -> string:
| z <- Exp.f()
| OneMore "hello"
| OneMore.more_call()
| -- Exp.f() returns literal, this func must return literal in AIR as well
| <- z
|""".stripMargin
),
Map(
"export2.aqua" ->
"""module Export declares foobar, foo
|
|func bar() -> string:
| <- " I am MyFooBar bar"
|
|func foo() -> string:
| <- "I am MyFooBar foo"
|
|func foobar() -> []string:
| res: *string
| res <- foo()
| res <- bar()
| <- res
|
|""".stripMargin,
"../gen/OneMore.aqua" ->
"""
|service OneMore:
| more_call()
| consume(s: string)
|""".stripMargin
)
it should "optimize math inside stream join" in {
val src = Map(
"main.aqua" -> """
|func main(i: i32):
| stream: *string
| stream <<- "a"
| stream <<- "b"
| join stream[i - 1]
|""".stripMargin
)
res.isValid should be(true)
val Validated.Valid(ctxs) = res
val transformCfg = TransformConfig()
val streamName = "stream"
val streamType = StreamType(ScalarType.string)
val argName = "-i-arg-"
val argType = ScalarType.i32
val arg = VarModel(argName, argType)
ctxs.length should be(1)
val ctx = ctxs.headOption.get
val transformCfg = TransformConfig(relayVarName = None)
val aquaRes = Transform.contextRes(ctx, transformCfg)
val Some(funcWrap) = aquaRes.funcs.find(_.funcName == "wrap")
val Some(barfoo) = aquaRes.funcs.find(_.funcName == "barfoo")
val resStreamType = StreamType(ScalarType.string)
val resVM = VarModel("res", resStreamType)
val resCanonVM = VarModel("-res-fix-0", CanonStreamType(ScalarType.string))
val resFlatVM = VarModel("-res-flat-0", ArrayType(ScalarType.string))
/**
* NOTE: Compiler generates this unused decrement bc
* it doesn't know that we are inlining just join
* and do not need to access the element.
*/
val decrement = CallServiceRes(
LiteralModel.quote("math"),
"sub",
CallRes(
List(arg, LiteralModel.number(1)),
Some(CallModel.Export("stream_idx", argType))
),
LiteralModel.fromRaw(ValueRaw.InitPeerId)
).leaf
val expected = XorRes.wrap(
SeqRes.wrap(
RestrictionRes(resVM.name, resStreamType).wrap(
getDataSrv("-relay-", "-relay-", ScalarType.string),
getDataSrv("i", argName, argType),
RestrictionRes(streamName, streamType).wrap(
SeqRes.wrap(
// res <- foo()
ApRes(
LiteralModel.fromRaw(LiteralRaw.quote("I am MyFooBar foo")),
CallModel.Export(resVM.name, resVM.`type`)
).leaf,
// res <- bar()
ApRes(
LiteralModel.fromRaw(LiteralRaw.quote(" I am MyFooBar bar")),
CallModel.Export(resVM.name, resVM.`type`)
).leaf,
// canonicalization
CanonRes(
resVM,
LiteralModel.fromRaw(ValueRaw.InitPeerId),
CallModel.Export(resCanonVM.name, resCanonVM.`type`)
).leaf,
// flattening
ApRes(
VarModel(resCanonVM.name, resCanonVM.`type`),
CallModel.Export(resFlatVM.name, resFlatVM.`type`)
).leaf
ApRes(LiteralModel.quote("a"), CallModel.Export(streamName, streamType)).leaf,
ApRes(LiteralModel.quote("b"), CallModel.Export(streamName, streamType)).leaf,
join(VarModel(streamName, streamType), arg),
decrement
)
),
respCall(transformCfg, resFlatVM, initPeer)
)
),
errorCall(transformCfg, 0, initPeer)
)
barfoo.body.equalsOrShowDiff(expected) should be(true)
insideRes(src, transformCfg = transformCfg)("main") { case main :: _ =>
main.body.equalsOrShowDiff(expected) should be(true)
}
}
}

View File

@ -3,20 +3,12 @@ package aqua.model.inline
import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler}
import aqua.model.inline.Inline.MergeMode.*
import aqua.model.*
import aqua.model.inline.raw.{
ApplyBinaryOpRawInliner,
ApplyFunctorRawInliner,
ApplyPropertiesRawInliner,
ApplyUnaryOpRawInliner,
CallArrowRawInliner,
CollectionRawInliner,
MakeAbilityRawInliner,
StreamGateInliner
}
import aqua.model.inline.raw.*
import aqua.raw.ops.*
import aqua.raw.value.*
import aqua.types.{ArrayType, LiteralType, OptionType, StreamType}
import cats.Eval
import cats.syntax.traverse.*
import cats.syntax.monoid.*
import cats.syntax.functor.*
@ -34,8 +26,10 @@ object RawValueInliner extends Logging {
private[inline] def unfold[S: Mangler: Exports: Arrows](
raw: ValueRaw,
propertiesAllowed: Boolean = true
): State[S, (ValueModel, Inline)] =
raw match {
): State[S, (ValueModel, Inline)] = for {
optimized <- StateT.liftF(Optimization.optimize(raw))
_ <- StateT.liftF(Eval.later(logger.trace("OPTIMIZIED " + optimized)))
result <- optimized match {
case VarRaw(name, t) =>
for {
exports <- Exports[S].exports
@ -65,7 +59,12 @@ object RawValueInliner extends Logging {
case cr: CallArrowRaw =>
CallArrowRawInliner(cr, propertiesAllowed)
case cs: CallServiceRaw =>
CallServiceRawInliner(cs, propertiesAllowed)
}
} yield result
private[inline] def inlineToTree[S: Mangler: Exports: Arrows](
inline: Inline
@ -101,10 +100,10 @@ object RawValueInliner extends Logging {
def valueToModel[S: Mangler: Exports: Arrows](
value: ValueRaw,
propertiesAllowed: Boolean = true
): State[S, (ValueModel, Option[OpModel.Tree])] = {
logger.trace("RAW " + value)
toModel(unfold(value, propertiesAllowed))
}
): State[S, (ValueModel, Option[OpModel.Tree])] = for {
_ <- StateT.liftF(Eval.later(logger.trace("RAW " + value)))
model <- toModel(unfold(value, propertiesAllowed))
} yield model
def valueListToModel[S: Mangler: Exports: Arrows](
values: List[ValueRaw]

View File

@ -4,7 +4,7 @@ import aqua.errors.Errors.internalError
import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.model.*
import aqua.model.inline.RawValueInliner.collectionToModel
import aqua.model.inline.raw.CallArrowRawInliner
import aqua.model.inline.raw.{CallArrowRawInliner, CallServiceRawInliner}
import aqua.raw.value.ApplyBinaryOpRaw.Op as BinOp
import aqua.raw.ops.*
import aqua.raw.value.*
@ -308,8 +308,13 @@ object TagInliner extends Logging {
TagInlined.Empty(prefix = parDesugarPrefix(nel.toList.flatMap(_._2)))
})
case CallArrowRawTag(exportTo, value: CallArrowRaw) =>
CallArrowRawInliner.unfoldArrow(value, exportTo).flatMap { case (_, inline) =>
case CallArrowRawTag(exportTo, value: (CallArrowRaw | CallServiceRaw)) =>
(value match {
case ca: CallArrowRaw =>
CallArrowRawInliner.unfold(ca, exportTo)
case cs: CallServiceRaw =>
CallServiceRawInliner.unfold(cs, exportTo)
}).flatMap { case (_, inline) =>
RawValueInliner
.inlineToTree(inline)
.map(tree =>

View File

@ -1,5 +1,6 @@
package aqua.model.inline.raw
import aqua.errors.Errors.internalError
import aqua.model.*
import aqua.model.inline.raw.RawInliner
import aqua.model.inline.TagInliner
@ -8,8 +9,9 @@ import aqua.raw.value.{AbilityRaw, LiteralRaw, MakeStructRaw}
import cats.data.{NonEmptyList, NonEmptyMap, State}
import aqua.model.inline.Inline
import aqua.model.inline.RawValueInliner.{unfold, valueToModel}
import aqua.types.{ArrowType, ScalarType}
import aqua.types.{ArrowType, ScalarType, Type}
import aqua.raw.value.ApplyBinaryOpRaw
import aqua.raw.value.ApplyBinaryOpRaw.Op
import aqua.raw.value.ApplyBinaryOpRaw.Op.*
import aqua.model.inline.Inline.MergeMode
@ -21,12 +23,10 @@ import cats.syntax.flatMap.*
import cats.syntax.apply.*
import cats.syntax.foldable.*
import cats.syntax.applicative.*
import aqua.types.LiteralType
object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
private type BoolOp = And.type | Or.type
private type EqOp = Eq.type | Neq.type
override def apply[S: Mangler: Exports: Arrows](
raw: ApplyBinaryOpRaw,
propertiesAllowed: Boolean
@ -37,16 +37,49 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
(rmodel, rinline) = right
result <- raw.op match {
case op @ (And | Or) => inlineBoolOp(lmodel, rmodel, linline, rinline, op)
case op @ (Eq | Neq) =>
case op: Op.Bool =>
inlineBoolOp(
lmodel,
rmodel,
linline,
rinline,
op,
raw.baseType
)
case op: Op.Eq =>
for {
// Canonicalize stream operands before comparison
leftStream <- TagInliner.canonicalizeIfStream(lmodel)
(lmodelStream, linlineStream) = leftStream.map(linline.append)
rightStream <- TagInliner.canonicalizeIfStream(rmodel)
(rmodelStream, rinlineStream) = rightStream.map(rinline.append)
result <- inlineEqOp(lmodelStream, rmodelStream, linlineStream, rinlineStream, op)
result <- inlineEqOp(
lmodelStream,
rmodelStream,
linlineStream,
rinlineStream,
op,
raw.baseType
)
} yield result
case op: Op.Cmp =>
inlineCmpOp(
lmodel,
rmodel,
linline,
rinline,
op,
raw.baseType
)
case op: Op.Math =>
inlineMathOp(
lmodel,
rmodel,
linline,
rinline,
op,
raw.baseType
)
}
} yield result
@ -55,7 +88,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: EqOp
op: Op.Eq,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
// Optimize in case compared values are literals
// Semantics should check that types are comparable
@ -69,7 +103,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
},
linline.mergeWith(rinline, MergeMode.ParMode)
).pure[State[S, *]]
case _ => fullInlineEqOp(lmodel, rmodel, linline, rinline, op)
case _ => fullInlineEqOp(lmodel, rmodel, linline, rinline, op, resType)
}
private def fullInlineEqOp[S: Mangler: Exports: Arrows](
@ -77,7 +111,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: EqOp
op: Op.Eq,
resType: Type
): State[S, (ValueModel, Inline)] = {
val (name, shouldMatch) = op match {
case Eq => ("eq", true)
@ -114,7 +149,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
)
)
result(name, predo)
result(name, resType, predo)
}
private def inlineBoolOp[S: Mangler: Exports: Arrows](
@ -122,7 +157,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: BoolOp
op: Op.Bool,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
// Optimize in case of left value is known at compile time
case (LiteralModel.Bool(lvalue), _) =>
@ -139,7 +175,7 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
case _ => (lmodel, linline)
}).pure[State[S, *]]
// Produce unoptimized inline
case _ => fullInlineBoolOp(lmodel, rmodel, linline, rinline, op)
case _ => fullInlineBoolOp(lmodel, rmodel, linline, rinline, op, resType)
}
private def fullInlineBoolOp[S: Mangler: Exports: Arrows](
@ -147,7 +183,8 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: BoolOp
op: Op.Bool,
resType: Type
): State[S, (ValueModel, Inline)] = {
val (name, compareWith) = op match {
case And => ("and", false)
@ -190,19 +227,162 @@ object ApplyBinaryOpRawInliner extends RawInliner[ApplyBinaryOpRaw] {
)
)
result(name, predo)
result(name, resType, predo)
}
private def inlineCmpOp[S: Mangler: Exports: Arrows](
lmodel: ValueModel,
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: Op.Cmp,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
case (
LiteralModel.Integer(lv, _),
LiteralModel.Integer(rv, _)
) =>
val res = op match {
case Lt => lv < rv
case Lte => lv <= rv
case Gt => lv > rv
case Gte => lv >= rv
}
(
LiteralModel.bool(res),
Inline(linline.predo ++ rinline.predo)
).pure
case _ =>
val fn = op match {
case Lt => "lt"
case Lte => "lte"
case Gt => "gt"
case Gte => "gte"
}
val predo = (resName: String) =>
SeqModel.wrap(
linline.predo ++ rinline.predo :+ CallServiceModel(
serviceId = LiteralModel.quote("cmp"),
funcName = fn,
call = CallModel(
args = lmodel :: rmodel :: Nil,
exportTo = CallModel.Export(resName, resType) :: Nil
)
).leaf
)
result(fn, resType, predo)
}
private def inlineMathOp[S: Mangler: Exports: Arrows](
lmodel: ValueModel,
rmodel: ValueModel,
linline: Inline,
rinline: Inline,
op: Op.Math,
resType: Type
): State[S, (ValueModel, Inline)] = (lmodel, rmodel) match {
case (
LiteralModel.Integer(lv, lt),
LiteralModel.Integer(rv, rt)
) if canOptimizeMath(lv, lt, rv, rt, op) =>
val res = op match {
case Add => lv + rv
case Sub => lv - rv
case Mul => lv * rv
case Div => lv / rv
case Rem => lv % rv
case Pow => intPow(lv, rv)
case _ => internalError(s"Unsupported operation $op for $lv and $rv")
}
(
LiteralModel.number(res),
Inline(linline.predo ++ rinline.predo)
).pure
case _ =>
val fn = op match {
case Add => "add"
case Sub => "sub"
case Mul => "mul"
case FMul => "fmul"
case Div => "div"
case Rem => "rem"
case Pow => "pow"
}
val predo = (resName: String) =>
SeqModel.wrap(
linline.predo ++ rinline.predo :+ CallServiceModel(
serviceId = LiteralModel.quote("math"),
funcName = fn,
call = CallModel(
args = lmodel :: rmodel :: Nil,
exportTo = CallModel.Export(resName, resType) :: Nil
)
).leaf
)
result(fn, resType, predo)
}
private def result[S: Mangler](
name: String,
resType: Type,
predo: String => OpModel.Tree
): State[S, (ValueModel, Inline)] =
Mangler[S]
.findAndForbidName(name)
.map(resName =>
(
VarModel(resName, ScalarType.bool),
VarModel(resName, resType),
Inline(Chain.one(predo(resName)))
)
)
/**
* Check if we can optimize math operation
* in compile time.
*
* @param left left operand
* @param leftType type of left operand
* @param right right operand
* @param rightType type of right operand
* @param op operation
* @return true if we can optimize this operation
*/
private def canOptimizeMath(
left: Long,
leftType: ScalarType | LiteralType,
right: Long,
rightType: ScalarType | LiteralType,
op: Op.Math
): Boolean = op match {
// Leave division by zero for runtime
case Op.Div | Op.Rem => right != 0
// Leave negative power for runtime
case Op.Pow => right >= 0
case Op.Sub =>
// Leave subtraction overflow for runtime
ScalarType.isSignedInteger(leftType) ||
ScalarType.isSignedInteger(rightType)
case _ => true
}
/**
* Integer power (binary exponentiation)
*
* @param base
* @param exp >= 0
* @return base ** exp
*/
private def intPow(base: Long, exp: Long): Long = {
def intPowTailRec(base: Long, exp: Long, acc: Long): Long =
if (exp <= 0) acc
else intPowTailRec(base * base, exp / 2, if (exp % 2 == 0) acc else acc * base)
intPowTailRec(base, exp, 1)
}
}

View File

@ -257,51 +257,59 @@ object ApplyPropertiesRawInliner extends RawInliner[ApplyPropertyRaw] with Loggi
idx: ValueRaw
): State[S, (VarModel, Inline)] = for {
/**
* Inline idx
* Inline size, which is `idx + 1`
* Increment on ValueRaw level to
* apply possible optimizations
*/
idxInlined <- unfold(idx)
sizeInlined <- unfold(idx.increment)
(sizeVM, sizeInline) = sizeInlined
/**
* Inline idx which is `size - 1`
* TODO: Do not generate it if
* it is not needed, e.g. in `join`
*/
idxInlined <- sizeVM match {
/**
* Micro optimization: if idx is a literal
* do not generate inline.
*/
case LiteralModel.Integer(i, t) =>
(LiteralModel((i - 1).toString, t), Inline.empty).pure[State[S, *]]
case _ =>
Mangler[S].findAndForbidName(s"${streamName}_idx").map { idxName =>
val idxVar = VarModel(idxName, sizeVM.`type`)
val idxInline = Inline.tree(
CallServiceModel(
"math",
funcName = "sub",
args = List(sizeVM, LiteralModel.number(1)),
result = idxVar
).leaf
)
(idxVar, idxInline)
}
}
(idxVM, idxInline) = idxInlined
/**
* Inline size which is `idx + 1`
* TODO: Refactor to apply optimizations
* Inline join of `size` elements of stream
*/
sizeName <- Mangler[S].findAndForbidName(s"${streamName}_size")
sizeVar = VarModel(sizeName, idxVM.`type`)
sizeInline = CallServiceModel(
"math",
funcName = "add",
args = List(idxVM, LiteralModel.number(1)),
result = sizeVar
).leaf
gateInlined <- StreamGateInliner(streamName, streamType, sizeVar)
gateInlined <- StreamGateInliner(streamName, streamType, sizeVM)
(gateVM, gateInline) = gateInlined
/**
* Remove properties from idx
* as we need to use it in index
* TODO: Do not generate it
* if it is not needed,
* e.g. in `join`
*/
idxFlattened <- idxVM match {
case vr: VarModel => removeProperties(vr)
case _ => (idxVM, Inline.empty).pure[State[S, *]]
}
(idxFlat, idxFlatInline) = idxFlattened
/**
* Construct stream[idx]
*/
gate = gateVM.withProperty(
IntoIndexModel
.fromValueModel(idxFlat, streamType.element)
.fromValueModel(idxVM, streamType.element)
.getOrElse(
internalError(s"Unexpected: could not convert ($idxFlat) to IntoIndexModel")
internalError(s"Unexpected: could not convert ($idxVM) to IntoIndexModel")
)
)
} yield gate -> Inline(
idxInline.predo
.append(sizeInline) ++
sizeInline.predo ++
gateInline.predo ++
idxFlatInline.predo,
idxInline.predo,
mergeMode = SeqMode
)

View File

@ -7,48 +7,28 @@ import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.model.inline.{ArrowInliner, Inline, TagInliner}
import aqua.raw.ops.Call
import aqua.raw.value.CallArrowRaw
import cats.data.{Chain, State}
import cats.syntax.traverse.*
import scribe.Logging
object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging {
private[inline] def unfoldArrow[S: Mangler: Exports: Arrows](
private[inline] def unfold[S: Mangler: Exports: Arrows](
value: CallArrowRaw,
exportTo: List[Call.Export]
): State[S, (List[ValueModel], Inline)] = Exports[S].exports.flatMap { exports =>
logger.trace(s"${exportTo.mkString(" ")} $value")
val call = Call(value.arguments, exportTo)
value.serviceId match {
case Some(serviceId) =>
logger.trace(Console.BLUE + s"call service id $serviceId" + Console.RESET)
for {
cd <- callToModel(call, true)
(callModel, callInline) = cd
sd <- valueToModel(serviceId)
(serviceIdValue, serviceIdInline) = sd
values = callModel.exportTo.map(e => e.name -> e.asVar.resolveWith(exports)).toMap
inline = Inline(
Chain(
SeqModel.wrap(
serviceIdInline.toList ++ callInline.toList :+
CallServiceModel(serviceIdValue, value.name, callModel).leaf
)
)
)
_ <- Exports[S].resolved(values)
_ <- Mangler[S].forbid(values.keySet)
} yield values.values.toList -> inline
case None =>
/**
* Here the back hop happens from [[TagInliner]] to [[ArrowInliner.callArrow]]
*/
val funcName = value.ability.fold(value.name)(_ + "." + value.name)
logger.trace(s" $funcName")
resolveArrow(funcName, call)
}
/**
* Here the back hop happens from [[TagInliner]] to [[ArrowInliner.callArrow]]
*/
val funcName = value.ability.fold(value.name)(_ + "." + value.name)
logger.trace(s" $funcName")
resolveArrow(funcName, call)
}
private def resolveFuncArrow[S: Mangler: Exports: Arrows](
@ -103,7 +83,7 @@ object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging {
Mangler[S]
.findAndForbidName(raw.name)
.flatMap(n =>
unfoldArrow(raw, Call.Export(n, raw.`type`) :: Nil).map {
unfold(raw, Call.Export(n, raw.`type`) :: Nil).map {
case (Nil, inline) => (VarModel(n, raw.`type`), inline)
case (h :: _, inline) => (h, inline)
}

View File

@ -0,0 +1,57 @@
package aqua.model.inline.raw
import aqua.errors.Errors.internalError
import aqua.model.*
import aqua.model.inline.RawValueInliner.{callToModel, valueToModel}
import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.model.inline.{ArrowInliner, Inline, TagInliner}
import aqua.raw.ops.Call
import aqua.raw.value.CallServiceRaw
import cats.data.{Chain, State}
import cats.syntax.traverse.*
import scribe.Logging
object CallServiceRawInliner extends RawInliner[CallServiceRaw] with Logging {
private[inline] def unfold[S: Mangler: Exports: Arrows](
value: CallServiceRaw,
exportTo: List[Call.Export]
): State[S, (List[ValueModel], Inline)] = Exports[S].exports.flatMap { exports =>
logger.trace(s"${exportTo.mkString(" ")} $value")
logger.trace(Console.BLUE + s"call service id ${value.serviceId}" + Console.RESET)
val call = Call(value.arguments, exportTo)
for {
cd <- callToModel(call, true)
(callModel, callInline) = cd
sd <- valueToModel(value.serviceId)
(serviceIdValue, serviceIdInline) = sd
values = callModel.exportTo.map(e => e.name -> e.asVar.resolveWith(exports)).toMap
inline = Inline(
Chain(
SeqModel.wrap(
serviceIdInline.toList ++ callInline.toList :+
CallServiceModel(serviceIdValue, value.fnName, callModel).leaf
)
)
)
_ <- Exports[S].resolved(values)
_ <- Mangler[S].forbid(values.keySet)
} yield values.values.toList -> inline
}
override def apply[S: Mangler: Exports: Arrows](
raw: CallServiceRaw,
propertiesAllowed: Boolean
): State[S, (ValueModel, Inline)] =
Mangler[S]
.findAndForbidName(raw.fnName)
.flatMap(n =>
unfold(raw, Call.Export(n, raw.`type`) :: Nil).map {
case (Nil, inline) => (VarModel(n, raw.`type`), inline)
case (h :: _, inline) => (h, inline)
}
)
}

View File

@ -4,7 +4,6 @@ import aqua.errors.Errors.internalError
import aqua.model.*
import aqua.model.inline.Inline
import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.raw.value.{LiteralRaw, VarRaw}
import aqua.model.inline.RawValueInliner.unfold
import aqua.types.{ArrayType, CanonStreamType, ScalarType, StreamType}
@ -25,16 +24,14 @@ object StreamGateInliner extends Logging {
* (seq
* (fold $stream s
* (seq
* (seq
* (ap s $stream_test)
* (canon <peer> $stream_test #stream_iter_canon)
* )
* (xor
* (match #stream_iter_canon.length size
* (null)
* )
* (next s)
* (ap s $stream_test)
* (canon <peer> $stream_test #stream_iter_canon)
* )
* (xor
* (match #stream_iter_canon.length size
* (null)
* )
* (next s)
* )
* (never)
* )
@ -100,7 +97,6 @@ object StreamGateInliner extends Logging {
uniqueCanonName <- Mangler[S].findAndForbidName(streamName + "_result_canon")
uniqueResultName <- Mangler[S].findAndForbidName(streamName + "_gate")
uniqueTestName <- Mangler[S].findAndForbidName(streamName + "_test")
uniqueIdxIncr <- Mangler[S].findAndForbidName(streamName + "_incr")
uniqueIterCanon <- Mangler[S].findAndForbidName(streamName + "_iter_canon")
uniqueIter <- Mangler[S].findAndForbidName(streamName + "_fold_var")
} yield {

View File

@ -21,7 +21,7 @@ final case class IfTagInliner(
def inlined[S: Mangler: Exports: Arrows] =
(valueRaw match {
// Optimize in case last operation is equality check
case ApplyBinaryOpRaw(op @ (BinOp.Eq | BinOp.Neq), left, right) =>
case ApplyBinaryOpRaw(op @ (BinOp.Eq | BinOp.Neq), left, right, _) =>
(
valueToModel(left) >>= canonicalizeIfStream,
valueToModel(right) >>= canonicalizeIfStream

View File

@ -827,7 +827,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val res1 = VarModel("res", ScalarType.u16)
val res2 = VarModel("res2", ScalarType.u16)
val res3 = VarModel("res-0", ScalarType.u16)
val tempAdd = VarModel("add-0", ScalarType.u16)
val tempAdd = VarModel("add", ScalarType.u16)
val expected = SeqModel.wrap(
MetaModel
@ -843,7 +843,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
),
SeqModel.wrap(
ModelBuilder.add(res2, res3)(tempAdd).leaf,
ModelBuilder.add(res1, tempAdd)(VarModel("add", ScalarType.u16)).leaf
ModelBuilder.add(res1, tempAdd)(VarModel("add-0", ScalarType.u16)).leaf
)
)
@ -889,15 +889,12 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
capturedTopology = None
)
val innerCall = CallArrowRaw(
ability = None,
name = innerName,
arguments = Nil,
val innerCall = CallArrowRaw.func(
funcName = innerName,
baseType = ArrowType(
domain = NilType,
codomain = ProductType(List(ScalarType.u16))
),
serviceId = None
)
)
val outerAdd = "37"
@ -943,36 +940,18 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.runA(InliningState())
.value
/* WARNING: This naming is unstable */
val tempAdd0 = VarModel("add-0", ScalarType.u16)
val tempAdd = VarModel("add", ScalarType.u16)
val expected = SeqModel.wrap(
ModelBuilder
.add(
LiteralModel(innerRet, ScalarType.u16),
LiteralModel(outerAdd, ScalarType.u16)
)(tempAdd0)
.leaf,
ModelBuilder
.add(
LiteralModel(innerRet, ScalarType.u16),
tempAdd0
)(tempAdd)
.leaf
)
model.equalsOrShowDiff(expected) shouldEqual true
// Addition is completely optimized out
model.equalsOrShowDiff(EmptyModel.leaf) shouldEqual true
}
/**
* closureName = (x: u16) -> u16:
* retval = x + add
* retval <- TestSrv.call(x, add)
* <- retval
*
* @return (closure func, closure type, closure type labelled)
*/
def addClosure(
def srvCallClosure(
closureName: String,
add: ValueRaw
): (FuncRaw, ArrowType, ArrowType) = {
@ -993,13 +972,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
)
val closureBody = SeqTag.wrap(
AssignmentTag(
RawBuilder.add(
closureArg,
add
),
closureRes.name
).leaf,
CallArrowRawTag
.service(
LiteralRaw.quote("test-srv"),
funcName = "call",
Call(
args = List(closureArg, add),
exportTo = List(Call.Export(closureRes.name, closureRes.`type`))
)
)
.leaf,
ReturnTag(
NonEmptyList.one(closureRes)
).leaf
@ -1017,10 +999,21 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
(closureFunc, closureType, closureTypeLabelled)
}
def srvCallModel(
x: ValueModel,
add: ValueModel,
result: VarModel
): CallServiceModel = CallServiceModel(
serviceId = "test-srv",
funcName = "call",
args = List(x, add),
result = result
)
/**
* func innerName(arg: u16) -> u16 -> u16:
* closureName = (x: u16) -> u16:
* retval = x + arg
* retval <- TestSrv.call(x, arg)
* <- retval
* <- closureName
*
@ -1042,7 +1035,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
)
val (closureFunc, closureType, closureTypeLabelled) =
addClosure(closureName, innerArg)
srvCallClosure(closureName, innerArg)
val innerRes = VarRaw(
closureName,
@ -1085,12 +1078,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val innerCall =
CallArrowRawTag(
List(Call.Export(outterClosure.name, outterClosure.`type`)),
CallArrowRaw(
ability = None,
name = innerName,
arguments = List(LiteralRaw("42", LiteralType.number)),
CallArrowRaw.func(
funcName = innerName,
baseType = innerType,
serviceId = None
arguments = List(LiteralRaw("42", LiteralType.number))
)
).leaf
@ -1128,7 +1119,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
/**
* func inner(arg: u16) -> u16 -> u16:
* closure = (x: u16) -> u16:
* retval = x + arg
* retval <- TestSrv.call(x, arg)
* <- retval
* <- closure
*
@ -1144,12 +1135,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val outterResultName = "retval"
val closureCall = (closureType: ArrowType, i: String) =>
CallArrowRaw(
ability = None,
name = outterClosureName,
arguments = List(LiteralRaw(i, LiteralType.number)),
CallArrowRaw.func(
funcName = outterClosureName,
baseType = closureType,
serviceId = None
arguments = List(LiteralRaw(i, LiteralType.unsigned))
)
val body = (closureType: ArrowType) =>
@ -1157,7 +1146,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
AssignmentTag(
RawBuilder.add(
RawBuilder.add(
LiteralRaw("37", LiteralType.number),
LiteralRaw("37", LiteralType.unsigned),
closureCall(closureType, "1")
),
closureCall(closureType, "2")
@ -1180,20 +1169,19 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.wrap(
ApplyTopologyModel(closureName)
.wrap(
ModelBuilder
.add(
LiteralModel(x, LiteralType.number),
LiteralModel("42", LiteralType.number)
)(o)
.leaf
srvCallModel(
LiteralModel(x, LiteralType.unsigned),
LiteralModel("42", LiteralType.unsigned),
result = o
).leaf
)
)
/* WARNING: This naming is unstable */
val tempAdd0 = VarModel("add-0", ScalarType.u16)
val tempAdd1 = VarModel("add-1", ScalarType.u16)
val tempAdd2 = VarModel("add-2", ScalarType.u16)
val retval1 = VarModel("retval-0", ScalarType.u16)
val retval2 = VarModel("retval-1", ScalarType.u16)
val tempAdd = VarModel("add", ScalarType.u16)
val tempAdd0 = VarModel("add-0", ScalarType.u16)
val expected = SeqModel.wrap(
MetaModel
@ -1202,23 +1190,21 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
CaptureTopologyModel(closureName).leaf
),
SeqModel.wrap(
ParModel.wrap(
SeqModel.wrap(
closureCallModel("1", tempAdd1),
ModelBuilder
.add(
LiteralModel("37", LiteralType.number),
tempAdd1
)(tempAdd0)
.leaf
),
closureCallModel("2", tempAdd2)
SeqModel.wrap(
closureCallModel("1", retval1),
closureCallModel("2", retval2),
ModelBuilder
.add(
retval1,
retval2
)(tempAdd)
.leaf
),
ModelBuilder
.add(
tempAdd0,
tempAdd2
)(tempAdd)
tempAdd,
LiteralModel.number(37)
)(tempAdd0)
.leaf
)
)
@ -1306,22 +1292,18 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val innerCall =
CallArrowRawTag(
List(Call.Export(outterClosure.name, outterClosure.`type`)),
CallArrowRaw(
ability = None,
name = innerName,
arguments = Nil,
CallArrowRaw.func(
funcName = innerName,
baseType = innerType,
serviceId = None
arguments = Nil
)
).leaf
val closureCall =
CallArrowRaw(
ability = None,
name = outterClosure.name,
arguments = Nil,
CallArrowRaw.func(
funcName = outterClosure.name,
baseType = closureType,
serviceId = None
arguments = Nil
)
val outerBody = SeqTag.wrap(
@ -1365,38 +1347,14 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.runA(InliningState())
.value
/* WARNING: This naming is unstable */
val tempAdd0 = VarModel("add-0", ScalarType.u16)
val tempAdd = VarModel("add", ScalarType.u16)
val number = (v: String) =>
LiteralModel(
v,
LiteralType.number
)
val expected = SeqModel.wrap(
ModelBuilder
.add(
number("37"),
number("42")
)(tempAdd0)
.leaf,
ModelBuilder
.add(
tempAdd0,
number("42")
)(tempAdd)
.leaf
)
model.equalsOrShowDiff(expected) shouldEqual true
// Addition is completely optimized out
model.equalsOrShowDiff(EmptyModel.leaf) shouldEqual true
}
/**
* func inner(arg: u16) -> u16 -> u16:
* closure = (x: u16) -> u16:
* retval = x + arg
* retval = TestSrv.call(x, arg)
* <- retval
* <- closure
*
@ -1404,7 +1362,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
* c <- inner(42)
* b = c
* a = b
* retval = 37 + a(1) + b(2) + c{3}
* retval = 37 + a(1) + b(2) + c(3)
* <- retval
*/
it should "correctly inline renamed closure [bug LNG-193]" in {
@ -1416,12 +1374,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val secondRename = "a"
val closureCall = (name: String, closureType: ArrowType, i: String) =>
CallArrowRaw(
ability = None,
name = name,
CallArrowRaw.func(
funcName = name,
arguments = List(LiteralRaw(i, LiteralType.number)),
baseType = closureType,
serviceId = None
baseType = closureType
)
val body = (closureType: ArrowType) =>
@ -1458,28 +1414,27 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
body = body
)
val closureCallModel = (x: String, o: VarModel) =>
val closureCallModel = (x: Long, o: VarModel) =>
MetaModel
.CallArrowModel(closureName)
.wrap(
ApplyTopologyModel(closureName)
.wrap(
ModelBuilder
.add(
LiteralModel(x, LiteralType.number),
LiteralModel("42", LiteralType.number)
)(o)
.leaf
srvCallModel(
LiteralModel.number(x),
LiteralModel.number(42),
result = o
).leaf
)
)
/* WARNING: This naming is unstable */
val tempAdd = VarModel("add", ScalarType.u16)
val tempAdd0 = VarModel("add-0", ScalarType.u16)
val tempAdd1 = VarModel("add-1", ScalarType.u16)
val tempAdd2 = VarModel("add-2", ScalarType.u16)
val tempAdd3 = VarModel("add-3", ScalarType.u16)
val tempAdd4 = VarModel("add-4", ScalarType.u16)
val tempAdd = VarModel("add", ScalarType.u16)
val retval0 = VarModel("retval-0", ScalarType.u16)
val retval1 = VarModel("retval-1", ScalarType.u16)
val retval2 = VarModel("retval-2", ScalarType.u16)
val expected = SeqModel.wrap(
MetaModel
@ -1488,35 +1443,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
CaptureTopologyModel(closureName).leaf
),
SeqModel.wrap(
ParModel.wrap(
SeqModel.wrap(
SeqModel.wrap(
ParModel.wrap(
SeqModel.wrap(
closureCallModel("1", tempAdd2),
ModelBuilder
.add(
LiteralModel("37", LiteralType.number),
tempAdd2
)(tempAdd1)
.leaf
),
closureCallModel("2", tempAdd3)
),
ModelBuilder
.add(
tempAdd1,
tempAdd3
)(tempAdd0)
.leaf
closureCallModel(1, retval0),
closureCallModel(2, retval1),
ModelBuilder.add(retval0, retval1)(tempAdd).leaf
),
closureCallModel("3", tempAdd4)
closureCallModel(3, retval2),
ModelBuilder.add(tempAdd, retval2)(tempAdd0).leaf
),
ModelBuilder
.add(
tempAdd0,
tempAdd4
)(tempAdd)
.leaf
ModelBuilder.add(tempAdd0, LiteralModel.number(37))(tempAdd1).leaf
)
)
@ -1530,7 +1466,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
*
* func test() -> u16:
* closure = (x: u16) -> u16:
* resC = x + 37
* resC <- TestSrv.call(x, 37)
* <- resC
* resT <- accept_closure(closure)
* <- resT
@ -1543,7 +1479,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val testRes = VarRaw("resT", ScalarType.u16)
val (closureFunc, closureType, closureTypeLabelled) =
addClosure(closureName, LiteralRaw("37", LiteralType.number))
srvCallClosure(closureName, LiteralRaw.number(37))
val acceptType = ArrowType(
domain = ProductType.labelled(List(closureName -> closureType)),
@ -1553,12 +1489,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val acceptBody = SeqTag.wrap(
CallArrowRawTag(
List(Call.Export(acceptRes.name, acceptRes.baseType)),
CallArrowRaw(
ability = None,
name = closureName,
arguments = List(LiteralRaw("42", LiteralType.number)),
CallArrowRaw.func(
funcName = closureName,
baseType = closureType,
serviceId = None
arguments = List(LiteralRaw.number(42))
)
).leaf,
ReturnTag(
@ -1586,12 +1520,10 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
).leaf,
CallArrowRawTag(
List(Call.Export(testRes.name, testRes.baseType)),
CallArrowRaw(
ability = None,
name = acceptName,
arguments = List(VarRaw(closureName, closureTypeLabelled)),
CallArrowRaw.func(
funcName = acceptName,
baseType = acceptFunc.arrowType,
serviceId = None
arguments = List(VarRaw(closureName, closureTypeLabelled))
)
).leaf,
ReturnTag(
@ -1621,7 +1553,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.value
/* WARNING: This naming is unstable */
val tempAdd = VarModel("add", ScalarType.u16)
val retval = VarModel("retval", ScalarType.u16)
val expected = SeqModel.wrap(
CaptureTopologyModel(closureName).leaf,
@ -1632,12 +1564,11 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.CallArrowModel(closureName)
.wrap(
ApplyTopologyModel(closureName).wrap(
ModelBuilder
.add(
LiteralModel("42", LiteralType.number),
LiteralModel("37", LiteralType.number)
)(tempAdd)
.leaf
srvCallModel(
LiteralModel.number(42),
LiteralModel.number(37),
retval
).leaf
)
)
)
@ -1673,17 +1604,16 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val testBody = SeqTag.wrap(
CallArrowRawTag
.service(
serviceId = serviceId,
fnName = argMethodName,
srvId = serviceId,
funcName = argMethodName,
call = Call(
args = VarRaw(argMethodName, ScalarType.string) :: Nil,
exportTo = Call.Export(res.name, res.`type`) :: Nil
),
name = serviceName,
arrowType = ArrowType(
domain = ProductType.labelled(List(argMethodName -> ScalarType.string)),
codomain = ProductType(ScalarType.string :: Nil)
)
).some
)
.leaf,
ReturnTag(
@ -2014,7 +1944,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
FuncArrow(
"dumb_func",
SeqTag.wrap(
AssignmentTag(LiteralRaw("1", LiteralType.number), argVar.name).leaf,
AssignmentTag(LiteralRaw.number(1), argVar.name).leaf,
foldOp
),
ArrowType(
@ -2037,7 +1967,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
CallServiceModel(
LiteralModel.fromRaw(serviceId),
fnName,
CallModel(LiteralModel("1", LiteralType.number) :: Nil, Nil)
CallModel(LiteralModel.number(1) :: Nil, Nil)
).leaf,
NextModel(iVar0.name).leaf
)
@ -2215,8 +2145,7 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
val closureBody = SeqTag.wrap(
AssignmentTag(
CallArrowRaw.service(
"cmp",
CallServiceRaw(
LiteralRaw.quote("cmp"),
"gt",
ArrowType(

View File

@ -6,6 +6,7 @@ import aqua.model.inline.state.InliningState
import aqua.raw.ops.*
import aqua.raw.value.*
import aqua.types.*
import cats.data.{Chain, NonEmptyList, NonEmptyMap}
import cats.syntax.show.*
import org.scalatest.flatspec.AnyFlatSpec
@ -24,12 +25,11 @@ class CopyInlinerSpec extends AnyFlatSpec with Matchers {
val length = FunctorRaw("length", ScalarType.u32)
val lengthValue = VarRaw("l", arrType).withProperty(length)
val getField = CallArrowRaw(
None,
val getField = CallServiceRaw(
LiteralRaw.quote("serv"),
"get_field",
Nil,
ArrowType(NilType, UnlabeledConsType(ScalarType.string, NilType)),
Option(LiteralRaw.quote("serv"))
Nil
)
val copyRaw =
@ -63,9 +63,22 @@ class CopyInlinerSpec extends AnyFlatSpec with Matchers {
).leaf
),
RestrictionModel(streamMapName, streamMapType).wrap(
InsertKeyValueModel(LiteralModel.quote("field1"), VarModel("l_length", ScalarType.u32), streamMapName, streamMapType).leaf,
InsertKeyValueModel(LiteralModel.quote("field2"), VarModel("get_field", ScalarType.string), streamMapName, streamMapType).leaf,
CanonicalizeModel(VarModel(streamMapName, streamMapType), CallModel.Export(result.name, result.`type`)).leaf
InsertKeyValueModel(
LiteralModel.quote("field1"),
VarModel("l_length", ScalarType.u32),
streamMapName,
streamMapType
).leaf,
InsertKeyValueModel(
LiteralModel.quote("field2"),
VarModel("get_field", ScalarType.string),
streamMapName,
streamMapType
).leaf,
CanonicalizeModel(
VarModel(streamMapName, streamMapType),
CallModel.Export(result.name, result.`type`)
).leaf
)
)
) shouldBe true

View File

@ -6,6 +6,7 @@ import aqua.model.inline.state.InliningState
import aqua.raw.ops.*
import aqua.raw.value.*
import aqua.types.*
import cats.data.{Chain, NonEmptyList, NonEmptyMap}
import cats.syntax.show.*
import org.scalatest.flatspec.AnyFlatSpec
@ -24,12 +25,11 @@ class MakeStructInlinerSpec extends AnyFlatSpec with Matchers {
val length = FunctorRaw("length", ScalarType.u32)
val lengthValue = VarRaw("l", arrType).withProperty(length)
val getField = CallArrowRaw(
None,
val getField = CallServiceRaw(
LiteralRaw.quote("serv"),
"get_field",
Nil,
ArrowType(NilType, UnlabeledConsType(ScalarType.string, NilType)),
Option(LiteralRaw.quote("serv"))
Nil
)
val makeStruct =
@ -62,9 +62,22 @@ class MakeStructInlinerSpec extends AnyFlatSpec with Matchers {
).leaf
),
RestrictionModel(streamMapName, streamMapType).wrap(
InsertKeyValueModel(LiteralModel.quote("field1"), VarModel("l_length", ScalarType.u32), streamMapName, streamMapType).leaf,
InsertKeyValueModel(LiteralModel.quote("field2"), VarModel("get_field", ScalarType.string), streamMapName, streamMapType).leaf,
CanonicalizeModel(VarModel(streamMapName, streamMapType), CallModel.Export(result.name, result.`type`)).leaf
InsertKeyValueModel(
LiteralModel.quote("field1"),
VarModel("l_length", ScalarType.u32),
streamMapName,
streamMapType
).leaf,
InsertKeyValueModel(
LiteralModel.quote("field2"),
VarModel("get_field", ScalarType.string),
streamMapName,
streamMapType
).leaf,
CanonicalizeModel(
VarModel(streamMapName, streamMapType),
CallModel.Export(result.name, result.`type`)
).leaf
)
)
) shouldBe true

View File

@ -1,21 +1,10 @@
package aqua.model.inline
import aqua.raw.value.{CallArrowRaw, LiteralRaw, ValueRaw}
import aqua.raw.value.{ApplyBinaryOpRaw, ValueRaw}
import aqua.types.{ArrowType, ProductType, ScalarType}
object RawBuilder {
def add(l: ValueRaw, r: ValueRaw): ValueRaw =
CallArrowRaw.service(
abilityName = "math",
serviceId = LiteralRaw.quote("math"),
funcName = "add",
baseType = ArrowType(
ProductType(List(ScalarType.i64, ScalarType.i64)),
ProductType(
List(l.`type` `` r.`type`)
)
),
arguments = List(l, r)
)
ApplyBinaryOpRaw.Add(l, r)
}

View File

@ -1,34 +1,48 @@
package aqua.model.inline
import aqua.model.inline.raw.ApplyPropertiesRawInliner
import aqua.model.{
EmptyModel,
FlattenModel,
FunctorModel,
IntoFieldModel,
IntoIndexModel,
ParModel,
SeqModel,
ValueModel,
VarModel
}
import aqua.model.inline.raw.{ApplyPropertiesRawInliner, StreamGateInliner}
import aqua.model.*
import aqua.model.inline.state.InliningState
import aqua.raw.value.{ApplyPropertyRaw, FunctorRaw, IntoIndexRaw, LiteralRaw, VarRaw}
import aqua.types.*
import aqua.raw.value.*
import cats.Eval
import cats.data.NonEmptyMap
import cats.data.Chain
import cats.syntax.show.*
import cats.syntax.foldable.*
import cats.free.Cofree
import scala.collection.immutable.SortedMap
import scala.math
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.Inside
import scala.collection.immutable.SortedMap
import aqua.raw.value.ApplyBinaryOpRaw
import aqua.raw.value.CallArrowRaw
class RawValueInlinerSpec extends AnyFlatSpec with Matchers {
class RawValueInlinerSpec extends AnyFlatSpec with Matchers with Inside {
import RawValueInliner.valueToModel
def join(stream: VarModel, size: ValueModel) =
stream match {
case VarModel(
streamName,
streamType: StreamType,
Chain.`nil`
) =>
StreamGateInliner.joinStreamOnIndexModel(
streamName = streamName,
streamType = streamType,
sizeModel = size,
testName = streamName + "_test",
iterName = streamName + "_fold_var",
canonName = streamName + "_result_canon",
iterCanonName = streamName + "_iter_canon",
resultName = streamName + "_gate"
)
case _ => ???
}
private def numVarWithLength(name: String) =
VarRaw(name, ArrayType(ScalarType.u32)).withProperty(
FunctorRaw("length", ScalarType.u32)
@ -126,6 +140,51 @@ class RawValueInlinerSpec extends AnyFlatSpec with Matchers {
IntoIndexRaw(ysVarRaw(1), ScalarType.string)
)
def int(i: Int): LiteralRaw = LiteralRaw.number(i)
extension (l: ValueRaw) {
def cmp(op: ApplyBinaryOpRaw.Op.Cmp)(r: ValueRaw): ApplyBinaryOpRaw =
ApplyBinaryOpRaw(op, l, r, ScalarType.bool)
def math(op: ApplyBinaryOpRaw.Op.Math)(r: ValueRaw): ApplyBinaryOpRaw =
ApplyBinaryOpRaw(op, l, r, ScalarType.i64) // result type is not important here
def `<`(r: ValueRaw): ApplyBinaryOpRaw =
cmp(ApplyBinaryOpRaw.Op.Lt)(r)
def `<=`(r: ValueRaw): ApplyBinaryOpRaw =
cmp(ApplyBinaryOpRaw.Op.Lte)(r)
def `>`(r: ValueRaw): ApplyBinaryOpRaw =
cmp(ApplyBinaryOpRaw.Op.Gt)(r)
def `>=`(r: ValueRaw): ApplyBinaryOpRaw =
cmp(ApplyBinaryOpRaw.Op.Gte)(r)
def `+`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Add)(r)
def `-`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Sub)(r)
def `*`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Mul)(r)
def `/`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Div)(r)
def `%`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Rem)(r)
def `**`(r: ValueRaw): ApplyBinaryOpRaw =
math(ApplyBinaryOpRaw.Op.Pow)(r)
}
private def ivar(name: String, t: Option[Type] = None): VarRaw =
VarRaw(name, t.getOrElse(ScalarType.i32))
"raw value inliner" should "desugarize a single non-recursive raw value" in {
// x[y]
valueToModel[InliningState](`raw x[y]`)
@ -305,24 +364,53 @@ class RawValueInlinerSpec extends AnyFlatSpec with Matchers {
}
it should "desugarize stream with gate" in {
val streamWithProps =
VarRaw("x", StreamType(ScalarType.string)).withProperty(
IntoIndexRaw(ysVarRaw(1), ScalarType.string)
)
val stream = VarRaw("x", StreamType(ScalarType.string))
val streamModel = VarModel.fromVarRaw(stream)
val idxRaw = ysVarRaw(1)
val streamWithProps = stream.withProperty(
IntoIndexRaw(idxRaw, ScalarType.string)
)
val (resVal, resTree) = valueToModel[InliningState](streamWithProps)
.runA(InliningState(noNames = Set("x", "ys")))
.value
val initState = InliningState(noNames = Set("x", "ys"))
// Here retrieve how size is inlined
val (afterSizeState, (sizeModel, sizeTree)) =
valueToModel[InliningState](idxRaw.increment).run(initState).value
val (resVal, resTree) =
valueToModel[InliningState](streamWithProps).runA(initState).value
val idxModel = VarModel("x_idx", ScalarType.i8)
val decrement = CallServiceModel(
"math",
"sub",
List(
sizeModel,
LiteralModel.number(1)
),
idxModel
).leaf
val expected = SeqModel.wrap(
sizeTree.toList :+
join(streamModel, sizeModel) :+
decrement
)
resVal should be(
VarModel(
"x_gate",
ArrayType(ScalarType.string),
Chain(
IntoIndexModel("ys_flat", ScalarType.string)
IntoIndexModel(idxModel.name, ScalarType.string)
)
)
)
inside(resTree) { case Some(tree) =>
tree.equalsOrShowDiff(expected) should be(true)
}
}
it should "desugarize stream with length" in {
@ -388,4 +476,165 @@ class RawValueInlinerSpec extends AnyFlatSpec with Matchers {
)
) should be(true)
}
it should "optimize constants comparison" in {
for {
l <- -100 to 100
r <- -100 to 100
} {
val lt = valueToModel[InliningState](
int(l) `<` int(r)
).runA(InliningState()).value
lt shouldBe (
LiteralModel.bool(l < r) -> None
)
val lte = valueToModel[InliningState](
int(l) `<=` int(r)
).runA(InliningState()).value
lte shouldBe (
LiteralModel.bool(l <= r) -> None
)
val gt = valueToModel[InliningState](
int(l) `>` int(r)
).runA(InliningState()).value
gt shouldBe (
LiteralModel.bool(l > r) -> None
)
val gte = valueToModel[InliningState](
int(l) `>=` int(r)
).runA(InliningState()).value
gte shouldBe (
LiteralModel.bool(l >= r) -> None
)
}
}
it should "optimize constants math" in {
for {
l <- -100 to 100
r <- -100 to 100
} {
val add = valueToModel[InliningState](
int(l) `+` int(r)
).runA(InliningState()).value
add shouldBe (
LiteralModel.number(l + r) -> None
)
val sub = valueToModel[InliningState](
int(l) `-` int(r)
).runA(InliningState()).value
sub shouldBe (
LiteralModel.number(l - r) -> None
)
val mul = valueToModel[InliningState](
int(l) `*` int(r)
).runA(InliningState()).value
mul shouldBe (
LiteralModel.number(l * r) -> None
)
val div = valueToModel[InliningState](
int(l) `/` int(r)
).runA(InliningState()).value
val rem = valueToModel[InliningState](
int(l) `%` int(r)
).runA(InliningState()).value
if (r != 0)
div shouldBe (
LiteralModel.number(l / r) -> None
)
rem shouldBe (
LiteralModel.number(l % r) -> None
)
else {
val (dmodel, dtree) = div
dmodel shouldBe a[VarModel]
dtree.nonEmpty shouldBe (true)
val (rmodel, rtree) = rem
rmodel shouldBe a[VarModel]
rtree.nonEmpty shouldBe (true)
}
if (r >= 0 && r <= 5) {
val pow = valueToModel[InliningState](
int(l) `**` int(r)
).runA(InliningState()).value
pow shouldBe (
LiteralModel.number(scala.math.pow(l, r).toLong) -> None
)
}
}
}
it should "optimize addition in expressions" in {
def test(numVars: Int, numLiterals: Int) = {
val vars = (1 to numVars).map(i => ivar(s"v$i")).toList
val literals = (1 to numLiterals).map(i => LiteralRaw.number(i)).toList
val values = vars ++ literals
/**
* Enumerate all possible binary trees of vals
*/
def genAllExprs(vals: List[ValueRaw]): List[ValueRaw] =
if (vals.length <= 1) vals
else
for {
split <- (1 until vals.length).toList
(left, right) = vals.splitAt(split)
l <- genAllExprs(left)
r <- genAllExprs(right)
} yield l `+` r
for {
perm <- values.permutations.toList
expr <- genAllExprs(perm)
} {
val state = InliningState(
resolvedExports = vars.map(v => v.name -> VarModel.fromVarRaw(v)).toMap
)
val (model, inline) = valueToModel[InliningState](expr).runA(state).value
model shouldBe a[VarModel]
inside(inline) { case Some(tree) =>
val numberOfAdditions = Cofree
.cata(tree) { (model, count: Chain[Int]) =>
Eval.later {
count.combineAll + (model match {
case CallServiceModel(_, "add", _) => 1
case _ => 0
})
}
}
.value
numberOfAdditions shouldEqual numVars
}
}
}
/**
* Number of expressions grows exponentially
* So we test only small cases
*/
test(2, 2)
test(3, 2)
test(2, 3)
}
}

View File

@ -2,7 +2,7 @@ package aqua.raw.ops
import aqua.raw.arrow.FuncRaw
import aqua.raw.ops.RawTag.Tree
import aqua.raw.value.{CallArrowRaw, ValueRaw}
import aqua.raw.value.{CallArrowRaw, CallServiceRaw, ValueRaw}
import aqua.tree.{TreeNode, TreeNodeCompanion}
import aqua.types.{ArrowType, DataType, ServiceType}
@ -224,26 +224,6 @@ object CallArrowRawTag {
)
)
def service(
serviceId: ValueRaw,
fnName: String,
call: Call,
name: String = null,
arrowType: ArrowType = null
): CallArrowRawTag =
CallArrowRawTag(
call.exportTo,
CallArrowRaw(
Option(name),
fnName,
call.args,
Option(arrowType).getOrElse(
call.arrowType
),
Some(serviceId)
)
)
def func(fnName: String, call: Call): CallArrowRawTag =
CallArrowRawTag(
call.exportTo,
@ -253,6 +233,22 @@ object CallArrowRawTag {
arguments = call.args
)
)
def service(
srvId: ValueRaw,
funcName: String,
call: Call,
arrowType: Option[ArrowType] = None
): CallArrowRawTag =
CallArrowRawTag(
call.exportTo,
CallServiceRaw(
srvId,
funcName,
arrowType.getOrElse(call.arrowType),
call.args
)
)
}
case class DeclareStreamTag(

View File

@ -0,0 +1,119 @@
package aqua.raw.value
import cats.Eval
import cats.data.Ior
import cats.Semigroup
import cats.syntax.applicative.*
import cats.syntax.apply.*
import cats.syntax.functor.*
import cats.syntax.semigroup.*
object Optimization {
/**
* Optimize raw value.
*
* A lot more optimizations could be done,
* it is here just as a proof of concept.
*/
def optimize(value: ValueRaw): Eval[ValueRaw] =
Addition.optimize(value)
object Addition {
def optimize(value: ValueRaw): Eval[ValueRaw] =
gatherLiteralsInAddition(value).map { res =>
res.fold(
// TODO: Type of literal is not preserved
LiteralRaw.number,
identity,
(literal, nonLiteral) =>
if (literal > 0) {
ApplyBinaryOpRaw.Add(nonLiteral, LiteralRaw.number(literal))
} else if (literal < 0) {
ApplyBinaryOpRaw.Sub(nonLiteral, LiteralRaw.number(-literal))
} else nonLiteral
)
}
private def gatherLiteralsInAddition(
value: ValueRaw
): Eval[Ior[Long, ValueRaw]] =
value match {
case ApplyBinaryOpRaw.Add(left, right) =>
(
gatherLiteralsInAddition(left),
gatherLiteralsInAddition(right)
).mapN(_ add _)
case ApplyBinaryOpRaw.Sub(
left,
/**
* We support subtraction only with literal at the right side
* because we don't have unary minus operator for values.
* (Or this algo should be much more complex to support it)
* But this case is pretty commonly generated by compiler
* in gates: `join stream[len - 1]` (see `StreamGateInliner`)
*/
LiteralRaw.Integer(i)
) =>
(
gatherLiteralsInAddition(left),
Eval.now(Ior.left(-i))
// NOTE: Use add as sign is stored inside Long
).mapN(_ add _)
case LiteralRaw.Integer(i) =>
Ior.left(i).pure
case _ =>
// Optimize expressions inside this value
Ior.right(value.mapValues(v => optimize(v).value)).pure
}
/**
* Rewritten `Ior.combine` method
* with custom combination functions.
*/
private def combineWith(
l: Ior[Long, ValueRaw],
r: Ior[Long, ValueRaw]
)(
lf: (Long, Long) => Long,
rf: (ValueRaw, ValueRaw) => ValueRaw
): Ior[Long, ValueRaw] =
l match {
case Ior.Left(lvl) =>
r match {
case Ior.Left(rvl) =>
Ior.left(lf(lvl, rvl))
case Ior.Right(rvr) =>
Ior.both(lvl, rvr)
case Ior.Both(rvl, rvr) =>
Ior.both(lf(lvl, rvl), rvr)
}
case Ior.Right(lvr) =>
r match {
case Ior.Left(rvl) =>
Ior.both(rvl, lvr)
case Ior.Right(rvr) =>
Ior.right(rf(lvr, rvr))
case Ior.Both(rvl, rvr) =>
Ior.both(rvl, rf(lvr, rvr))
}
case Ior.Both(lvl, lvr) =>
r match {
case Ior.Left(rvl) =>
Ior.both(lf(lvl, rvl), lvr)
case Ior.Right(rvr) =>
Ior.both(lvl, rf(lvr, rvr))
case Ior.Both(rvl, rvr) =>
Ior.both(lf(lvl, rvl), rf(lvr, rvr))
}
}
extension (l: Ior[Long, ValueRaw]) {
def add(r: Ior[Long, ValueRaw]): Ior[Long, ValueRaw] =
combineWith(l, r)(_ + _, ApplyBinaryOpRaw.Add(_, _))
}
}
}

View File

@ -6,6 +6,9 @@ import cats.data.NonEmptyMap
sealed trait PropertyRaw {
def `type`: Type
/**
* Apply function to values in this property
*/
def map(f: ValueRaw => ValueRaw): PropertyRaw
def renameVars(vals: Map[String, String]): PropertyRaw = this
@ -24,7 +27,8 @@ case class IntoArrowRaw(name: String, arrowType: Type, arguments: List[ValueRaw]
override def `type`: Type = arrowType
override def map(f: ValueRaw => ValueRaw): PropertyRaw = this
override def map(f: ValueRaw => ValueRaw): PropertyRaw =
copy(arguments = arguments.map(f))
override def varNames: Set[String] = arguments.flatMap(_.varNames).toSet

View File

@ -15,7 +15,16 @@ sealed trait ValueRaw {
def renameVars(map: Map[String, String]): ValueRaw
def map(f: ValueRaw => ValueRaw): ValueRaw
/**
* Apply function to all values in the tree
*/
final def map(f: ValueRaw => ValueRaw): ValueRaw =
f(mapValues(_.map(f)))
/**
* Apply function to values in this value
*/
def mapValues(f: ValueRaw => ValueRaw): ValueRaw
def varNames: Set[String]
}
@ -60,6 +69,12 @@ object ValueRaw {
type ApplyRaw = ApplyPropertyRaw | CallArrowRaw | CollectionRaw | ApplyBinaryOpRaw |
ApplyUnaryOpRaw
extension (v: ValueRaw) {
def add(a: ValueRaw): ValueRaw = ApplyBinaryOpRaw.Add(v, a)
def increment: ValueRaw = ApplyBinaryOpRaw.Add(v, LiteralRaw.number(1))
}
}
case class ApplyPropertyRaw(value: ValueRaw, property: PropertyRaw) extends ValueRaw {
@ -70,8 +85,8 @@ case class ApplyPropertyRaw(value: ValueRaw, property: PropertyRaw) extends Valu
override def renameVars(map: Map[String, String]): ValueRaw =
ApplyPropertyRaw(value.renameVars(map), property.renameVars(map))
override def map(f: ValueRaw => ValueRaw): ValueRaw =
f(ApplyPropertyRaw(f(value), property.map(_.map(f))))
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
ApplyPropertyRaw(f(value), property.map(f))
override def toString: String = s"$value.$property"
@ -96,7 +111,7 @@ object ApplyPropertyRaw {
case class VarRaw(name: String, baseType: Type) extends ValueRaw {
override def map(f: ValueRaw => ValueRaw): ValueRaw = f(this)
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = this
override def renameVars(map: Map[String, String]): ValueRaw =
copy(name = map.getOrElse(name, name))
@ -110,7 +125,7 @@ case class VarRaw(name: String, baseType: Type) extends ValueRaw {
}
case class LiteralRaw(value: String, baseType: Type) extends ValueRaw {
override def map(f: ValueRaw => ValueRaw): ValueRaw = f(this)
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = this
override def toString: String = s"{$value: ${baseType}}"
@ -122,12 +137,25 @@ case class LiteralRaw(value: String, baseType: Type) extends ValueRaw {
object LiteralRaw {
def quote(value: String): LiteralRaw = LiteralRaw("\"" + value + "\"", LiteralType.string)
def number(value: Int): LiteralRaw = LiteralRaw(value.toString, LiteralType.forInt(value))
def number(value: Long): LiteralRaw = LiteralRaw(value.toString, LiteralType.forInt(value))
val Zero: LiteralRaw = number(0)
val True: LiteralRaw = LiteralRaw("true", LiteralType.bool)
val False: LiteralRaw = LiteralRaw("false", LiteralType.bool)
object Integer {
/*
* Used to match integer literals in pattern matching
*/
def unapply(value: ValueRaw): Option[Long] =
value match {
case LiteralRaw(value, t) if ScalarType.integer.exists(_.acceptsValueOf(t)) =>
value.toLongOption
case _ => none
}
}
}
case class CollectionRaw(values: NonEmptyList[ValueRaw], boxType: BoxType) extends ValueRaw {
@ -136,10 +164,10 @@ case class CollectionRaw(values: NonEmptyList[ValueRaw], boxType: BoxType) exten
override lazy val baseType: Type = boxType
override def map(f: ValueRaw => ValueRaw): ValueRaw = {
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw = {
val vals = values.map(f)
val el = vals.map(_.`type`).reduceLeft(_ `∩` _)
f(copy(vals, boxType.withElement(el)))
copy(vals, boxType.withElement(el))
}
override def varNames: Set[String] = values.toList.flatMap(_.varNames).toSet
@ -153,7 +181,8 @@ case class MakeStructRaw(fields: NonEmptyMap[String, ValueRaw], structType: Stru
override def baseType: Type = structType
override def map(f: ValueRaw => ValueRaw): ValueRaw = f(copy(fields = fields.map(f)))
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
copy(fields = fields.map(f))
override def varNames: Set[String] = {
fields.toSortedMap.values.flatMap(_.varNames).toSet
@ -168,8 +197,8 @@ case class AbilityRaw(fieldsAndArrows: NonEmptyMap[String, ValueRaw], abilityTyp
override def baseType: Type = abilityType
override def map(f: ValueRaw => ValueRaw): ValueRaw =
f(copy(fieldsAndArrows = fieldsAndArrows.map(f)))
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
copy(fieldsAndArrows = fieldsAndArrows.map(f))
override def varNames: Set[String] = {
fieldsAndArrows.toSortedMap.values.flatMap(_.varNames).toSet
@ -182,29 +211,79 @@ case class AbilityRaw(fieldsAndArrows: NonEmptyMap[String, ValueRaw], abilityTyp
case class ApplyBinaryOpRaw(
op: ApplyBinaryOpRaw.Op,
left: ValueRaw,
right: ValueRaw
right: ValueRaw,
// TODO: Refactor type, get rid of `LiteralType`
resultType: ScalarType | LiteralType
) extends ValueRaw {
// Only boolean operations are supported for now
override def baseType: Type = ScalarType.bool
override val baseType: Type = resultType
override def map(f: ValueRaw => ValueRaw): ValueRaw =
f(copy(left = f(left), right = f(right)))
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
copy(left = f(left), right = f(right))
override def varNames: Set[String] = left.varNames ++ right.varNames
override def renameVars(map: Map[String, String]): ValueRaw =
copy(left = left.renameVars(map), right = right.renameVars(map))
override def toString(): String =
s"(${left} ${op} ${right}) :: ${resultType}"
}
object ApplyBinaryOpRaw {
enum Op {
case And
case Or
case And, Or
case Eq, Neq
case Lt, Lte, Gt, Gte
case Add, Sub, Mul, FMul, Div, Pow, Rem
}
case Eq
case Neq
object Op {
type Bool = And.type | Or.type
type Eq = Eq.type | Neq.type
type Cmp = Lt.type | Lte.type | Gt.type | Gte.type
type Math = Add.type | Sub.type | Mul.type | FMul.type | Div.type | Pow.type | Rem.type
}
object Add {
def apply(left: ValueRaw, right: ValueRaw): ValueRaw =
ApplyBinaryOpRaw(
Op.Add,
left,
right,
ScalarType.resolveMathOpType(left.`type`, right.`type`).`type`
)
def unapply(value: ValueRaw): Option[(ValueRaw, ValueRaw)] =
value match {
case ApplyBinaryOpRaw(Op.Add, left, right, _) =>
(left, right).some
case _ => none
}
}
object Sub {
def apply(left: ValueRaw, right: ValueRaw): ValueRaw =
ApplyBinaryOpRaw(
Op.Sub,
left,
right,
ScalarType.resolveMathOpType(left.`type`, right.`type`).`type`
)
def unapply(value: ValueRaw): Option[(ValueRaw, ValueRaw)] =
value match {
case ApplyBinaryOpRaw(Op.Sub, left, right, _) =>
(left, right).some
case _ => none
}
}
}
@ -216,8 +295,8 @@ case class ApplyUnaryOpRaw(
// Only boolean operations are supported for now
override def baseType: Type = ScalarType.bool
override def map(f: ValueRaw => ValueRaw): ValueRaw =
f(copy(value = f(value)))
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
copy(value = f(value))
override def varNames: Set[String] = value.varNames
@ -237,37 +316,28 @@ case class CallArrowRaw(
ability: Option[String],
name: String,
arguments: List[ValueRaw],
baseType: ArrowType,
// TODO: there should be no serviceId there
serviceId: Option[ValueRaw]
baseType: ArrowType
) extends ValueRaw {
override def `type`: Type = baseType.codomain.uncons.map(_._1).getOrElse(baseType)
override def `type`: Type = baseType.codomain.headOption.getOrElse(baseType)
override def map(f: ValueRaw => ValueRaw): ValueRaw =
f(
copy(
arguments = arguments.map(_.map(f)),
serviceId = serviceId.map(_.map(f))
)
)
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
copy(arguments = arguments.map(f))
override def varNames: Set[String] = name.some
.filterNot(_ => ability.isDefined || serviceId.isDefined)
.filterNot(_ => ability.isDefined)
.toSet ++ arguments.flatMap(_.varNames).toSet
override def renameVars(map: Map[String, String]): ValueRaw =
copy(
name = map
.get(name)
// Rename only if it is **not** a service or ability call, see [bug LNG-199]
// Rename only if it is **not** an ability call, see [bug LNG-199]
.filterNot(_ => ability.isDefined)
.filterNot(_ => serviceId.isDefined)
.getOrElse(name)
)
override def toString: String =
s"(call ${ability.fold("")(a => s"|$a| ")} (${serviceId.fold("")(_.toString + " ")}$name) [${arguments
.mkString(" ")}] :: $baseType)"
s"${ability.fold("")(a => s"$a.")}$name(${arguments.mkString(",")}) :: $baseType)"
}
object CallArrowRaw {
@ -280,8 +350,7 @@ object CallArrowRaw {
ability = None,
name = funcName,
arguments = arguments,
baseType = baseType,
serviceId = None
baseType = baseType
)
def ability(
@ -293,22 +362,46 @@ object CallArrowRaw {
ability = None,
name = AbilityType.fullName(abilityName, funcName),
arguments = arguments,
baseType = baseType,
serviceId = None
)
def service(
abilityName: String,
serviceId: ValueRaw,
funcName: String,
baseType: ArrowType,
arguments: List[ValueRaw] = Nil
): CallArrowRaw = CallArrowRaw(
ability = abilityName.some,
name = funcName,
arguments = arguments,
baseType = baseType,
serviceId = Some(serviceId)
baseType = baseType
)
}
/**
* WARNING: This class is internal and is used to generate code.
* Calls to services in aqua code are represented as [[CallArrowRaw]]
* and resolved through ability resolution.
*
* @param serviceId service id
* @param fnName service method name
* @param baseType type of the service method
* @param arguments call arguments
*/
case class CallServiceRaw(
serviceId: ValueRaw,
fnName: String,
baseType: ArrowType,
arguments: List[ValueRaw]
) extends ValueRaw {
override def `type`: Type = baseType.codomain.headOption.getOrElse(baseType)
override def mapValues(f: ValueRaw => ValueRaw): ValueRaw =
copy(
serviceId = f(serviceId),
arguments = arguments.map(f)
)
override def varNames: Set[String] =
arguments
.flatMap(_.varNames)
.toSet ++ serviceId.varNames
override def renameVars(map: Map[String, String]): ValueRaw =
copy(
serviceId = serviceId.renameVars(map),
arguments = arguments.map(_.renameVars(map))
)
override def toString: String =
s"call (${serviceId}) $fnName(${arguments.mkString(",")}) :: $baseType)"
}

View File

@ -144,18 +144,13 @@ object AquaContext extends Logging {
blank.copy(
module = Some(sm.name),
funcs = sm.`type`.arrows.map { case (fnName, arrowType) =>
val (args, call, ret) = ArgsCall.arrowToArgsCallRet(arrowType)
fnName ->
FuncArrow(
fnName,
// TODO: capture ability resolution, get ID from the call context
CallArrowRawTag.service(serviceId, fnName, call, sm.name).leaf,
arrowType,
ret.map(_.toRaw),
Map.empty,
Map.empty,
None
)
fnName -> FuncArrow.fromServiceMethod(
fnName,
sm.name,
fnName,
arrowType,
serviceId
)
}
)

View File

@ -2,10 +2,12 @@ package aqua.model
import aqua.raw.Raw
import aqua.raw.arrow.FuncRaw
import aqua.raw.ops.{Call, CallArrowRawTag, RawTag}
import aqua.raw.ops.{Call, CallArrowRawTag, EmptyTag, RawTag}
import aqua.raw.value.{ValueRaw, VarRaw}
import aqua.types.{ArrowType, ServiceType, Type}
import cats.syntax.option.*
case class FuncArrow(
funcName: String,
body: RawTag.Tree,
@ -58,21 +60,28 @@ object FuncArrow {
serviceName: String,
methodName: String,
methodType: ArrowType,
idValue: ValueModel
idValue: ValueModel | ValueRaw
): FuncArrow = {
val id = VarRaw("id", idValue.`type`)
val (id, capturedValues) = idValue match {
case i: ValueModel =>
(
VarRaw("id", i.`type`),
Map("id" -> i)
)
case i: ValueRaw => (i, Map.empty[String, ValueModel])
}
val retVar = methodType.res.map(t => VarRaw("ret", t))
val call = Call(
methodType.domain.toLabelledList().map(VarRaw.apply),
retVar.map(r => Call.Export(r.name, r.`type`)).toList
)
val body = CallArrowRawTag.service(
serviceId = id,
fnName = methodName,
srvId = id,
funcName = methodName,
call = call,
name = serviceName,
arrowType = methodType
arrowType = methodType.some
)
FuncArrow(
@ -81,9 +90,7 @@ object FuncArrow {
arrowType = methodType,
ret = retVar.toList,
capturedArrows = Map.empty,
capturedValues = Map(
id.name -> idValue
),
capturedValues = capturedValues,
capturedTopology = None
)
}

View File

@ -7,6 +7,7 @@ import aqua.types.*
import cats.Eq
import cats.data.{Chain, NonEmptyMap}
import cats.syntax.option.*
import cats.syntax.apply.*
import scribe.Logging
sealed trait ValueModel {
@ -91,6 +92,22 @@ object LiteralModel {
}
}
/*
* Used to match integer literals in pattern matching
*/
object Integer {
def unapply(lm: LiteralModel): Option[(Long, ScalarType | LiteralType)] =
lm match {
case LiteralModel(value, t) if ScalarType.integer.exists(_.acceptsValueOf(t)) =>
(
value.toLongOption,
t.some.collect { case t: (ScalarType | LiteralType) => t }
).tupled
case _ => none
}
}
// AquaVM will return 0 for
// :error:.$.error_code if there is no :error:
val emptyErrorCode = number(0)
@ -102,7 +119,7 @@ object LiteralModel {
def quote(str: String): LiteralModel = LiteralModel(s"\"$str\"", LiteralType.string)
def number(n: Int): LiteralModel = LiteralModel(n.toString, LiteralType.forInt(n))
def number(n: Long): LiteralModel = LiteralModel(n.toString, LiteralType.forInt(n))
def bool(b: Boolean): LiteralModel = LiteralModel(b.toString.toLowerCase, LiteralType.bool)
}

View File

@ -50,7 +50,7 @@ class CallArrowSem[S[_]](val expr: CallArrowExpr[S]) extends AnyVal {
)
} yield tag.map(_.funcOpLeaf)
def program[Alg[_]: Monad](implicit
def program[Alg[_]: Monad](using
N: NamesAlgebra[S, Alg],
T: TypesAlgebra[S, Alg],
V: ValuesAlgebra[S, Alg]

View File

@ -184,7 +184,6 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using
case (Some(leftRaw), Some(rightRaw)) =>
val lType = leftRaw.`type`
val rType = rightRaw.`type`
lazy val uType = lType `` rType
it.op match {
case InfOp.Bool(bop) =>
@ -200,7 +199,8 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using
case BoolOp.Or => ApplyBinaryOpRaw.Op.Or
},
left = leftRaw,
right = rightRaw
right = rightRaw,
resultType = ScalarType.bool
)
)
case InfOp.Eq(eop) =>
@ -216,7 +216,8 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using
case EqOp.Neq => ApplyBinaryOpRaw.Op.Neq
},
left = leftRaw,
right = rightRaw
right = rightRaw,
resultType = ScalarType.bool
)
)
)
@ -228,70 +229,62 @@ class ValuesAlgebra[S[_], Alg[_]: Monad](using
case InfOp.Cmp(op) => op
}
val hasFloat = List(lType, rType).exists(
lazy val hasFloat = List(lType, rType).exists(
_ acceptsValueOf LiteralType.float
)
// See https://github.com/fluencelabs/aqua-lib/blob/main/math.aqua
val (id, fn) = iop match {
case MathOp.Add => ("math", "add")
case MathOp.Sub => ("math", "sub")
case MathOp.Mul if hasFloat => ("math", "fmul")
case MathOp.Mul => ("math", "mul")
case MathOp.Div => ("math", "div")
case MathOp.Rem => ("math", "rem")
case MathOp.Pow => ("math", "pow")
case CmpOp.Gt => ("cmp", "gt")
case CmpOp.Gte => ("cmp", "gte")
case CmpOp.Lt => ("cmp", "lt")
case CmpOp.Lte => ("cmp", "lte")
val bop = iop match {
case MathOp.Add => ApplyBinaryOpRaw.Op.Add
case MathOp.Sub => ApplyBinaryOpRaw.Op.Sub
case MathOp.Mul if hasFloat => ApplyBinaryOpRaw.Op.FMul
case MathOp.Mul => ApplyBinaryOpRaw.Op.Mul
case MathOp.Div => ApplyBinaryOpRaw.Op.Div
case MathOp.Rem => ApplyBinaryOpRaw.Op.Rem
case MathOp.Pow => ApplyBinaryOpRaw.Op.Pow
case CmpOp.Gt => ApplyBinaryOpRaw.Op.Gt
case CmpOp.Gte => ApplyBinaryOpRaw.Op.Gte
case CmpOp.Lt => ApplyBinaryOpRaw.Op.Lt
case CmpOp.Lte => ApplyBinaryOpRaw.Op.Lte
}
/*
* If `uType == TopType`, it means that we don't
* have type big enough to hold the result of operation.
* e.g. We will use `i64` for result of `i32 * u64`
* TODO: Handle this more gracefully
* (use warning system when it is implemented)
*/
def uTypeBounded = if (uType == TopType) {
val bounded = ScalarType.i64
logger.warn(
s"Result type of ($lType ${it.op} $rType) is $TopType, " +
s"using $bounded instead"
)
bounded
} else uType
lazy val numbersTypeBounded: Alg[ScalarType | LiteralType] = {
val resType = ScalarType.resolveMathOpType(lType, rType)
report
.warning(
it,
s"Result type of ($lType ${it.op} $rType) is unknown, " +
s"using ${resType.`type`} instead"
)
.whenA(resType.overflow)
.as(resType.`type`)
}
// Expected type sets of left and right operands, result type
val (leftExp, rightExp, resType) = iop match {
val (leftExp, rightExp, resTypeM) = iop match {
case MathOp.Add | MathOp.Sub | MathOp.Div | MathOp.Rem =>
(ScalarType.integer, ScalarType.integer, uTypeBounded)
(ScalarType.integer, ScalarType.integer, numbersTypeBounded)
case MathOp.Pow =>
(ScalarType.integer, ScalarType.unsigned, uTypeBounded)
(ScalarType.integer, ScalarType.unsigned, numbersTypeBounded)
case MathOp.Mul if hasFloat =>
(ScalarType.float, ScalarType.float, ScalarType.i64)
(ScalarType.float, ScalarType.float, ScalarType.i64.pure)
case MathOp.Mul =>
(ScalarType.integer, ScalarType.integer, uTypeBounded)
(ScalarType.integer, ScalarType.integer, numbersTypeBounded)
case CmpOp.Gt | CmpOp.Lt | CmpOp.Gte | CmpOp.Lte =>
(ScalarType.integer, ScalarType.integer, ScalarType.bool)
(ScalarType.integer, ScalarType.integer, ScalarType.bool.pure)
}
for {
leftChecked <- T.ensureTypeOneOf(l, leftExp, lType)
rightChecked <- T.ensureTypeOneOf(r, rightExp, rType)
resType <- resTypeM
} yield Option.when(
leftChecked.isDefined && rightChecked.isDefined
)(
CallArrowRaw.service(
abilityName = id,
serviceId = LiteralRaw.quote(id),
funcName = fn,
baseType = ArrowType(
ProductType(lType :: rType :: Nil),
ProductType(resType :: Nil)
),
arguments = leftRaw :: rightRaw :: Nil
ApplyBinaryOpRaw(
op = bop,
left = leftRaw,
right = rightRaw,
resultType = resType
)
)

View File

@ -108,10 +108,10 @@ class SemanticsSpec extends AnyFlatSpec with Matchers with Inside {
.leaf
def equ(left: ValueRaw, right: ValueRaw): ApplyBinaryOpRaw =
ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Eq, left, right)
ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Eq, left, right, ScalarType.bool)
def neq(left: ValueRaw, right: ValueRaw): ApplyBinaryOpRaw =
ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Neq, left, right)
ApplyBinaryOpRaw(ApplyBinaryOpRaw.Op.Neq, left, right, ScalarType.bool)
def declareStreamPush(
name: String,

View File

@ -267,7 +267,7 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside {
.run(state)
.value
inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _)) =>
inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _, ScalarType.bool)) =>
bop shouldBe (op match {
case InfixToken.BoolOp.And => ApplyBinaryOpRaw.Op.And
case InfixToken.BoolOp.Or => ApplyBinaryOpRaw.Op.Or
@ -319,7 +319,7 @@ class ValuesAlgebraSpec extends AnyFlatSpec with Matchers with Inside {
.run(state)
.value
inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _)) =>
inside(res) { case Some(ApplyBinaryOpRaw(bop, _, _, ScalarType.bool)) =>
bop shouldBe (op match {
case InfixToken.EqOp.Eq => ApplyBinaryOpRaw.Op.Eq
case InfixToken.EqOp.Neq => ApplyBinaryOpRaw.Op.Neq

View File

@ -53,6 +53,11 @@ sealed trait ProductType extends Type {
case _ => None
}
def headOption: Option[Type] = this match {
case ConsType(t, _) => Some(t)
case _ => None
}
lazy val toList: List[Type] = this match {
case ConsType(t, pt) => t :: pt.toList
case _ => Nil
@ -182,6 +187,45 @@ object ScalarType {
val integer = signed ++ unsigned
val number = float ++ integer
val all = number ++ Set(bool, string)
final case class MathOpType(
`type`: ScalarType | LiteralType,
overflow: Boolean
)
/**
* Resolve type of math operation
* on two given types.
*
* WARNING: General `Type` is accepted
* but only integer `ScalarType` and `LiteralType`
* are actually expected.
*/
def resolveMathOpType(
lType: Type,
rType: Type
): MathOpType = {
val uType = lType `` rType
uType match {
case t: (ScalarType | LiteralType) => MathOpType(t, false)
case _ => MathOpType(ScalarType.i64, true)
}
}
/**
* Check if given type is signed.
*
* NOTE: Only integer types are expected.
* But it is impossible to enforce it.
*/
def isSignedInteger(t: ScalarType | LiteralType): Boolean =
t match {
case st: ScalarType => signed.contains(st)
/**
* WARNING: LiteralType.unsigned is signed integer!
*/
case lt: LiteralType => lt.oneOf.exists(signed.contains)
}
}
case class LiteralType private (oneOf: Set[ScalarType], name: String) extends DataType {
@ -200,7 +244,7 @@ object LiteralType {
val bool = LiteralType(Set(ScalarType.bool), "bool")
val string = LiteralType(Set(ScalarType.string), "string")
def forInt(n: Int): LiteralType = if (n < 0) signed else unsigned
def forInt(n: Long): LiteralType = if (n < 0) signed else unsigned
}
sealed trait BoxType extends DataType {
@ -323,8 +367,7 @@ case class StructType(name: String, fields: NonEmptyMap[String, Type])
s"$name{${fields.map(_.toString).toNel.toList.map(kv => kv._1 + ": " + kv._2).mkString(", ")}}"
}
case class StreamMapType(element: Type)
extends DataType {
case class StreamMapType(element: Type) extends DataType {
override def toString: String = s"%$element"
}