diff --git a/compiler/src/main/scala/aqua/compiler/AquaCompiler.scala b/compiler/src/main/scala/aqua/compiler/AquaCompiler.scala
index 94643ff8..bb67fac2 100644
--- a/compiler/src/main/scala/aqua/compiler/AquaCompiler.scala
+++ b/compiler/src/main/scala/aqua/compiler/AquaCompiler.scala
@@ -1,6 +1,7 @@
 package aqua.compiler
 
 import aqua.compiler.AquaError.{ParserError as AquaParserError, *}
+import aqua.linker.Linker.link
 import aqua.linker.{AquaModule, Linker, Modules}
 import aqua.parser.{Ast, ParserError}
 import aqua.semantics.header.{HeaderHandler, Picker}
@@ -22,43 +23,31 @@ class AquaCompiler[F[_]: Monad, E, I: Order, S[_]: Comonad, C: Monoid: Picker](
 ) extends Logging {
 
   type Err = AquaError[I, E, S]
-  type Ctx = NonEmptyMap[I, C]
 
   type CompileWarns = [A] =>> CompileWarnings[S][A]
   type CompileRes = [A] =>> CompileResult[I, E, S][A]
-  type CompiledCtx = CompileRes[Ctx]
-  type CompiledCtxT = CompiledCtx => CompiledCtx
+  // Transpilation function for module
+  // (Imports contexts => Compilation result)
+  type TP = Map[String, C] => CompileRes[C]
 
-  private def linkModules(
-    modules: Modules[I, Err, CompiledCtxT],
-    cycleError: Linker.DepCycle[AquaModule[I, Err, CompiledCtxT]] => Err
-  ): CompileRes[Map[I, C]] = {
-    logger.trace("linking modules...")
-
-    // By default, provide an empty context for this module's id
-    val empty: I => CompiledCtx = i => NonEmptyMap.one(i, Monoid[C].empty).pure[CompileRes]
-
-    for {
-      linked <- Linker
-        .link(modules, cycleError, empty)
-        .toEither
-        .toEitherT[CompileWarns]
-      res <- EitherT(
-        linked.toList.traverse { case (id, ctx) =>
-          ctx
-            .map(
-              /**
-               * NOTE: This should be safe
-               * as result for id should contain itself
-               */
-              _.apply(id).map(id -> _).get
-            )
-            .toValidated
-        }.map(_.sequence.toEither)
-      )
-    } yield res.toMap
-  }
+  private def transpile(body: Ast[S]): TP =
+    imports =>
+      for {
+        // Process header, get initial context
+        headerSem <- headerHandler
+          .sem(imports, body.head)
+          .toCompileRes
+        // Analyze the body, with prepared initial context
+        _ = logger.trace("semantic processing...")
+        processed <- semantics
+          .process(body, headerSem.initCtx)
+          .toCompileRes
+        // Handle exports, declares - finalize the resulting context
+        rc <- headerSem
+          .finCtx(processed)
+          .toCompileRes
+      } yield rc
 
   def compileRaw(
     sources: AquaSources[F, E, I],
@@ -66,48 +55,18 @@ class AquaCompiler[F[_]: Monad, E, I: Order, S[_]: Comonad, C: Monoid: Picker](
   ): F[CompileRes[Map[I, C]]] = {
     logger.trace("starting resolving sources...")
 
-    new AquaParser[F, E, I, S](sources, parser)
-      .resolve[CompiledCtx](mod =>
-        context =>
-          for {
-            // Context with prepared imports
-            ctx <- context
-            imports = mod.imports.flatMap { case (fn, id) =>
-              ctx.apply(id).map(fn -> _)
-            }
-            header = mod.body.head
-            headerSem <- headerHandler
-              .sem(imports, header)
-              .toCompileRes
-            // Analyze the body, with prepared initial context
-            _ = logger.trace("semantic processing...")
-            processed <- semantics
-              .process(
-                mod.body,
-                headerSem.initCtx
-              )
-              .toCompileRes
-            // Handle exports, declares - finalize the resulting context
-            rc <- headerSem
-              .finCtx(processed)
-              .toCompileRes
-            /**
-             * Here we build a map of contexts while processing modules.
-             * Should not linker provide this info inside this process?
-             * Building this map complicates things a lot.
-             */
-          } yield NonEmptyMap.one(mod.id, rc)
-      )
-      .value
-      .map(resolved =>
-        for {
-          modules <- resolved.toEitherT[CompileWarns]
-          linked <- linkModules(
-            modules,
-            cycle => CycleError(cycle.map(_.id))
-          )
-        } yield linked
-      )
+    val parsing = new AquaParser(sources, parser)
+
+    parsing.resolve.value.map(resolution =>
+      for {
+        // Lift resolution to CompileRes
+        modules <- resolution.toEitherT[CompileWarns]
+        // Generate transpilation functions for each module
+        transpiled = modules.map(body => transpile(body))
+        // Link modules
+        linked <- Linker.link(transpiled, CycleError.apply)
+      } yield linked
+    )
   }
 
   private val warningsK: semantics.Warnings ~> CompileWarns =
diff --git a/compiler/src/main/scala/aqua/compiler/AquaParser.scala b/compiler/src/main/scala/aqua/compiler/AquaParser.scala
index 70262764..952fec14 100644
--- a/compiler/src/main/scala/aqua/compiler/AquaParser.scala
+++ b/compiler/src/main/scala/aqua/compiler/AquaParser.scala
@@ -5,6 +5,7 @@ import aqua.linker.{AquaModule, Modules}
 import aqua.parser.head.{FilenameExpr, ImportExpr}
 import aqua.parser.lift.{LiftParser, Span}
 import aqua.parser.{Ast, ParserError}
+import aqua.syntax.eithert.fromValidatedF
 
 import cats.data.Chain.*
 import cats.data.Validated.*
@@ -16,6 +17,7 @@ import cats.syntax.flatMap.*
 import cats.syntax.foldable.*
 import cats.syntax.functor.*
 import cats.syntax.monad.*
+import cats.syntax.parallel.*
 import cats.syntax.traverse.*
 import cats.syntax.validated.*
 import cats.{Comonad, Monad, ~>}
@@ -32,108 +34,82 @@ class AquaParser[F[_]: Monad, E, I, S[_]: Comonad](
 
   private type FE[A] = EitherT[F, NonEmptyChain[Err], A]
 
-  // Parse all the source files
-  private def parseSources: F[ValidatedNec[Err, Chain[(I, Body)]]] =
-    sources.sources.map(
-      _.leftMap(_.map(SourcesError.apply)).andThen(
-        _.traverse { case (i, s) =>
-          parser(i)(s).bimap(
-            _.map(AquaParserError.apply),
-            ast => i -> ast
-          )
-        }
-      )
+  // Parse one source (text)
+  private def parse(id: I, src: String): EitherNec[Err, (I, Body)] =
+    parser(id)(src).toEither.bimap(
+      _.map(AquaParserError.apply),
+      ast => id -> ast
     )
 
+  // Parse all the source files
+  private def parseSources: FE[Chain[(I, Body)]] =
+    for {
+      srcs <- EitherT
+        .fromValidatedF(sources.sources)
+        .leftMap(_.map(SourcesError.apply))
+      parsed <- srcs
+        .parTraverse(parse.tupled)
+        .toEitherT
+    } yield parsed
+
+  // Load one module (parse, resolve imports)
+  private def loadModule(id: I): FE[AquaModule[I, Err, Body]] =
+    for {
+      src <- EitherT
+        .fromValidatedF(sources.load(id))
+        .leftMap(_.map(SourcesError.apply))
+      parsed <- parse(id, src).toEitherT
+      (id, ast) = parsed
+      resolved <- resolveImports(id, ast)
+    } yield resolved
+
   // Resolve imports (not parse, just resolve) of the given file
-  private def resolveImports(id: I, ast: Body): F[ValidatedNec[Err, AquaModule[I, Err, Body]]] =
+  private def resolveImports(id: I, ast: Body): FE[AquaModule[I, Err, Body]] =
     ast.head.collect { case fe: FilenameExpr[S] =>
       fe.fileValue -> fe.token
-    }.traverse { case (filename, token) =>
-      sources
-        .resolveImport(id, filename)
-        .map(
-          _.bimap(
-            _.map(ResolveImportsError(id, token, _): Err),
-            importId => importId -> (filename, ImportError(token): Err)
-          )
+    }.parTraverse { case (filename, token) =>
+      EitherT
+        .fromValidatedF(
+          sources.resolveImport(id, filename)
        )
-    }.map(_.sequence.map { collected =>
-      AquaModule[I, Err, Body](
-        id,
+        .bimap(
+          _.map(ResolveImportsError(id, token, _): Err),
+          importId => importId -> (filename, ImportError(token): Err)
+        )
+    }.map { collected =>
+      AquaModule(
+        id = id,
         // How filenames correspond to the resolved IDs
-        collected.map { case (i, (fn, _)) =>
+        imports = collected.map { case (i, (fn, _)) =>
           fn -> i
-        }.toList.toMap[String, I],
+        }.toList.toMap,
         // Resolved IDs to errors that point to the import in source code
-        collected.map { case (i, (_, err)) =>
+        dependsOn = collected.map { case (i, (_, err)) =>
           i -> err
-        }.toList.toMap[I, Err],
-        ast
+        }.toList.toMap,
+        body = ast
       )
-    })
-
-  // Parse sources, convert to modules
-  private def sourceModules: F[ValidatedNec[Err, Modules[I, Err, Body]]] =
-    parseSources.flatMap {
-      case Validated.Valid(srcs) =>
-        srcs.traverse { case (id, ast) =>
-          resolveImports(id, ast)
-        }.map(_.sequence)
-      case Validated.Invalid(errs) =>
-        errs.invalid.pure[F]
-    }.map(
-      _.map(
-        _.foldLeft(Modules[I, Err, Body]())(
-          _.add(_, toExport = true)
-        )
-      )
-    )
-
-  private def loadModule(imp: I): F[ValidatedNec[Err, AquaModule[I, Err, Body]]] =
-    sources
-      .load(imp)
-      .map(_.leftMap(_.map(SourcesError.apply)).andThen { src =>
-        parser(imp)(src).leftMap(_.map(AquaParserError.apply))
-      })
-      .flatMap {
-        case Validated.Valid(ast) =>
-          resolveImports(imp, ast)
-        case Validated.Invalid(errs) =>
-          errs.invalid.pure[F]
-      }
-
-  private def resolveModules(
-    modules: Modules[I, Err, Body]
-  ): F[ValidatedNec[Err, Modules[I, Err, Ast[S]]]] =
-    modules.dependsOn.toList.traverse { case (moduleId, unresolvedErrors) =>
-      loadModule(moduleId).map(_.leftMap(_ ++ unresolvedErrors))
-    }.map(
-      _.sequence.map(
-        _.foldLeft(modules)(_ add _)
-      )
-    ).flatMap {
-      case Validated.Valid(ms) if ms.isResolved =>
-        ms.validNec.pure[F]
-      case Validated.Valid(ms) =>
-        resolveModules(ms)
-      case err =>
-        err.pure[F]
     }
 
-  private def resolveSources: FE[Modules[I, Err, Ast[S]]] =
+  // Load modules (parse, resolve imports) of all the source files
+  private lazy val loadModules: FE[Modules[I, Err, Body]] =
     for {
-      ms <- EitherT(
-        sourceModules.map(_.toEither)
-      )
-      res <- EitherT(
-        resolveModules(ms).map(_.toEither)
-      )
-    } yield res
+      srcs <- parseSources
+      modules <- srcs.parTraverse(resolveImports.tupled)
+    } yield Modules.from(modules)
 
-  def resolve[T](
-    transpile: AquaModule[I, Err, Body] => T => T
-  ): FE[Modules[I, Err, T => T]] =
-    resolveSources.map(_.mapModuleToBody(transpile))
+  // Resolve modules (load all the dependencies)
+  private def resolveModules(
+    modules: Modules[I, Err, Body]
+  ): FE[Modules[I, Err, Ast[S]]] =
+    modules.iterateUntilM(ms =>
+      // Load all modules that are dependencies of the current modules
+      ms.dependsOn.toList.parTraverse { case (moduleId, unresolvedErrors) =>
+        loadModule(moduleId).leftMap(_ ++ unresolvedErrors)
+      }.map(ms.addAll) // Add all loaded modules to the current modules
+    )(_.isResolved)
+
+  lazy val resolve: FE[Modules[I, Err, Body]] =
+    loadModules >>= resolveModules
 }
diff --git a/linker/src/main/scala/aqua/linker/Linker.scala b/linker/src/main/scala/aqua/linker/Linker.scala
index 8669e201..dab46f45 100644
--- a/linker/src/main/scala/aqua/linker/Linker.scala
+++ b/linker/src/main/scala/aqua/linker/Linker.scala
@@ -1,17 +1,26 @@
 package aqua.linker
 
+import aqua.errors.Errors.internalError
+
+import cats.MonadError
 import cats.data.{NonEmptyChain, Validated, ValidatedNec}
-import cats.kernel.{Monoid, Semigroup}
-import cats.syntax.semigroup.*
-import cats.syntax.validated.*
-import cats.syntax.functor.*
 import cats.instances.list.*
+import cats.kernel.{Monoid, Semigroup}
+import cats.syntax.applicative.*
+import cats.syntax.flatMap.*
+import cats.syntax.functor.*
+import cats.syntax.semigroup.*
+import cats.syntax.traverse.*
+import cats.syntax.validated.*
+
+import scala.annotation.tailrec
 import scribe.Logging
 
-import scala.annotation.tailrec
-
 object Linker extends Logging {
 
+  // Transpilation function for module
+  // (Imports contexts => Compilation result)
+  type TP = [F[_], T] =>> Map[String, T] => F[T]
+
   // Dependency Cycle, prev element import next
   // and last imports head
   type DepCycle[I] = NonEmptyChain[I]
@@ -23,8 +32,8 @@ object Linker extends Logging {
    * @return [[List]] of dependecy cycles found
    */
   private def findDepCycles[I, E, T](
-    mods: List[AquaModule[I, E, T => T]]
-  ): List[DepCycle[AquaModule[I, E, T => T]]] = {
+    mods: List[AquaModule[I, E, T]]
+  ): List[DepCycle[I]] = {
     val modsIds = mods.map(_.id).toSet
     // Limit search to only passed modules (there maybe dependencies not from `mods`)
     val deps = mods.map(m => m.id -> m.dependsOn.keySet.intersect(modsIds)).toMap
@@ -56,7 +65,7 @@ object Linker extends Logging {
       )
     }
 
-    val cycles = mods
+    mods
      .flatMap(m =>
        findCycles(
          paths = NonEmptyChain.one(m.id) :: Nil,
@@ -69,73 +78,83 @@ object Linker extends Logging {
        // should not be a lot of cycles
        _.toChain.toList.toSet
      )
-
-    val modsById = mods.fproductLeft(_.id).toMap
-
-    // This should be safe
-    cycles.map(_.map(modsById))
   }
 
-  @tailrec
-  def iter[I, E, T: Semigroup](
-    mods: List[AquaModule[I, E, T => T]],
-    proc: Map[I, T => T],
-    cycleError: DepCycle[AquaModule[I, E, T => T]] => E
-  ): ValidatedNec[E, Map[I, T => T]] =
+  /**
+   * Main iterative linking function
+   * @param mods Modules to link
+   * @param proc Already processed modules
+   * @param cycle Function to create error from dependency cycle
+   * @return Result for all modules
+   */
+  def iter[I, E, F[_], T](
+    mods: List[AquaModule[I, E, TP[F, T]]],
+    proc: Map[I, T],
+    cycle: DepCycle[I] => E
+  )(using me: MonadError[F, NonEmptyChain[E]]): F[Map[I, T]] =
     mods match {
      case Nil =>
-        proc.valid
+        proc.pure
      case _ =>
-        val (canHandle, postpone) = mods.partition(_.dependsOn.keySet.forall(proc.contains))
+        // Find modules that can be processed
+        val (canHandle, postpone) = mods.partition(
+          _.dependsOn.keySet.forall(proc.contains)
+        )
 
        logger.debug("ITERATE, can handle: " + canHandle.map(_.id))
        logger.debug(s"dependsOn = ${mods.map(_.dependsOn.keySet)}")
        logger.debug(s"postpone = ${postpone.map(_.id)}")
        logger.debug(s"proc = ${proc.keySet}")
 
+        // If there are no modules that can be processed
        if (canHandle.isEmpty && postpone.nonEmpty) {
-          findDepCycles(postpone)
-            .map(cycleError)
-            .invalid
-            .leftMap(
-              // This should be safe as cycles should exist at this moment
-              errs => NonEmptyChain.fromSeq(errs).get
-            )
-        } else {
-          val folded = canHandle.foldLeft(proc) { case (acc, m) =>
-            val importKeys = m.dependsOn.keySet
-            logger.debug(s"${m.id} dependsOn $importKeys")
-            val deps: T => T =
-              importKeys.map(acc).foldLeft(identity[T]) { case (fAcc, f) =>
-                logger.debug("COMBINING ONE TIME ")
-                t => {
-                  logger.debug(s"call combine $t")
-                  fAcc(t) |+| f(t)
-                }
-              }
-            acc + (m.id -> m.body.compose(deps))
-          }
-          iter(
-            postpone,
-            // TODO can be done in parallel
-            folded,
-            cycleError
+          me.raiseError(
+            // This should be safe as cycles should exist at this moment
+            NonEmptyChain
+              .fromSeq(findDepCycles(postpone).map(cycle))
+              .get
+          )
+        } else
+          canHandle.traverse { mod =>
+            // Gather all imports for module
+            val imports = mod.imports.mapValues { imp =>
+              proc
+                .get(imp)
+                .getOrElse(
+                  // Should not happen as we check it above
+                  internalError(s"Module $imp not found in $proc")
+                )
+            }.toMap
+
+            // Process (transpile) module
+            mod.body(imports).map(mod.id -> _)
+          }.flatMap(processed =>
+            // flatMap should be stack safe
+            iter(
+              postpone,
+              proc ++ processed,
+              cycle
+            )
          )
-        }
-    }
-
-  def link[I, E, T: Semigroup](
-    modules: Modules[I, E, T => T],
-    cycleError: DepCycle[AquaModule[I, E, T => T]] => E,
-    empty: I => T
-  ): ValidatedNec[E, Map[I, T]] =
-    if (modules.dependsOn.nonEmpty) Validated.invalid(modules.dependsOn.values.reduce(_ ++ _))
-    else {
-      val result = iter(modules.loaded.values.toList, Map.empty, cycleError)
-
-      result.map(_.collect {
-        case (i, f) if modules.exports(i) =>
-          i -> f(empty(i))
-      })
    }
+
+  /**
+   * Link modules
+   *
+   * @param modules Modules to link (with transpilation functions as bodies)
+   * @param cycle Function to create error from dependency cycle
+   * @return Result for all **exported** modules
+   */
+  def link[I, E, F[_], T](
+    modules: Modules[I, E, TP[F, T]],
+    cycle: DepCycle[I] => E
+  )(using me: MonadError[F, NonEmptyChain[E]]): F[Map[I, T]] =
+    if (modules.dependsOn.nonEmpty)
+      me.raiseError(
+        modules.dependsOn.values.reduce(_ ++ _)
+      )
+    else
+      iter(modules.loaded.values.toList, Map.empty, cycle).map(
+        // Remove all modules that are not exported from result
+        _.filterKeys(modules.exports.contains).toMap
+      )
 }
diff --git a/linker/src/main/scala/aqua/linker/Modules.scala b/linker/src/main/scala/aqua/linker/Modules.scala
index 5371b58d..e3270209 100644
--- a/linker/src/main/scala/aqua/linker/Modules.scala
+++ b/linker/src/main/scala/aqua/linker/Modules.scala
@@ -1,6 +1,8 @@
 package aqua.linker
 
-import cats.data.NonEmptyChain
+import cats.Foldable
+import cats.data.{Chain, NonEmptyChain}
+import cats.syntax.foldable._
 import cats.syntax.option._
 
 case class Modules[I, E, T](
@@ -23,17 +25,23 @@ case class Modules[I, E, T](
       exports = if (toExport) exports + aquaModule.id else exports
     )
 
+  def addAll[F[_]: Foldable](modules: F[AquaModule[I, E, T]]): Modules[I, E, T] =
+    modules.foldLeft(this)(_ add _)
+
   def isResolved: Boolean = dependsOn.isEmpty
 
   def map[TT](f: T => TT): Modules[I, E, TT] =
     copy(loaded = loaded.view.mapValues(_.map(f)).toMap)
 
-  def mapModuleToBody[TT](f: AquaModule[I, E, T] => TT): Modules[I, E, TT] =
-    copy(loaded = loaded.view.mapValues(v => v.map(_ => f(v))).toMap)
-
   def mapErr[EE](f: E => EE): Modules[I, EE, T] =
     copy(
      loaded = loaded.view.mapValues(_.mapErr(f)).toMap,
      dependsOn = dependsOn.view.mapValues(_.map(f)).toMap
     )
 }
+
+object Modules {
+
+  def from[I, E, T](modules: Chain[AquaModule[I, E, T]]): Modules[I, E, T] =
+    modules.foldLeft(Modules[I, E, T]())(_.add(_, toExport = true))
+}
diff --git a/linker/src/test/scala/aqua/linker/LinkerSpec.scala b/linker/src/test/scala/aqua/linker/LinkerSpec.scala
index 065a5ee5..9c7b3c6f 100644
--- a/linker/src/test/scala/aqua/linker/LinkerSpec.scala
+++ b/linker/src/test/scala/aqua/linker/LinkerSpec.scala
@@ -1,44 +1,61 @@
 package aqua.linker
 
-import cats.data.Validated
+import cats.Id
+import cats.data.{EitherNec, NonEmptyChain}
+import cats.syntax.either.*
 
 import org.scalatest.flatspec.AnyFlatSpec
 import org.scalatest.matchers.should.Matchers
 
 class LinkerSpec extends AnyFlatSpec with Matchers {
 
+  type TP = Map[String, String] => EitherNec[String, String]
+
+  val cycle: NonEmptyChain[String] => String =
+    _.toChain.toList.mkString(" -> ")
+
   "linker" should "resolve dependencies" in {
 
-    val empty = Modules[String, String, String => String]()
+    val empty = Modules[String, String, TP]()
 
-    val withMod1 =
-      empty
-        .add(
-          AquaModule[String, String, String => String](
-            id = "mod1",
-            imports = Map.empty,
-            dependsOn = Map("mod2" -> "unresolved mod2 in mod1"),
-            body = _ ++ " | mod1"
-          ),
-          toExport = true
-        )
+    val withMod1 = empty.add(
+      AquaModule(
+        id = "mod1",
+        imports = Map("mod2" -> "mod2"),
+        dependsOn = Map("mod2" -> "unresolved mod2 in mod1"),
+        body = imports => {
+          println(s"mod1: $imports")
+
+          imports
+            .get("mod2")
+            .toRight("mod2 not found in mod1")
+            .toEitherNec
+            .map(_ ++ " | mod1")
+        }
+      ),
+      toExport = true
+    )
 
     withMod1.isResolved should be(false)
 
-    Linker.link[String, String, String](
-      withMod1,
-      cycle => cycle.map(_.id).toChain.toList.mkString(" -> "),
-      _ => ""
-    ) should be(Validated.invalidNec("unresolved mod2 in mod1"))
+    Linker.link(withMod1, cycle) should be(
+      Left("unresolved mod2 in mod1").toEitherNec
+    )
 
-    val withMod2 =
-      withMod1.add(AquaModule("mod2", Map.empty, Map.empty, _ ++ " | mod2"))
+    val withMod2 = withMod1.add(
+      AquaModule(
+        id = "mod2",
+        imports = Map.empty,
+        dependsOn = Map.empty,
+        body = _ => "mod2".asRight.toEitherNec
+      )
+    )
 
    withMod2.isResolved should be(true)
 
-    Linker.link[String, String, String](
-      withMod2,
-      cycle => cycle.map(_.id + "?").toChain.toList.mkString(" -> "),
-      _ => ""
-    ) should be(Validated.validNec(Map("mod1" -> " | mod2 | mod1")))
+    Linker.link(withMod2, cycle) should be(
+      Map(
+        "mod1" -> "mod2 | mod1"
+      ).asRight.toEitherNec
+    )
   }
 }
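Usage sketch (illustration only, mirroring LinkerSpec above): the reworked Linker.link takes Modules whose bodies are transpilation functions TP[F, T] = Map[String, T] => F[T] and returns results for exported modules only. AquaCompiler plugs in CompileRes[C] bodies via transpile; this sketch instead uses EitherNec[String, *] with String bodies, and the module ids ("lib.aqua", "main.aqua") and error messages are made up for the example.

import aqua.linker.{AquaModule, Linker, Modules}

import cats.data.{Chain, EitherNec, NonEmptyChain}
import cats.syntax.either.*

object LinkerUsageSketch {

  // Same shape as in LinkerSpec: results and errors are plain Strings
  type TP = Map[String, String] => EitherNec[String, String]

  val cycle: NonEmptyChain[String] => String =
    _.toChain.toList.mkString(" -> ")

  // A dependency that imports nothing
  val lib: AquaModule[String, String, TP] =
    AquaModule(
      id = "lib.aqua",
      imports = Map.empty,
      dependsOn = Map.empty,
      body = _ => "lib".asRight
    )

  // An exported module: its body receives the already-linked result of its import
  val main: AquaModule[String, String, TP] =
    AquaModule(
      id = "main.aqua",
      // filename -> resolved id
      imports = Map("lib.aqua" -> "lib.aqua"),
      // resolved id -> error to report if it stays unresolved
      dependsOn = Map("lib.aqua" -> "unresolved lib.aqua in main.aqua"),
      body = imports =>
        imports
          .get("lib.aqua")
          .toRight("lib.aqua not passed to main.aqua")
          .toEitherNec
          .map(_ ++ " | main")
    )

  // Modules.from marks its modules as exported (as AquaParser.loadModules does);
  // `lib` is added as a plain dependency, so it is linked but filtered from the result
  val modules: Modules[String, String, TP] =
    Modules.from(Chain.one(main)).add(lib)

  val linked: EitherNec[String, Map[String, String]] =
    Linker.link(modules, cycle)
  // linked == Right(Map("main.aqua" -> "lib | main"))
}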