Skip to content

Commit

Permalink
[skip ci] WIP: Load Stainless standard library from the classpath
Browse files Browse the repository at this point in the history
  • Loading branch information
mbovel committed Dec 6, 2024
1 parent 4c908a8 commit 46e3dcd
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ case class VC[T <: ast.Trees](val trees: T)(val condition: trees.Expr, val fid:
val state = Seq(trees, condition, fid, kind, satisfiability)
state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
}

override def toString(): String = s"VC($condition)"
}

sealed abstract class VCKind(val name: String, val abbrv: String) {
Expand Down
2 changes: 1 addition & 1 deletion frontends/common/src/test/scala/stainless/InputUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ trait InputUtils {
loadFiles(files, filterOpt, sanitize)
}

/** Compile and extract the given files (& the library). */
/** Compile and extract the given files. */
def loadFiles(files: Seq[String], filterOpt: Option[Filter] = None, sanitize: Boolean = true)
(using ctx: inox.Context): (Seq[xt.UnitDef], Program { val trees: xt.type }) = {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import plugins._
import dotty.tools.dotc.reporting.{Diagnostic, Reporter => DottyReporter}
import dotty.tools.dotc.interfaces.Diagnostic.{ERROR, WARNING, INFO}
import dotty.tools.dotc.util.SourcePosition
import dotty.tools.dotc.core.Symbols.{ClassSymbol => DottyClasSymbol}
import dotty.tools.io.AbstractFile
import core.Contexts.{Context => DottyContext, _}
import core.Phases._
Expand Down Expand Up @@ -42,16 +43,19 @@ class DottyCompiler(ctx: inox.Context, callback: CallBack) extends Compiler {
// to be shared across multiple compilation unit.
private val extraction = new StainlessExtraction(ctx)
private var exportedSymsMapping: ExportedSymbolsMapping = ExportedSymbolsMapping.empty


// This method id called for every compilation unit, and in the same thread.
override def run(using dottyCtx: DottyContext): Unit =
extraction.extractUnit(exportedSymsMapping).foreach(extracted =>
callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs))

override def runOn(units: List[CompilationUnit])(using dottyCtx: DottyContext): List[CompilationUnit] = {
exportedSymsMapping = exportedSymbolsMapping(ctx, this.start, units)
val res = super.runOn(units)
extraction.extractClasspathUnits(exportedSymsMapping).foreach(extracted =>

extraction.extractClasspathUnits(exportedSymsMapping, ctx).foreach(extracted =>
ctx.reporter.info(s"Extracted classpath unit: ${extracted.unit.id}")
callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs))
res
}
Expand Down Expand Up @@ -142,10 +146,16 @@ object DottyCompiler {
override val libraryPaths: Seq[String]
) extends FrontendFactory {

/** Overriden to not include library sources. */
final override protected def allCompilerArguments(ctx: inox.Context, compilerArgs: Seq[String]): Seq[String] = {
val extraSources = extraSourceFiles(ctx)
extraCompilerArguments ++ extraSources ++ compilerArgs
}

override def apply(ctx: inox.Context, compilerArgs: Seq[String], callback: CallBack): Frontend =
new ThreadedFrontend(callback, ctx) {
val args = {
// Attempt to find where the Scala 2.13 and 3.0 libs are.
// Attempt to find where the Scala 2.13 and 3.0 libs, and the Stainless lib are.
// The 3.0 library depends on the 2.13, so we need to fetch the later as well.
val scala213Lib: String = Option(scala.Predef.getClass.getProtectionDomain.getCodeSource) map {
x => new File(x.getLocation.toURI).getAbsolutePath
Expand All @@ -155,12 +165,19 @@ object DottyCompiler {
val scala3Lib: String = Option(scala.util.NotGiven.getClass.getProtectionDomain.getCodeSource) map {
x => new File(x.getLocation.toURI).getAbsolutePath
} getOrElse { ctx.reporter.fatalError("No Scala 3 library found.") }
// Find the Stainless library by looking at the location of the `stainless.collection.List`.
val stainlessLib: String = Option(stainless.collection.List.getClass.getProtectionDomain.getCodeSource) map {
x => new File(x.getLocation.toURI).getAbsolutePath
} getOrElse { ctx.reporter.fatalError("No Stainless Library found.") }

ctx.reporter.info(s"Stainless library found at: $stainlessLib")

val extraCps = ctx.options.findOptionOrDefault(frontend.optClasspath).toSeq
val cps = (extraCps ++ Seq(scala213Lib, scala3Lib)).distinct.mkString(java.io.File.pathSeparator)
val cps = (extraCps ++ Seq(stainlessLib, scala213Lib, scala3Lib)).distinct.mkString(java.io.File.pathSeparator)
val flags = Seq("-Yretain-trees", "-color:never", "-language:implicitConversions", "-Wsafe-init", s"-cp:$cps") // -Ysafe-init is deprecated (SAM 21.08.2024)
allCompilerArguments(ctx, compilerArgs) ++ flags
}

val compiler: DottyCompiler = new DottyCompiler(ctx, this.callback)

val driver = new DottyDriver(args, compiler, new SimpleReporter(ctx.reporter))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
def extractUnit(exportedSymsMapping: ExportedSymbolsMapping)(using ctx: DottyContext): Option[ExtractedUnit] = {
val unit = ctx.compilationUnit
val tree = unit.tpdTree
extractUnit(tree, unit.source, exportedSymsMapping)
extractUnit(tree, unit.source, exportedSymsMapping, isFromSource = true)
}

def extractUnit(
tree: tpd.Tree,
source: SourceFile,
exportedSymsMapping: ExportedSymbolsMapping
exportedSymsMapping: ExportedSymbolsMapping,
isFromSource: Boolean
)(using ctx: DottyContext): Option[ExtractedUnit] = {
// Remark: the method `extractUnit` is called for each compilation unit (which corresponds more or less to a Scala file)
// Therefore, the symbolMapping instances needs to be shared accross compilation unit.
Expand All @@ -40,28 +41,31 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
import extraction._

val (id, stats) = tree match {
case pd@PackageDef(_, lst) =>
case pd@PackageDef(pid, lst) =>
val id = lst.collectFirst { case PackageDef(ref, _) => ref } match {
case Some(ref) => extractRef(ref)
case None => FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", ""))
}
(id, pd.stats)
case _ =>
inoxCtx.reporter.info("Empty package definition: " + source.file.name)
(FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", "")), List.empty)
}

val fragmentChecker = new FragmentChecker(inoxCtx)
fragmentChecker.ghostChecker(tree)
fragmentChecker.checker(tree)

if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, source, id, stats)
if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, source, id, stats, isFromSource)
else None
}

private def tryExtractUnit(extraction: CodeExtraction,
source: SourceFile,
id: Identifier,
stats: List[tpd.Tree])(using DottyContext): Option[ExtractedUnit] = {
stats: List[tpd.Tree],
isFromSource: Boolean
)(using DottyContext): Option[ExtractedUnit] = {
// If the user annotates a function with @main, the compiler will generate a top-level class
// with the same name as the function.
// This generated class defines def main(args: Array[String]): Unit
Expand All @@ -78,7 +82,8 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
assert(unitFunctions.isEmpty, "Packages shouldn't contain functions")
val file = source.file.absolute.path
val isLibrary = stainless.Main.libraryFiles contains file
val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, !isLibrary)
val isMain = isFromSource && !isLibrary
val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, isMain)
Some(ExtractedUnit(file, xtUnit, classes, functions, typeDefs))
} catch {
case UnsupportedCodeException(pos, msg) =>
Expand All @@ -102,14 +107,14 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
trAcc(None, stats)
}

def extractClasspathUnits(exportedSymsMapping: ExportedSymbolsMapping)(using DottyContext): Seq[ExtractedUnit] = {
def extractClasspathUnits(exportedSymsMapping: ExportedSymbolsMapping, inoxCtx: inox.Context)(using DottyContext): Seq[ExtractedUnit] = {
@scala.annotation.tailrec
def loop(units: Map[ClassSymbol, ExtractedUnit]): Seq[ExtractedUnit] =
val newUnits =
symbolMapping
.getUsedTastyClasses()
.filterNot(units.contains)
.map(sym => sym -> extractUnit(sym.rootTree, sym.sourceOfClass, exportedSymsMapping).get)
.map(sym => sym -> extractUnit(sym.rootTree, sym.sourceOfClass, exportedSymsMapping, isFromSource = false).get)
.toMap
if (newUnits.isEmpty) units.values.toSeq
else loop(units ++ newUnits)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,17 @@ class SymbolMapping {
private val usedTastyClasses = MutableSet[ClassSymbol]()
def getUsedTastyClasses(): Set[ClassSymbol] = usedTastyClasses.toSet

private def maybeRegisterTastyClass(sym: Symbol)(using Context): Unit = {
if (sym.tastyInfo.isDefined) {
usedTastyClasses += sym.topLevelClass.asClass
}
}

/** Get the identifier associated with the given [[sym]], creating a new one if needed. */
def fetch(sym: Symbol, mode: FetchingMode)(using Context): SymbolIdentifier = mode match {
case Plain =>
s2s.getOrElseUpdate(sym, {
if (sym.tastyInfo.isDefined) {
usedTastyClasses += sym.topLevelClass.asClass
}
maybeRegisterTastyClass(sym)
val overrides = sym.allOverriddenSymbols.toSeq
val top = overrides.lastOption.getOrElse(sym)
if (top eq sym) {
Expand All @@ -45,6 +49,7 @@ class SymbolMapping {
})
case FieldAccessor =>
s2sAccessor.getOrElseUpdate(sym, {
maybeRegisterTastyClass(sym)
val overrides = sym.allOverriddenSymbols.toSeq
val top = overrides.lastOption.getOrElse(sym)
if (top eq sym) {
Expand All @@ -58,6 +63,7 @@ class SymbolMapping {
})
case EnumType =>
s2sEnumType.getOrElseUpdate(sym, {
maybeRegisterTastyClass(sym)
assert(sym.allOverriddenSymbols.isEmpty)
SymbolIdentifier(ast.Symbol(symFullName(sym)))
})
Expand Down

0 comments on commit 46e3dcd

Please sign in to comment.