diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/TreeOpsImpl.scala b/compiler/src/dotty/tools/dotc/tastyreflect/TreeOpsImpl.scala index beb38c381ac4..5f86ea0dd0f1 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/TreeOpsImpl.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/TreeOpsImpl.scala @@ -383,6 +383,9 @@ trait TreeOpsImpl extends scala.tasty.reflect.TreeOps with RootPositionImpl with } object Ident extends IdentModule { + def apply(tmref: TermRef)(implicit ctx: Context): Term = + withDefaultPos(implicit ctx => tpd.ref(tmref).asInstanceOf[Term]) + def copy(original: Tree)(name: String)(implicit ctx: Context): Ident = tpd.cpy.Ident(original)(name.toTermName) diff --git a/compiler/src/dotty/tools/dotc/tastyreflect/TypeOrBoundsOpsImpl.scala b/compiler/src/dotty/tools/dotc/tastyreflect/TypeOrBoundsOpsImpl.scala index f375733f5e00..6dd3e134aa79 100644 --- a/compiler/src/dotty/tools/dotc/tastyreflect/TypeOrBoundsOpsImpl.scala +++ b/compiler/src/dotty/tools/dotc/tastyreflect/TypeOrBoundsOpsImpl.scala @@ -1,6 +1,7 @@ package dotty.tools.dotc.tastyreflect import dotty.tools.dotc.core.{Contexts, Names, Types} +import dotty.tools.dotc.core.Decorators._ trait TypeOrBoundsOpsImpl extends scala.tasty.reflect.TypeOrBoundsOps with CoreImpl { @@ -23,6 +24,11 @@ trait TypeOrBoundsOpsImpl extends scala.tasty.reflect.TypeOrBoundsOps with CoreI if (tpe.classSymbol.exists) Some(tpe.classSymbol.asClass) else None def typeSymbol(implicit ctx: Context): Symbol = tpe.typeSymbol + + def isSingleton(implicit ctx: Context): Boolean = tpe.isSingleton + + def memberType(member: Symbol)(implicit ctx: Context): Type = + member.info.asSeenFrom(tpe, member.owner) } def ConstantTypeDeco(x: ConstantType): Type.ConstantTypeAPI = new Type.ConstantTypeAPI { @@ -181,6 +187,9 @@ trait TypeOrBoundsOpsImpl extends scala.tasty.reflect.TypeOrBoundsOps with CoreI } object TermRef extends TermRefModule { + def apply(qual: TypeOrBounds, name: String)(implicit ctx: Context): TermRef = + Types.TermRef(qual, name.toTermName) + def unapply(x: TypeOrBounds)(implicit ctx: Context): Option[(String, TypeOrBounds /* Type | NoPrefix */)] = x match { case tp: Types.NamedType => tp.designator match { diff --git a/library/src/scala/tasty/reflect/TreeOps.scala b/library/src/scala/tasty/reflect/TreeOps.scala index 7358fb9d64c4..89f12ae03ba4 100644 --- a/library/src/scala/tasty/reflect/TreeOps.scala +++ b/library/src/scala/tasty/reflect/TreeOps.scala @@ -2,7 +2,6 @@ package scala.tasty package reflect trait TreeOps extends Core { - // Decorators implicit def TreeDeco(tree: Tree): TreeAPI @@ -249,6 +248,7 @@ trait TreeOps extends Core { /** Scala term identifier */ val Ident: IdentModule abstract class IdentModule { + def apply(tmref: TermRef)(implicit ctx: Context): Term def copy(original: Tree)(name: String)(implicit ctx: Context): Ident diff --git a/library/src/scala/tasty/reflect/TypeOrBoundsOps.scala b/library/src/scala/tasty/reflect/TypeOrBoundsOps.scala index 1063d4345e25..9e3beb602f37 100644 --- a/library/src/scala/tasty/reflect/TypeOrBoundsOps.scala +++ b/library/src/scala/tasty/reflect/TypeOrBoundsOps.scala @@ -55,6 +55,8 @@ trait TypeOrBoundsOps extends Core { def widen(implicit ctx: Context): Type def classSymbol(implicit ctx: Context): Option[ClassSymbol] def typeSymbol(implicit ctx: Context): Symbol + def isSingleton(implicit ctx: Context): Boolean + def memberType(member: Symbol)(implicit ctx: Context): Type } val IsType: IsTypeModule @@ -107,6 +109,7 @@ trait TypeOrBoundsOps extends Core { val TermRef: TermRefModule abstract class TermRefModule { + def apply(qual: TypeOrBounds, name: String)(implicit ctx: Context): TermRef def unapply(typeOrBounds: TypeOrBounds)(implicit ctx: Context): Option[(String, TypeOrBounds /* Type | NoPrefix */)] } diff --git a/tests/neg-with-compiler/i5941/macro_1.scala b/tests/neg-with-compiler/i5941/macro_1.scala new file mode 100644 index 000000000000..4791335569a5 --- /dev/null +++ b/tests/neg-with-compiler/i5941/macro_1.scala @@ -0,0 +1,55 @@ +abstract class Lens[S, T] { + def get(s: S): T + def set(t: T, s: S) :S +} + +import scala.quoted._ +import scala.tasty._ + +object Lens { + def apply[S, T](_get: S => T)(_set: T => S => S): Lens[S, T] = new Lens { + def get(s: S): T = _get(s) + def set(t: T, s: S): S = _set(t)(s) + } + + def impl[S: Type, T: Type](getter: Expr[S => T])(implicit refl: Reflection): Expr[Lens[S, T]] = { + import refl._ + import util._ + import quoted.Toolbox.Default._ + + // obj.copy(field = value) + def setterBody(obj: Expr[S], value: Expr[T], field: String): Expr[S] = + Term.Select.overloaded(obj.unseal, "copy", Nil, Term.NamedArg(field, value.unseal) :: Nil).seal[S] + + // exception: getter.unseal.underlyingArgument + getter.unseal match { + case Term.Inlined( + None, Nil, + Term.Block( + DefDef(_, Nil, (param :: Nil) :: Nil, _, Some(Term.Select(o, field))) :: Nil, + Term.Lambda(meth, _) + ) + ) if o.symbol == param.symbol => + '{ + val setter = (t: T) => (s: S) => ${ setterBody('s, 't, field) } + apply($getter)(setter) + } + case _ => + throw new QuoteError("Unsupported syntax. Example: `GenLens[Address](_.streetNumber)`") + } + } +} + +object GenLens { + /** case class Address(streetNumber: Int, streetName: String) + * + * GenLens[Address](_.streetNumber) ~~> + * + * Lens[Address, Int](_.streetNumber)(n => a => a.copy(streetNumber = n)) + */ + + def apply[S] = new MkGenLens[S] + class MkGenLens[S] { + inline def apply[T](get: => (S => T)): Lens[S, T] = ${ Lens.impl('get) } + } +} \ No newline at end of file diff --git a/tests/neg-with-compiler/i5941/usage_2.scala b/tests/neg-with-compiler/i5941/usage_2.scala new file mode 100644 index 000000000000..c419f9beb0c3 --- /dev/null +++ b/tests/neg-with-compiler/i5941/usage_2.scala @@ -0,0 +1,11 @@ +case class Address(streetNumber: Int, streetName: String) + +object Test { + def main(args: Array[String]): Unit = { + val len = GenLens[Address](_.streetNumber + 3) // error + val address = Address(10, "High Street") + assert(len.get(address) == 10) + val addr2 = len.set(5, address) + assert(len.get(addr2) == 5) + } +} \ No newline at end of file diff --git a/tests/run-with-compiler/i5941/macro_1.scala b/tests/run-with-compiler/i5941/macro_1.scala new file mode 100644 index 000000000000..6eed35a2a69c --- /dev/null +++ b/tests/run-with-compiler/i5941/macro_1.scala @@ -0,0 +1,225 @@ +trait Lens[S, T] { + def get(s: S): T + def set(t: T, s: S) :S +} + +import scala.quoted._ +import scala.tasty._ + +object Lens { + def apply[S, T](_get: S => T)(_set: T => S => S): Lens[S, T] = new Lens { + def get(s: S): T = _get(s) + def set(t: T, s: S): S = _set(t)(s) + } + + def impl[S: Type, T: Type](getter: Expr[S => T])(implicit refl: Reflection): Expr[Lens[S, T]] = { + import refl._ + import util._ + import quoted.Toolbox.Default._ + + + // obj.copy(a = obj.a.copy(b = a.b.copy(c = v))) + def setterBody(obj: Term, value: Term, parts: List[String]): Term = { + // o.copy(field = value) + def helper(obj: Term, value: Term, field: String): Term = + Term.Select.overloaded(obj, "copy", Nil, Term.NamedArg(field, value) :: Nil) + + parts match { + case field :: Nil => helper(obj, value, field) + case field :: parts => + helper(obj, setterBody(Term.Select.unique(obj, field), value, parts), field) + } + } + + object Path { + private def recur(tree: Term, selects: List[String]): Option[(Term, List[String])] = tree match { + case Term.Ident(_) if selects.nonEmpty => Some((tree, selects)) + case Term.Select(qual, name) => recur(qual, name :: selects) + case _ => None + } + + def unapply(t: Term): Option[(Term, List[String])] = recur(t, Nil) + } + + object Function { + def unapply(t: Term): Option[(List[ValDef], Term)] = t match { + case Term.Inlined( + None, Nil, + Term.Block( + (ddef @ DefDef(_, Nil, params :: Nil, _, Some(body))) :: Nil, + Term.Lambda(meth, _) + ) + ) if meth.symbol == ddef.symbol => Some((params, body)) + case _ => None + } + } + + // exception: getter.unseal.underlyingArgument + getter.unseal match { + case Function(param :: Nil, Path(o, parts)) if o.symbol == param.symbol => + '{ + val setter = (t: T) => (s: S) => ${ setterBody(('s).unseal, ('t).unseal, parts).seal[S] } + apply($getter)(setter) + } + case _ => + throw new QuoteError("Unsupported syntax. Example: `GenLens[Address](_.streetNumber)`") + } + } +} + +object GenLens { + /** case class Address(streetNumber: Int, streetName: String) + * + * GenLens[Address](_.streetNumber) ~~> + * + * Lens[Address, Int](_.streetNumber)(n => a => a.copy(streetNumber = n)) + */ + + def apply[S] = new MkGenLens[S] + class MkGenLens[S] { + inline def apply[T](get: => (S => T)): Lens[S, T] = ${ Lens.impl('get) } + } +} + +trait Iso[S, A] { + def from(a: A): S + def to(s: S): A +} + +object Iso { + def apply[S, A](_from: A => S)(_to: S => A): Iso[S, A] = new Iso { + def from(a: A): S = _from(a) + def to(s: S): A = _to(s) + } + + def impl[S: Type, A: Type](implicit refl: Reflection): Expr[Iso[S, A]] = { + import refl._ + import util._ + import quoted.Toolbox.Default._ + + val tpS = typeOf[S] + val tpA = typeOf[A] + + // 1. S must be a case class + // 2. A must be a tuple + // 3. The parameters of S must match A + if (tpS.classSymbol.flatMap(cls => if (cls.flags.is(Flags.Case)) Some(true) else None).isEmpty) + throw new QuoteError("Only support generation for case classes") + + val cls = tpS.classSymbol.get + + val companion = tpS match { + case Type.SymRef(sym, prefix) => Type.TermRef(prefix, sym.name) + case Type.TypeRef(name, prefix) => Type.TermRef(prefix, name) + } + + if (cls.caseFields.size != 1) + throw new QuoteError("Use GenIso.fields for case classes more than one parameter") + + val fieldTp = tpS.memberType(cls.caseFields.head) + if (!(fieldTp =:= tpA)) + throw new QuoteError(s"The type of case class field $fieldTp does not match $tpA") + + '{ + // (p: S) => p._1 + val to = (p: S) => ${ Term.Select.unique(('p).unseal, "_1").seal[A] } + // (p: A) => S(p) + val from = (p: A) => ${ Term.Select.overloaded(Term.Ident(companion), "apply", Nil, ('p).unseal :: Nil).seal[S] } + apply(from)(to) + } + } + + def implUnit[S: Type](implicit refl: Reflection): Expr[Iso[S, 1]] = { + import refl._ + import util._ + import quoted.Toolbox.Default._ + + val tpS = typeOf[S] + + if (tpS.isSingleton) { + val ident = Term.Ident(tpS.asInstanceOf[TermRef]).seal[S] + '{ + Iso[S, 1](Function.const($ident))(Function.const(1)) + } + } + else if (tpS.classSymbol.flatMap(cls => if (cls.flags.is(Flags.Case)) Some(true) else None).nonEmpty) { + val cls = tpS.classSymbol.get + + if (cls.caseFields.size != 0) + throw new QuoteError("Use GenIso.fields for case classes more than one parameter") + + val companion = tpS match { + case Type.SymRef(sym, prefix) => Type.TermRef(prefix, sym.name) + case Type.TypeRef(name, prefix) => Type.TermRef(prefix, name) + } + + val obj = Term.Select.overloaded(Term.Ident(companion), "apply", Nil, Nil).seal[S] + + '{ + Iso[S, 1](Function.const($obj))(Function.const(1)) + } + } + else { + throw new QuoteError("Only support generation for case classes or singleton types") + } + } + + // TODO: require whitebox macro + def implFields[S: Type](implicit refl: Reflection): Expr[Iso[S, Any]] = ??? +} + +object GenIso { + /** + * GenIso[Person, String] ~~> + * + * Iso[Person, String] + * { p => p._1 } + * { p => Person(p) } + */ + inline def apply[S, A]: Iso[S, A] = ${ Iso.impl[S, A] } + + // TODO: require whitebox macro + inline def fields[S]: Iso[S, Any] = ${ Iso.implFields[S] } + + inline def unit[S]: Iso[S, 1] = ${ Iso.implUnit[S] } +} + +trait Prism[S, A] { outer => + def getOption(s: S): Option[A] + def apply(a: A): S + + def composeIso[B](iso: Iso[A, B]): Prism[S, B] = new Prism { + def getOption(s: S): Option[B] = outer.getOption(s).map(a => iso.to(a)) + def apply(b: B): S = outer(iso.from(b)) + } +} + +object Prism { + def apply[S, A](getOpt: S => Option[A])(app: A => S): Prism[S, A] = new Prism { + def getOption(s: S): Option[A] = getOpt(s) + def apply(a: A): S = app(a) + } + + def impl[S: Type, A <: S : Type](implicit refl: Reflection): Expr[Prism[S, A]] = { + import refl._ + import util._ + + '{ + val get = (p: S) => if (p.isInstanceOf[A]) Some(p.asInstanceOf[A]) else None + val app = (p: A) => p + apply(get)(app) + } + } +} + +object GenPrism { + /** + * GenPrism[Json, JStr] ~~> + * + * Prism[Json, JStr]{ + * case JStr(v) => Some(v) + * case _ => None + * }(jstr => jstr) + */ + inline def apply[S, A <: S]: Prism[S, A] = ${ Prism.impl[S, A] } +} \ No newline at end of file diff --git a/tests/run-with-compiler/i5941/usage_2.scala b/tests/run-with-compiler/i5941/usage_2.scala new file mode 100644 index 000000000000..6cad3fc0cec5 --- /dev/null +++ b/tests/run-with-compiler/i5941/usage_2.scala @@ -0,0 +1,58 @@ +case class Address(streetNumber: Int, streetName: String) +case class Employee(name: String, addr: Address) + +sealed trait Json +case object JNull extends Json +case class JStr(v: String) extends Json +case class JNum(v: Double) extends Json +case class JObj(v: Map[String, Json]) extends Json + +object Test { + def main(args: Array[String]): Unit = { + val len = GenLens[Address](_.streetNumber) + val address = Address(10, "High Street") + assert(len.get(address) == 10) + val addr2 = len.set(5, address) + assert(len.get(addr2) == 5) + + // a.b.c + val len2 = GenLens[Employee](_.addr.streetNumber) + val employee = Employee("Bob", Address(10, "High Street")) + assert(len2.get(employee) == 10) + val employee2 = len2.set(5, employee) + assert(employee2.name == "Bob") + assert(len2.get(employee2) == 5) + + // prism + val jStr: Prism[Json, JStr] = GenPrism[Json, JStr] + assert(jStr.getOption(JNum(4.5)) == None) + assert(jStr.getOption(JStr("hello")) == Some(JStr("hello"))) + assert(jStr(JStr("world")) == JStr("world")) + + assert(GenIso[JStr, String].to(JStr("Hello")) == "Hello") + assert(GenIso.unit[JNull.type].to(JNull) == 1) + assert(GenIso.unit[JNull.type].from(1) == JNull) + + // TODO: require whitebox macros + // assert(GenIso.fields[Address].from((0, "a")) == Address(0, "a")) + + val jNum: Prism[Json, Double] = GenPrism[Json, JNum] composeIso GenIso[JNum, Double] + assert(jNum(3.5) == JNum(3.5)) + assert(jNum.getOption(JNum(3.5)) == Some(3.5)) + assert(jNum.getOption(JNull) == None) + + // inner classes + val inner = new Inner + assert(GenIso[inner.JStr, String].to(inner.JStr("Hello")) == "Hello") + assert(GenIso.unit[inner.JNull.type].to(inner.JNull) == 1) + assert(GenIso.unit[inner.JNull.type].from(1) == inner.JNull) + } +} + +class Inner { + sealed trait Json + case object JNull extends Json + case class JStr(v: String) extends Json + case class JNum(v: Double) extends Json + case class JObj(v: Map[String, Json]) extends Json +} \ No newline at end of file