feat: Return arrows from functions (#693)

This commit is contained in:
Dima 2023-04-14 16:28:17 +03:00 committed by GitHub
parent a3c1b0ed31
commit 8fa979cd33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 249 additions and 82 deletions

View File

@ -1,15 +1,10 @@
service Console("run-console"):
print(any: )
get() -> string
zzz() -> string
data Azazaz:
s: string
func exec(peers: []string) -> []string:
on "":
closure = (s: Azazaz) -> Azazaz:
Console.get()
func returnCall() -> string -> string:
closure = (s: string) -> string:
<- s
Console.zzz()
<- peers
closure("123asdf")
<- closure
func test() -> string:
a = returnCall()
b = a("arg")
<- b

View File

@ -3,6 +3,7 @@ package aqua.model.inline
import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler}
import aqua.model.*
import aqua.raw.ops.RawTag
import aqua.types.ArrowType
import aqua.raw.value.{ValueRaw, VarRaw}
import aqua.types.{BoxType, StreamType}
import cats.data.{Chain, State, StateT}
@ -131,14 +132,24 @@ object ArrowInliner extends Logging {
argsToArrows = argsToArrowsRaw.map { case (k, v) => argsShouldRename.getOrElse(k, k) -> v }
returnedArrows = fn.ret.collect { case VarRaw(name, ArrowType(_, _)) =>
name
}.toSet
returnedArrowsShouldRename <- Mangler[S].findNewNames(returnedArrows)
renamedCapturedArrows = fn.capturedArrows.map { case (k, v) =>
returnedArrowsShouldRename.getOrElse(k, k) -> v
}
// Going to resolve arrows: collect them all. Names should never collide: it's semantically checked
_ <- Arrows[S].purge
_ <- Arrows[S].resolved(fn.capturedArrows ++ argsToArrows)
_ <- Arrows[S].resolved(renamedCapturedArrows ++ argsToArrows)
// Rename all renamed arguments in the body
treeRenamed =
fn.body
.rename(argsShouldRename)
.rename(returnedArrowsShouldRename)
.map(_.mapValues(_.map {
// if an argument is a BoxType (Array or Option), but we pass a stream,
// change a type as stream to not miss `$` sign in air
@ -172,7 +183,7 @@ object ArrowInliner extends Logging {
tree = treeRenamed.rename(shouldRename)
// Result could be renamed; take care about that
} yield (tree, fn.ret.map(_.renameVars(shouldRename)))
} yield (tree, fn.ret.map(_.renameVars(shouldRename ++ returnedArrowsShouldRename)))
private[inline] def callArrowRet[S: Exports: Arrows: Mangler](
arrow: FuncArrow,
@ -185,12 +196,15 @@ object ArrowInliner extends Logging {
for {
_ <- Arrows[S].resolved(passArrows)
av <- ArrowInliner.inline(arrow, call)
} yield av
)
(appliedOp, value) = av
_ <- Exports[S].resolved(call.exportTo.map(_.name).zip(value).toMap)
} yield appliedOp -> value
// find and get resolved arrows if we return them from the function
returnedArrows = av._2.collect { case VarModel(name, ArrowType(_, _), _) =>
name
}
arrowsToSave <- Arrows[S].pickArrows(returnedArrows.toSet)
} yield av -> arrowsToSave
)
((appliedOp, values), arrowsToSave) = av
_ <- Arrows[S].resolved(arrowsToSave)
_ <- Exports[S].resolved(call.exportTo.map(_.name).zip(values).toMap)
} yield appliedOp -> values
}

View File

@ -4,9 +4,10 @@ import aqua.model.inline.state.{Arrows, Counter, Exports, Mangler}
import aqua.model.*
import aqua.model.inline.RawValueInliner.collectionToModel
import aqua.model.inline.raw.{CallArrowRawInliner, CollectionRawInliner}
import aqua.raw.arrow.FuncRaw
import aqua.raw.ops.*
import aqua.raw.value.*
import aqua.types.{ArrayType, BoxType, CanonStreamType, StreamType}
import aqua.types.{ArrayType, ArrowType, BoxType, CanonStreamType, StreamType}
import cats.syntax.traverse.*
import cats.syntax.applicative.*
import cats.instances.list.*

View File

@ -1,11 +1,12 @@
package aqua.model.inline.raw
import aqua.model.inline.Inline.parDesugarPrefixOpt
import aqua.model.{CallServiceModel, SeqModel, ValueModel, VarModel}
import aqua.model.{CallServiceModel, FuncArrow, SeqModel, ValueModel, VarModel}
import aqua.model.inline.{ArrowInliner, Inline, TagInliner}
import aqua.model.inline.RawValueInliner.{callToModel, valueToModel}
import aqua.model.inline.state.{Arrows, Exports, Mangler}
import aqua.raw.ops.Call
import aqua.types.ArrowType
import aqua.raw.value.CallArrowRaw
import cats.data.{Chain, State}
import scribe.Logging
@ -42,10 +43,13 @@ object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging {
*/
val funcName = value.ability.fold(value.name)(_ + "." + value.name)
logger.trace(s" $funcName")
Arrows[S].arrows.flatMap(arrows =>
arrows.get(funcName) match {
case Some(fn) =>
logger.trace(Console.YELLOW + s"Call arrow $funcName" + Console.RESET)
resolveArrow(funcName, call)
}
}
private def resolveFuncArrow[S: Mangler: Exports: Arrows](fn: FuncArrow, call: Call) = {
logger.trace(Console.YELLOW + s"Call arrow ${fn.funcName}" + Console.RESET)
callToModel(call, false).flatMap { case (cm, p) =>
ArrowInliner
.callArrowRet(fn, cm)
@ -56,16 +60,40 @@ object CallArrowRawInliner extends RawInliner[CallArrowRaw] with Logging {
)
}
}
}
private def resolveArrow[S: Mangler: Exports: Arrows](funcName: String, call: Call) =
Arrows[S].arrows.flatMap(arrows =>
arrows.get(funcName) match {
case Some(fn) =>
resolveFuncArrow(fn, call)
case None =>
Exports[S].exports.flatMap { exps =>
// if there is no arrow, check if it is stored in Exports as variable and try to resolve it
exps.get(funcName) match {
case Some(VarModel(name, ArrowType(_, _), _)) =>
Arrows[S].arrows.flatMap(arrows =>
arrows.get(name) match {
case Some(fn) =>
resolveFuncArrow(fn, call)
case _ =>
logger.error(
s"Inlining, cannot find arrow ${funcName}, available: ${arrows.keys
s"Inlining, cannot find arrow $funcName, available: ${arrows.keys
.mkString(", ")}"
)
State.pure(Nil -> Inline.empty)
}
)
case _ =>
logger.error(
s"Inlining, cannot find arrow $funcName, available: ${arrows.keys
.mkString(", ")}"
)
State.pure(Nil -> Inline.empty)
}
}
}
)
override def apply[S: Mangler: Exports: Arrows](
raw: CallArrowRaw,

View File

@ -3,7 +3,7 @@ package aqua.raw.ops
import aqua.raw.Raw
import aqua.raw.arrow.FuncRaw
import aqua.raw.ops.RawTag.Tree
import aqua.raw.value.{CallArrowRaw, ValueRaw}
import aqua.raw.value.{CallArrowRaw, ValueRaw, VarRaw}
import aqua.tree.{TreeNode, TreeNodeCompanion}
import aqua.types.{ArrowType, ProductType}
import cats.{Eval, Show}
@ -104,7 +104,8 @@ case class MatchMismatchTag(left: ValueRaw, right: ValueRaw, shouldMatch: Boolea
MatchMismatchTag(left.map(f), right.map(f), shouldMatch)
}
case class ForTag(item: String, iterable: ValueRaw, mode: Option[ForTag.Mode] = None) extends SeqGroupTag {
case class ForTag(item: String, iterable: ValueRaw, mode: Option[ForTag.Mode] = None)
extends SeqGroupTag {
override def restrictsVarNames: Set[String] = Set(item)
@ -195,6 +196,9 @@ case class ClosureTag(
detach: Boolean
) extends NoExecTag {
override def renameExports(map: Map[String, String]): RawTag =
copy(func = func.copy(name = map.getOrElse(func.name, func.name)))
override def mapValues(f: ValueRaw => ValueRaw): RawTag =
copy(
func.copy(arrow =

View File

@ -3,14 +3,14 @@ package aqua.parser.expr.func
import aqua.parser.Expr
import aqua.parser.expr.func.DeclareStreamExpr
import aqua.parser.lexer.Token.*
import aqua.parser.lexer.{Name, Token, TypeToken}
import aqua.parser.lexer.{DataTypeToken, Name, Token, TypeToken}
import aqua.parser.lift.LiftParser
import cats.parse.Parser
import cats.{~>, Comonad}
import cats.parse.Parser as P
import cats.{Comonad, ~>}
import aqua.parser.lift.Span
import aqua.parser.lift.Span.{P0ToSpan, PToSpan}
case class DeclareStreamExpr[F[_]](name: Name[F], `type`: TypeToken[F])
case class DeclareStreamExpr[F[_]](name: Name[F], `type`: DataTypeToken[F])
extends Expr[F](DeclareStreamExpr, name) {
override def mapK[K[_]: Comonad](fk: F ~> K): DeclareStreamExpr[K] =
@ -19,8 +19,8 @@ case class DeclareStreamExpr[F[_]](name: Name[F], `type`: TypeToken[F])
object DeclareStreamExpr extends Expr.Leaf {
override val p: Parser[DeclareStreamExpr[Span.S]] =
((Name.p <* ` : `) ~ TypeToken.`typedef`).map { case (name, t) =>
override val p: P[DeclareStreamExpr[Span.S]] =
((Name.p <* ` : `) ~ DataTypeToken.`datatypedef`).map { case (name, t) =>
DeclareStreamExpr(name, t)
}

View File

@ -20,6 +20,7 @@ case class IfExpr[F[_]](left: ValueToken[F], eqOp: EqOp[F], right: ValueToken[F]
object IfExpr extends Expr.AndIndented {
// list of expressions that can be used inside this block
override def validChildren: List[Expr.Lexem] = ForExpr.validChildren
override val p: P[IfExpr[Span.S]] =

View File

@ -10,7 +10,7 @@ import cats.syntax.comonad.*
import cats.syntax.functor.*
import cats.~>
import aqua.parser.lift.Span
import aqua.parser.lift.Span.{P0ToSpan, PToSpan}
import aqua.parser.lift.Span.{P0ToSpan, PToSpan, S}
sealed trait TypeToken[S[_]] extends Token[S] {
def mapK[K[_]: Comonad](fk: S ~> K): TypeToken[K]
@ -102,7 +102,7 @@ object BasicTypeToken {
case class ArrowTypeToken[S[_]: Comonad](
override val unit: S[Unit],
args: List[(Option[Name[S]], TypeToken[S])],
res: List[DataTypeToken[S]]
res: List[TypeToken[S]]
) extends TypeToken[S] {
override def as[T](v: T): S[T] = unit.as(v)
@ -117,9 +117,15 @@ case class ArrowTypeToken[S[_]: Comonad](
object ArrowTypeToken {
def typeDef(): P[TypeToken[S]] = P.defer(TypeToken.`typedef`.between(`(`, `)`).backtrack | TypeToken.`typedef`)
def returnDef(): P[List[TypeToken[S]]] = comma(
typeDef().backtrack
).map(_.toList)
def `arrowdef`(argTypeP: P[TypeToken[Span.S]]): P[ArrowTypeToken[Span.S]] =
(comma0(argTypeP).with1 ~ ` -> `.lift ~
(comma(DataTypeToken.`datatypedef`).map(_.toList)
(returnDef().backtrack
| `()`.as(Nil))).map { case ((args, point), res)
ArrowTypeToken(point, args.map(Option.empty[Name[Span.S]] -> _), res)
}
@ -129,7 +135,7 @@ object ArrowTypeToken {
(Name.p.map(Option(_)) ~ (` : ` *> (argTypeP | argTypeP.between(`(`, `)`))))
.surroundedBy(`/s*`)
) <* (`/s*` *> `)` <* ` `.?)) ~
(` -> ` *> comma(DataTypeToken.`datatypedef`)).?).map { case ((point, args), res) =>
(` -> ` *> returnDef()).?).map { case ((point, args), res) =>
ArrowTypeToken(point, args, res.toList.flatMap(_.toList))
}
}

View File

@ -4,6 +4,7 @@ import aqua.parser.lift.LiftParser.Implicits.idLiftParser
import aqua.types.ScalarType
import aqua.types.ScalarType.u32
import cats.Id
import cats.parse.Parser
import org.scalatest.EitherValues
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
@ -21,9 +22,61 @@ class TypeTokenSpec extends AnyFlatSpec with Matchers with EitherValues {
BasicTypeToken.`basictypedef`.parseAll("()").isLeft should be(true)
}
"Return type" should "parse" in {
def typedef(str: String) =
ArrowTypeToken.typeDef().parseAll(str).value.mapK(spanToId)
def returndef(str: String) =
ArrowTypeToken.returnDef().parseAll(str).value.map(_.mapK(spanToId))
typedef("(A -> ())") should be(
ArrowTypeToken[Id]((), List((None, CustomTypeToken[Id]("A"))), Nil)
)
typedef("(A -> B)") should be(
ArrowTypeToken[Id]((), List((None, CustomTypeToken[Id]("A"))), List(CustomTypeToken[Id]("B")))
)
returndef("(A -> B), (C -> D)") should be(
List(
ArrowTypeToken[Id](
(),
(None, CustomTypeToken[Id]("A")) :: Nil,
List(CustomTypeToken[Id]("B"))
),
ArrowTypeToken[Id](
(),
(None, CustomTypeToken[Id]("C")) :: Nil,
List(CustomTypeToken[Id]("D"))
)
)
)
returndef("A, (B, C -> D, E), F -> G, H") should be(
List(
CustomTypeToken[Id]("A"),
ArrowTypeToken[Id](
(),
(None, CustomTypeToken[Id]("B")) :: (None, CustomTypeToken[Id]("C")) :: Nil,
List(CustomTypeToken[Id]("D"), CustomTypeToken[Id]("E"))
),
ArrowTypeToken[Id](
(),
(None, CustomTypeToken[Id]("F")) :: Nil,
List(CustomTypeToken[Id]("G"), CustomTypeToken[Id]("H"))
)
)
)
}
"Arrow type" should "parse" in {
def arrowdef(str: String) = ArrowTypeToken.`arrowdef`(DataTypeToken.`datatypedef`).parseAll(str).value.mapK(spanToId)
def arrowWithNames(str: String) = ArrowTypeToken.`arrowWithNames`(DataTypeToken.`datatypedef`).parseAll(str).value.mapK(spanToId)
def arrowdef(str: String) =
ArrowTypeToken.`arrowdef`(DataTypeToken.`datatypedef`).parseAll(str).value.mapK(spanToId)
def arrowWithNames(str: String) = ArrowTypeToken
.`arrowWithNames`(DataTypeToken.`datatypedef`)
.parseAll(str)
.value
.mapK(spanToId)
arrowdef("-> B") should be(
ArrowTypeToken[Id]((), Nil, List(CustomTypeToken[Id]("B")))
@ -36,6 +89,53 @@ class TypeTokenSpec extends AnyFlatSpec with Matchers with EitherValues {
)
)
arrowdef("A -> B -> C") should be(
ArrowTypeToken[Id](
(),
(None -> CustomTypeToken[Id]("A")) :: Nil,
List(
ArrowTypeToken[Id](
(),
(None -> CustomTypeToken[Id]("B")) :: Nil,
List(CustomTypeToken[Id]("C"))
)
)
)
)
arrowdef("A -> B, C -> D") should be(
ArrowTypeToken[Id](
(),
(None -> CustomTypeToken[Id]("A")) :: Nil,
List(
ArrowTypeToken[Id](
(),
(None -> CustomTypeToken[Id]("B")) :: (None -> CustomTypeToken[Id]("C")) :: Nil,
List(CustomTypeToken[Id]("D"))
)
)
)
)
arrowdef("A -> (B -> F), (C -> D, E)") should be(
ArrowTypeToken[Id](
(),
(None -> CustomTypeToken[Id]("A")) :: Nil,
List(
ArrowTypeToken[Id](
(),
(None -> CustomTypeToken[Id]("B")) :: Nil,
CustomTypeToken[Id]("F") :: Nil
),
ArrowTypeToken[Id](
(),
(None -> CustomTypeToken[Id]("C")) :: Nil,
CustomTypeToken[Id]("D") :: CustomTypeToken[Id]("E") :: Nil
)
)
)
)
arrowWithNames("(a: A) -> B") should be(
ArrowTypeToken[Id](
(),

View File

@ -1,8 +1,11 @@
package aqua.semantics.expr.func
import aqua.raw.Raw
import aqua.raw.ops.{AssignmentTag, FuncOp}
import aqua.types.ArrowType
import aqua.raw.value.CallArrowRaw
import aqua.raw.ops.{AssignmentTag, ClosureTag, FuncOp}
import aqua.parser.expr.func.AssignmentExpr
import aqua.raw.arrow.FuncRaw
import aqua.semantics.Prog
import aqua.semantics.rules.ValuesAlgebra
import aqua.semantics.rules.names.NamesAlgebra
@ -19,10 +22,19 @@ class AssignmentSem[S[_]](val expr: AssignmentExpr[S]) extends AnyVal {
): Prog[Alg, Raw] =
V.valueToRaw(expr.value).flatMap {
case Some(vm) =>
vm.`type` match {
case at @ ArrowType(_, _) =>
N.defineArrow(expr.variable, at, false) as (AssignmentTag(
vm,
expr.variable.value
).funcOpLeaf: Raw)
case _ =>
N.derive(expr.variable, vm.`type`, vm.varNames) as (AssignmentTag(
vm,
expr.variable.value
).funcOpLeaf: Raw)
}
case _ => Raw.error("Cannot resolve assignment type").pure[Alg]
}

View File

@ -65,6 +65,11 @@ class NamesInterpreter[S[_], X](implicit lens: Lens[X, NamesState[S]], error: Re
case Some(g) =>
modify(st => st.copy(locations = st.locations :+ (name, g))).map(_ => Option(g.tokenType))
case None =>
// check if we have arrow in variable
readName(name.value).flatMap {
case Some(tt@TokenTypeInfo(_, at@ArrowType(_, _))) =>
modify(st => st.copy(locations = st.locations :+ (name, tt))).map(_ => Option(at))
case _ =>
getState.flatMap(st =>
report(
name,
@ -77,6 +82,7 @@ class NamesInterpreter[S[_], X](implicit lens: Lens[X, NamesState[S]], error: Re
.as(Option.empty[ArrowType])
)
}
}
def readArrowHelper(name: String): SX[Option[TokenArrowInfo[S]]] =
getState.map { st =>

View File

@ -172,7 +172,7 @@ class TypesInterpreter[S[_], X](implicit lens: Lens[X, TypesState[S]], error: Re
ensureTypeMatches(op.fields.lookup(fieldName).getOrElse(op), t, value.`type`)
case None => report(op, s"No field with name '$fieldName' in $rootT").as(false)
}
}.map(res => if (res.toList.fold(true)(_ && _)) Some(IntoCopyRaw(st, fields)) else None)
}.map(res => if (res.forall(identity)) Some(IntoCopyRaw(st, fields)) else None)
case _ =>
report(op, s"Expected $rootT to be a data type").as(None)
@ -230,7 +230,7 @@ class TypesInterpreter[S[_], X](implicit lens: Lens[X, TypesState[S]], error: Re
if (expected.acceptsValueOf(givenType)) State.pure(true)
else {
(expected, givenType) match {
case (StructType(n, valueFields), StructType(typeName, typeFields)) =>
case (StructType(n, valueFields), StructType(_, typeFields)) =>
// value can have more fields
if (valueFields.length < typeFields.length) {
report(
@ -253,7 +253,7 @@ class TypesInterpreter[S[_], X](implicit lens: Lens[X, TypesState[S]], error: Re
s"Wrong value type, expected: $expected, given: $givenType"
).as(false)
}
}.map(_.toList.fold(true)(_ && _))
}.map(_.forall(identity))
}
case _ =>
val notes =
@ -287,7 +287,7 @@ class TypesInterpreter[S[_], X](implicit lens: Lens[X, TypesState[S]], error: Re
else
report(
token,
s"Number of arguments doesn't match the function type, expected: ${expected}, given: ${givenNum}"
s"Number of arguments doesn't match the function type, expected: ${expected}, given: $givenNum"
).as(false)
override def beginArrowScope(token: ArrowTypeToken[S]): State[X, ArrowType] =
@ -362,7 +362,7 @@ class TypesInterpreter[S[_], X](implicit lens: Lens[X, TypesState[S]], error: Re
frame.arrowType.codomain.toList
.lazyZip(values.toList)
.foldLeft[Either[(Token[S], String, Boolean), List[ValueRaw]]](Right(Nil)) {
case (acc, (returnType, (token, returnValue))) =>
case (acc, (returnType, (_, returnValue))) =>
acc.flatMap { a =>
if (!returnType.acceptsValueOf(returnValue.`type`))
Left(
@ -372,7 +372,7 @@ class TypesInterpreter[S[_], X](implicit lens: Lens[X, TypesState[S]], error: Re
.headOption
.getOrElse(values.last)
._1,
s"Wrong value type, expected: ${returnType}, given: ${returnValue.`type`}",
s"Wrong value type, expected: $returnType, given: ${returnValue.`type`}",
false
)
)