Skip to content

Commit d3a6c44

Browse files
committed
Improve usage of match types
1 parent 6137bd5 commit d3a6c44

File tree

6 files changed

+42
-35
lines changed

6 files changed

+42
-35
lines changed

src/main/FrameSchema.scala

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,30 @@ import types.{DataType, Encoder, StructEncoder}
66
import MacroHelpers.TupleSubtype
77

88
object FrameSchema:
9-
type Merge[S1, S2] = S1 match
10-
case TupleSubtype[s1] => S2 match
11-
case TupleSubtype[s2] => Tuple.Concat[s1, s2]
12-
case _ => Tuple.Concat[s1, S2 *: EmptyTuple]
13-
case _ => S2 match
14-
case TupleSubtype[s2] => S1 *: s2
15-
case _ => S1 *: S2 *: EmptyTuple
9+
type AsTuple[A] = A match
10+
case Tuple => A
11+
case Any => A *: EmptyTuple
12+
13+
type FromTuple[T] = T match
14+
case h *: EmptyTuple => h
15+
case Tuple => T
16+
17+
type Merge[S1, S2] = (S1, S2) match
18+
case (Tuple, Tuple) =>
19+
Tuple.Concat[S1, S2]
20+
case (Any, Tuple) =>
21+
S1 *: S2
22+
case (Tuple, Any) =>
23+
Tuple.Append[S1, S2]
24+
case (Any, Any) =>
25+
(S1, S2)
1626

1727
type NullableLabeledDataType[T] = T match
1828
case label := tpe => label := DataType.Nullable[tpe]
1929

2030
type NullableSchema[T] = T match
21-
case TupleSubtype[s] => Tuple.Map[s, NullableLabeledDataType]
22-
case _ => NullableLabeledDataType[T]
31+
case Tuple => Tuple.Map[T, NullableLabeledDataType]
32+
case Any => NullableLabeledDataType[T]
2333

2434
def reownType[Owner <: Name : Type](schema: Type[?])(using Quotes): Type[?] =
2535
schema match

src/main/MacroHelpers.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,3 @@ private[iskra] object MacroHelpers:
1111
Position(file, start, end)
1212

1313
type TupleSubtype[T <: Tuple] = T
14-
15-
type AsTuple[A] <: Tuple = A match
16-
case TupleSubtype[t] => t
17-
case _ => A *: EmptyTuple

src/main/SchemaView.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ package org.virtuslab.iskra
33
import scala.quoted.*
44
import org.apache.spark.sql.functions.col
55
import types.DataType
6-
import MacroHelpers.AsTuple
6+
import MacroHelpers.TupleSubtype
77

88
inline def $(using view: SchemaView): view.type = view
99

@@ -79,7 +79,7 @@ object StructSchemaView:
7979
import quotes.reflect.*
8080
Type.of[DF] match
8181
case '[StructDataFrame[schema]] =>
82-
val schemaType = Type.of[AsTuple[schema]]
82+
val schemaType = Type.of[FrameSchema.AsTuple[schema]]
8383
val aliasViewsByName = frameAliasViewsByName(schemaType)
8484
val columns = unambiguousColumns(schemaType)
8585
val frameAliasNames = Expr(aliasViewsByName.map(_._1))

src/main/Select.scala

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,7 @@ object Select:
3737
case Some(collectColumns) =>
3838
collectColumns match
3939
case '{ $cc: CollectColumns[?] { type CollectedColumns = collectedColumns } } =>
40-
Type.of[collectedColumns] match
41-
case '[head *: EmptyTuple] =>
42-
'{
43-
val cols = ${ cc }.underlyingColumns(${ columns }(using ${ select }.view))
44-
StructDataFrame[head](${ select }.underlying.select(cols*))
45-
}
46-
40+
Type.of[FrameSchema.FromTuple[collectedColumns]] match
4741
case '[s] =>
4842
'{
4943
val cols = ${ cc }.underlyingColumns(${ columns }(using ${ select }.view))

src/main/StructDataFrame.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package org.virtuslab.iskra
33
import scala.quoted.*
44

55
import types.{DataType, Encoder, StructEncoder}
6+
import MacroHelpers.TupleSubtype
67

78

89
class StructDataFrame[Schema](val untyped: UntypedDataFrame) extends DataFrame
@@ -20,7 +21,13 @@ object StructDataFrame:
2021
Expr.summon[Encoder[A]] match
2122
case Some(encoder) => encoder match
2223
case '{ $enc: StructEncoder[A] { type StructSchema = structSchema } } =>
23-
Type.of[MacroHelpers.AsTuple[FrameSchema]] match
24+
val frameSchemaTuple = Type.of[FrameSchema] match
25+
case '[TupleSubtype[t]] =>
26+
Type.of[t]
27+
case '[t] =>
28+
Type.of[t *: EmptyTuple]
29+
30+
frameSchemaTuple match
2431
case '[`structSchema`] =>
2532
'{ ClassDataFrame[A](${ df }.untyped) }
2633
case _ =>

src/main/types/DataType.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -34,20 +34,20 @@ object DataType:
3434
case StructOptType[schema] => StructOptType[schema]
3535

3636
type CommonNumericNullableType[T1 <: DataType, T2 <: DataType] <: NumericOptType = (T1, T2) match
37-
case (DoubleOptType, _) | (_, DoubleOptType) => DoubleOptType
38-
case (FloatOptType, _) | (_, FloatOptType) => FloatOptType
39-
case (LongOptType, _) | (_, LongOptType) => LongOptType
40-
case (IntegerOptType, _) | (_, IntegerOptType) => IntegerOptType
41-
case (ShortOptType, _) | (_, ShortOptType) => ShortOptType
42-
case (ByteOptType, _) | (_, ByteOptType) => ByteOptType
37+
case (DoubleOptType, DataType) | (DataType, DoubleOptType) => DoubleOptType
38+
case (FloatOptType, DataType) | (DataType, FloatOptType) => FloatOptType
39+
case (LongOptType, DataType) | (DataType, LongOptType) => LongOptType
40+
case (IntegerOptType, DataType) | (DataType, IntegerOptType) => IntegerOptType
41+
case (ShortOptType, DataType) | (DataType, ShortOptType) => ShortOptType
42+
case (ByteOptType, DataType) | (DataType, ByteOptType) => ByteOptType
4343

4444
type CommonNumericNonNullableType[T1 <: DataType, T2 <: DataType] <: NumericOptType = (T1, T2) match
45-
case (DoubleOptType, _) | (_, DoubleOptType) => DoubleType
46-
case (FloatOptType, _) | (_, FloatOptType) => FloatType
47-
case (LongOptType, _) | (_, LongOptType) => LongType
48-
case (IntegerOptType, _) | (_, IntegerOptType) => IntegerType
49-
case (ShortOptType, _) | (_, ShortOptType) => ShortType
50-
case (ByteOptType, _) | (_, ByteOptType) => ByteType
45+
case (DoubleOptType, DataType) | (DataType, DoubleOptType) => DoubleType
46+
case (FloatOptType, DataType) | (DataType, FloatOptType) => FloatType
47+
case (LongOptType, DataType) | (DataType, LongOptType) => LongType
48+
case (IntegerOptType, DataType) | (DataType, IntegerOptType) => IntegerType
49+
case (ShortOptType, DataType) | (DataType, ShortOptType) => ShortType
50+
case (ByteOptType, DataType) | (DataType, ByteOptType) => ByteType
5151

5252
import DataType.NotNull
5353

0 commit comments

Comments
 (0)