From 4c908a86ea1d63772fbee8cce98c5ff417dd5c7f Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Thu, 5 Dec 2024 00:11:35 +0000 Subject: [PATCH] Add support for separate compilation --- .../main/scala/stainless/MainHelpers.scala | 1 + .../scala/stainless/frontend/package.scala | 7 ++++ .../frontends/dotc/DottyCompiler.scala | 12 ++++-- .../frontends/dotc/StainlessExtraction.scala | 38 +++++++++++++++---- .../frontends/dotc/SymbolMapping.scala | 11 +++++- 5 files changed, 56 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/stainless/MainHelpers.scala b/core/src/main/scala/stainless/MainHelpers.scala index 85135bba4e..f4bc6a9b7d 100644 --- a/core/src/main/scala/stainless/MainHelpers.scala +++ b/core/src/main/scala/stainless/MainHelpers.scala @@ -74,6 +74,7 @@ trait MainHelpers extends inox.MainHelpers { self => frontend.optKeep -> Description(General, "Keep library objects marked by @keepFor(g) for some g in g1,g2,... (implies --batched)"), frontend.optExtraDeps -> Description(General, "Fetch the specified extra source dependencies and add their source files to the session"), frontend.optExtraResolvers -> Description(General, "Extra resolvers to use to fetch extra source dependencies"), + frontend.optClasspath -> Description(General, "Add the specified directory to the classpath"), utils.Caches.optCacheDir -> Description(General, "Specify the directory in which cache files should be stored"), utils.Caches.optBinaryCache -> Description(General, "Set Binary mode for the cache instead of Hash mode, i.e., the cache will contain the entire VC and program in serialized format. This is less space efficient."), testgen.optOutputFile -> Description(TestsGeneration, "Specify the output file"), diff --git a/core/src/main/scala/stainless/frontend/package.scala b/core/src/main/scala/stainless/frontend/package.scala index c1c4e87aa4..fe79b94b0f 100644 --- a/core/src/main/scala/stainless/frontend/package.scala +++ b/core/src/main/scala/stainless/frontend/package.scala @@ -23,6 +23,13 @@ package object frontend { */ object optBatchedProgram extends inox.FlagOptionDef("batched", false) + object optClasspath extends inox.OptionDef[Option[String]] { + val name = "classpath" + val default = None + val parser = input => Some(Some(input)) + val usageRhs = "DIR" + } + /** * Given a context (with its reporter) and a frontend factory, proceed to compile, * extract, transform and verify the input programs based on the active components diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala index d0d1d48067..14efc27d22 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala @@ -50,7 +50,10 @@ class DottyCompiler(ctx: inox.Context, callback: CallBack) extends Compiler { override def runOn(units: List[CompilationUnit])(using dottyCtx: DottyContext): List[CompilationUnit] = { exportedSymsMapping = exportedSymbolsMapping(ctx, this.start, units) - super.runOn(units) + val res = super.runOn(units) + extraction.extractClasspathUnits(exportedSymsMapping).foreach(extracted => + callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs)) + res } } @@ -153,8 +156,9 @@ object DottyCompiler { x => new File(x.getLocation.toURI).getAbsolutePath } getOrElse { ctx.reporter.fatalError("No Scala 3 library found.") } - val cps = Seq(scala213Lib, scala3Lib).distinct.mkString(java.io.File.pathSeparator) - val flags = Seq("-color:never", "-language:implicitConversions", "-Wsafe-init", s"-cp:$cps") // -Ysafe-init is deprecated (SAM 21.08.2024) + val extraCps = ctx.options.findOptionOrDefault(frontend.optClasspath).toSeq + val cps = (extraCps ++ Seq(scala213Lib, scala3Lib)).distinct.mkString(java.io.File.pathSeparator) + val flags = Seq("-Yretain-trees", "-color:never", "-language:implicitConversions", "-Wsafe-init", s"-cp:$cps") // -Ysafe-init is deprecated (SAM 21.08.2024) allCompilerArguments(ctx, compilerArgs) ++ flags } val compiler: DottyCompiler = new DottyCompiler(ctx, this.callback) @@ -180,4 +184,4 @@ object DottyCompiler { } } } -} \ No newline at end of file +} diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala index 9c1984384d..912958f7c9 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala @@ -11,6 +11,7 @@ import transform._ import ast.tpd import ast.Trees._ import typer._ +import util.SourceFile import extraction.xlang.{trees => xt} import frontend.{CallBack, Frontend, FrontendFactory, ThreadedFrontend, UnsupportedCodeException} @@ -22,35 +23,43 @@ class StainlessExtraction(val inoxCtx: inox.Context) { private val symbolMapping = new SymbolMapping def extractUnit(exportedSymsMapping: ExportedSymbolsMapping)(using ctx: DottyContext): Option[ExtractedUnit] = { + val unit = ctx.compilationUnit + val tree = unit.tpdTree + extractUnit(tree, unit.source, exportedSymsMapping) + } + + def extractUnit( + tree: tpd.Tree, + source: SourceFile, + exportedSymsMapping: ExportedSymbolsMapping + )(using ctx: DottyContext): Option[ExtractedUnit] = { // Remark: the method `extractUnit` is called for each compilation unit (which corresponds more or less to a Scala file) // Therefore, the symbolMapping instances needs to be shared accross compilation unit. // Since `extractUnit` is called within the same thread, we do not need to synchronize accesses to symbolMapping. val extraction = new CodeExtraction(inoxCtx, symbolMapping, exportedSymsMapping) import extraction._ - val unit = ctx.compilationUnit - val tree = unit.tpdTree val (id, stats) = tree match { case pd@PackageDef(_, lst) => val id = lst.collectFirst { case PackageDef(ref, _) => ref } match { case Some(ref) => extractRef(ref) - case None => FreshIdentifier(unit.source.file.name.replaceFirst("[.][^.]+$", "")) + case None => FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", "")) } (id, pd.stats) case _ => - (FreshIdentifier(unit.source.file.name.replaceFirst("[.][^.]+$", "")), List.empty) + (FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", "")), List.empty) } val fragmentChecker = new FragmentChecker(inoxCtx) fragmentChecker.ghostChecker(tree) fragmentChecker.checker(tree) - if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, unit, id, stats) + if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, source, id, stats) else None } private def tryExtractUnit(extraction: CodeExtraction, - unit: CompilationUnit, + source: SourceFile, id: Identifier, stats: List[tpd.Tree])(using DottyContext): Option[ExtractedUnit] = { // If the user annotates a function with @main, the compiler will generate a top-level class @@ -67,7 +76,7 @@ class StainlessExtraction(val inoxCtx: inox.Context) { try { val (imports, unitClasses, unitFunctions, _, subs, classes, functions, typeDefs) = extraction.extractStatic(filteredStats) assert(unitFunctions.isEmpty, "Packages shouldn't contain functions") - val file = unit.source.file.absolute.path + val file = source.file.absolute.path val isLibrary = stainless.Main.libraryFiles contains file val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, !isLibrary) Some(ExtractedUnit(file, xtUnit, classes, functions, typeDefs)) @@ -92,4 +101,19 @@ class StainlessExtraction(val inoxCtx: inox.Context) { } trAcc(None, stats) } + + def extractClasspathUnits(exportedSymsMapping: ExportedSymbolsMapping)(using DottyContext): Seq[ExtractedUnit] = { + @scala.annotation.tailrec + def loop(units: Map[ClassSymbol, ExtractedUnit]): Seq[ExtractedUnit] = + val newUnits = + symbolMapping + .getUsedTastyClasses() + .filterNot(units.contains) + .map(sym => sym -> extractUnit(sym.rootTree, sym.sourceOfClass, exportedSymsMapping).get) + .toMap + if (newUnits.isEmpty) units.values.toSeq + else loop(units ++ newUnits) + + loop(Map.empty) + } } diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala index 64d2aee84f..67a22fbc6b 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala @@ -10,7 +10,7 @@ import dotty.tools.dotc.core.Symbols._ import dotty.tools.dotc.core.Contexts._ import stainless.ast.SymbolIdentifier -import scala.collection.mutable.{ Map => MutableMap } +import scala.collection.mutable.{ Map => MutableMap, Set => MutableSet } class SymbolMapping { import SymbolMapping._ @@ -25,10 +25,16 @@ class SymbolMapping { private val s2sAccessor = MutableMap[Symbol, SymbolIdentifier]() private val s2sEnumType = MutableMap[Symbol, SymbolIdentifier]() + private val usedTastyClasses = MutableSet[ClassSymbol]() + def getUsedTastyClasses(): Set[ClassSymbol] = usedTastyClasses.toSet + /** Get the identifier associated with the given [[sym]], creating a new one if needed. */ def fetch(sym: Symbol, mode: FetchingMode)(using Context): SymbolIdentifier = mode match { case Plain => s2s.getOrElseUpdate(sym, { + if (sym.tastyInfo.isDefined) { + usedTastyClasses += sym.topLevelClass.asClass + } val overrides = sym.allOverriddenSymbols.toSeq val top = overrides.lastOption.getOrElse(sym) if (top eq sym) { @@ -66,6 +72,7 @@ class SymbolMapping { res }) } + } object SymbolMapping { @@ -93,4 +100,4 @@ object SymbolMapping { } .mkString(".") } -} \ No newline at end of file +}