Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

NODE-2529 Nested pattern restrictions #3888

Merged
merged 11 commits into from
Nov 27, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ object StateSyntheticBenchmark {

val textScript = "sigVerify(tx.bodyBytes,tx.proofs[0],tx.senderPublicKey)"
val untypedScript = Parser.parseExpr(textScript).get.value
val typedScript = ExpressionCompiler(compilerContext(V1, Expression, isAssetScript = false), untypedScript).explicitGet()._1
val typedScript = ExpressionCompiler(compilerContext(V1, Expression, isAssetScript = false), V1, untypedScript).explicitGet()._1

val setScriptBlock = nextBlock(
Seq(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package com.wavesplatform.lang.v1.compiler

import java.nio.charset.StandardCharsets

import cats.Show
import com.wavesplatform.lang.v1.ContractLimits
import com.wavesplatform.lang.v1.compiler.Types.*
import com.wavesplatform.lang.v1.evaluator.ctx.FunctionTypeSignature
import com.wavesplatform.lang.v1.parser.Expressions
import com.wavesplatform.lang.v1.parser.Expressions.{Declaration, PART}

import java.nio.charset.StandardCharsets

sealed trait CompilationError {
def start: Int
def end: Int
Expand Down Expand Up @@ -52,18 +52,11 @@ object CompilationError {
s"but ${names.map(n => s"`$n`").mkString(", ")} found"
}

final case class UnusedCaseVariables(start: Int, end: Int, names: List[String]) extends CompilationError {
val message = s"Unused case variable(s) ${names.map(n => s"`$n`").mkString(", ")}"
}

final case class AlreadyDefined(start: Int, end: Int, name: String, isFunction: Boolean) extends CompilationError {
val message =
if (isFunction) s"Value '$name' can't be defined because function with this name is already defined"
else s"Value '$name' already defined in the scope"
}
final case class NonExistingType(start: Int, end: Int, name: String, existing: List[String]) extends CompilationError {
val message = s"Value '$name' declared as non-existing type, while all possible types are $existing"
}

final case class BadFunctionSignatureSameArgNames(start: Int, end: Int, name: String) extends CompilationError {
val message = s"Function '$name' declared with duplicating argument names"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.wavesplatform.lang.v1.compiler

import com.wavesplatform.lang.v1.compiler.Types.FINAL
import com.wavesplatform.lang.v1.parser.Expressions

case class CompilationStepResultDec(
ctx: CompilerContext,
dec: Terms.DECLARATION,
t: FINAL,
parseNodeExpr: Expressions.Declaration,
errors: Iterable[CompilationError] = Iterable.empty
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package com.wavesplatform.lang.v1.compiler

import com.wavesplatform.lang.v1.compiler.Types.FINAL
import com.wavesplatform.lang.v1.parser.Expressions

case class CompilationStepResultExpr(
ctx: CompilerContext,
expr: Terms.EXPR,
t: FINAL,
parseNodeExpr: Expressions.EXPR,
errors: Iterable[CompilationError] = Iterable.empty
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import com.wavesplatform.lang.contract.meta.{MetaMapper, V1 as MetaV1, V2 as Met
import com.wavesplatform.lang.directives.values.{StdLibVersion, V3, V6}
import com.wavesplatform.lang.v1.compiler.CompilationError.{AlreadyDefined, Generic, UnionNotAllowedForCallableArgs, WrongArgumentType}
import com.wavesplatform.lang.v1.compiler.CompilerContext.{VariableInfo, vars}
import com.wavesplatform.lang.v1.compiler.ExpressionCompiler.*
import com.wavesplatform.lang.v1.compiler.ContractCompiler.*
import com.wavesplatform.lang.v1.compiler.ScriptResultSource.FreeCall
import com.wavesplatform.lang.v1.compiler.Terms.EXPR
import com.wavesplatform.lang.v1.compiler.Types.{BOOLEAN, BYTESTR, LONG, STRING}
Expand All @@ -25,16 +25,13 @@ import com.wavesplatform.lang.v1.parser.Expressions.{FUNC, PART, Type}
import com.wavesplatform.lang.v1.parser.Parser.LibrariesOffset
import com.wavesplatform.lang.v1.parser.{Expressions, Parser}
import com.wavesplatform.lang.v1.task.imports.*
import com.wavesplatform.lang.v1.{ContractLimits, FunctionHeader, compiler}
import com.wavesplatform.lang.v1.{ContractLimits, FunctionHeader}

import scala.annotation.tailrec

object ContractCompiler {
val FreeCallInvocationArg = "i"

class ContractCompiler(version: StdLibVersion) extends ExpressionCompiler(version) {
private def compileAnnotatedFunc(
af: Expressions.ANNOTATEDFUNC,
version: StdLibVersion,
saveExprContext: Boolean,
allowIllFormedStrings: Boolean,
source: ScriptResultSource
Expand Down Expand Up @@ -88,10 +85,10 @@ object ContractCompiler {
_.flatMap(_.dic(version).toList).map(nameAndType => (nameAndType._1, VariableInfo(AnyPos, nameAndType._2)))
)
.getOrElse(List.empty)
unionInCallableErrs <- checkCallableUnions(af, annotationsWithErr._1.toList.flatten, version)
unionInCallableErrs <- checkCallableUnions(af, annotationsWithErr._1.toList.flatten)
compiledBody <- local {
modify[Id, CompilerContext, CompilationError](vars.modify(_)(_ ++ annotationBindings)).flatMap(_ =>
compiler.ExpressionCompiler.compileFunc(af.f.position, af.f, saveExprContext, annotationBindings.map(_._1), allowIllFormedStrings)
compileFunc(af.f.position, af.f, saveExprContext, annotationBindings.map(_._1), allowIllFormedStrings)
)
}
annotatedFuncWithErr <- getCompiledAnnotatedFunc(annotationsWithErr, compiledBody._1).handleError()
Expand Down Expand Up @@ -132,7 +129,6 @@ object ContractCompiler {

private def compileContract(
parsedDapp: Expressions.DAPP,
version: StdLibVersion,
needCompaction: Boolean,
removeUnusedCode: Boolean,
source: ScriptResultSource,
Expand All @@ -149,7 +145,7 @@ object ContractCompiler {
annFuncArgTypesErr <- validateAnnotatedFuncsArgTypes(parsedDapp).handleError()
compiledAnnFuncsWithErr <- parsedDapp.fs
.traverse[CompileM, (Option[AnnotatedFunction], List[(String, Types.FINAL)], Expressions.ANNOTATEDFUNC, Iterable[CompilationError])](af =>
local(compileAnnotatedFunc(af, version, saveExprContext, allowIllFormedStrings, source))
local(compileAnnotatedFunc(af, saveExprContext, allowIllFormedStrings, source))
)
annotatedFuncs = compiledAnnFuncsWithErr.filter(_._1.nonEmpty).map(_._1.get)
parsedNodeAFuncs = compiledAnnFuncsWithErr.map(_._3)
Expand Down Expand Up @@ -236,7 +232,7 @@ object ContractCompiler {
} yield result
}

def handleValid[T](part: PART[T]): CompileM[PART.VALID[T]] = part match {
private def handleValid[T](part: PART[T]): CompileM[PART.VALID[T]] = part match {
case x: PART.VALID[T] => x.pure[CompileM]
case PART.INVALID(p, message) => raiseError(Generic(p.start, p.end, message))
}
Expand Down Expand Up @@ -305,13 +301,7 @@ object ContractCompiler {
}
}

val primitiveCallableTypes: Set[String] =
Set(LONG, BYTESTR, BOOLEAN, STRING).map(_.name)

val allowedCallableTypesV4: Set[String] =
primitiveCallableTypes + "List[]"

private def validateDuplicateVarsInContract(contract: Expressions.DAPP): CompileM[Any] = {
private def validateDuplicateVarsInContract(contract: Expressions.DAPP): CompileM[Any] =
for {
ctx <- get[Id, CompilerContext, CompilationError]
annotationVars = contract.fs.flatMap(_.anns.flatMap(_.args)).traverse[CompileM, PART.VALID[String]](handleValid)
Expand Down Expand Up @@ -339,7 +329,52 @@ object ContractCompiler {
}
}
} yield ()

private def checkCallableUnions(
func: Expressions.ANNOTATEDFUNC,
annotations: List[Annotation],
): CompileM[Seq[UnionNotAllowedForCallableArgs]] = {
@tailrec
def containsUnion(tpe: Type): Boolean =
tpe match {
case Expressions.Union(types) if types.size > 1 => true
case Expressions.Single(PART.VALID(_, Type.ListTypeName), Some(PART.VALID(_, Expressions.Union(types)))) if types.size > 1 => true
case Expressions.Single(
PART.VALID(_, Type.ListTypeName),
Some(PART.VALID(_, inner @ Expressions.Single(PART.VALID(_, Type.ListTypeName), _)))
) =>
containsUnion(inner)
case _ => false
}

val isCallable = annotations.exists {
case CallableAnnotation(_) => true
case _ => false
}

if (version < V6 || !isCallable) {
Seq.empty[UnionNotAllowedForCallableArgs].pure[CompileM]
} else {
func.f.args
.filter { case (_, tpe) =>
containsUnion(tpe)
}
.map { case (argName, _) =>
UnionNotAllowedForCallableArgs(argName.position.start, argName.position.end)
}
.pure[CompileM]
}
}
}

object ContractCompiler {
val FreeCallInvocationArg = "i"

val primitiveCallableTypes: Set[String] =
Set(LONG, BYTESTR, BOOLEAN, STRING).map(_.name)

val allowedCallableTypesV4: Set[String] =
primitiveCallableTypes + "List[]"

def apply(
c: CompilerContext,
Expand All @@ -350,7 +385,8 @@ object ContractCompiler {
removeUnusedCode: Boolean = false,
allowIllFormedStrings: Boolean = false
): Either[String, DApp] = {
compileContract(contract, version, needCompaction, removeUnusedCode, source, allowIllFormedStrings = allowIllFormedStrings)
new ContractCompiler(version)
.compileContract(contract, needCompaction, removeUnusedCode, source, allowIllFormedStrings = allowIllFormedStrings)
.run(c)
.map(
_._2
Expand All @@ -375,7 +411,7 @@ object ContractCompiler {
val parser = new Parser(version)(offset)
parser.parseContract(input) match {
case fastparse.Parsed.Success(xs, _) =>
ContractCompiler(ctx, xs, version, source, needCompaction, removeUnusedCode, allowIllFormedStrings) match {
apply(ctx, xs, version, source, needCompaction, removeUnusedCode, allowIllFormedStrings) match {
case Left(err) => Left(err)
case Right(c) => Right(c)
}
Expand All @@ -396,7 +432,8 @@ object ContractCompiler {
new Parser(version)(offset)
.parseDAPPWithErrorRecovery(input)
.flatMap { case (parseResult, removedCharPosOpt) =>
compileContract(parseResult, version, needCompaction, removeUnusedCode, ScriptResultSource.CallableFunction, saveExprContext)
new ContractCompiler(version)
.compileContract(parseResult, needCompaction, removeUnusedCode, ScriptResultSource.CallableFunction, saveExprContext)
.run(ctx)
.map(
_._2
Expand Down Expand Up @@ -437,41 +474,4 @@ object ContractCompiler {
Left(parser.toString(input, f))
}
}

private def checkCallableUnions(
func: Expressions.ANNOTATEDFUNC,
annotations: List[Annotation],
version: StdLibVersion
): CompileM[Seq[UnionNotAllowedForCallableArgs]] = {
@tailrec
def containsUnion(tpe: Type): Boolean =
tpe match {
case Expressions.Union(types) if types.size > 1 => true
case Expressions.Single(PART.VALID(_, Type.ListTypeName), Some(PART.VALID(_, Expressions.Union(types)))) if types.size > 1 => true
case Expressions.Single(
PART.VALID(_, Type.ListTypeName),
Some(PART.VALID(_, inner @ Expressions.Single(PART.VALID(_, Type.ListTypeName), _)))
) =>
containsUnion(inner)
case _ => false
}

val isCallable = annotations.exists {
case CallableAnnotation(_) => true
case _ => false
}

if (version < V6 || !isCallable) {
Seq.empty[UnionNotAllowedForCallableArgs].pure[CompileM]
} else {
func.f.args
.filter { case (_, tpe) =>
containsUnion(tpe)
}
.map { case (argName, _) =>
UnionNotAllowedForCallableArgs(argName.position.start, argName.position.end)
}
.pure[CompileM]
}
}
}
Loading