Skip to content

Commit

Permalink
Fix spurious duplicate loads of Tasty units
Browse files Browse the repository at this point in the history
  • Loading branch information
mbovel committed Dec 11, 2024
1 parent 04fedc0 commit 8d9d614
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ import typer._
import frontend.{CallBack, Frontend, FrontendFactory, ThreadedFrontend}
import Utils._

import inox.DebugSection

import java.io.File
import java.net.URL

Expand Down Expand Up @@ -52,7 +54,7 @@ class DottyCompiler(ctx: inox.Context, callback: CallBack) extends Compiler {
override def runOn(units: List[CompilationUnit])(using dottyCtx: DottyContext): List[CompilationUnit] = {
exportedSymsMapping = exportedSymbolsMapping(ctx, this.start, units)
val res = super.runOn(units)
extraction.extractClasspathUnits(exportedSymsMapping, ctx).foreach(extracted =>
extraction.extractTastyUnits(exportedSymsMapping, ctx).foreach(extracted =>
callback(extracted.file, extracted.unit, extracted.classes, extracted.functions, extracted.typeDefs))
res
}
Expand Down Expand Up @@ -167,7 +169,10 @@ object DottyCompiler {
x => new File(x.getLocation.toURI).getAbsolutePath
} getOrElse { ctx.reporter.fatalError("No Stainless Library found.") }

ctx.reporter.info(s"Stainless library found at: $stainlessLib")
given DebugSection = frontend.DebugSectionFrontend
ctx.reporter.debug(s"Scala library 2.13 found at: $scala213Lib")
ctx.reporter.debug(s"Scala library 3 found at: $scala3Lib")
ctx.reporter.debug(s"Stainless library found at: $stainlessLib")

val extraCps = ctx.options.findOptionOrDefault(frontend.optClasspath).toSeq
val cps = (extraCps ++ Seq(stainlessLib, scala213Lib, scala3Lib)).distinct.mkString(java.io.File.pathSeparator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,25 @@
package stainless
package frontends.dotc

import scala.collection.mutable.{LinkedHashMap, ArrayBuffer}

import dotty.tools.io.AbstractFile
import dotty.tools.dotc._
import core.Names._
import core.Symbols._
import core.CompilationUnitInfo
import core.Contexts.{Context => DottyContext}
import transform._
import ast.tpd
import ast.Trees._
import typer._
import util.SourceFile

import inox.DebugSection

import extraction.xlang.{trees => xt}
import frontend.{CallBack, Frontend, FrontendFactory, ThreadedFrontend, UnsupportedCodeException, DebugSectionFrontend}
import Utils._
import stainless.verification.CoqEncoder.m

case class ExtractedUnit(file: String, unit: xt.UnitDef, classes: Seq[xt.ClassDef], functions: Seq[xt.FunDef], typeDefs: Seq[xt.TypeDef])

Expand All @@ -25,14 +31,14 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
def extractUnit(exportedSymsMapping: ExportedSymbolsMapping)(using ctx: DottyContext): Option[ExtractedUnit] = {
val unit = ctx.compilationUnit
val tree = unit.tpdTree
extractUnit(tree, unit.source, exportedSymsMapping, isFromSource = true)
extractUnit(tree, unit.source.file, exportedSymsMapping, isFromSource = true)
}

def extractUnit(
tree: tpd.Tree,
source: SourceFile,
file: AbstractFile,
exportedSymsMapping: ExportedSymbolsMapping,
isFromSource: Boolean
isFromSource: Boolean,
)(using ctx: DottyContext): Option[ExtractedUnit] = {
// Remark: the method `extractUnit` is called for each compilation unit (which corresponds more or less to a Scala file)
// Therefore, the symbolMapping instances needs to be shared accross compilation unit.
Expand All @@ -44,24 +50,24 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
case pd@PackageDef(pid, lst) =>
val id = lst.collectFirst { case PackageDef(ref, _) => ref } match {
case Some(ref) => extractRef(ref)
case None => FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", ""))
case None => FreshIdentifier(file.name.replaceFirst("[.][^.]+$", ""))
}
(id, pd.stats)
case _ =>
inoxCtx.reporter.info("Empty package definition: " + source.file.name)
(FreshIdentifier(source.file.name.replaceFirst("[.][^.]+$", "")), List.empty)
inoxCtx.reporter.info("Empty package definition: " + file.name)
(FreshIdentifier(file.name.replaceFirst("[.][^.]+$", "")), List.empty)
}

val fragmentChecker = new FragmentChecker(inoxCtx)
fragmentChecker.ghostChecker(tree)
fragmentChecker.checker(tree)

if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, source, id, stats, isFromSource)
if (!fragmentChecker.hasErrors()) tryExtractUnit(extraction, file, id, stats, isFromSource)
else None
}

private def tryExtractUnit(extraction: CodeExtraction,
source: SourceFile,
file: AbstractFile,
id: Identifier,
stats: List[tpd.Tree],
isFromSource: Boolean
Expand All @@ -80,11 +86,8 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
try {
val (imports, unitClasses, unitFunctions, _, subs, classes, functions, typeDefs) = extraction.extractStatic(filteredStats)
assert(unitFunctions.isEmpty, "Packages shouldn't contain functions")
val file = source.file.absolute.path
val isLibrary = stainless.Main.libraryFiles contains file
val isMain = isFromSource && !isLibrary
val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, isMain)
Some(ExtractedUnit(file, xtUnit, classes, functions, typeDefs))
val xtUnit = xt.UnitDef(id, imports, unitClasses, subs, isFromSource)
Some(ExtractedUnit(file.absolute.path, xtUnit, classes, functions, typeDefs))
} catch {
case UnsupportedCodeException(pos, msg) =>
inoxCtx.reporter.error(pos, msg)
Expand All @@ -107,19 +110,50 @@ class StainlessExtraction(val inoxCtx: inox.Context) {
trAcc(None, stats)
}

def extractClasspathUnits(exportedSymsMapping: ExportedSymbolsMapping, inoxCtx: inox.Context)(using DottyContext): Seq[ExtractedUnit] = {
def loop(units: Map[ClassSymbol, ExtractedUnit], depth: Int): Seq[ExtractedUnit] =
val newSymbols = symbolMapping.getUsedTastyClasses().filterNot(units.contains)
inoxCtx.reporter.debug(f"Symbols to extract from classpath at depth $depth: [${newSymbols.map(_.fullName).mkString(", ")}]")(using DebugSectionFrontend)
/** Extract units defined in Tasty files.
*
* This will only extract units that have not been extracted yet.
*
* See [[SymbolMapping.popUsedTastyUnits]] for more information about how
* these units are collected.
*
* Side-effect: calls [[SymbolMapping.popUsedTastyUnits]].
*/
def extractTastyUnits(exportedSymsMapping: ExportedSymbolsMapping, inoxCtx: inox.Context)(using DottyContext): Seq[ExtractedUnit] = {
given DebugSection = DebugSectionFrontend

val unextractedPackages: Set[Symbol] = Set(defn.ScalaPackageClass, defn.JavaPackageClass)

def extractTastyUnit(tree: tpd.Tree, info: CompilationUnitInfo): Option[ExtractedUnit] = {
val res = extractUnit(tree, info.associatedFile, exportedSymsMapping, isFromSource = false)
res match {
case Some(extracted) => inoxCtx.reporter.debug(s"- Extracted ${extracted.unit.id}.")
case None => inoxCtx.reporter.debug(s"- Failed to extract Tasty unit from ${info.associatedFile.path}.")
}
res
}

var depth = 0
// Potential performance improvement: share the Map of extracted Tasty units
// accross runs, so that we don't extract the same units multiple times in
// watch mode.
val extractedTastyUnits = LinkedHashMap[tpd.Tree, Option[ExtractedUnit]]()
while depth < 100 do
inoxCtx.reporter.debug(f"Extracting Tasty units at depth $depth:")
val newUnits =
newSymbols.map(sym => {
val extracted = extractUnit(sym.rootTree, sym.sourceOfClass, exportedSymsMapping, isFromSource = false).get
inoxCtx.reporter.debug(s"Extracted class ${sym.fullName} from classpath as unit ${extracted.unit.id}.")(using DebugSectionFrontend)
(sym -> extracted)
})
.toMap
if (newUnits.isEmpty) units.values.toSeq
else loop(units ++ newUnits, depth + 1)
loop(Map.empty, 0)
symbolMapping
.popUsedTastyUnits()
.filterNot((tree, _) => extractedTastyUnits.contains(tree))
.filterNot((tree, _) => tree.symbol.ownersIterator.exists(unextractedPackages))
.map((tree, info) => tree -> extractTastyUnit(tree, info))
if newUnits.isEmpty then
inoxCtx.reporter.debug(f"- No more units to extract.")
return extractedTastyUnits.values.flatten.toSeq
extractedTastyUnits ++= newUnits
depth += 1

// This should not be reached.
inoxCtx.reporter.error("Reached maximum depth while extracting Tasty units. This is likely a bug in Stainless.")
Nil
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ package frontends.dotc

import scala.language.implicitConversions

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.CompilationUnitInfo
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.Contexts._
import stainless.ast.SymbolIdentifier

import scala.collection.mutable.{ Map => MutableMap, Set => MutableSet }
import scala.collection.mutable.{ Map => MutableMap, LinkedHashMap }

class SymbolMapping {
import SymbolMapping._
Expand All @@ -25,20 +27,48 @@ class SymbolMapping {
private val s2sAccessor = MutableMap[Symbol, SymbolIdentifier]()
private val s2sEnumType = MutableMap[Symbol, SymbolIdentifier]()

private val usedTastyClasses = MutableSet[ClassSymbol]()
def getUsedTastyClasses(): Set[ClassSymbol] = usedTastyClasses.toSet
/** Returns a mapping from Dotty unit trees loaded from Tasty files to their
* [[CompilationUnitInfo]].
*
* Everytime a [[Symbol]] is touched through [[fetch]], we check if it has an
* associated Tasty file. If it does, we register the symbol's root tree
* (which is the tree of the compilation unit containing the symbol's
* definition) and its associated compilation unit info (which contains the
* path to the Tasty file among other things).
*
* We keep this information to later extract compilation units defined in
* Tasty files, in [[StainlessExtraction.extractTastyUnits]].
*
* This a [[LinkedHashMap]] to keep elements in insertion order (there is no
* efficient immutable equivalent in the Scala standard library).
*
* Side-effect: calling this method clears the internal list of used Tasty
* units.
*
* Potential performance improvement: `Tree.equals` and `Tree.hashCode` might
* be suboptimal. Comparing by reference sould be sufficient.
*/
def popUsedTastyUnits(): LinkedHashMap[tpd.Tree, CompilationUnitInfo] =
val res = usedTastyUnits
usedTastyUnits = LinkedHashMap[tpd.Tree, CompilationUnitInfo]()
res

private var usedTastyUnits = LinkedHashMap[tpd.Tree, CompilationUnitInfo]()

private def maybeRegisterTastyClass(sym: Symbol)(using Context): Unit = {
private def maybeRegisterTastyUnit(sym: Symbol)(using Context): Unit = {
if (sym.tastyInfo.isDefined) {
usedTastyClasses += sym.topLevelClass.asClass
val classSym = sym.topLevelClass.asClass
// Below, `classSym.rootTree` returns the tree read from the Tasty file.
// It works because we passed the `-Yretain-trees` option to the compiler.
usedTastyUnits += (classSym.rootTree -> classSym.compilationUnitInfo)
}
}

/** Get the identifier associated with the given [[sym]], creating a new one if needed. */
def fetch(sym: Symbol, mode: FetchingMode)(using Context): SymbolIdentifier = mode match {
case Plain =>
s2s.getOrElseUpdate(sym, {
maybeRegisterTastyClass(sym)
maybeRegisterTastyUnit(sym)
val overrides = sym.allOverriddenSymbols.toSeq
val top = overrides.lastOption.getOrElse(sym)
if (top eq sym) {
Expand All @@ -49,7 +79,7 @@ class SymbolMapping {
})
case FieldAccessor =>
s2sAccessor.getOrElseUpdate(sym, {
maybeRegisterTastyClass(sym)
maybeRegisterTastyUnit(sym)
val overrides = sym.allOverriddenSymbols.toSeq
val top = overrides.lastOption.getOrElse(sym)
if (top eq sym) {
Expand All @@ -63,7 +93,7 @@ class SymbolMapping {
})
case EnumType =>
s2sEnumType.getOrElseUpdate(sym, {
maybeRegisterTastyClass(sym)
maybeRegisterTastyUnit(sym)
assert(sym.allOverriddenSymbols.isEmpty)
SymbolIdentifier(ast.Symbol(symFullName(sym)))
})
Expand Down

0 comments on commit 8d9d614

Please sign in to comment.