feat(compiler): Always generate last argument of fold [LNG-265] (#947)

* Always generate last in fold

* Fix unit tests

* Add methods
This commit is contained in:
InversionSpaces 2023-10-30 10:58:51 +01:00 committed by GitHub
parent 634b1c17b6
commit 78ee753c7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 254 additions and 177 deletions

View File

@ -93,7 +93,7 @@ object Air {
iterable: DataView,
label: String,
instruction: Air,
lastNextInstruction: Option[Air]
lastNextInstruction: Air
) extends Air(Keyword.Fold)
case class Match(left: DataView, right: DataView, instruction: Air) extends Air(Keyword.Match)
@ -137,7 +137,7 @@ object Air {
case Air.Next(label) s" $label"
case Air.New(item, inst) s" ${item.show}\n${showNext(inst)}$space"
case Air.Fold(iter, label, inst, lastInst)
val l = lastInst.map(a => show(depth + 1, a)).getOrElse("")
val l = show(depth + 1, lastInst)
s" ${iter.show} $label\n${showNext(inst)}$l$space"
case Air.Match(left, right, inst)
s" ${left.show} ${right.show}\n${showNext(inst)}$space"

View File

@ -94,9 +94,9 @@ object AirGen extends Logging {
)
case FoldRes(item, iterable, mode) =>
val m = mode.map {
case ForModel.Mode.Null => NullGen
case ForModel.Mode.Never => NeverGen
val m = mode match {
case FoldRes.Mode.Null => NullGen
case FoldRes.Mode.Never => NeverGen
}
Eval later ForGen(valueToData(iterable), item, opsToSingle(ops), m)
case RestrictionRes(item, itemType) =>
@ -202,9 +202,8 @@ case class MatchMismatchGen(
else Air.Mismatch(left, right, body.generate)
}
case class ForGen(iterable: DataView, item: String, body: AirGen, mode: Option[AirGen])
extends AirGen {
override def generate: Air = Air.Fold(iterable, item, body.generate, mode.map(_.generate))
case class ForGen(iterable: DataView, item: String, body: AirGen, mode: AirGen) extends AirGen {
override def generate: Air = Air.Fold(iterable, item, body.generate, mode.generate)
}
case class NewGen(name: String, body: AirGen) extends AirGen {

View File

@ -169,7 +169,9 @@ class AquaCompilerSpec extends AnyFlatSpec with Matchers with Inside {
RestrictionRes(results.name, resultsType).wrap(
SeqRes.wrap(
ParRes.wrap(
FoldRes(peer.name, peers, ForModel.Mode.Never.some).wrap(
FoldRes
.lastNever(peer.name, peers)
.wrap(
ParRes.wrap(
XorRes.wrap(
// better if first relay will be outside `for`

View File

@ -227,12 +227,12 @@ object TagInliner extends Logging {
)
}
_ <- Exports[S].resolved(item, VarModel(n, elementType))
m = mode.map {
case ForTag.Mode.Wait => ForModel.Mode.Never
case ForTag.Mode.Pass => ForModel.Mode.Null
modeModel = mode match {
case ForTag.Mode.Blocking => ForModel.Mode.Never
case ForTag.Mode.NonBlocking => ForModel.Mode.Null
}
} yield TagInlined.Single(
model = ForModel(n, v, m),
model = ForModel(n, v, modeModel),
prefix = p
)

View File

@ -56,7 +56,7 @@ object StreamGateInliner extends Logging {
val resultCanon = VarModel(canonName, CanonStreamType(streamType.element))
RestrictionModel(varSTest.name, streamType).wrap(
ForModel(iter.name, VarModel(streamName, streamType), ForModel.Mode.Never.some).wrap(
ForModel(iter.name, VarModel(streamName, streamType), ForModel.Mode.Never).wrap(
PushToStreamModel(
iter,
CallModel.Export(varSTest.name, varSTest.`type`)

View File

@ -2064,8 +2064,12 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
.leaf
)
val foldOp =
ForTag(iVar.name, array, ForTag.Mode.Wait.some).wrap(inFold, NextTag(iVar.name).leaf)
val foldOp = ForTag
.blocking(iVar.name, array)
.wrap(
inFold,
NextTag(iVar.name).leaf
)
val model: OpModel.Tree = ArrowInliner
.callArrow[InliningState](
@ -2091,7 +2095,9 @@ class ArrowInlinerSpec extends AnyFlatSpec with Matchers with Inside {
._2
model.equalsOrShowDiff(
ForModel(iVar0.name, ValueModel.fromRaw(array), ForModel.Mode.Never.some).wrap(
ForModel
.neverMode(iVar0.name, ValueModel.fromRaw(array))
.wrap(
CallServiceModel(
LiteralModel.fromRaw(serviceId),
fnName,

View File

@ -168,8 +168,7 @@ case class RestrictionTag(name: String, `type`: DataType) extends SeqGroupTag {
copy(name = map.getOrElse(name, name))
}
case class ForTag(item: String, iterable: ValueRaw, mode: Option[ForTag.Mode] = None)
extends SeqGroupTag {
case class ForTag(item: String, iterable: ValueRaw, mode: ForTag.Mode) extends SeqGroupTag {
override def restrictsVarNames: Set[String] = Set(item)
@ -185,9 +184,15 @@ case class ForTag(item: String, iterable: ValueRaw, mode: Option[ForTag.Mode] =
object ForTag {
enum Mode {
case Wait
case Pass
case Blocking
case NonBlocking
}
def blocking(item: String, iterable: ValueRaw): ForTag =
ForTag(item, iterable, Mode.Blocking)
def nonBlocking(item: String, iterable: ValueRaw): ForTag =
ForTag(item, iterable, Mode.NonBlocking)
}
case class CallArrowRawTag(

View File

@ -46,7 +46,13 @@ object MakeRes {
case SeqModel | _: OnModel | _: ApplyTopologyModel => SeqRes.leaf
case MatchMismatchModel(a, b, s) =>
MatchMismatchRes(a, b, s).leaf
case ForModel(item, iter, mode) if !isNillLiteral(iter) => FoldRes(item, iter, mode).leaf
case ForModel(item, iter, mode) if !isNillLiteral(iter) =>
val modeRes = mode match {
case ForModel.Mode.Null => FoldRes.Mode.Null
case ForModel.Mode.Never => FoldRes.Mode.Never
}
FoldRes(item, iter, modeRes).leaf
case RestrictionModel(item, itemType) => RestrictionRes(item, itemType).leaf
case DetachModel => ParRes.leaf
case ParModel => ParRes.leaf

View File

@ -32,9 +32,18 @@ case class MatchMismatchRes(left: ValueModel, right: ValueModel, shouldMatch: Bo
override def toString: String = s"(${if (shouldMatch) "match" else "mismatch"} $left $right)"
}
case class FoldRes(item: String, iterable: ValueModel, mode: Option[ForModel.Mode] = None)
extends ResolvedOp {
override def toString: String = s"(fold $iterable $item ${mode.map(_.toString).getOrElse("")}"
case class FoldRes(item: String, iterable: ValueModel, mode: FoldRes.Mode) extends ResolvedOp {
override def toString: String = s"(fold $iterable $item ${mode.toString.toLowerCase()}"
}
object FoldRes {
enum Mode { case Null, Never }
def lastNull(item: String, iterable: ValueModel): FoldRes =
FoldRes(item, iterable, Mode.Null)
def lastNever(item: String, iterable: ValueModel): FoldRes =
FoldRes(item, iterable, Mode.Never)
}
case class RestrictionRes(item: String, `type`: DataType) extends ResolvedOp {
@ -50,7 +59,8 @@ case class CallServiceRes(
override def toString: String = s"(call $peerId ($serviceId $funcName) $call)"
}
case class ApStreamMapRes(key: ValueModel, value: ValueModel, exportTo: CallModel.Export) extends ResolvedOp {
case class ApStreamMapRes(key: ValueModel, value: ValueModel, exportTo: CallModel.Export)
extends ResolvedOp {
override def toString: String = s"(ap ($key $value) $exportTo)"
}

View File

@ -18,7 +18,7 @@ object ResBuilder {
val arrayRes = VarModel(stream.name + "_gate", ArrayType(ScalarType.string))
RestrictionRes(testVM.name, testStreamType).wrap(
FoldRes(iter.name, stream, ForModel.Mode.Never.some).wrap(
FoldRes(iter.name, stream, FoldRes.Mode.Never).wrap(
ApRes(iter, CallModel.Export(testVM.name, testVM.`type`)).leaf,
CanonRes(testVM, peer, CallModel.Export(canon.name, canon.`type`)).leaf,
XorRes.wrap(

View File

@ -147,11 +147,11 @@ case class MatchMismatchModel(left: ValueModel, right: ValueModel, shouldMatch:
case class ForModel(
item: String,
iterable: ValueModel,
mode: Option[ForModel.Mode] = Some(ForModel.Mode.Null)
mode: ForModel.Mode = ForModel.Mode.Null
) extends SeqGroupModel {
override def toString: String =
s"for $item <- $iterable${mode.map(m => " " + m.toString).getOrElse("")}"
s"for $item <- $iterable${mode.toString}"
override def restrictsVarNames: Set[String] = Set(item)
@ -165,6 +165,12 @@ object ForModel {
case Null
case Never
}
def neverMode(item: String, iterable: ValueModel): ForModel =
ForModel(item, iterable, Mode.Never)
def nullMode(item: String, iterable: ValueModel): ForModel =
ForModel(item, iterable, Mode.Null)
}
// TODO how is it used? remove, if it's not
@ -175,7 +181,12 @@ case class DeclareStreamModel(value: ValueModel) extends NoExecModel {
}
// key must be only string or number
case class InsertKeyValueModel(key: ValueModel, value: ValueModel, assignTo: String, assignToType: StreamMapType) extends OpModel {
case class InsertKeyValueModel(
key: ValueModel,
value: ValueModel,
assignTo: String,
assignToType: StreamMapType
) extends OpModel {
override def usesVarNames: Set[String] = value.usesVarNames
override def exportsVarNames: Set[String] = Set(assignTo)

View File

@ -35,7 +35,9 @@ case class ArgsFromService(dataServiceId: ValueRaw) extends ArgsProvider {
Call(Nil, Call.Export(iter, ArrayType(t.element)) :: Nil)
)
.leaf,
ForTag(item, VarRaw(iter, ArrayType(t.element))).wrap(
ForTag
.nonBlocking(item, VarRaw(iter, ArrayType(t.element)))
.wrap(
SeqTag.wrap(
PushToStreamTag(VarRaw(item, t.element), Call.Export(varName, t)).leaf,
NextTag(item).leaf

View File

@ -377,7 +377,11 @@ object Topology extends Logging {
NextRes(itemName).leaf
)
FoldRes(itemName, v).wrap(if (reversed) steps.reverse else steps)
FoldRes
.lastNull(itemName, v)
.wrap(
if (reversed) steps.reverse else steps
)
case _ =>
MakeRes.hop(v)
}

View File

@ -124,7 +124,7 @@ object ModelBuilder {
failErrorModel
)
def fold(item: String, iter: ValueRaw, mode: Option[ForModel.Mode], body: OpModel.Tree*) = {
def fold(item: String, iter: ValueRaw, mode: ForModel.Mode, body: OpModel.Tree*) = {
val ops = SeqModel.wrap(body: _*)
ForModel(item, ValueModel.fromRaw(iter), mode).wrap(ops, NextModel(item).leaf)
}
@ -132,7 +132,8 @@ object ModelBuilder {
def foldPar(item: String, iter: ValueRaw, body: OpModel.Tree*) = {
val ops = SeqModel.wrap(body: _*)
DetachModel.wrap(
ForModel(item, ValueModel.fromRaw(iter), ForModel.Mode.Never.some)
ForModel
.neverMode(item, ValueModel.fromRaw(iter))
.wrap(ParModel.wrap(ops, NextModel(item).leaf))
)
}

View File

@ -1,7 +1,7 @@
package aqua.model.transform.topology
import aqua.model.transform.ModelBuilder
import aqua.model.{CallModel, OnModel, SeqModel}
import aqua.model.{CallModel, ForModel, OnModel, SeqModel}
import aqua.model.transform.cursor.ChainZipper
import aqua.raw.value.{LiteralRaw, ValueRaw, VarRaw}
import aqua.raw.ops.{Call, FuncOp, OnTag}
@ -137,7 +137,7 @@ class OpModelTreeCursorSpec extends AnyFlatSpec with Matchers {
fold(
"item",
VarRaw("iterable", ArrayType(ScalarType.string)),
None,
ForModel.Mode.Null,
OnModel(
VarRaw("-in-fold-", ScalarType.string),
Chain.one(VarRaw("-fold-relay-", ScalarType.string))

View File

@ -463,7 +463,8 @@ class TopologySpec extends AnyFlatSpec with Matchers {
through(relay),
callRes(0, otherPeer),
ParRes.wrap(
FoldRes("i", valueArray, ForModel.Mode.Never.some)
FoldRes
.lastNever("i", valueArray)
.wrap(ParRes.wrap(callRes(2, otherPeer2), NextRes("i").leaf))
),
through(relay),
@ -509,7 +510,9 @@ class TopologySpec extends AnyFlatSpec with Matchers {
val proc = Topology.resolve(init).value
val foldRes = ParRes.wrap(
FoldRes("i", valueArray, ForModel.Mode.Never.some).wrap(
FoldRes
.lastNever("i", valueArray)
.wrap(
ParRes.wrap(
// better if first relay will be outside `for`
SeqRes.wrap(
@ -579,7 +582,9 @@ class TopologySpec extends AnyFlatSpec with Matchers {
val proc = Topology.resolve(init).value
val fold = ParRes.wrap(
FoldRes("i", valueArray, ForModel.Mode.Never.some).wrap(
FoldRes
.lastNever("i", valueArray)
.wrap(
ParRes.wrap(
// better if first relay will be outside `for`
SeqRes.wrap(
@ -626,7 +631,7 @@ class TopologySpec extends AnyFlatSpec with Matchers {
fold(
"i",
valueArray,
None,
ForModel.Mode.Null,
OnModel(otherPeer2, Chain.one(otherRelay2)).wrap(
callModel(2)
)
@ -643,7 +648,9 @@ class TopologySpec extends AnyFlatSpec with Matchers {
through(relay),
callRes(1, otherPeer),
through(otherRelay2),
FoldRes("i", valueArray).wrap(
FoldRes
.lastNull("i", valueArray)
.wrap(
callRes(2, otherPeer2),
NextRes("i").leaf
),
@ -662,7 +669,7 @@ class TopologySpec extends AnyFlatSpec with Matchers {
fold(
"i",
valueArray,
None,
ForModel.Mode.Null,
OnModel(i, Chain.one(otherRelay)).wrap(
callModel(1)
)
@ -674,7 +681,9 @@ class TopologySpec extends AnyFlatSpec with Matchers {
val expected =
SeqRes.wrap(
through(relay),
FoldRes("i", valueArray).wrap(
FoldRes
.lastNull("i", valueArray)
.wrap(
SeqRes.wrap(
through(otherRelay),
callRes(1, i)
@ -766,7 +775,9 @@ class TopologySpec extends AnyFlatSpec with Matchers {
val expected = SeqRes.wrap(
callRes(1, otherPeer),
ParRes.wrap(
FoldRes("i", valueArray, ForModel.Mode.Never.some).wrap(
FoldRes
.lastNever("i", valueArray)
.wrap(
ParRes.wrap(
SeqRes.wrap(
// TODO: should be outside of fold
@ -843,7 +854,9 @@ class TopologySpec extends AnyFlatSpec with Matchers {
val expected = SeqRes.wrap(
ParRes.wrap(
FoldRes("i", ValueModel.fromRaw(valueArray), ForModel.Mode.Never.some).wrap(
FoldRes
.lastNever("i", ValueModel.fromRaw(valueArray))
.wrap(
ParRes.wrap(
SeqRes.wrap(
through(relay),
@ -892,7 +905,9 @@ class TopologySpec extends AnyFlatSpec with Matchers {
val proc = Topology.resolve(init).value
val foldRes = ParRes.wrap(
FoldRes("i", ValueModel.fromRaw(valueArray), ForModel.Mode.Never.some).wrap(
FoldRes
.lastNever("i", ValueModel.fromRaw(valueArray))
.wrap(
ParRes.wrap(
SeqRes.wrap(
through(relay),
@ -1036,7 +1051,9 @@ class TopologySpec extends AnyFlatSpec with Matchers {
CallModel.Export(array.name, array.`type`)
).leaf
),
FoldRes(iterName, array, ForModel.Mode.Null.some).wrap(
FoldRes
.lastNull(iterName, array)
.wrap(
NextRes(iterName).leaf
)
)

View File

@ -21,6 +21,7 @@ import cats.syntax.apply.*
import cats.syntax.flatMap.*
import cats.syntax.functor.*
import cats.syntax.option.*
import aqua.parser.expr.func.ForExpr.Mode
class ForSem[S[_]](val expr: ForExpr[S]) extends AnyVal {
@ -44,7 +45,14 @@ class ForSem[S[_]](val expr: ForExpr[S]) extends AnyVal {
case ForExpr.Mode.TryMode => TryTag
}
val mode = expr.mode.collect { case ForExpr.Mode.ParMode => ForTag.Mode.Wait }
/**
* `for ... par` => blocking (`never` as `last` in `fold`)
* `for` and `for ... try` => non blocking (`null` as `last` in `fold`)
*/
val mode = expr.mode.fold(ForTag.Mode.NonBlocking) {
case ForExpr.Mode.ParMode => ForTag.Mode.Blocking
case Mode.TryMode => ForTag.Mode.NonBlocking
}
val forTag = ForTag(expr.item.value, vm, mode).wrap(
innerTag.wrap(

View File

@ -22,7 +22,7 @@ import cats.syntax.functor.*
class ParSeqSem[S[_]](val expr: ParSeqExpr[S]) extends AnyVal {
def program[F[_]: Monad](implicit
def program[F[_]: Monad](using
V: ValuesAlgebra[S, F],
N: NamesAlgebra[S, F],
T: TypesAlgebra[S, F],
@ -63,7 +63,13 @@ class ParSeqSem[S[_]](val expr: ParSeqExpr[S]) extends AnyVal {
via = Chain.fromSeq(viaVM),
strategy = OnTag.ReturnStrategy.Relay.some
)
tag = ForTag(expr.item.value, vm).wrap(
/**
* `parseq` => blocking (`never` as `last` in `fold`)
* So that peer initiating `parseq` would not continue execution past it
*/
tag = ForTag
.blocking(expr.item.value, vm)
.wrap(
ParTag.wrap(
onTag.wrap(restricted),
NextTag(expr.item.value).leaf

View File

@ -581,7 +581,7 @@ class SemanticsSpec extends AnyFlatSpec with Matchers with Inside {
|""".stripMargin
insideBody(script) { body =>
matchSubtree(body) { case (ForTag("p", _, None), forTag) =>
matchSubtree(body) { case (ForTag("p", _, ForTag.Mode.Blocking), forTag) =>
matchChildren(forTag) { case (ParTag, parTag) =>
matchChildren(parTag)(
{ case (OnTag(_, _, strat), _) =>