diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala index bf10e37943a8..1d559c9950f1 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala @@ -4,7 +4,7 @@ package jvm import scala.language.unsafeNulls -import scala.annotation.switch +import scala.annotation.{switch, tailrec} import scala.collection.mutable.SortedMap import scala.tools.asm @@ -79,9 +79,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { tree match { case Assign(lhs @ DesugaredSelect(qual, _), rhs) => + val savedStackHeight = stackHeight val isStatic = lhs.symbol.isStaticMember - if (!isStatic) { genLoadQualifier(lhs) } + if (!isStatic) { + genLoadQualifier(lhs) + stackHeight += 1 + } genLoad(rhs, symInfoTK(lhs.symbol)) + stackHeight = savedStackHeight lineNumber(tree) // receiverClass is used in the bytecode to access the field. using sym.owner may lead to IllegalAccessError val receiverClass = qual.tpe.typeSymbol @@ -145,7 +150,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } genLoad(larg, resKind) + stackHeight += resKind.size genLoad(rarg, if (isShift) INT else resKind) + stackHeight -= resKind.size (code: @switch) match { case ADD => bc add resKind @@ -182,14 +189,19 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { if (isArrayGet(code)) { // load argument on stack assert(args.length == 1, s"Too many arguments for array get operation: $tree"); + stackHeight += 1 genLoad(args.head, INT) + stackHeight -= 1 generatedType = k.asArrayBType.componentType bc.aload(elementType) } else if (isArraySet(code)) { val List(a1, a2) = args + stackHeight += 1 genLoad(a1, INT) + stackHeight += 1 genLoad(a2) + stackHeight -= 2 generatedType = UNIT bc.astore(elementType) } else { @@ -223,7 +235,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val resKind = if (hasUnitBranch) UNIT else tpeTK(tree) val postIf = new asm.Label - genLoadTo(thenp, resKind, LoadDestination.Jump(postIf)) + genLoadTo(thenp, resKind, LoadDestination.Jump(postIf, stackHeight)) markProgramPoint(failure) genLoadTo(elsep, resKind, LoadDestination.FallThrough) markProgramPoint(postIf) @@ -482,7 +494,17 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { dest match case LoadDestination.FallThrough => () - case LoadDestination.Jump(label) => + case LoadDestination.Jump(label, targetStackHeight) => + if targetStackHeight < stackHeight then + val stackDiff = stackHeight - targetStackHeight + if expectedType == UNIT then + bc dropMany stackDiff + else + val loc = locals.makeTempLocal(expectedType) + bc.store(loc.idx, expectedType) + bc dropMany stackDiff + bc.load(loc.idx, expectedType) + end if bc goTo label case LoadDestination.Return => bc emitRETURN returnType @@ -577,7 +599,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { if dest == LoadDestination.FallThrough then val resKind = tpeTK(tree) val jumpTarget = new asm.Label - registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget)) + registerJumpDest(labelSym, resKind, LoadDestination.Jump(jumpTarget, stackHeight)) genLoad(expr, resKind) markProgramPoint(jumpTarget) resKind @@ -635,7 +657,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { markProgramPoint(loop) if isInfinite then - val dest = LoadDestination.Jump(loop) + val dest = LoadDestination.Jump(loop, stackHeight) genLoadTo(body, UNIT, dest) dest else @@ -650,7 +672,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val failure = new asm.Label genCond(cond, success, failure, targetIfNoJump = success) markProgramPoint(success) - genLoadTo(body, UNIT, LoadDestination.Jump(loop)) + genLoadTo(body, UNIT, LoadDestination.Jump(loop, stackHeight)) markProgramPoint(failure) end match LoadDestination.FallThrough @@ -744,7 +766,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { // scala/bug#10290: qual can be `this.$outer()` (not just `this`), so we call genLoad (not just ALOAD_0) genLoad(superQual) + stackHeight += 1 genLoadArguments(args, paramTKs(app)) + stackHeight -= 1 generatedType = genCallMethod(fun.symbol, InvokeStyle.Super, app.span) // 'new' constructor call: Note: since constructors are @@ -766,7 +790,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { assert(classBTypeFromSymbol(ctor.owner) == rt, s"Symbol ${ctor.owner.showFullName} is different from $rt") mnode.visitTypeInsn(asm.Opcodes.NEW, rt.internalName) bc dup generatedType + stackHeight += 2 genLoadArguments(args, paramTKs(app)) + stackHeight -= 2 genCallMethod(ctor, InvokeStyle.Special, app.span) case _ => @@ -799,8 +825,12 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { else if (app.hasAttachment(BCodeHelpers.UseInvokeSpecial)) InvokeStyle.Special else InvokeStyle.Virtual - if (invokeStyle.hasInstance) genLoadQualifier(fun) + val savedStackHeight = stackHeight + if invokeStyle.hasInstance then + genLoadQualifier(fun) + stackHeight += 1 genLoadArguments(args, paramTKs(app)) + stackHeight = savedStackHeight val DesugaredSelect(qual, name) = fun: @unchecked // fun is a Select, also checked in genLoadQualifier val isArrayClone = name == nme.clone_ && qual.tpe.widen.isInstanceOf[JavaArrayType] @@ -858,6 +888,8 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { bc iconst elems.length bc newarray elmKind + stackHeight += 3 // during the genLoad below, there is the result, its dup, and the index + var i = 0 var rest = elems while (!rest.isEmpty) { @@ -869,6 +901,8 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { i = i + 1 } + stackHeight -= 3 + generatedType } @@ -883,7 +917,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val (generatedType, postMatch, postMatchDest) = if dest == LoadDestination.FallThrough then val postMatch = new asm.Label - (tpeTK(tree), postMatch, LoadDestination.Jump(postMatch)) + (tpeTK(tree), postMatch, LoadDestination.Jump(postMatch, stackHeight)) else (expectedType, null, dest) @@ -1160,14 +1194,21 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } def genLoadArguments(args: List[Tree], btpes: List[BType]): Unit = - args match - case arg :: args1 => - btpes match - case btpe :: btpes1 => - genLoad(arg, btpe) - genLoadArguments(args1, btpes1) - case _ => - case _ => + @tailrec def loop(args: List[Tree], btpes: List[BType]): Unit = + args match + case arg :: args1 => + btpes match + case btpe :: btpes1 => + genLoad(arg, btpe) + stackHeight += btpe.size + loop(args1, btpes1) + case _ => + case _ => + + val savedStackHeight = stackHeight + loop(args, btpes) + stackHeight = savedStackHeight + end genLoadArguments def genLoadModule(tree: Tree): BType = { val module = ( @@ -1266,11 +1307,14 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { }.sum bc.genNewStringBuilder(approxBuilderSize) + stackHeight += 1 // during the genLoad below, there is a reference to the StringBuilder on the stack for (elem <- concatArguments) { val elemType = tpeTK(elem) genLoad(elem, elemType) bc.genStringBuilderAppend(elemType) } + stackHeight -= 1 + bc.genStringBuilderEnd } else { @@ -1287,12 +1331,15 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { var totalArgSlots = 0 var countConcats = 1 // ie. 1 + how many times we spilled + val savedStackHeight = stackHeight + for (elem <- concatArguments) { val tpe = tpeTK(elem) val elemSlots = tpe.size // Unlikely spill case if (totalArgSlots + elemSlots >= MaxIndySlots) { + stackHeight = savedStackHeight + countConcats bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result()) countConcats += 1 totalArgSlots = 0 @@ -1317,8 +1364,10 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val tpe = tpeTK(elem) argTypes += tpe.toASMType genLoad(elem, tpe) + stackHeight += 1 } } + stackHeight = savedStackHeight bc.genIndyStringConcat(recipe.toString, argTypes.result(), constVals.result()) // If we spilled, generate one final concat @@ -1513,7 +1562,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } else { val tk = tpeTK(l).maxType(tpeTK(r)) genLoad(l, tk) + stackHeight += tk.size genLoad(r, tk) + stackHeight -= tk.size genCJUMP(success, failure, op, tk, targetIfNoJump) } } @@ -1628,7 +1679,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } genLoad(l, ObjectRef) + stackHeight += 1 genLoad(r, ObjectRef) + stackHeight -= 1 genCallMethod(equalsMethod, InvokeStyle.Static) genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump) } @@ -1644,7 +1697,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { } else if (isNonNullExpr(l)) { // SI-7852 Avoid null check if L is statically non-null. genLoad(l, ObjectRef) + stackHeight += 1 genLoad(r, ObjectRef) + stackHeight -= 1 genCallMethod(defn.Any_equals, InvokeStyle.Virtual) genCZJUMP(success, failure, Primitives.NE, BOOL, targetIfNoJump) } else { @@ -1654,7 +1709,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder { val lNonNull = new asm.Label genLoad(l, ObjectRef) + stackHeight += 1 genLoad(r, ObjectRef) + stackHeight -= 1 locals.store(eqEqTempLocal) bc dup ObjectRef genCZJUMP(lNull, lNonNull, Primitives.EQ, ObjectRef, targetIfNoJump = lNull) diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeIdiomatic.scala b/compiler/src/dotty/tools/backend/jvm/BCodeIdiomatic.scala index 2d4b22a10527..b86efb7cacb1 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeIdiomatic.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeIdiomatic.scala @@ -620,6 +620,16 @@ trait BCodeIdiomatic { // can-multi-thread final def drop(tk: BType): Unit = { emit(if (tk.isWideType) Opcodes.POP2 else Opcodes.POP) } + // can-multi-thread + final def dropMany(size: Int): Unit = { + var s = size + while s >= 2 do + emit(Opcodes.POP2) + s -= 2 + if s > 0 then + emit(Opcodes.POP) + } + // can-multi-thread final def dup(tk: BType): Unit = { emit(if (tk.isWideType) Opcodes.DUP2 else Opcodes.DUP) } diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala b/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala index 1d8a9c579cb9..1885210a6687 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala @@ -45,7 +45,7 @@ trait BCodeSkelBuilder extends BCodeHelpers { /** The value is put on the stack, and control flows through to the next opcode. */ case FallThrough /** The value is put on the stack, and control flow is transferred to the given `label`. */ - case Jump(label: asm.Label) + case Jump(label: asm.Label, targetStackHeight: Int) /** The value is RETURN'ed from the enclosing method. */ case Return /** The value is ATHROW'n. */ @@ -368,6 +368,8 @@ trait BCodeSkelBuilder extends BCodeHelpers { // used by genLoadTry() and genSynchronized() var earlyReturnVar: Symbol = null var shouldEmitCleanup = false + // stack tracking + var stackHeight = 0 // line numbers var lastEmittedLineNr = -1 @@ -504,6 +506,13 @@ trait BCodeSkelBuilder extends BCodeHelpers { loc } + def makeTempLocal(tk: BType): Local = + assert(nxtIdx != -1, "not a valid start index") + assert(tk.size > 0, "makeLocal called for a symbol whose type is Unit.") + val loc = Local(tk, "temp", nxtIdx, isSynth = true) + nxtIdx += tk.size + loc + // not to be confused with `fieldStore` and `fieldLoad` which also take a symbol but a field-symbol. def store(locSym: Symbol): Unit = { val Local(tk, _, idx, _) = slots(locSym) @@ -574,6 +583,8 @@ trait BCodeSkelBuilder extends BCodeHelpers { earlyReturnVar = null shouldEmitCleanup = false + stackHeight = 0 + lastEmittedLineNr = -1 } diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 375cdaaa2e94..b03953afb37c 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -88,7 +88,8 @@ class Compiler { new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only) new ExplicitOuter, // Add accessors to outer classes from nested ones. new ExplicitSelf, // Make references to non-trivial self types explicit as casts - new StringInterpolatorOpt) :: // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats + new StringInterpolatorOpt, // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats + new DropBreaks) :: // Optimize local Break throws by rewriting them List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_` new InlinePatterns, // Remove placeholders of inlined patterns diff --git a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala index ccc6e99737d4..eec32598de1a 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeInfo.scala @@ -102,6 +102,12 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] => case _ => tree } + def stripTyped(tree: Tree): Tree = unsplice(tree) match + case Typed(expr, _) => + stripTyped(expr) + case _ => + tree + /** The number of arguments in an application */ def numArgs(tree: Tree): Int = unsplice(tree) match { case Apply(fn, args) => numArgs(fn) + args.length diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 81a9d4db2e40..13a2e1559b06 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -968,6 +968,10 @@ class Definitions { def TupledFunctionClass(using Context): ClassSymbol = TupledFunctionTypeRef.symbol.asClass def RuntimeTupleFunctionsModule(using Context): Symbol = requiredModule("scala.runtime.TupledFunctions") + @tu lazy val boundaryModule: Symbol = requiredModule("scala.util.boundary") + @tu lazy val LabelClass: Symbol = requiredClass("scala.util.boundary.Label") + @tu lazy val BreakClass: Symbol = requiredClass("scala.util.boundary.Break") + @tu lazy val CapsModule: Symbol = requiredModule("scala.caps") @tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("*") @tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe") diff --git a/compiler/src/dotty/tools/dotc/core/NameKinds.scala b/compiler/src/dotty/tools/dotc/core/NameKinds.scala index d121288a9cd8..2c968ab9446c 100644 --- a/compiler/src/dotty/tools/dotc/core/NameKinds.scala +++ b/compiler/src/dotty/tools/dotc/core/NameKinds.scala @@ -325,6 +325,8 @@ object NameKinds { val LocalOptInlineLocalObj: UniqueNameKind = new UniqueNameKind("ilo") + val BoundaryName: UniqueNameKind = new UniqueNameKind("boundary") + /** The kind of names of default argument getters */ val DefaultGetterName: NumberedNameKind = new NumberedNameKind(DEFAULTGETTER, "DefaultGetter") { def mkString(underlying: TermName, info: ThisInfo) = { diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index dff423fd0bb4..7755b7877ca8 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -420,6 +420,7 @@ object StdNames { val assert_ : N = "assert" val assume_ : N = "assume" val box: N = "box" + val break: N = "break" val build : N = "build" val bundle: N = "bundle" val bytes: N = "bytes" @@ -511,10 +512,12 @@ object StdNames { val isInstanceOfPM: N = "$isInstanceOf$" val java: N = "java" val key: N = "key" + val label: N = "label" val lang: N = "lang" val language: N = "language" val length: N = "length" val lengthCompare: N = "lengthCompare" + val local: N = "local" val longHash: N = "longHash" val macroThis : N = "_this" val macroContext : N = "c" diff --git a/compiler/src/dotty/tools/dotc/transform/DropBreaks.scala b/compiler/src/dotty/tools/dotc/transform/DropBreaks.scala new file mode 100644 index 000000000000..3081bd5c2b20 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/DropBreaks.scala @@ -0,0 +1,251 @@ +package dotty.tools +package dotc +package transform + +import ast.{Trees, tpd} +import core.* +import Decorators.* +import NameKinds.BoundaryName +import MegaPhase._ +import Types._, Contexts._, Flags._, DenotTransformers._ +import Symbols._, StdNames._, Trees._ +import util.Property +import Constants.Constant +import Flags.MethodOrLazy + +object DropBreaks: + val name: String = "dropBreaks" + val description: String = "replace local Break throws by labeled returns" + + /** Usage data and other info associated with a Label symbol. + * @param goto the return-label to use for a labeled return. + * @param enclMeth the enclosing method + */ + class LabelUsage(val goto: TermSymbol, val enclMeth: Symbol): + /** The number of references to associated label that come from labeled returns */ + var returnRefs: Int = 0 + /** The number of other references to associated label */ + var otherRefs: Int = 0 + + private val LabelUsages = new Property.Key[Map[Symbol, LabelUsage]] + private val ShadowedLabels = new Property.Key[Set[Symbol]] + +/** Rewrites local Break throws to labeled returns. + * Drops `try` statements on breaks if no other uses of its label remain. + * A Break throw with a `Label` created by some enclosing boundary is replaced + * with a labeled return if + * + * - the throw and the boundary are in the same method, and + * - there is no try expression inside the boundary that encloses the throw. + */ +class DropBreaks extends MiniPhase: + import DropBreaks.* + + import tpd._ + + override def phaseName: String = DropBreaks.name + + override def description: String = DropBreaks.description + + override def runsAfterGroupsOf: Set[String] = Set(ElimByName.name) + // we want by-name parameters to be converted to closures + + /** The number of boundary nodes enclosing the currently analized tree. */ + private var enclosingBoundaries: Int = 0 + + private object LabelTry: + + object GuardedThrow: + + /** `(ex, local)` provided `expr` matches + * + * if ex.label.eq(local) then ex.value else throw ex + */ + def unapply(expr: Tree)(using Context): Option[(Symbol, Symbol)] = stripTyped(expr) match + case If( + Apply(Select(Select(ex: Ident, label), eq), (lbl @ Ident(local)) :: Nil), + Select(ex2: Ident, value), + Apply(throww, (ex3: Ident) :: Nil)) + if label == nme.label && eq == nme.eq && local == nme.local && value == nme.value + && throww.symbol == defn.throwMethod + && ex.symbol == ex2.symbol && ex.symbol == ex3.symbol => + Some((ex.symbol, lbl.symbol)) + case _ => + None + end GuardedThrow + + /** `(local, body)` provided `tree` matches + * + * try body + * catch case ex: Break => + * if ex.label.eq(local) then ex.value else throw ex + */ + def unapply(tree: Tree)(using Context): Option[(Symbol, Tree)] = stripTyped(tree) match + case Try(body, CaseDef(pat @ Bind(_, Typed(_, tpt)), EmptyTree, GuardedThrow(exc, local)) :: Nil, EmptyTree) + if tpt.tpe.isRef(defn.BreakClass) && exc == pat.symbol => + Some((local, body)) + case _ => + None + end LabelTry + + private object BreakBoundary: + + /** `(local, body)` provided `tree` matches + * + * { val local: Label[...] = ...; } + */ + def unapply(tree: Tree)(using Context): Option[(Symbol, Tree)] = stripTyped(tree) match + case Block((vd @ ValDef(nme.local, _, _)) :: Nil, LabelTry(caughtAndRhs)) + if vd.symbol.info.isRef(defn.LabelClass) && vd.symbol == caughtAndRhs._1 => + Some(caughtAndRhs) + case _ => + None + end BreakBoundary + + private object Break: + + private def isBreak(sym: Symbol)(using Context): Boolean = + sym.name == nme.break && sym.owner == defn.boundaryModule.moduleClass + + /** `(local, arg)` provided `tree` matches + * + * break[...](arg)(local) + * + * or `(local, ())` provided `tree` matches + * + * break()(local) + */ + def unapply(tree: Tree)(using Context): Option[(Symbol, Tree)] = tree match + case Apply(Apply(fn, args), id :: Nil) + if isBreak(fn.symbol) => + stripInlined(id) match + case id: Ident => + val arg = (args: @unchecked) match + case arg :: Nil => arg + case Nil => Literal(Constant(())).withSpan(tree.span) + Some((id.symbol, arg)) + case _ => None + case _ => None + end Break + + /** The LabelUsage data associated with `lbl` in the current context */ + private def labelUsage(lbl: Symbol)(using Context): Option[LabelUsage] = + for + usesMap <- ctx.property(LabelUsages) + uses <- usesMap.get(lbl) + yield + uses + + /** If `tree` is a BreakBoundary, associate a fresh `LabelUsage` with its label. */ + override def prepareForBlock(tree: Block)(using Context): Context = tree match + case BreakBoundary(label, _) => + enclosingBoundaries += 1 + val mapSoFar = ctx.property(LabelUsages).getOrElse(Map.empty) + val goto = newSymbol(ctx.owner, BoundaryName.fresh(), Synthetic | Label, tree.tpe) + ctx.fresh.setProperty(LabelUsages, + mapSoFar.updated(label, LabelUsage(goto, ctx.owner.enclosingMethod))) + case _ => + ctx + + /** Include all enclosing labels in the `ShadowedLabels` context property. + * This means that breaks to these labels will not be translated to labeled + * returns while this context is valid. + */ + private def shadowLabels(using Context): Context = + ctx.property(LabelUsages) match + case Some(usesMap) => + val setSoFar = ctx.property(ShadowedLabels).getOrElse(Set.empty) + ctx.fresh.setProperty(ShadowedLabels, setSoFar ++ usesMap.keysIterator) + case _ => ctx + + /** Need to suppress labeled returns if there is an intervening try + */ + override def prepareForTry(tree: Try)(using Context): Context = + if enclosingBoundaries == 0 then ctx + else tree match + case LabelTry(_, _) => ctx + case _ => shadowLabels + + override def prepareForValDef(tree: ValDef)(using Context): Context = + if enclosingBoundaries != 0 + && tree.symbol.is(Lazy) + && tree.symbol.owner == ctx.owner.enclosingMethod + then shadowLabels // RHS be converted to a lambda + else ctx + + /** If `tree` is a BreakBoundary, transform it as follows: + * - Wrap it in a labeled block if its label has local uses + * - Drop the try/catch if its label has no other uses + */ + override def transformBlock(tree: Block)(using Context): Tree = tree match + case BreakBoundary(label, expr) => + enclosingBoundaries -= 1 + val uses = ctx.property(LabelUsages).get(label) + val tree1 = + if uses.otherRefs > 1 then + // one non-local ref is always in the catch clause; this one does not count + tree + else + expr + report.log(i"trans boundary block $label // ${uses.returnRefs}, ${uses.otherRefs}") + if uses.returnRefs > 0 then Labeled(uses.goto, tree1) else tree1 + case _ => + tree + + private def isBreak(sym: Symbol)(using Context): Boolean = + sym.name == nme.break && sym.owner == defn.boundaryModule.moduleClass + + private def transformBreak(tree: Tree, arg: Tree, lbl: Symbol)(using Context): Tree = + report.log(i"transform break $tree/$arg/$lbl") + labelUsage(lbl) match + case Some(uses: LabelUsage) + if uses.enclMeth == ctx.owner.enclosingMethod + && !ctx.property(ShadowedLabels).getOrElse(Set.empty).contains(lbl) + => + uses.otherRefs -= 1 + uses.returnRefs += 1 + Return(arg, ref(uses.goto)).withSpan(arg.span) + case _ => + tree + + + /** Rewrite a break call + * + * break.apply[...](value)(using lbl) + * + * where `lbl` is a label defined in the current method and is not included in + * ShadowedLabels to + * + * return[target] arg + * + * where `target` is the `goto` return label associated with `lbl`. + * Adjust associated ref counts accordingly. The local refcount is increased + * and the non-local refcount is decreased, since the `lbl` implicit argument + * to `break` is dropped. + */ + override def transformApply(tree: Apply)(using Context): Tree = + if enclosingBoundaries == 0 then tree + else tree match + case Break(lbl, arg) => + labelUsage(lbl) match + case Some(uses: LabelUsage) + if uses.enclMeth == ctx.owner.enclosingMethod + && !ctx.property(ShadowedLabels).getOrElse(Set.empty).contains(lbl) + => + uses.otherRefs -= 1 + uses.returnRefs += 1 + Return(arg, ref(uses.goto)).withSpan(arg.span) + case _ => tree + case _ => tree + + /** If `tree` refers to an enclosing label, increase its non local recount. + * This increase is corrected in `transformInlined` if the reference turns + * out to be part of a BreakThrow to a local, non-shadowed label. + */ + override def transformIdent(tree: Ident)(using Context): Tree = + if enclosingBoundaries != 0 then + for uses <- labelUsage(tree.symbol) do + uses.otherRefs += 1 + tree + +end DropBreaks diff --git a/compiler/src/dotty/tools/dotc/transform/NonLocalReturns.scala b/compiler/src/dotty/tools/dotc/transform/NonLocalReturns.scala index ddf858994220..a75d6da9dd6a 100644 --- a/compiler/src/dotty/tools/dotc/transform/NonLocalReturns.scala +++ b/compiler/src/dotty/tools/dotc/transform/NonLocalReturns.scala @@ -97,7 +97,7 @@ class NonLocalReturns extends MiniPhase { override def transformReturn(tree: Return)(using Context): Tree = if isNonLocalReturn(tree) then report.gradualErrorOrMigrationWarning( - em"Non local returns are no longer supported; use scala.util.control.NonLocalReturns instead", + em"Non local returns are no longer supported; use `boundary` and `boundary.break` in `scala.util` instead", tree.srcPos, warnFrom = `3.2`, errorFrom = future) diff --git a/compiler/test/dotty/tools/backend/jvm/LabelBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/LabelBytecodeTests.scala new file mode 100644 index 000000000000..aea567b87f91 --- /dev/null +++ b/compiler/test/dotty/tools/backend/jvm/LabelBytecodeTests.scala @@ -0,0 +1,166 @@ +package dotty.tools.backend.jvm + +import scala.language.unsafeNulls + +import org.junit.Assert._ +import org.junit.Test + +import scala.tools.asm +import asm._ +import asm.tree._ + +import scala.tools.asm.Opcodes +import scala.jdk.CollectionConverters._ +import Opcodes._ + +class LabelBytecodeTests extends DottyBytecodeTest { + import ASMConverters._ + + @Test def localLabelBreak = { + testLabelBytecodeEquals( + """val local = boundary.Label[Long]() + |try break(5L)(using local) + |catch case ex: boundary.Break[Long] @unchecked => + | if ex.label eq local then ex.value + | else throw ex + """.stripMargin, + "Long", + Ldc(LDC, 5), + Op(LRETURN) + ) + } + + @Test def simpleBoundaryBreak = { + testLabelBytecodeEquals( + """boundary: l ?=> + | break(2)(using l) + """.stripMargin, + "Int", + Op(ICONST_2), + Op(IRETURN) + ) + + testLabelBytecodeEquals( + """boundary: + | break(3) + """.stripMargin, + "Int", + Op(ICONST_3), + Op(IRETURN) + ) + + testLabelBytecodeEquals( + """boundary: + | break() + """.stripMargin, + "Unit", + Op(RETURN) + ) + } + + @Test def labelExtraction = { + // Test extra Inlined around the label + testLabelBytecodeEquals( + """boundary: + | break(2)(using summon[boundary.Label[Int]]) + """.stripMargin, + "Int", + Op(ICONST_2), + Op(IRETURN) + ) + + // Test extra Block around the label + testLabelBytecodeEquals( + """boundary: l ?=> + | break(2)(using { l }) + """.stripMargin, + "Int", + Op(ICONST_2), + Op(IRETURN) + ) + } + + @Test def boundaryLocalBreak = { + testLabelBytecodeExpect( + """val x: Boolean = true + |boundary[Unit]: + | var i = 0 + | while true do + | i += 1 + | if i > 10 then break() + """.stripMargin, + "Unit", + !throws(_) + ) + } + + @Test def boundaryNonLocalBreak = { + testLabelBytecodeExpect( + """boundary[Unit]: + | nonLocalBreak() + """.stripMargin, + "Unit", + throws + ) + + testLabelBytecodeExpect( + """boundary[Unit]: + | def f() = break() + | f() + """.stripMargin, + "Unit", + throws + ) + } + + @Test def boundaryLocalAndNonLocalBreak = { + testLabelBytecodeExpect( + """boundary[Unit]: l ?=> + | break() + | nonLocalBreak() + """.stripMargin, + "Unit", + throws + ) + } + + private def throws(instructions: List[Instruction]): Boolean = + instructions.exists { + case Op(ATHROW) => true + case _ => false + } + + private def testLabelBytecodeEquals(code: String, tpe: String, expected: Instruction*): Unit = + checkLabelBytecodeInstructions(code, tpe) { instructions => + val expectedList = expected.toList + assert(instructions == expectedList, + "`test` was not properly generated\n" + diffInstructions(instructions, expectedList)) + } + + private def testLabelBytecodeExpect(code: String, tpe: String, expected: List[Instruction] => Boolean): Unit = + checkLabelBytecodeInstructions(code, tpe) { instructions => + assert(expected(instructions), + "`test` was not properly generated\n" + instructions) + } + + private def checkLabelBytecodeInstructions(code: String, tpe: String)(checkOutput: List[Instruction] => Unit): Unit = { + val source = + s"""import scala.util.boundary, boundary.break + |class Test: + | def test: $tpe = { + | ${code.linesIterator.toList.mkString("", "\n ", "")} + | } + | def nonLocalBreak[T](value: T)(using boundary.Label[T]): Nothing = break(value) + | def nonLocalBreak()(using boundary.Label[Unit]): Nothing = break(()) + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Test.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val method = getMethod(clsNode, "test") + + checkOutput(instructionsFromMethod(method)) + } + } + +} diff --git a/library/src/scala/util/boundary.scala b/library/src/scala/util/boundary.scala new file mode 100644 index 000000000000..3c6c6982c7ee --- /dev/null +++ b/library/src/scala/util/boundary.scala @@ -0,0 +1,62 @@ +package scala.util + +/** A boundary that can be exited by `break` calls. + * `boundary` and `break` represent a unified and superior alternative for the + * `scala.util.control.NonLocalReturns` and `scala.util.control.Breaks` APIs. + * The main differences are: + * + * - Unified names: `boundary` to establish a scope, `break` to leave it. + * `break` can optionally return a value. + * - Integration with exceptions. `break`s are logically non-fatal exceptions. + * The `Break` exception class extends `RuntimeException` and is optimized so + * that stack trace generation is suppressed. + * - Better performance: breaks to enclosing scopes in the same method can + * be rewritten to jumps. + * + * Example usage: + * + * import scala.util.boundary, boundary.break + * + * def firstIndex[T](xs: List[T], elem: T): Int = + * boundary: + * for (x, i) <- xs.zipWithIndex do + * if x == elem then break(i) + * -1 + */ +object boundary: + + /** User code should call `break.apply` instead of throwing this exception + * directly. + */ + final class Break[T] private[boundary](val label: Label[T], val value: T) + extends RuntimeException( + /*message*/ null, /*cause*/ null, /*enableSuppression=*/ false, /*writableStackTrace*/ false) + + /** Labels are targets indicating which boundary will be exited by a `break`. + */ + final class Label[-T] + + /** Abort current computation and instead return `value` as the value of + * the enclosing `boundary` call that created `label`. + */ + def break[T](value: T)(using label: Label[T]): Nothing = + throw Break(label, value) + + /** Abort current computation and instead continue after the `boundary` call that + * created `label`. + */ + def break()(using label: Label[Unit]): Nothing = + throw Break(label, ()) + + /** Run `body` with freshly generated label as implicit argument. Catch any + * breaks associated with that label and return their results instead of + * `body`'s result. + */ + inline def apply[T](inline body: Label[T] ?=> T): T = + val local = Label[T]() + try body(using local) + catch case ex: Break[T] @unchecked => + if ex.label eq local then ex.value + else throw ex + +end boundary diff --git a/library/src/scala/util/control/NonLocalReturns.scala b/library/src/scala/util/control/NonLocalReturns.scala index c32e0ff16457..ad4dc05f36ac 100644 --- a/library/src/scala/util/control/NonLocalReturns.scala +++ b/library/src/scala/util/control/NonLocalReturns.scala @@ -7,8 +7,19 @@ package scala.util.control * import scala.util.control.NonLocalReturns.* * * returning { ... throwReturn(x) ... } + * + * This API has been deprecated. Its functionality is better served by + * + * - `scala.util.boundary` in place of `returning` + * - `scala.util.break` in place of `throwReturn` + * + * The new abstractions work with plain `RuntimeExceptions` and are more + * performant, since returns within the scope of the same method can be + * rewritten by the compiler to jumps. */ +@deprecated("Use scala.util.boundary instead", "3.3") object NonLocalReturns { + @deprecated("Use scala.util.boundary.Break instead", "3.3") class ReturnThrowable[T] extends ControlThrowable { private var myResult: T = _ def throwReturn(result: T): Nothing = { @@ -19,10 +30,12 @@ object NonLocalReturns { } /** Performs a nonlocal return by throwing an exception. */ + @deprecated("Use scala.util.boundary.break instead", "3.3") def throwReturn[T](result: T)(using returner: ReturnThrowable[? >: T]): Nothing = returner.throwReturn(result) /** Enable nonlocal returns in `op`. */ + @deprecated("Use scala.util.boundary instead", "3.3") def returning[T](op: ReturnThrowable[T] ?=> T): T = { val returner = new ReturnThrowable[T] try op(using returner) diff --git a/tests/pos/simple-boundary.scala b/tests/pos/simple-boundary.scala new file mode 100644 index 000000000000..b7adaf4e11cf --- /dev/null +++ b/tests/pos/simple-boundary.scala @@ -0,0 +1,4 @@ +import scala.util.boundary, boundary.break +def test: Unit = + boundary: label ?=> + while true do break() diff --git a/tests/run/break-opt.check b/tests/run/break-opt.check new file mode 100644 index 000000000000..ac1be39bbe0a --- /dev/null +++ b/tests/run/break-opt.check @@ -0,0 +1,2 @@ +done +done diff --git a/tests/run/break-opt.scala b/tests/run/break-opt.scala new file mode 100644 index 000000000000..ec979ff0e8ad --- /dev/null +++ b/tests/run/break-opt.scala @@ -0,0 +1,101 @@ +import scala.util.boundary, boundary.break + +object breakOpt: + + var zero = 0 + + def test1(x: Int): Int = + boundary: + if x < 0 then break(zero) + x + + def test2(xx: Int, xs: List[Int]): Int = + boundary: + if xx < 0 then break(zero) + xs.map: y => + if y < 0 then break(zero) + y + xx + xs.sum + + def test3(xx: Int, xs: List[Int]): Int = + def cond[T](p: Boolean, x: => T, y: => T): T = + if p then x else y + boundary: + cond(true, { if xx < 0 then break(zero); xx }, xx) + + def test3a(xx: Int, xs: List[Int]): Int = + inline def cond[T](p: Boolean, inline x: T, y: => T): T = + if p then x else y + boundary: + cond(true, { if xx < 0 then break(zero); xx }, xx) + + def test4(x: Int): Int = + boundary: + try + if x < 0 then break(zero) + boundary: + if x == 0 then break(-1) + x + finally + println("done") + + def test5(x: Int): Int = + boundary: lab1 ?=> + if x < 0 then break(zero) + boundary: + if x == 0 then break(-1) + if x > 0 then break(+1)(using lab1) + x + + def test6(x0: Int): Int = + var x = x0 + var y = x + boundary: + while true do + y = y * x + x -= 1 + if x == 0 then break() + y + + def test7(x0: Int): Option[Int] = + val result = + boundary: + Some( + 1 + ( + if x0 < 0 then break(None) // no jump possible, since stacksize changes and no direct RETURN + else x0 + ) + ) + result.map(_ + 10) + + def test8(x0: Int): Option[Int] = + boundary: + lazy val x = + if x0 < 0 then break(None) // no jump possible, since ultimately in a different method + else x0 + 1 + Some(x) + + def test9(x0: Int): Option[Int] = + boundary: + def x = + if x0 < 0 then break(None) // no jump possible, since in a different method + else x0 + 1 + Some(x) + +@main def Test = + import breakOpt.* + assert(test1(0) == 0) + assert(test1(-1) == 0) + assert(test2(1, List(1, 2, 3)) == 7) + assert(test2(-1, List(1, 2, 3)) == 0) + assert(test2(1, List(1, -2, 3)) == 0) + test4(1) + test4(-1) + assert(test5(2) == 1) + assert(test6(3) == 18) + assert(test7(3) == Some(14)) + assert(test7(-3) == None) + assert(test8(3) == Some(4)) + assert(test8(-3) == None) + assert(test9(3) == Some(4)) + assert(test9(-3) == None) diff --git a/tests/run/breaks.scala b/tests/run/breaks.scala new file mode 100644 index 000000000000..3036bc6c9124 --- /dev/null +++ b/tests/run/breaks.scala @@ -0,0 +1,38 @@ +import scala.util.boundary, boundary.break +import collection.mutable.ListBuffer + +object Test { + def has(xs: List[Int], elem: Int) = + boundary: + for x <- xs do + if x == elem then break(true) + false + + def takeUntil(xs: List[Int], elem: Int) = + boundary: + var buf = new ListBuffer[Int] + for x <- xs yield + if x == elem then break(buf.toList) + buf += x + x + + trait Animal + object Dog extends Animal + object Cat extends Animal + + def animal(arg: Int): Animal = + boundary: + if arg < 0 then break(Dog) + Cat + + def main(arg: Array[String]): Unit = { + assert(has(1 :: 2 :: Nil, 1)) + assert(has(1 :: 2 :: Nil, 2)) + assert(!has(1 :: 2 :: Nil, 3)) + assert(animal(1) == Cat) + assert(animal(-1) == Dog) + + assert(has(List(1, 2, 3), 2)) + assert(takeUntil(List(1, 2, 3), 3) == List(1, 2)) + } +} \ No newline at end of file diff --git a/tests/run/errorhandling.check b/tests/run/errorhandling.check new file mode 100644 index 000000000000..882ca57f5022 --- /dev/null +++ b/tests/run/errorhandling.check @@ -0,0 +1,4 @@ +breakTest +optTest +resultTest +Person(Kostas,5) diff --git a/tests/run/errorhandling/Result.scala b/tests/run/errorhandling/Result.scala new file mode 100644 index 000000000000..027c07c86769 --- /dev/null +++ b/tests/run/errorhandling/Result.scala @@ -0,0 +1,67 @@ +package scala.util +import boundary.{Label, break} + +abstract class Result[+T, +E] +case class Ok[+T](value: T) extends Result[T, Nothing] +case class Err[+E](value: E) extends Result[Nothing, E] + +object Result: + extension [T, E](r: Result[T, E]) + + /** `_.?` propagates Err to current Label */ + transparent inline def ? (using Label[Err[E]]): T = r match + case r: Ok[_] => r.value + case err => break(err.asInstanceOf[Err[E]]) + + /** If this is an `Err`, map its value */ + def mapErr[E1](f: E => E1): Result[T, E1] = r match + case err: Err[_] => Err(f(err.value)) + case ok: Ok[_] => ok + + /** Map Ok values, propagate Errs */ + def map[U](f: T => U): Result[U, E] = r match + case Ok(x) => Ok(f(x)) + case err: Err[_] => err + + /** Flatmap Ok values, propagate Errs */ + def flatMap[U](f: T => Result[U, E]): Result[U, E] = r match + case Ok(x) => f(x) + case err: Err[_] => err + + /** Validate both `r` and `other`; return a pair of successes or a list of failures. */ + def * [U](other: Result[U, E]): Result[(T, U), List[E]] = (r, other) match + case (Ok(x), Ok(y)) => Ok((x, y)) + case (Ok(_), Err(e)) => Err(e :: Nil) + case (Err(e), Ok(_)) => Err(e :: Nil) + case (Err(e1), Err(e2)) => Err(e1 :: e2 :: Nil) + + /** Validate both `r` and `other`; return a tuple of successes or a list of failures. + * Unlike with `*`, the right hand side `other` must be a `Result` returning a `Tuple`, + * and the left hand side is added to it. See `Result.empty` for a convenient + * right unit of chains of `*:`s. + */ + def *: [U <: Tuple](other: Result[U, List[E]]): Result[T *: U, List[E]] = (r, other) match + case (Ok(x), Ok(ys)) => Ok(x *: ys) + case (Ok(_), es: Err[?]) => es + case (Err(e), Ok(_)) => Err(e :: Nil) + case (Err(e), Err(es)) => Err(e :: es) + end extension + + /** Similar to `Try`: Convert exceptions raised by `body` to `Err`s. + */ + def apply[T](body: => T): Result[T, Exception] = + try Ok(body) + catch case ex: Exception => Err(ex) + + /** Right unit for chains of `*:`s. Returns an `Ok` with an `EmotyTuple` value. */ + def empty: Result[EmptyTuple, Nothing] = Ok(EmptyTuple) +end Result + +/** A prompt for `_.?`. It establishes a boundary to which `_.?` returns */ +object respond: + inline def apply[T, E](inline body: Label[Err[E]] ?=> T): Result[T, E] = + boundary: + val result = body + Ok(result) + + diff --git a/tests/run/errorhandling/Test.scala b/tests/run/errorhandling/Test.scala new file mode 100644 index 000000000000..4aa1cd28c5aa --- /dev/null +++ b/tests/run/errorhandling/Test.scala @@ -0,0 +1,78 @@ +import scala.util.*, boundary.break + +/** boundary/break as a replacement for non-local returns */ +def indexOf[T](xs: List[T], elem: T): Int = + boundary: + for (x, i) <- xs.zipWithIndex do + if x == elem then break(i) + -1 + +def breakTest() = + println("breakTest") + assert(indexOf(List(1, 2, 3), 2) == 1) + assert(indexOf(List(1, 2, 3), 0) == -1) + +/** traverse becomes trivial to write */ +def traverse[T](xs: List[Option[T]]): Option[List[T]] = + optional(xs.map(_.?)) + +def optTest() = + println("optTest") + assert(traverse(List(Some(1), Some(2), Some(3))) == Some(List(1, 2, 3))) + assert(traverse(List(Some(1), None, Some(3))) == None) + +/** A check function returning a Result[Unit, _] */ +inline def check[E](p: Boolean, err: E): Result[Unit, E] = + if p then Ok(()) else Err(err) + +/** Another variant of a check function that returns directly to the given + * label in case of error. + */ +inline def check_![E](p: Boolean, err: E)(using l: boundary.Label[Err[E]]): Unit = + if p then () else break(Err(err)) + +/** Use `Result` to convert exceptions to `Err` values */ +def parseDouble(s: String): Result[Double, Exception] = + Result(s.toDouble) + +def parseDoubles(ss: List[String]): Result[List[Double], Exception] = + respond: + ss.map(parseDouble(_).?) + +/** Demonstrate combination of `check` and `.?`. */ +def trySqrt(x: Double) = // inferred: Result[Double, String] + respond: + check(x >= 0, s"cannot take sqrt of negative $x").? // direct jump + math.sqrt(x) + +/** Instead of `check(...).?` one can also use `check_!(...)`. + * Note use of `mapErr` to convert Exception errors to String errors. + */ +def sumRoots(xs: List[String]) = // inferred: Result[Double, String] + respond: + check_!(xs.nonEmpty, "list is empty") // direct jump + val ys = parseDoubles(xs).mapErr(_.toString).? // direct jump + ys.reduce((x, y) => x + trySqrt(y).?) // need exception to propagate `Err` + +def resultTest() = + println("resultTest") + def assertFail(value: Any, s: String) = value match + case Err(msg: String) => assert(msg.contains(s)) + assert(sumRoots(List("1", "4", "9")) == Ok(6)) + assertFail(sumRoots(List("1", "-2", "4")), "cannot take sqrt of negative") + assertFail(sumRoots(List()), "list is empty") + assertFail(sumRoots(List("1", "3ab")), "NumberFormatException") + val xs = sumRoots(List("1", "-2", "4")) *: sumRoots(List()) *: sumRoots(List("1", "3ab")) *: Result.empty + xs match + case Err(msgs) => assert(msgs.length == 3) + case _ => assert(false) + val ys = sumRoots(List("1", "2", "4")) *: sumRoots(List("1")) *: sumRoots(List("2")) *: Result.empty + ys match + case Ok((a, b, c)) => // ok + case _ => assert(false) + +@main def Test = + breakTest() + optTest() + resultTest() + parseCsvIgnoreErrors() \ No newline at end of file diff --git a/tests/run/errorhandling/kostas.scala b/tests/run/errorhandling/kostas.scala new file mode 100644 index 000000000000..085275d5cd82 --- /dev/null +++ b/tests/run/errorhandling/kostas.scala @@ -0,0 +1,35 @@ +package optionMockup: + import scala.util.boundary, boundary.break + object optional: + transparent inline def apply[T](inline body: boundary.Label[None.type] ?=> T): Option[T] = + boundary(Some(body)) + + extension [T](r: Option[T]) + transparent inline def ? (using label: boundary.Label[None.type]): T = r match + case Some(x) => x + case None => break(None) + +import optionMockup.* + +case class Person(name: String, age: Int) + +object PersonCsvParserIgnoreErrors: + def parse(csv: Seq[String]): Seq[Person] = + for + line <- csv + columns = line.split(",") + parsed <- parseColumns(columns) + yield + parsed + + private def parseColumns(columns: Seq[String]): Option[Person] = + columns match + case Seq(name, age) => parsePerson(name, age) + case _ => None + + private def parsePerson(name: String, age: String): Option[Person] = + optional: + Person(name, age.toIntOption.?) + +def parseCsvIgnoreErrors() = + println(PersonCsvParserIgnoreErrors.parse(Seq("Kostas,5", "George,invalid", "too,many,columns")).mkString("\n")) \ No newline at end of file diff --git a/tests/run/errorhandling/optional.scala b/tests/run/errorhandling/optional.scala new file mode 100644 index 000000000000..e041834b825c --- /dev/null +++ b/tests/run/errorhandling/optional.scala @@ -0,0 +1,20 @@ +package scala.util +import boundary.{break, Label} + +/** A mockup of scala.Option */ +abstract class Option[+T] +case class Some[+T](x: T) extends Option[T] +case object None extends Option[Nothing] + +object Option: + /** This extension should be added to the companion object of scala.Option */ + extension [T](r: Option[T]) + inline def ? (using label: Label[None.type]): T = r match + case Some(x) => x + case None => break(None) + +/** A prompt for `Option`, which establishes a boundary which `_.?` on `Option` can return */ +object optional: + inline def apply[T](inline body: Label[None.type] ?=> T): Option[T] = + boundary(Some(body)) + diff --git a/tests/run/loops-alt.scala b/tests/run/loops-alt.scala new file mode 100644 index 000000000000..19c3312bff50 --- /dev/null +++ b/tests/run/loops-alt.scala @@ -0,0 +1,46 @@ +import scala.util.boundary +import boundary.{break, Label} +import java.util.concurrent.TimeUnit + +object loop: + opaque type ExitLabel = Label[Unit] + opaque type ContinueLabel = Label[Unit] + + inline def apply(inline op: (ExitLabel, ContinueLabel) ?=> Unit): Unit = + boundary { exitLabel ?=> + while true do + boundary { continueLabel ?=> + op(using exitLabel, continueLabel) + } + } + end apply + + inline def exit()(using ExitLabel): Unit = + break() + + inline def continue()(using ContinueLabel): Unit = + break() +end loop + +def testLoop(xs: List[Int]) = + var current = xs + var sum = 0 + loop: + // This should be convertible to labeled returns but isn't, since + // the following code is still passed as a closure to `boundary`. + // That's probably due to the additional facade operations necessary + // for opaque types. + if current.isEmpty then loop.exit() + val hd = current.head + current = current.tail + if hd == 0 then loop.exit() + if hd < 0 then loop.continue() + sum += hd + sum + +@main def Test = + assert(testLoop(List(1, 2, 3, -2, 4, -3, 0, 1, 2, 3)) == 10) + assert(testLoop(List()) == 0) + assert(testLoop(List(-2, -3, 0, 1)) == 0) + + diff --git a/tests/run/loops.scala b/tests/run/loops.scala new file mode 100644 index 000000000000..e6a9d3c98dbe --- /dev/null +++ b/tests/run/loops.scala @@ -0,0 +1,37 @@ +import scala.util.boundary, boundary.break + +object loop: + + // We could use a boolean instead, but with `Ctrl` label types in using clauses are + // more specific. + enum Ctrl: + case Exit, Continue + + inline def apply(inline op: boundary.Label[Ctrl] ?=> Unit) = + while boundary { op; Ctrl.Continue } == Ctrl.Continue do () + + inline def exit()(using boundary.Label[Ctrl]): Unit = + break(Ctrl.Exit) + + inline def continue()(using boundary.Label[Ctrl]): Unit = + break(Ctrl.Continue) +end loop + +def testLoop(xs: List[Int]) = + var current = xs + var sum = 0 + loop: + if current.isEmpty then loop.exit() + val hd = current.head + current = current.tail + if hd == 0 then loop.exit() + if hd < 0 then loop.continue() + sum += hd + sum + +@main def Test = + assert(testLoop(List(1, 2, 3, -2, 4, -3, 0, 1, 2, 3)) == 10) + assert(testLoop(List()) == 0) + assert(testLoop(List(-2, -3, 0, 1)) == 0) + + diff --git a/tests/run/rescue-boundary.scala b/tests/run/rescue-boundary.scala new file mode 100644 index 000000000000..3e7fe8508686 --- /dev/null +++ b/tests/run/rescue-boundary.scala @@ -0,0 +1,76 @@ +import scala.util.control.NonFatal +import scala.util.boundary, boundary.break + +object lib: + extension [T](op: => T) inline def rescue (fallback: => T) = + try op + catch + case ex: boundary.Break[_] => throw ex + case NonFatal(_) => fallback + + extension [T, E <: Throwable](op: => T) inline def rescue (fallback: PartialFunction[E, T]) = + try op + catch + case ex: E => + // user should never match `ReturnThrowable`, which breaks semantics of non-local return + if fallback.isDefinedAt(ex) && !ex.isInstanceOf[boundary.Break[_]] then fallback(ex) else throw ex +end lib + +import lib.* + +@main def Test = { + assert((9 / 1 rescue 1) == 9) + assert((9 / 0 rescue 1) == 1) + assert(((9 / 0 rescue { case ex: NullPointerException => 5 }) rescue 10) == 10) + assert(((9 / 0 rescue { case ex: ArithmeticException => 5 }) rescue 10) == 5) + + assert( + { + 9 / 0 rescue { + case ex: NullPointerException => 4 + case ex: ArithmeticException => 3 + } + } == 3 + ) + + (9 / 0) rescue { case ex: ArithmeticException => 4 } + + assert( + { + { + val a = 9 / 0 rescue { + case ex: NullPointerException => 4 + } + a * a + } rescue { + case ex: ArithmeticException => 3 + } + } == 3 + ) + + assert(foo(10) == 40) + assert(bar(10) == 40) + + // should not catch fatal errors + assert( + try { { throw new OutOfMemoryError(); true } rescue false } + catch { case _: OutOfMemoryError => true } + ) + + // should catch any errors specified, including fatal errors + assert( + try { { throw new OutOfMemoryError(); true } rescue { case _: OutOfMemoryError => true } } + catch { case _: OutOfMemoryError => false } + ) + + // should not catch NonLocalReturns + def foo(x: Int): Int = boundary { + { break[Int](4 * x) : Int } rescue 10 + } + + // should catch specified exceptions, but not NonLocalReturn + def bar(x: Int): Int = boundary { + { break[Int](4 * x) : Int } rescue { case _ => 10 } + } + +}