perf(compiler): Optimize Linker [LNG-321] (#1049)

* Refactor parser

* Refactor parser

* Refactor parser

* Refactor parser

* Savepoint

* Savepoint

* Refactor

* Fix unit test

* Remove file

* Filter exported modules

* Fix test

* Restore Test.scala

* Add comments
This commit is contained in:
InversionSpaces 2024-01-22 16:01:54 +01:00 committed by GitHub
parent abcb63db3b
commit 7b6c7245ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 234 additions and 255 deletions

View File

@ -1,6 +1,7 @@
package aqua.compiler package aqua.compiler
import aqua.compiler.AquaError.{ParserError as AquaParserError, *} import aqua.compiler.AquaError.{ParserError as AquaParserError, *}
import aqua.linker.Linker.link
import aqua.linker.{AquaModule, Linker, Modules} import aqua.linker.{AquaModule, Linker, Modules}
import aqua.parser.{Ast, ParserError} import aqua.parser.{Ast, ParserError}
import aqua.semantics.header.{HeaderHandler, Picker} import aqua.semantics.header.{HeaderHandler, Picker}
@ -22,43 +23,31 @@ class AquaCompiler[F[_]: Monad, E, I: Order, S[_]: Comonad, C: Monoid: Picker](
) extends Logging { ) extends Logging {
type Err = AquaError[I, E, S] type Err = AquaError[I, E, S]
type Ctx = NonEmptyMap[I, C]
type CompileWarns = [A] =>> CompileWarnings[S][A] type CompileWarns = [A] =>> CompileWarnings[S][A]
type CompileRes = [A] =>> CompileResult[I, E, S][A] type CompileRes = [A] =>> CompileResult[I, E, S][A]
type CompiledCtx = CompileRes[Ctx] // Transpilation function for module
type CompiledCtxT = CompiledCtx => CompiledCtx // (Imports contexts => Compilation result)
type TP = Map[String, C] => CompileRes[C]
private def linkModules( private def transpile(body: Ast[S]): TP =
modules: Modules[I, Err, CompiledCtxT], imports =>
cycleError: Linker.DepCycle[AquaModule[I, Err, CompiledCtxT]] => Err for {
): CompileRes[Map[I, C]] = { // Process header, get initial context
logger.trace("linking modules...") headerSem <- headerHandler
.sem(imports, body.head)
// By default, provide an empty context for this module's id .toCompileRes
val empty: I => CompiledCtx = i => NonEmptyMap.one(i, Monoid[C].empty).pure[CompileRes] // Analyze the body, with prepared initial context
_ = logger.trace("semantic processing...")
for { processed <- semantics
linked <- Linker .process(body, headerSem.initCtx)
.link(modules, cycleError, empty) .toCompileRes
.toEither // Handle exports, declares - finalize the resulting context
.toEitherT[CompileWarns] rc <- headerSem
res <- EitherT( .finCtx(processed)
linked.toList.traverse { case (id, ctx) => .toCompileRes
ctx } yield rc
.map(
/**
* NOTE: This should be safe
* as result for id should contain itself
*/
_.apply(id).map(id -> _).get
)
.toValidated
}.map(_.sequence.toEither)
)
} yield res.toMap
}
def compileRaw( def compileRaw(
sources: AquaSources[F, E, I], sources: AquaSources[F, E, I],
@ -66,48 +55,18 @@ class AquaCompiler[F[_]: Monad, E, I: Order, S[_]: Comonad, C: Monoid: Picker](
): F[CompileRes[Map[I, C]]] = { ): F[CompileRes[Map[I, C]]] = {
logger.trace("starting resolving sources...") logger.trace("starting resolving sources...")
new AquaParser[F, E, I, S](sources, parser) val parsing = new AquaParser(sources, parser)
.resolve[CompiledCtx](mod =>
context => parsing.resolve.value.map(resolution =>
for { for {
// Context with prepared imports // Lift resolution to CompileRes
ctx <- context modules <- resolution.toEitherT[CompileWarns]
imports = mod.imports.flatMap { case (fn, id) => // Generate transpilation functions for each module
ctx.apply(id).map(fn -> _) transpiled = modules.map(body => transpile(body))
} // Link modules
header = mod.body.head linked <- Linker.link(transpiled, CycleError.apply)
headerSem <- headerHandler } yield linked
.sem(imports, header) )
.toCompileRes
// Analyze the body, with prepared initial context
_ = logger.trace("semantic processing...")
processed <- semantics
.process(
mod.body,
headerSem.initCtx
)
.toCompileRes
// Handle exports, declares - finalize the resulting context
rc <- headerSem
.finCtx(processed)
.toCompileRes
/**
* Here we build a map of contexts while processing modules.
* Should not linker provide this info inside this process?
* Building this map complicates things a lot.
*/
} yield NonEmptyMap.one(mod.id, rc)
)
.value
.map(resolved =>
for {
modules <- resolved.toEitherT[CompileWarns]
linked <- linkModules(
modules,
cycle => CycleError(cycle.map(_.id))
)
} yield linked
)
} }
private val warningsK: semantics.Warnings ~> CompileWarns = private val warningsK: semantics.Warnings ~> CompileWarns =

View File

@ -5,6 +5,7 @@ import aqua.linker.{AquaModule, Modules}
import aqua.parser.head.{FilenameExpr, ImportExpr} import aqua.parser.head.{FilenameExpr, ImportExpr}
import aqua.parser.lift.{LiftParser, Span} import aqua.parser.lift.{LiftParser, Span}
import aqua.parser.{Ast, ParserError} import aqua.parser.{Ast, ParserError}
import aqua.syntax.eithert.fromValidatedF
import cats.data.Chain.* import cats.data.Chain.*
import cats.data.Validated.* import cats.data.Validated.*
@ -16,6 +17,7 @@ import cats.syntax.flatMap.*
import cats.syntax.foldable.* import cats.syntax.foldable.*
import cats.syntax.functor.* import cats.syntax.functor.*
import cats.syntax.monad.* import cats.syntax.monad.*
import cats.syntax.parallel.*
import cats.syntax.traverse.* import cats.syntax.traverse.*
import cats.syntax.validated.* import cats.syntax.validated.*
import cats.{Comonad, Monad, ~>} import cats.{Comonad, Monad, ~>}
@ -32,108 +34,82 @@ class AquaParser[F[_]: Monad, E, I, S[_]: Comonad](
private type FE[A] = EitherT[F, NonEmptyChain[Err], A] private type FE[A] = EitherT[F, NonEmptyChain[Err], A]
// Parse all the source files // Parse one source (text)
private def parseSources: F[ValidatedNec[Err, Chain[(I, Body)]]] = private def parse(id: I, src: String): EitherNec[Err, (I, Body)] =
sources.sources.map( parser(id)(src).toEither.bimap(
_.leftMap(_.map(SourcesError.apply)).andThen( _.map(AquaParserError.apply),
_.traverse { case (i, s) => ast => id -> ast
parser(i)(s).bimap(
_.map(AquaParserError.apply),
ast => i -> ast
)
}
)
) )
// Parse all the source files
private def parseSources: FE[Chain[(I, Body)]] =
for {
srcs <- EitherT
.fromValidatedF(sources.sources)
.leftMap(_.map(SourcesError.apply))
parsed <- srcs
.parTraverse(parse.tupled)
.toEitherT
} yield parsed
// Load one module (parse, resolve imports)
private def loadModule(id: I): FE[AquaModule[I, Err, Body]] =
for {
src <- EitherT
.fromValidatedF(sources.load(id))
.leftMap(_.map(SourcesError.apply))
parsed <- parse(id, src).toEitherT
(id, ast) = parsed
resolved <- resolveImports(id, ast)
} yield resolved
// Resolve imports (not parse, just resolve) of the given file // Resolve imports (not parse, just resolve) of the given file
private def resolveImports(id: I, ast: Body): F[ValidatedNec[Err, AquaModule[I, Err, Body]]] = private def resolveImports(id: I, ast: Body): FE[AquaModule[I, Err, Body]] =
ast.head.collect { case fe: FilenameExpr[S] => ast.head.collect { case fe: FilenameExpr[S] =>
fe.fileValue -> fe.token fe.fileValue -> fe.token
}.traverse { case (filename, token) => }.parTraverse { case (filename, token) =>
sources EitherT
.resolveImport(id, filename) .fromValidatedF(
.map( sources.resolveImport(id, filename)
_.bimap(
_.map(ResolveImportsError(id, token, _): Err),
importId => importId -> (filename, ImportError(token): Err)
)
) )
}.map(_.sequence.map { collected => .bimap(
AquaModule[I, Err, Body]( _.map(ResolveImportsError(id, token, _): Err),
id, importId => importId -> (filename, ImportError(token): Err)
)
}.map { collected =>
AquaModule(
id = id,
// How filenames correspond to the resolved IDs // How filenames correspond to the resolved IDs
collected.map { case (i, (fn, _)) => imports = collected.map { case (i, (fn, _)) =>
fn -> i fn -> i
}.toList.toMap[String, I], }.toList.toMap,
// Resolved IDs to errors that point to the import in source code // Resolved IDs to errors that point to the import in source code
collected.map { case (i, (_, err)) => dependsOn = collected.map { case (i, (_, err)) =>
i -> err i -> err
}.toList.toMap[I, Err], }.toList.toMap,
ast body = ast
) )
})
// Parse sources, convert to modules
private def sourceModules: F[ValidatedNec[Err, Modules[I, Err, Body]]] =
parseSources.flatMap {
case Validated.Valid(srcs) =>
srcs.traverse { case (id, ast) =>
resolveImports(id, ast)
}.map(_.sequence)
case Validated.Invalid(errs) =>
errs.invalid.pure[F]
}.map(
_.map(
_.foldLeft(Modules[I, Err, Body]())(
_.add(_, toExport = true)
)
)
)
private def loadModule(imp: I): F[ValidatedNec[Err, AquaModule[I, Err, Body]]] =
sources
.load(imp)
.map(_.leftMap(_.map(SourcesError.apply)).andThen { src =>
parser(imp)(src).leftMap(_.map(AquaParserError.apply))
})
.flatMap {
case Validated.Valid(ast) =>
resolveImports(imp, ast)
case Validated.Invalid(errs) =>
errs.invalid.pure[F]
}
private def resolveModules(
modules: Modules[I, Err, Body]
): F[ValidatedNec[Err, Modules[I, Err, Ast[S]]]] =
modules.dependsOn.toList.traverse { case (moduleId, unresolvedErrors) =>
loadModule(moduleId).map(_.leftMap(_ ++ unresolvedErrors))
}.map(
_.sequence.map(
_.foldLeft(modules)(_ add _)
)
).flatMap {
case Validated.Valid(ms) if ms.isResolved =>
ms.validNec.pure[F]
case Validated.Valid(ms) =>
resolveModules(ms)
case err =>
err.pure[F]
} }
private def resolveSources: FE[Modules[I, Err, Ast[S]]] = // Load modules (parse, resolve imports) of all the source files
private lazy val loadModules: FE[Modules[I, Err, Body]] =
for { for {
ms <- EitherT( srcs <- parseSources
sourceModules.map(_.toEither) modules <- srcs.parTraverse(resolveImports.tupled)
) } yield Modules.from(modules)
res <- EitherT(
resolveModules(ms).map(_.toEither)
)
} yield res
def resolve[T]( // Resolve modules (load all the dependencies)
transpile: AquaModule[I, Err, Body] => T => T private def resolveModules(
): FE[Modules[I, Err, T => T]] = modules: Modules[I, Err, Body]
resolveSources.map(_.mapModuleToBody(transpile)) ): FE[Modules[I, Err, Ast[S]]] =
modules.iterateUntilM(ms =>
// Load all modules that are dependencies of the current modules
ms.dependsOn.toList.parTraverse { case (moduleId, unresolvedErrors) =>
loadModule(moduleId).leftMap(_ ++ unresolvedErrors)
}.map(ms.addAll) // Add all loaded modules to the current modules
)(_.isResolved)
lazy val resolve: FE[Modules[I, Err, Body]] =
loadModules >>= resolveModules
} }

View File

@ -1,17 +1,26 @@
package aqua.linker package aqua.linker
import aqua.errors.Errors.internalError
import cats.MonadError
import cats.data.{NonEmptyChain, Validated, ValidatedNec} import cats.data.{NonEmptyChain, Validated, ValidatedNec}
import cats.kernel.{Monoid, Semigroup}
import cats.syntax.semigroup.*
import cats.syntax.validated.*
import cats.syntax.functor.*
import cats.instances.list.* import cats.instances.list.*
import cats.kernel.{Monoid, Semigroup}
import cats.syntax.applicative.*
import cats.syntax.flatMap.*
import cats.syntax.functor.*
import cats.syntax.semigroup.*
import cats.syntax.traverse.*
import cats.syntax.validated.*
import scala.annotation.tailrec
import scribe.Logging import scribe.Logging
import scala.annotation.tailrec
object Linker extends Logging { object Linker extends Logging {
// Transpilation function for module
// (Imports contexts => Compilation result)
type TP = [F[_], T] =>> Map[String, T] => F[T]
// Dependency Cycle, prev element import next // Dependency Cycle, prev element import next
// and last imports head // and last imports head
type DepCycle[I] = NonEmptyChain[I] type DepCycle[I] = NonEmptyChain[I]
@ -23,8 +32,8 @@ object Linker extends Logging {
* @return [[List]] of dependecy cycles found * @return [[List]] of dependecy cycles found
*/ */
private def findDepCycles[I, E, T]( private def findDepCycles[I, E, T](
mods: List[AquaModule[I, E, T => T]] mods: List[AquaModule[I, E, T]]
): List[DepCycle[AquaModule[I, E, T => T]]] = { ): List[DepCycle[I]] = {
val modsIds = mods.map(_.id).toSet val modsIds = mods.map(_.id).toSet
// Limit search to only passed modules (there maybe dependencies not from `mods`) // Limit search to only passed modules (there maybe dependencies not from `mods`)
val deps = mods.map(m => m.id -> m.dependsOn.keySet.intersect(modsIds)).toMap val deps = mods.map(m => m.id -> m.dependsOn.keySet.intersect(modsIds)).toMap
@ -56,7 +65,7 @@ object Linker extends Logging {
) )
} }
val cycles = mods mods
.flatMap(m => .flatMap(m =>
findCycles( findCycles(
paths = NonEmptyChain.one(m.id) :: Nil, paths = NonEmptyChain.one(m.id) :: Nil,
@ -69,73 +78,83 @@ object Linker extends Logging {
// should not be a lot of cycles // should not be a lot of cycles
_.toChain.toList.toSet _.toChain.toList.toSet
) )
val modsById = mods.fproductLeft(_.id).toMap
// This should be safe
cycles.map(_.map(modsById))
} }
@tailrec /**
def iter[I, E, T: Semigroup]( * Main iterative linking function
mods: List[AquaModule[I, E, T => T]], * @param mods Modules to link
proc: Map[I, T => T], * @param proc Already processed modules
cycleError: DepCycle[AquaModule[I, E, T => T]] => E * @param cycle Function to create error from dependency cycle
): ValidatedNec[E, Map[I, T => T]] = * @return Result for all modules
*/
def iter[I, E, F[_], T](
mods: List[AquaModule[I, E, TP[F, T]]],
proc: Map[I, T],
cycle: DepCycle[I] => E
)(using me: MonadError[F, NonEmptyChain[E]]): F[Map[I, T]] =
mods match { mods match {
case Nil => case Nil =>
proc.valid proc.pure
case _ => case _ =>
val (canHandle, postpone) = mods.partition(_.dependsOn.keySet.forall(proc.contains)) // Find modules that can be processed
val (canHandle, postpone) = mods.partition(
_.dependsOn.keySet.forall(proc.contains)
)
logger.debug("ITERATE, can handle: " + canHandle.map(_.id)) logger.debug("ITERATE, can handle: " + canHandle.map(_.id))
logger.debug(s"dependsOn = ${mods.map(_.dependsOn.keySet)}") logger.debug(s"dependsOn = ${mods.map(_.dependsOn.keySet)}")
logger.debug(s"postpone = ${postpone.map(_.id)}") logger.debug(s"postpone = ${postpone.map(_.id)}")
logger.debug(s"proc = ${proc.keySet}") logger.debug(s"proc = ${proc.keySet}")
// If there are no modules that can be processed
if (canHandle.isEmpty && postpone.nonEmpty) { if (canHandle.isEmpty && postpone.nonEmpty) {
findDepCycles(postpone) me.raiseError(
.map(cycleError) // This should be safe as cycles should exist at this moment
.invalid NonEmptyChain
.leftMap( .fromSeq(findDepCycles(postpone).map(cycle))
// This should be safe as cycles should exist at this moment .get
errs => NonEmptyChain.fromSeq(errs).get )
) } else
} else { canHandle.traverse { mod =>
val folded = canHandle.foldLeft(proc) { case (acc, m) => // Gather all imports for module
val importKeys = m.dependsOn.keySet val imports = mod.imports.mapValues { imp =>
logger.debug(s"${m.id} dependsOn $importKeys") proc
val deps: T => T = .get(imp)
importKeys.map(acc).foldLeft(identity[T]) { case (fAcc, f) => .getOrElse(
logger.debug("COMBINING ONE TIME ") // Should not happen as we check it above
t => { internalError(s"Module $imp not found in $proc")
logger.debug(s"call combine $t") )
fAcc(t) |+| f(t) }.toMap
}
} // Process (transpile) module
acc + (m.id -> m.body.compose(deps)) mod.body(imports).map(mod.id -> _)
} }.flatMap(processed =>
iter( // flatMap should be stack safe
postpone, iter(
// TODO can be done in parallel postpone,
folded, proc ++ processed,
cycleError cycle
)
) )
}
}
def link[I, E, T: Semigroup](
modules: Modules[I, E, T => T],
cycleError: DepCycle[AquaModule[I, E, T => T]] => E,
empty: I => T
): ValidatedNec[E, Map[I, T]] =
if (modules.dependsOn.nonEmpty) Validated.invalid(modules.dependsOn.values.reduce(_ ++ _))
else {
val result = iter(modules.loaded.values.toList, Map.empty, cycleError)
result.map(_.collect {
case (i, f) if modules.exports(i) =>
i -> f(empty(i))
})
} }
/**
* Link modules
*
* @param modules Modules to link (with transpilation functions as bodies)
* @param cycle Function to create error from dependency cycle
* @return Result for all **exported** modules
*/
def link[I, E, F[_], T](
modules: Modules[I, E, TP[F, T]],
cycle: DepCycle[I] => E
)(using me: MonadError[F, NonEmptyChain[E]]): F[Map[I, T]] =
if (modules.dependsOn.nonEmpty)
me.raiseError(
modules.dependsOn.values.reduce(_ ++ _)
)
else
iter(modules.loaded.values.toList, Map.empty, cycle).map(
// Remove all modules that are not exported from result
_.filterKeys(modules.exports.contains).toMap
)
} }

View File

@ -1,6 +1,8 @@
package aqua.linker package aqua.linker
import cats.data.NonEmptyChain import cats.Foldable
import cats.data.{Chain, NonEmptyChain}
import cats.syntax.foldable._
import cats.syntax.option._ import cats.syntax.option._
case class Modules[I, E, T]( case class Modules[I, E, T](
@ -23,17 +25,23 @@ case class Modules[I, E, T](
exports = if (toExport) exports + aquaModule.id else exports exports = if (toExport) exports + aquaModule.id else exports
) )
def addAll[F[_]: Foldable](modules: F[AquaModule[I, E, T]]): Modules[I, E, T] =
modules.foldLeft(this)(_ add _)
def isResolved: Boolean = dependsOn.isEmpty def isResolved: Boolean = dependsOn.isEmpty
def map[TT](f: T => TT): Modules[I, E, TT] = def map[TT](f: T => TT): Modules[I, E, TT] =
copy(loaded = loaded.view.mapValues(_.map(f)).toMap) copy(loaded = loaded.view.mapValues(_.map(f)).toMap)
def mapModuleToBody[TT](f: AquaModule[I, E, T] => TT): Modules[I, E, TT] =
copy(loaded = loaded.view.mapValues(v => v.map(_ => f(v))).toMap)
def mapErr[EE](f: E => EE): Modules[I, EE, T] = def mapErr[EE](f: E => EE): Modules[I, EE, T] =
copy( copy(
loaded = loaded.view.mapValues(_.mapErr(f)).toMap, loaded = loaded.view.mapValues(_.mapErr(f)).toMap,
dependsOn = dependsOn.view.mapValues(_.map(f)).toMap dependsOn = dependsOn.view.mapValues(_.map(f)).toMap
) )
} }
object Modules {
def from[I, E, T](modules: Chain[AquaModule[I, E, T]]): Modules[I, E, T] =
modules.foldLeft(Modules[I, E, T]())(_.add(_, toExport = true))
}

View File

@ -1,44 +1,61 @@
package aqua.linker package aqua.linker
import cats.data.Validated import cats.Id
import cats.data.{EitherNec, NonEmptyChain}
import cats.syntax.either.*
import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
class LinkerSpec extends AnyFlatSpec with Matchers { class LinkerSpec extends AnyFlatSpec with Matchers {
type TP = Map[String, String] => EitherNec[String, String]
val cycle: NonEmptyChain[String] => String =
_.toChain.toList.mkString(" -> ")
"linker" should "resolve dependencies" in { "linker" should "resolve dependencies" in {
val empty = Modules[String, String, String => String]() val empty = Modules[String, String, TP]()
val withMod1 = val withMod1 = empty.add(
empty AquaModule(
.add( id = "mod1",
AquaModule[String, String, String => String]( imports = Map("mod2" -> "mod2"),
id = "mod1", dependsOn = Map("mod2" -> "unresolved mod2 in mod1"),
imports = Map.empty, body = imports => {
dependsOn = Map("mod2" -> "unresolved mod2 in mod1"), println(s"mod1: $imports")
body = _ ++ " | mod1"
), imports
toExport = true .get("mod2")
) .toRight("mod2 not found in mod1")
.toEitherNec
.map(_ ++ " | mod1")
}
),
toExport = true
)
withMod1.isResolved should be(false) withMod1.isResolved should be(false)
Linker.link[String, String, String]( Linker.link(withMod1, cycle) should be(
withMod1, Left("unresolved mod2 in mod1").toEitherNec
cycle => cycle.map(_.id).toChain.toList.mkString(" -> "), )
_ => ""
) should be(Validated.invalidNec("unresolved mod2 in mod1"))
val withMod2 = val withMod2 = withMod1.add(
withMod1.add(AquaModule("mod2", Map.empty, Map.empty, _ ++ " | mod2")) AquaModule(
id = "mod2",
imports = Map.empty,
dependsOn = Map.empty,
body = _ => "mod2".asRight.toEitherNec
)
)
withMod2.isResolved should be(true) withMod2.isResolved should be(true)
Linker.link[String, String, String]( Linker.link(withMod2, cycle) should be(
withMod2, Map(
cycle => cycle.map(_.id + "?").toChain.toList.mkString(" -> "), "mod1" -> "mod2 | mod1"
_ => "" ).asRight.toEitherNec
) should be(Validated.validNec(Map("mod1" -> " | mod2 | mod1"))) )
} }
} }