Skip to content

Commit

Permalink
Add support for separate compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
mbovel committed Dec 6, 2024
1 parent fdf67d4 commit 4c908a8
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 13 deletions.
1 change: 1 addition & 0 deletions core/src/main/scala/stainless/MainHelpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ trait MainHelpers extends inox.MainHelpers { self =>
frontend.optKeep -> Description(General, "Keep library objects marked by @keepFor(g) for some g in g1,g2,... (implies --batched)"),
frontend.optExtraDeps -> Description(General, "Fetch the specified extra source dependencies and add their source files to the session"),
frontend.optExtraResolvers -> Description(General, "Extra resolvers to use to fetch extra source dependencies"),
frontend.optClasspath -> Description(General, "Add the specified directory to the classpath"),
utils.Caches.optCacheDir -> Description(General, "Specify the directory in which cache files should be stored"),
utils.Caches.optBinaryCache -> Description(General, "Set Binary mode for the cache instead of Hash mode, i.e., the cache will contain the entire VC and program in serialized format. This is less space efficient."),
testgen.optOutputFile -> Description(TestsGeneration, "Specify the output file"),
Expand Down
7 changes: 7 additions & 0 deletions core/src/main/scala/stainless/frontend/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ package object frontend {
*/
object optBatchedProgram extends inox.FlagOptionDef("batched", false)

object optClasspath extends inox.OptionDef[Option[String]] {
val name = "classpath"
val default = None
val parser = input => Some(Some(input))
val usageRhs = "DIR"
}

/**
* Given a context (with its reporter) and a frontend factory, proceed to compile,
* extract, transform and verify the input programs based on the active components
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,10 @@ class DottyCompiler(ctx: inox.Context, callback: CallBack) extends Compiler {

override def runOn(units: List[CompilationUnit])(using dottyCtx: DottyContext): List[CompilationUnit] = {
exportedSymsMapping = exportedSymbolsMapping(ctx, this.start, units)
super.runOn(units)
val res = super.runOn(units)
extraction.extractClasspathUnits(exportedSymsMapping).foreach(extracted =>
callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs))
res
}
}

Expand Down Expand Up @@ -153,8 +156,9 @@ object DottyCompiler {
x => new File(x.getLocation.toURI).getAbsolutePath
} getOrElse { ctx.reporter.fatalError("No Scala 3 library found.") }

val cps = Seq(scala213Lib, scala3Lib).distinct.mkString(java.io.File.pathSeparator)
val flags = Seq("-color:never", "-language:implicitConversions", "-Wsafe-init", s"-cp:$cps") // -Ysafe-init is deprecated (SAM 21.08.2024)
val extraCps = ctx.options.findOptionOrDefault(frontend.optClasspath).toSeq
val cps = (extraCps ++ Seq(scala213Lib, scala3Lib)).distinct.mkString(java.io.File.pathSeparator)
val flags = Seq("-Yretain-trees", "-color:never", "-language:implicitConversions", "-Wsafe-init", s"-cp:$cps") // -Ysafe-init is deprecated (SAM 21.08.2024)
allCompilerArguments(ctx, compilerArgs) ++ flags
}
val compiler: DottyCompiler = new DottyCompiler(ctx, this.callback)
Expand All @@ -180,4 +184,4 @@ object DottyCompiler {
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import transform._
import ast.tpd
import ast.Trees._
import typer._
import util.SourceFile

import extraction.xlang.{trees => xt}
import frontend.{CallBack, Frontend, FrontendFactory, ThreadedFrontend, UnsupportedCodeException}
Expand All @@ -22,35 +23,43 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
private val symbolMapping = new SymbolMapping

def extractUnit(exportedSymsMapping: ExportedSymbolsMapping)(using ctx: DottyContext): Option[ExtractedUnit] = {
val unit = ctx.compilationUnit
val tree = unit.tpdTree
extractUnit(tree, unit.source, exportedSymsMapping)
}

def extractUnit(
tree: tpd.Tree,
source: SourceFile,
exportedSymsMapping: ExportedSymbolsMapping
)(using ctx: DottyContext): Option[ExtractedUnit] = {
// Remark: the method `extractUnit` is called for each compilation unit (which corresponds more or less to a Scala file)
// Therefore, the symbolMapping instances needs to be shared accross compilation unit.
// Since `extractUnit` is called within the same thread, we do not need to synchronize accesses to symbolMapping.
val extraction = new CodeExtraction(inoxCtx, symbolMapping, exportedSymsMapping)
import extraction._

val unit = ctx.compilationUnit
val tree = unit.tpdTree
val (id, stats) = tree match {
case pd@PackageDef(_, lst) =>
val id = lst.collectFirst { case PackageDef(ref, _) => ref } match {
case Some(ref) => extractRef(ref)
case None => FreshIdentifier(unit.source.file.name.replaceFirst("[.][^.]+$", ""))
case None => FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", ""))
}
(id, pd.stats)
case _ =>
(FreshIdentifier(unit.source.file.name.replaceFirst("[.][^.]+$", "")), List.empty)
(FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", "")), List.empty)
}

val fragmentChecker = new FragmentChecker(inoxCtx)
fragmentChecker.ghostChecker(tree)
fragmentChecker.checker(tree)

if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, unit, id, stats)
if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, source, id, stats)
else None
}

private def tryExtractUnit(extraction: CodeExtraction,
unit: CompilationUnit,
source: SourceFile,
id: Identifier,
stats: List[tpd.Tree])(using DottyContext): Option[ExtractedUnit] = {
// If the user annotates a function with @main, the compiler will generate a top-level class
Expand All @@ -67,7 +76,7 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
try {
val (imports, unitClasses, unitFunctions, _, subs, classes, functions, typeDefs) = extraction.extractStatic(filteredStats)
assert(unitFunctions.isEmpty, "Packages shouldn't contain functions")
val file = unit.source.file.absolute.path
val file = source.file.absolute.path
val isLibrary = stainless.Main.libraryFiles contains file
val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, !isLibrary)
Some(ExtractedUnit(file, xtUnit, classes, functions, typeDefs))
Expand All @@ -92,4 +101,19 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
}
trAcc(None, stats)
}

def extractClasspathUnits(exportedSymsMapping: ExportedSymbolsMapping)(using DottyContext): Seq[ExtractedUnit] = {
@scala.annotation.tailrec
def loop(units: Map[ClassSymbol, ExtractedUnit]): Seq[ExtractedUnit] =
val newUnits =
symbolMapping
.getUsedTastyClasses()
.filterNot(units.contains)
.map(sym => sym -> extractUnit(sym.rootTree, sym.sourceOfClass, exportedSymsMapping).get)
.toMap
if (newUnits.isEmpty) units.values.toSeq
else loop(units ++ newUnits)

loop(Map.empty)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.Contexts._
import stainless.ast.SymbolIdentifier

import scala.collection.mutable.{ Map => MutableMap }
import scala.collection.mutable.{ Map => MutableMap, Set => MutableSet }

class SymbolMapping {
import SymbolMapping._
Expand All @@ -25,10 +25,16 @@ class SymbolMapping {
private val s2sAccessor = MutableMap[Symbol, SymbolIdentifier]()
private val s2sEnumType = MutableMap[Symbol, SymbolIdentifier]()

private val usedTastyClasses = MutableSet[ClassSymbol]()
def getUsedTastyClasses(): Set[ClassSymbol] = usedTastyClasses.toSet

/** Get the identifier associated with the given [[sym]], creating a new one if needed. */
def fetch(sym: Symbol, mode: FetchingMode)(using Context): SymbolIdentifier = mode match {
case Plain =>
s2s.getOrElseUpdate(sym, {
if (sym.tastyInfo.isDefined) {
usedTastyClasses += sym.topLevelClass.asClass
}
val overrides = sym.allOverriddenSymbols.toSeq
val top = overrides.lastOption.getOrElse(sym)
if (top eq sym) {
Expand Down Expand Up @@ -66,6 +72,7 @@ class SymbolMapping {
res
})
}

}

object SymbolMapping {
Expand Down Expand Up @@ -93,4 +100,4 @@ object SymbolMapping {
}
.mkString(".")
}
}
}

0 comments on commit 4c908a8

Please sign in to comment.