@@ -491,7 +491,6 @@ def _generate_transformed_rows(
491
491
parsed_rows : Iterator [Dict ],
492
492
transform_dict : Optional [SignalTransforms ] = None ,
493
493
transform_args : Optional [Dict ] = None ,
494
- group_keyfunc : Optional [Callable ] = None ,
495
494
) -> Iterator [Dict ]:
496
495
"""Applies time-series transformations to streamed rows from a database.
497
496
@@ -503,9 +502,6 @@ def _generate_transformed_rows(
503
502
For example, transform_dict may be {("jhu-csse", "confirmed_cumulative_num"): [("jhu-csse", "confirmed_incidence_num"), ("jhu-csse", "confirmed_7dav_incidence_num")]}.
504
503
transform_args: Optional[Dict], default None
505
504
A dictionary of keyword arguments for the transformer functions.
506
- group_keyfunc: Optional[Callable], default None
507
- The groupby function to use to order the streamed rows. Note that Python groupby does not do any sorting, so
508
- parsed_rows are assumed to be sorted in accord with this groupby.
509
505
510
506
Yields:
511
507
transformed rows: Dict
@@ -515,59 +511,57 @@ def _generate_transformed_rows(
515
511
transform_args = dict ()
516
512
if not transform_dict :
517
513
transform_dict = dict ()
518
- if not group_keyfunc :
519
- group_keyfunc = lambda row : (row ["source" ], row ["signal" ], row ["geo_type" ], row ["geo_value" ])
520
514
521
- for key , source_signal_geo_rows in groupby (parsed_rows , group_keyfunc ):
515
+ # TODO: Fix these to come as an argument?
516
+ fields_string = ["geo_type" , "geo_value" , "source" , "signal" , "time_type" ]
517
+ fields_int = ["time_value" , "direction" , "issue" , "lag" , "missing_value" , "missing_stderr" , "missing_sample_size" ]
518
+ fields_float = ["value" , "stderr" , "sample_size" ]
519
+ columns = fields_string + fields_int + fields_float
520
+ df = pd .DataFrame (parsed_rows , columns = columns )
521
+ for key , group_df in df .groupby (["source" , "signal" , "geo_type" , "geo_value" ]):
522
522
base_source_name , base_signal_name , _ , _ = key
523
523
# Extract the list of derived signals; if a signal is not in the dictionary, then use the identity map.
524
524
derived_signal_transform_map : SourceSignalPair = transform_dict .get (SourceSignalPair (base_source_name , [base_signal_name ]), SourceSignalPair (base_source_name , [base_signal_name ]))
525
525
# Create a list of source-signal pairs along with the transformation required for the signal.
526
526
signal_names_and_transforms : List [Tuple [str , Callable ]] = [(derived_signal , _get_base_signal_transform ((base_source_name , derived_signal ))) for derived_signal in derived_signal_transform_map .signal ]
527
527
528
- # TODO: Fix these to come as an argument.
529
- fields_string = ["geo_type" , "geo_value" , "source" , "signal" , "time_type" ]
530
- fields_int = ["time_value" , "direction" , "issue" , "lag" , "missing_value" , "missing_stderr" , "missing_sample_size" ]
531
- fields_float = ["value" , "stderr" , "sample_size" ]
532
- columns = fields_string + fields_int + fields_float
533
- df = pd .DataFrame .from_records (source_signal_geo_rows , columns = columns )
534
528
for derived_signal , transform in signal_names_and_transforms :
535
529
if transform == IDENTITY :
536
- yield from df .to_dict (orient = "records" )
530
+ yield from group_df .to_dict (orient = "records" )
537
531
continue
538
-
539
- df2 = df .set_index (["time_value" ])
540
- df2 = df2 .reindex (iterate_over_range (df2 .index .min (), df2 .index .max (), inclusive = True ))
532
+
533
+ derived_df = group_df .set_index (["time_value" ])
534
+ derived_df = derived_df .reindex (iterate_over_range (derived_df .index .min (), derived_df .index .max (), inclusive = True ))
541
535
542
536
if transform == DIFF :
543
- df2 ["value" ] = df2 ["value" ].diff ()
537
+ derived_df ["value" ] = derived_df ["value" ].diff ()
544
538
window_length = 2
545
539
elif transform == SMOOTH :
546
- df2 ["value" ] = df2 ["value" ].rolling (7 ).mean ()
540
+ derived_df ["value" ] = derived_df ["value" ].rolling (7 ).mean ()
547
541
window_length = 7
548
542
elif transform == DIFF_SMOOTH :
549
- df2 ["value" ] = df2 ["value" ].diff ().rolling (7 ).mean ()
543
+ derived_df ["value" ] = derived_df ["value" ].diff ().rolling (7 ).mean ()
550
544
window_length = 8
551
545
else :
552
546
raise ValueError (f"Unknown transform for { derived_signal } ." )
553
547
554
- df2 = df2 .assign (
555
- geo_type = df2 ["geo_type" ].fillna (method = "ffill" ),
556
- geo_value = df2 ["geo_value" ].fillna (method = "ffill" ),
557
- source = df2 ["source" ].fillna (method = "ffill" ),
548
+ derived_df = derived_df .assign (
549
+ geo_type = derived_df ["geo_type" ].fillna (method = "ffill" ),
550
+ geo_value = derived_df ["geo_value" ].fillna (method = "ffill" ),
551
+ source = derived_df ["source" ].fillna (method = "ffill" ),
558
552
signal = derived_signal ,
559
- time_type = df2 ["time_type" ].fillna (method = "ffill" ),
560
- direction = df2 ["direction" ].fillna (method = "ffill" ),
561
- issue = df2 ["issue" ].rolling (window_length ).max (),
562
- lag = df2 ["lag" ].rolling (window_length ).max (),
563
- missing_value = np .where (df2 ["value" ].isna (), Nans .NOT_APPLICABLE , Nans .NOT_MISSING ),
553
+ time_type = derived_df ["time_type" ].fillna (method = "ffill" ),
554
+ direction = derived_df ["direction" ].fillna (method = "ffill" ),
555
+ issue = derived_df ["issue" ].rolling (window_length ).max (),
556
+ lag = derived_df ["lag" ].rolling (window_length ).max (),
557
+ missing_value = np .where (derived_df ["value" ].isna (), Nans .NOT_APPLICABLE , Nans .NOT_MISSING ),
564
558
missing_stderr = Nans .NOT_APPLICABLE ,
565
559
missing_sample_size = Nans .NOT_APPLICABLE ,
566
560
stderr = np .nan ,
567
561
sample_size = np .nan ,
568
562
)
569
- df2 = df2 .iloc [window_length - 1 :]
570
- for row in df2 .reset_index ().to_dict (orient = "records" ):
563
+ derived_df = derived_df .iloc [window_length - 1 :]
564
+ for row in derived_df .reset_index ().to_dict (orient = "records" ):
571
565
row .update ({
572
566
"issue" : int (row ["issue" ]) if not np .isnan (row ["issue" ]) else row ["issue" ],
573
567
"lag" : int (row ["lag" ]) if not np .isnan (row ["lag" ]) else row ["lag" ]
0 commit comments