Skip to content

Commit df03576

Browse files
committed
Fix nested zip bug
1 parent a7ff543 commit df03576

File tree

2 files changed

+94
-33
lines changed

2 files changed

+94
-33
lines changed

tests/run-with-compiler-custom-args/staged-streams_1.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@
1616

1717
15
1818

19-
36
19+
72

tests/run-with-compiler-custom-args/staged-streams_1.scala

Lines changed: 93 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@ object Test {
3333
*
3434
* The latter transforms the state and returns either the end-of-the-stream or a value and
3535
* the new state. The existential quantification over the state keeps it private: the only permissible operation is
36-
* to pass it to the function. However in `Producer` the elements are not pulled but the step accepts a continuation.
36+
* to pass it to the function.
37+
*
38+
* Note: in `Producer` the elements are not pulled but the step accepts a continuation.
3739
*
3840
* A Producer defines the three basic elements of a loop structure:
3941
* - `init` contributes the code before iteration starts
4042
* - `step` contributes the code during execution
4143
* - `hasNext` contributes the code of the boolean test to end the iteration
4244
*
43-
* @tparam A type of the collection element. Since a `Producer` is polymorphic yet it handles `Expr` values, we
45+
* @tparam A type of the collection element. Since a `Producer` is polymorphic it handles `Expr` values, we
4446
* can pack together fragments of code to accompany each element production (e.g., a variable incremented
4547
* during each transformation)
4648
*/
@@ -57,11 +59,17 @@ object Test {
5759

5860
/** Step method that defines the transformation of data.
5961
*
60-
* @param st the state needed for this iteration step
61-
* @param k
62-
* @return
62+
* @param st the state needed for this iteration step
63+
* @param k the continuation that accepts each element and proceeds with the step-wise processing
64+
* @return expr value of unit per the CPS-encoding
6365
*/
6466
def step(st: St, k: (A => Expr[Unit])): Expr[Unit]
67+
68+
/** The condition that checks for termination
69+
*
70+
* @param st the state needed for this iteration check
71+
* @return the expression for a boolean
72+
*/
6573
def hasNext(st: St): Expr[Boolean]
6674
}
6775

@@ -75,9 +83,18 @@ object Test {
7583

7684
case class Stream[A: Type](stream: StagedStream[Expr[A]]) {
7785

86+
/** Main consumer
87+
*
88+
* Fold accumulates the results in a variable and delegates its functionality to `foldRaw`
89+
*
90+
* @param z the accumulator
91+
* @param f the zipping function
92+
* @tparam W the type of the accumulator
93+
* @return
94+
*/
7895
def fold[W: Type](z: Expr[W], f: ((Expr[W], Expr[A]) => Expr[W])): Expr[W] = {
7996
Var(z) { s: Var[W] => '{
80-
~fold_raw[Expr[A]]((a: Expr[A]) => '{
97+
~foldRaw[Expr[A]]((a: Expr[A]) => '{
8198
~s.update(f(s.get, a))
8299
}, stream)
83100

@@ -86,7 +103,7 @@ object Test {
86103
}
87104
}
88105

89-
private def fold_raw[A](consumer: A => Expr[Unit], stream: StagedStream[A]): Expr[Unit] = {
106+
private def foldRaw[A](consumer: A => Expr[Unit], stream: StagedStream[A]): Expr[Unit] = {
90107
stream match {
91108
case Linear(producer) => {
92109
producer.card match {
@@ -105,7 +122,7 @@ object Test {
105122
}
106123
}
107124
case nested: Nested[A, bt] => {
108-
fold_raw[bt](((e: bt) => fold_raw[A](consumer, nested.nestedf(e))), Linear(nested.producer))
125+
foldRaw[bt](((e: bt) => foldRaw[A](consumer, nested.nestedf(e))), Linear(nested.producer))
109126
}
110127
}
111128
}
@@ -355,6 +372,43 @@ object Test {
355372
* The reified stream is an imperative *non-recursive* function, called `adv`, of `Unit => Unit` type. Nested streams are
356373
* also handled.
357374
*
375+
* @example {{{
376+
*
377+
* Stream.of(1,2,3).flatMap(d => ...)
378+
* .zip(Stream.of(1,2,3).flatMap(d => ...))
379+
* .map{ case (a, b) => a + b }
380+
* .fold(0)((a, b) => a + b)
381+
* }}}
382+
*
383+
* -->
384+
*
385+
* {{{
386+
* /* initialization for stream 1 */
387+
*
388+
* var curr = null.asInstanceOf[Int] // keeps each element from reified stream
389+
* var nadv: Unit => Unit = (_) => () // keeps the advance for each nested level
390+
*
391+
* def adv: Unit => Unit = /* Linearization of stream1 - updates curr from stream1 */
392+
* nadv = adv
393+
* adv()
394+
*
395+
* /* initialization for stream 2 */
396+
*
397+
* def outer () = {
398+
* /* initialization for outer stream of stream 2 */
399+
* def inner() = {
400+
* /* initialization for inner stream of stream 2 */
401+
* val el = curr
402+
* nadv()
403+
* /* process elements for map and fold */
404+
* inner()
405+
* }
406+
* inner()
407+
* outer()
408+
* }
409+
* outer()
410+
* }}}
411+
*
358412
* @param stream
359413
* @tparam A
360414
* @return
@@ -366,18 +420,18 @@ object Test {
366420
/** Helper function that orchestrates the handling of the function that represents an `advance: Unit => Unit`.
367421
* It reifies a nested stream as calls to `advance`. Advance encodes the step function of each nested stream.
368422
* It is used in the init of a producer of a nested stream. When an inner stream finishes, the
369-
* `currentAdvance` holds the function to the `advance` function of the earlier stream.
423+
* `nadv` holds the function to the `advance` function of the earlier stream.
370424
* `makeAdvanceFunction`, for each nested stream, installs a new `advance` function that after
371425
* the stream finishes it will restore the earlier one.
372426
*
373427
* When `advance` is called the result is consumed in the continuation. Within this continuation
374428
* the resulting value should be saved in a variable.
375429
*
376-
* @param currentAdvance variable that holds a function that represents the stream at each level.
430+
* @param nadv variable that holds a function that represents the stream at each level.
377431
* @param k the continuation that consumes a variable.
378432
* @return the quote of the orchestrated code that will be executed as
379433
*/
380-
def makeAdvanceFunction[A](currentAdvance: Var[Unit => Unit], k: A => Expr[Unit], stream: StagedStream[A]): Expr[Unit] = {
434+
def makeAdvanceFunction[A](nadv: Var[Unit => Unit], k: A => Expr[Unit], stream: StagedStream[A]): Expr[Unit] = {
381435
stream match {
382436
case Linear(producer) =>
383437
producer.card match {
@@ -386,28 +440,28 @@ object Test {
386440
~producer.step(st, k)
387441
}
388442
else {
389-
val newAdvance = ~currentAdvance.get
390-
newAdvance(_)
443+
val f = ~nadv.get
444+
f(())
391445
}
392446
})
393447
case Many => producer.init(st => '{
394-
val oldAdvance : Unit => Unit = ~currentAdvance.get
395-
val newAdvance : Unit => Unit = { _: Unit => {
448+
val oldnadv: Unit => Unit = ~nadv.get
449+
val adv1: Unit => Unit = { _: Unit => {
396450
if(~producer.hasNext(st)) {
397451
~producer.step(st, k)
398452
}
399453
else {
400-
~currentAdvance.update('{oldAdvance})
401-
oldAdvance(_)
454+
~nadv.update('{oldnadv})
455+
oldnadv(())
402456
}
403457
}}
404458

405-
~currentAdvance.update('{newAdvance})
406-
newAdvance(_)
459+
~nadv.update('{adv1})
460+
adv1(())
407461
})
408462
}
409463
case nested: Nested[A, bt] =>
410-
makeAdvanceFunction(currentAdvance, (a: bt) => makeAdvanceFunction(currentAdvance, k, nested.nestedf(a)), Linear(nested.producer))
464+
makeAdvanceFunction(nadv, (a: bt) => makeAdvanceFunction(nadv, k, nested.nestedf(a)), Linear(nested.producer))
411465
}
412466
}
413467

@@ -420,32 +474,37 @@ object Test {
420474

421475
def init(k: St => Expr[Unit]): Expr[Unit] = {
422476
producer.init(st =>
423-
Var('{ (_: Unit) => ()}){ advf => {
477+
Var('{ (_: Unit) => ()}){ nadv => {
424478
Var('{ true }) { hasNext => {
425479
Var('{ null.asInstanceOf[A] }) { curr => '{
480+
481+
// Code generation of the `adv` function
426482
def adv: Unit => Unit = { _ =>
427483
~hasNext.update(producer.hasNext(st))
428484
if(~hasNext.get) {
429-
~producer.step(st, el => makeAdvanceFunction[Expr[A]](advf, (a => curr.update(a)), nestedf(el)))
485+
~producer.step(st, el => {
486+
makeAdvanceFunction[Expr[A]](nadv, (a => curr.update(a)), nestedf(el))
487+
})
430488
}
431489
}
432490

433-
~advf.update('{adv})
491+
~nadv.update('{adv})
434492
adv(())
435-
436-
~k((hasNext, curr, advf))
493+
~k((hasNext, curr, nadv))
437494
}}
438495
}}
439496
}})
440497
}
441498

442499
def step(st: St, k: Expr[A] => Expr[Unit]): Expr[Unit] = {
443-
val (flag, current, advf) = st
444-
var el: Var[A] = current
445-
val f: Expr[Unit => Unit] = advf.get
500+
val (flag, current, nadv) = st
501+
'{
502+
var el = ~current.get
503+
val f: Unit => Unit = ~nadv.get
504+
f(())
505+
~k('(el))
506+
}
446507

447-
f('())
448-
k((el.get))
449508
}
450509

451510
def hasNext(st: St): Expr[Boolean] = {
@@ -465,8 +524,8 @@ object Test {
465524

466525
def init(k: St => Expr[Unit]): Expr[Unit] = {
467526
producer.init(s1 => '{ ~nestedProducer.init(s2 =>
468-
Var('{ ~producer.hasNext(s1) }) { term1r =>
469-
k((term1r, s1, s2))
527+
Var('{ ~producer.hasNext(s1) }) { flag =>
528+
k((flag, s1, s2))
470529
})})
471530
}
472531

@@ -491,6 +550,7 @@ object Test {
491550
})
492551
}
493552

553+
/** Computes the producer of zipping two linear streams **/
494554
private def zip_producer[A, B](producer1: Producer[A], producer2: Producer[B]) = {
495555
new Producer[(A, B)] {
496556

@@ -513,6 +573,7 @@ object Test {
513573
}
514574
}
515575

576+
/** zip **/
516577
def zip[B: Type, C: Type](f: (Expr[A] => Expr[B] => Expr[C]), stream2: Stream[B]): Stream[C] = {
517578
val Stream(stream_b) = stream2
518579
Stream(mapRaw[(Expr[A], Expr[B]), Expr[C]]((t => k => '{ ~k(f(t._1)(t._2)) }), zipRaw[A, Expr[B]](stream, stream_b)))

0 commit comments

Comments
 (0)