From 46e3dcdc391be5a2f11443a2fc18e70c42d29ffc Mon Sep 17 00:00:00 2001 From: Matt Bovel Date: Fri, 6 Dec 2024 11:50:26 +0000 Subject: [PATCH] [skip ci] WIP: Load Stainless standard library from the classpath --- .../verification/VerificationConditions.scala | 2 ++ .../src/test/scala/stainless/InputUtils.scala | 2 +- .../frontends/dotc/DottyCompiler.scala | 25 ++++++++++++++++--- .../frontends/dotc/StainlessExtraction.scala | 21 ++++++++++------ .../frontends/dotc/SymbolMapping.scala | 12 ++++++--- 5 files changed, 46 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/stainless/verification/VerificationConditions.scala b/core/src/main/scala/stainless/verification/VerificationConditions.scala index 220bb36ee..33adf70ae 100644 --- a/core/src/main/scala/stainless/verification/VerificationConditions.scala +++ b/core/src/main/scala/stainless/verification/VerificationConditions.scala @@ -27,6 +27,8 @@ case class VC[T <: ast.Trees](val trees: T)(val condition: trees.Expr, val fid: val state = Seq(trees, condition, fid, kind, satisfiability) state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b) } + + override def toString(): String = s"VC($condition)" } sealed abstract class VCKind(val name: String, val abbrv: String) { diff --git a/frontends/common/src/test/scala/stainless/InputUtils.scala b/frontends/common/src/test/scala/stainless/InputUtils.scala index 7d8a296af..41b5bdffa 100644 --- a/frontends/common/src/test/scala/stainless/InputUtils.scala +++ b/frontends/common/src/test/scala/stainless/InputUtils.scala @@ -35,7 +35,7 @@ trait InputUtils { loadFiles(files, filterOpt, sanitize) } - /** Compile and extract the given files (& the library). */ + /** Compile and extract the given files. */ def loadFiles(files: Seq[String], filterOpt: Option[Filter] = None, sanitize: Boolean = true) (using ctx: inox.Context): (Seq[xt.UnitDef], Program { val trees: xt.type }) = { diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala index 14efc27d2..deba809d6 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala @@ -8,6 +8,7 @@ import plugins._ import dotty.tools.dotc.reporting.{Diagnostic, Reporter => DottyReporter} import dotty.tools.dotc.interfaces.Diagnostic.{ERROR, WARNING, INFO} import dotty.tools.dotc.util.SourcePosition +import dotty.tools.dotc.core.Symbols.{ClassSymbol => DottyClasSymbol} import dotty.tools.io.AbstractFile import core.Contexts.{Context => DottyContext, _} import core.Phases._ @@ -42,16 +43,19 @@ class DottyCompiler(ctx: inox.Context, callback: CallBack) extends Compiler { // to be shared across multiple compilation unit. private val extraction = new StainlessExtraction(ctx) private var exportedSymsMapping: ExportedSymbolsMapping = ExportedSymbolsMapping.empty + // This method id called for every compilation unit, and in the same thread. override def run(using dottyCtx: DottyContext): Unit = extraction.extractUnit(exportedSymsMapping).foreach(extracted => callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs)) - + override def runOn(units: List[CompilationUnit])(using dottyCtx: DottyContext): List[CompilationUnit] = { exportedSymsMapping = exportedSymbolsMapping(ctx, this.start, units) val res = super.runOn(units) - extraction.extractClasspathUnits(exportedSymsMapping).foreach(extracted => + + extraction.extractClasspathUnits(exportedSymsMapping, ctx).foreach(extracted => + ctx.reporter.info(s"Extracted classpath unit: ${extracted.unit.id}") callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs)) res } @@ -142,10 +146,16 @@ object DottyCompiler { override val libraryPaths: Seq[String] ) extends FrontendFactory { + /** Overriden to not include library sources. */ + final override protected def allCompilerArguments(ctx: inox.Context, compilerArgs: Seq[String]): Seq[String] = { + val extraSources = extraSourceFiles(ctx) + extraCompilerArguments ++ extraSources ++ compilerArgs + } + override def apply(ctx: inox.Context, compilerArgs: Seq[String], callback: CallBack): Frontend = new ThreadedFrontend(callback, ctx) { val args = { - // Attempt to find where the Scala 2.13 and 3.0 libs are. + // Attempt to find where the Scala 2.13 and 3.0 libs, and the Stainless lib are. // The 3.0 library depends on the 2.13, so we need to fetch the later as well. val scala213Lib: String = Option(scala.Predef.getClass.getProtectionDomain.getCodeSource) map { x => new File(x.getLocation.toURI).getAbsolutePath @@ -155,12 +165,19 @@ object DottyCompiler { val scala3Lib: String = Option(scala.util.NotGiven.getClass.getProtectionDomain.getCodeSource) map { x => new File(x.getLocation.toURI).getAbsolutePath } getOrElse { ctx.reporter.fatalError("No Scala 3 library found.") } + // Find the Stainless library by looking at the location of the `stainless.collection.List`. + val stainlessLib: String = Option(stainless.collection.List.getClass.getProtectionDomain.getCodeSource) map { + x => new File(x.getLocation.toURI).getAbsolutePath + } getOrElse { ctx.reporter.fatalError("No Stainless Library found.") } + + ctx.reporter.info(s"Stainless library found at: $stainlessLib") val extraCps = ctx.options.findOptionOrDefault(frontend.optClasspath).toSeq - val cps = (extraCps ++ Seq(scala213Lib, scala3Lib)).distinct.mkString(java.io.File.pathSeparator) + val cps = (extraCps ++ Seq(stainlessLib, scala213Lib, scala3Lib)).distinct.mkString(java.io.File.pathSeparator) val flags = Seq("-Yretain-trees", "-color:never", "-language:implicitConversions", "-Wsafe-init", s"-cp:$cps") // -Ysafe-init is deprecated (SAM 21.08.2024) allCompilerArguments(ctx, compilerArgs) ++ flags } + val compiler: DottyCompiler = new DottyCompiler(ctx, this.callback) val driver = new DottyDriver(args, compiler, new SimpleReporter(ctx.reporter)) diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala index 912958f7c..a62b8f92b 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala @@ -25,13 +25,14 @@ class StainlessExtraction(val inoxCtx: inox.Context) { def extractUnit(exportedSymsMapping: ExportedSymbolsMapping)(using ctx: DottyContext): Option[ExtractedUnit] = { val unit = ctx.compilationUnit val tree = unit.tpdTree - extractUnit(tree, unit.source, exportedSymsMapping) + extractUnit(tree, unit.source, exportedSymsMapping, isFromSource = true) } def extractUnit( tree: tpd.Tree, source: SourceFile, - exportedSymsMapping: ExportedSymbolsMapping + exportedSymsMapping: ExportedSymbolsMapping, + isFromSource: Boolean )(using ctx: DottyContext): Option[ExtractedUnit] = { // Remark: the method `extractUnit` is called for each compilation unit (which corresponds more or less to a Scala file) // Therefore, the symbolMapping instances needs to be shared accross compilation unit. @@ -40,13 +41,14 @@ class StainlessExtraction(val inoxCtx: inox.Context) { import extraction._ val (id, stats) = tree match { - case pd@PackageDef(_, lst) => + case pd@PackageDef(pid, lst) => val id = lst.collectFirst { case PackageDef(ref, _) => ref } match { case Some(ref) => extractRef(ref) case None => FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", "")) } (id, pd.stats) case _ => + inoxCtx.reporter.info("Empty package definition: " + source.file.name) (FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", "")), List.empty) } @@ -54,14 +56,16 @@ class StainlessExtraction(val inoxCtx: inox.Context) { fragmentChecker.ghostChecker(tree) fragmentChecker.checker(tree) - if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, source, id, stats) + if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, source, id, stats, isFromSource) else None } private def tryExtractUnit(extraction: CodeExtraction, source: SourceFile, id: Identifier, - stats: List[tpd.Tree])(using DottyContext): Option[ExtractedUnit] = { + stats: List[tpd.Tree], + isFromSource: Boolean + )(using DottyContext): Option[ExtractedUnit] = { // If the user annotates a function with @main, the compiler will generate a top-level class // with the same name as the function. // This generated class defines def main(args: Array[String]): Unit @@ -78,7 +82,8 @@ class StainlessExtraction(val inoxCtx: inox.Context) { assert(unitFunctions.isEmpty, "Packages shouldn't contain functions") val file = source.file.absolute.path val isLibrary = stainless.Main.libraryFiles contains file - val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, !isLibrary) + val isMain = isFromSource && !isLibrary + val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, isMain) Some(ExtractedUnit(file, xtUnit, classes, functions, typeDefs)) } catch { case UnsupportedCodeException(pos, msg) => @@ -102,14 +107,14 @@ class StainlessExtraction(val inoxCtx: inox.Context) { trAcc(None, stats) } - def extractClasspathUnits(exportedSymsMapping: ExportedSymbolsMapping)(using DottyContext): Seq[ExtractedUnit] = { + def extractClasspathUnits(exportedSymsMapping: ExportedSymbolsMapping, inoxCtx: inox.Context)(using DottyContext): Seq[ExtractedUnit] = { @scala.annotation.tailrec def loop(units: Map[ClassSymbol, ExtractedUnit]): Seq[ExtractedUnit] = val newUnits = symbolMapping .getUsedTastyClasses() .filterNot(units.contains) - .map(sym => sym -> extractUnit(sym.rootTree, sym.sourceOfClass, exportedSymsMapping).get) + .map(sym => sym -> extractUnit(sym.rootTree, sym.sourceOfClass, exportedSymsMapping, isFromSource = false).get) .toMap if (newUnits.isEmpty) units.values.toSeq else loop(units ++ newUnits) diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala index 67a22fbc6..3fdd6dad0 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala @@ -28,13 +28,17 @@ class SymbolMapping { private val usedTastyClasses = MutableSet[ClassSymbol]() def getUsedTastyClasses(): Set[ClassSymbol] = usedTastyClasses.toSet + private def maybeRegisterTastyClass(sym: Symbol)(using Context): Unit = { + if (sym.tastyInfo.isDefined) { + usedTastyClasses += sym.topLevelClass.asClass + } + } + /** Get the identifier associated with the given [[sym]], creating a new one if needed. */ def fetch(sym: Symbol, mode: FetchingMode)(using Context): SymbolIdentifier = mode match { case Plain => s2s.getOrElseUpdate(sym, { - if (sym.tastyInfo.isDefined) { - usedTastyClasses += sym.topLevelClass.asClass - } + maybeRegisterTastyClass(sym) val overrides = sym.allOverriddenSymbols.toSeq val top = overrides.lastOption.getOrElse(sym) if (top eq sym) { @@ -45,6 +49,7 @@ class SymbolMapping { }) case FieldAccessor => s2sAccessor.getOrElseUpdate(sym, { + maybeRegisterTastyClass(sym) val overrides = sym.allOverriddenSymbols.toSeq val top = overrides.lastOption.getOrElse(sym) if (top eq sym) { @@ -58,6 +63,7 @@ class SymbolMapping { }) case EnumType => s2sEnumType.getOrElseUpdate(sym, { + maybeRegisterTastyClass(sym) assert(sym.allOverriddenSymbols.isEmpty) SymbolIdentifier(ast.Symbol(symFullName(sym))) })