diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala index 53e04ebf1..d2792451e 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/DottyCompiler.scala @@ -17,6 +17,8 @@ import typer._ import frontend.{CallBack, Frontend, FrontendFactory, ThreadedFrontend} import Utils._ +import inox.DebugSection + import java.io.File import java.net.URL @@ -52,7 +54,7 @@ class DottyCompiler(ctx: inox.Context, callback: CallBack) extends Compiler { override def runOn(units: List[CompilationUnit])(using dottyCtx: DottyContext): List[CompilationUnit] = { exportedSymsMapping = exportedSymbolsMapping(ctx, this.start, units) val res = super.runOn(units) - extraction.extractClasspathUnits(exportedSymsMapping, ctx).foreach(extracted => + extraction.extractTastyUnits(exportedSymsMapping, ctx).foreach(extracted => callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs)) res } @@ -167,7 +169,10 @@ object DottyCompiler { x => new File(x.getLocation.toURI).getAbsolutePath } getOrElse { ctx.reporter.fatalError("No Stainless Library found.") } - ctx.reporter.info(s"Stainless library found at: $stainlessLib") + given DebugSection = frontend.DebugSectionFrontend + ctx.reporter.debug(s"Scala library 2.13 found at: $scala213Lib") + ctx.reporter.debug(s"Scala library 3 found at: $scala3Lib") + ctx.reporter.debug(s"Stainless library found at: $stainlessLib") val extraCps = ctx.options.findOptionOrDefault(frontend.optClasspath).toSeq val cps = (extraCps ++ Seq(stainlessLib, scala213Lib, scala3Lib)).distinct.mkString(java.io.File.pathSeparator) diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala index 52f7953a8..5df49b97e 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/StainlessExtraction.scala @@ -3,19 +3,25 @@ package stainless package frontends.dotc +import scala.collection.mutable.{LinkedHashMap, ArrayBuffer} + +import dotty.tools.io.AbstractFile import dotty.tools.dotc._ import core.Names._ import core.Symbols._ +import core.CompilationUnitInfo import core.Contexts.{Context => DottyContext} import transform._ import ast.tpd import ast.Trees._ import typer._ -import util.SourceFile + +import inox.DebugSection import extraction.xlang.{trees => xt} import frontend.{CallBack, Frontend, FrontendFactory, ThreadedFrontend, UnsupportedCodeException, DebugSectionFrontend} import Utils._ +import stainless.verification.CoqEncoder.m case class ExtractedUnit(file: String, unit: xt.UnitDef, classes: Seq[xt.ClassDef], functions: Seq[xt.FunDef], typeDefs: Seq[xt.TypeDef]) @@ -25,14 +31,14 @@ class StainlessExtraction(val inoxCtx: inox.Context) { def extractUnit(exportedSymsMapping: ExportedSymbolsMapping)(using ctx: DottyContext): Option[ExtractedUnit] = { val unit = ctx.compilationUnit val tree = unit.tpdTree - extractUnit(tree, unit.source, exportedSymsMapping, isFromSource = true) + extractUnit(tree, unit.source.file, exportedSymsMapping, isFromSource = true) } def extractUnit( tree: tpd.Tree, - source: SourceFile, + file: AbstractFile, exportedSymsMapping: ExportedSymbolsMapping, - isFromSource: Boolean + isFromSource: Boolean, )(using ctx: DottyContext): Option[ExtractedUnit] = { // Remark: the method `extractUnit` is called for each compilation unit (which corresponds more or less to a Scala file) // Therefore, the symbolMapping instances needs to be shared accross compilation unit. @@ -44,24 +50,24 @@ class StainlessExtraction(val inoxCtx: inox.Context) { case pd@PackageDef(pid, lst) => val id = lst.collectFirst { case PackageDef(ref, _) => ref } match { case Some(ref) => extractRef(ref) - case None => FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", "")) + case None => FreshIdentifier(file.name.replaceFirst("[.][^.]+$", "")) } (id, pd.stats) case _ => - inoxCtx.reporter.info("Empty package definition: " + source.file.name) - (FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", "")), List.empty) + inoxCtx.reporter.info("Empty package definition: " + file.name) + (FreshIdentifier(file.name.replaceFirst("[.][^.]+$", "")), List.empty) } val fragmentChecker = new FragmentChecker(inoxCtx) fragmentChecker.ghostChecker(tree) fragmentChecker.checker(tree) - if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, source, id, stats, isFromSource) + if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, file, id, stats, isFromSource) else None } private def tryExtractUnit(extraction: CodeExtraction, - source: SourceFile, + file: AbstractFile, id: Identifier, stats: List[tpd.Tree], isFromSource: Boolean @@ -80,11 +86,8 @@ class StainlessExtraction(val inoxCtx: inox.Context) { try { val (imports, unitClasses, unitFunctions, _, subs, classes, functions, typeDefs) = extraction.extractStatic(filteredStats) assert(unitFunctions.isEmpty, "Packages shouldn't contain functions") - val file = source.file.absolute.path - val isLibrary = stainless.Main.libraryFiles contains file - val isMain = isFromSource && !isLibrary - val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, isMain) - Some(ExtractedUnit(file, xtUnit, classes, functions, typeDefs)) + val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, isFromSource) + Some(ExtractedUnit(file.absolute.path, xtUnit, classes, functions, typeDefs)) } catch { case UnsupportedCodeException(pos, msg) => inoxCtx.reporter.error(pos, msg) @@ -107,19 +110,50 @@ class StainlessExtraction(val inoxCtx: inox.Context) { trAcc(None, stats) } - def extractClasspathUnits(exportedSymsMapping: ExportedSymbolsMapping, inoxCtx: inox.Context)(using DottyContext): Seq[ExtractedUnit] = { - def loop(units: Map[ClassSymbol, ExtractedUnit], depth: Int): Seq[ExtractedUnit] = - val newSymbols = symbolMapping.getUsedTastyClasses().filterNot(units.contains) - inoxCtx.reporter.debug(f"Symbols to extract from classpath at depth $depth: [${newSymbols.map(_.fullName).mkString(", ")}]")(using DebugSectionFrontend) + /** Extract units defined in Tasty files. + * + * This will only extract units that have not been extracted yet. + * + * See [[SymbolMapping.popUsedTastyUnits]] for more information about how + * these units are collected. + * + * Side-effect: calls [[SymbolMapping.popUsedTastyUnits]]. + */ + def extractTastyUnits(exportedSymsMapping: ExportedSymbolsMapping, inoxCtx: inox.Context)(using DottyContext): Seq[ExtractedUnit] = { + given DebugSection = DebugSectionFrontend + + val unextractedPackages: Set[Symbol] = Set(defn.ScalaPackageClass, defn.JavaPackageClass) + + def extractTastyUnit(tree: tpd.Tree, info: CompilationUnitInfo): Option[ExtractedUnit] = { + val res = extractUnit(tree, info.associatedFile, exportedSymsMapping, isFromSource = false) + res match { + case Some(extracted) => inoxCtx.reporter.debug(s"- Extracted ${extracted.unit.id}.") + case None => inoxCtx.reporter.debug(s"- Failed to extract Tasty unit from ${info.associatedFile.path}.") + } + res + } + + var depth = 0 + // Potential performance improvement: share the Map of extracted Tasty units + // accross runs, so that we don't extract the same units multiple times in + // watch mode. + val extractedTastyUnits = LinkedHashMap[tpd.Tree, Option[ExtractedUnit]]() + while depth < 100 do + inoxCtx.reporter.debug(f"Extracting Tasty units at depth $depth:") val newUnits = - newSymbols.map(sym => { - val extracted = extractUnit(sym.rootTree, sym.sourceOfClass, exportedSymsMapping, isFromSource = false).get - inoxCtx.reporter.debug(s"Extracted class ${sym.fullName} from classpath as unit ${extracted.unit.id}.")(using DebugSectionFrontend) - (sym -> extracted) - }) - .toMap - if (newUnits.isEmpty) units.values.toSeq - else loop(units ++ newUnits, depth + 1) - loop(Map.empty, 0) + symbolMapping + .popUsedTastyUnits() + .filterNot((tree, _) => extractedTastyUnits.contains(tree)) + .filterNot((tree, _) => tree.symbol.ownersIterator.exists(unextractedPackages)) + .map((tree, info) => tree -> extractTastyUnit(tree, info)) + if newUnits.isEmpty then + inoxCtx.reporter.debug(f"- No more units to extract.") + return extractedTastyUnits.values.flatten.toSeq + extractedTastyUnits ++= newUnits + depth += 1 + + // This should not be reached. + inoxCtx.reporter.error("Reached maximum depth while extracting Tasty units. This is likely a bug in Stainless.") + Nil } } diff --git a/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala b/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala index 3fdd6dad0..261a3595e 100644 --- a/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala +++ b/frontends/dotty/src/main/scala/stainless/frontends/dotc/SymbolMapping.scala @@ -5,12 +5,14 @@ package frontends.dotc import scala.language.implicitConversions +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.CompilationUnitInfo import dotty.tools.dotc.core.Flags._ import dotty.tools.dotc.core.Symbols._ import dotty.tools.dotc.core.Contexts._ import stainless.ast.SymbolIdentifier -import scala.collection.mutable.{ Map => MutableMap, Set => MutableSet } +import scala.collection.mutable.{ Map => MutableMap, LinkedHashMap } class SymbolMapping { import SymbolMapping._ @@ -25,12 +27,40 @@ class SymbolMapping { private val s2sAccessor = MutableMap[Symbol, SymbolIdentifier]() private val s2sEnumType = MutableMap[Symbol, SymbolIdentifier]() - private val usedTastyClasses = MutableSet[ClassSymbol]() - def getUsedTastyClasses(): Set[ClassSymbol] = usedTastyClasses.toSet + /** Returns a mapping from Dotty unit trees loaded from Tasty files to their + * [[CompilationUnitInfo]]. + * + * Everytime a [[Symbol]] is touched through [[fetch]], we check if it has an + * associated Tasty file. If it does, we register the symbol's root tree + * (which is the tree of the compilation unit containing the symbol's + * definition) and its associated compilation unit info (which contains the + * path to the Tasty file among other things). + * + * We keep this information to later extract compilation units defined in + * Tasty files, in [[StainlessExtraction.extractTastyUnits]]. + * + * This a [[LinkedHashMap]] to keep elements in insertion order (there is no + * efficient immutable equivalent in the Scala standard library). + * + * Side-effect: calling this method clears the internal list of used Tasty + * units. + * + * Potential performance improvement: `Tree.equals` and `Tree.hashCode` might + * be suboptimal. Comparing by reference sould be sufficient. + */ + def popUsedTastyUnits(): LinkedHashMap[tpd.Tree, CompilationUnitInfo] = + val res = usedTastyUnits + usedTastyUnits = LinkedHashMap[tpd.Tree, CompilationUnitInfo]() + res + + private var usedTastyUnits = LinkedHashMap[tpd.Tree, CompilationUnitInfo]() - private def maybeRegisterTastyClass(sym: Symbol)(using Context): Unit = { + private def maybeRegisterTastyUnit(sym: Symbol)(using Context): Unit = { if (sym.tastyInfo.isDefined) { - usedTastyClasses += sym.topLevelClass.asClass + val classSym = sym.topLevelClass.asClass + // Below, `classSym.rootTree` returns the tree read from the Tasty file. + // It works because we passed the `-Yretain-trees` option to the compiler. + usedTastyUnits += (classSym.rootTree -> classSym.compilationUnitInfo) } } @@ -38,7 +68,7 @@ class SymbolMapping { def fetch(sym: Symbol, mode: FetchingMode)(using Context): SymbolIdentifier = mode match { case Plain => s2s.getOrElseUpdate(sym, { - maybeRegisterTastyClass(sym) + maybeRegisterTastyUnit(sym) val overrides = sym.allOverriddenSymbols.toSeq val top = overrides.lastOption.getOrElse(sym) if (top eq sym) { @@ -49,7 +79,7 @@ class SymbolMapping { }) case FieldAccessor => s2sAccessor.getOrElseUpdate(sym, { - maybeRegisterTastyClass(sym) + maybeRegisterTastyUnit(sym) val overrides = sym.allOverriddenSymbols.toSeq val top = overrides.lastOption.getOrElse(sym) if (top eq sym) { @@ -63,7 +93,7 @@ class SymbolMapping { }) case EnumType => s2sEnumType.getOrElseUpdate(sym, { - maybeRegisterTastyClass(sym) + maybeRegisterTastyUnit(sym) assert(sym.allOverriddenSymbols.isEmpty) SymbolIdentifier(ast.Symbol(symFullName(sym))) })