1
1
from dataclasses import asdict , dataclass , field
2
2
from enum import Enum
3
3
from functools import partial
4
- from itertools import groupby , repeat , tee
4
+ from itertools import groupby
5
5
from numbers import Number
6
6
from typing import Callable , Generator , Iterator , Optional , Dict , List , Set , Tuple , Union
7
7
8
8
from pathlib import Path
9
9
import re
10
- from more_itertools import flatten , interleave_longest , peekable
10
+ from more_itertools import flatten , peekable
11
11
import pandas as pd
12
12
import numpy as np
13
13
14
14
from delphi_utils .nancodes import Nans
15
15
from ..._params import SourceSignalPair , TimePair
16
16
from .smooth_diff import generate_smoothed_rows , generate_diffed_rows
17
- from ...utils import shift_time_value , iterate_over_ints_and_ranges
17
+ from ...utils import shift_time_value , iterate_over_ints_and_ranges , iterate_over_range
18
18
19
19
20
20
IDENTITY : Callable = lambda rows , ** kwargs : rows
@@ -489,7 +489,6 @@ def get_day_range(time_pairs: List[TimePair]) -> Iterator[int]:
489
489
490
490
def _generate_transformed_rows (
491
491
parsed_rows : Iterator [Dict ],
492
- time_pairs : Optional [List [TimePair ]] = None ,
493
492
transform_dict : Optional [SignalTransforms ] = None ,
494
493
transform_args : Optional [Dict ] = None ,
495
494
group_keyfunc : Optional [Callable ] = None ,
@@ -499,9 +498,6 @@ def _generate_transformed_rows(
499
498
Parameters:
500
499
parsed_rows: Iterator[Dict]
501
500
An iterator streaming rows from a database query. Assumed to be sorted by source, signal, geo_type, geo_value, time_type, and time_value.
502
- time_pairs: Optional[List[TimePair]], default None
503
A list of TimePairs, which can be used to create a contiguous time index for time-series operations.
504
The min and max dates in the TimePairs list are used.
505
501
transform_dict: Optional[SignalTransforms], default None
506
502
A dictionary mapping base sources to a list of their derived signals that the user wishes to query.
507
503
For example, transform_dict may be {("jhu-csse", "confirmed_cumulative_num"): [("jhu-csse", "confirmed_incidence_num"), ("jhu-csse", "confirmed_7dav_incidence_num")]}.
@@ -527,18 +523,56 @@ def _generate_transformed_rows(
527
523
# Extract the list of derived signals; if a signal is not in the dictionary, then use the identity map.
528
524
derived_signal_transform_map : SourceSignalPair = transform_dict .get (SourceSignalPair (base_source_name , [base_signal_name ]), SourceSignalPair (base_source_name , [base_signal_name ]))
529
525
# Create a list of source-signal pairs along with the transformation required for the signal.
530
- signal_names_and_transforms : List [Tuple [Tuple [str , str ], Callable ]] = [(derived_signal , _get_base_signal_transform ((base_source_name , derived_signal ))) for derived_signal in derived_signal_transform_map .signal ]
531
- # Put the current time series on a contiguous time index.
532
- source_signal_geo_rows = _reindex_iterable (source_signal_geo_rows , time_pairs , fill_value = transform_args .get ("pad_fill_value" ))
533
- # Create copies of the iterable, with smart memory usage.
534
- source_signal_geo_rows_copies : Iterator [Iterator [Dict ]] = tee (source_signal_geo_rows , len (signal_names_and_transforms ))
535
- # Create a list of transformed group iterables, remembering their derived name as needed.
536
- transformed_signals_iterator : Iterator [Tuple [str , Iterator [Dict ]]] = (zip (repeat (derived_signal ), transform (rows , ** transform_args )) for (derived_signal , transform ), rows in zip (signal_names_and_transforms , source_signal_geo_rows_copies ))
537
- # Traverse through the transformed iterables in an interleaved fashion, which makes sure that only a small window
538
- # of the original iterable (group) is stored in memory.
539
- for derived_signal_name , row in interleave_longest (* transformed_signals_iterator ):
540
- row ["signal" ] = derived_signal_name
541
- yield row
526
+ signal_names_and_transforms : List [Tuple [str , Callable ]] = [(derived_signal , _get_base_signal_transform ((base_source_name , derived_signal ))) for derived_signal in derived_signal_transform_map .signal ]
527
+
528
+ # TODO: Fix these to come as an argument.
529
+ fields_string = ["geo_type" , "geo_value" , "source" , "signal" , "time_type" ]
530
+ fields_int = ["time_value" , "direction" , "issue" , "lag" , "missing_value" , "missing_stderr" , "missing_sample_size" ]
531
+ fields_float = ["value" , "stderr" , "sample_size" ]
532
+ columns = fields_string + fields_int + fields_float
533
+ df = pd .DataFrame .from_records (source_signal_geo_rows , columns = columns )
534
+ for derived_signal , transform in signal_names_and_transforms :
535
+ if transform == IDENTITY :
536
+ yield from df .to_dict (orient = "records" )
537
+ continue
538
+
539
+ df2 = df .set_index (["time_value" ])
540
+ df2 = df2 .reindex (iterate_over_range (df2 .index .min (), df2 .index .max (), inclusive = True ))
541
+
542
+ if transform == DIFF :
543
+ df2 ["value" ] = df2 ["value" ].diff ()
544
+ window_length = 2
545
+ elif transform == SMOOTH :
546
+ df2 ["value" ] = df2 ["value" ].rolling (7 ).mean ()
547
+ window_length = 7
548
+ elif transform == DIFF_SMOOTH :
549
+ df2 ["value" ] = df2 ["value" ].diff ().rolling (7 ).mean ()
550
+ window_length = 8
551
+ else :
552
+ raise ValueError (f"Unknown transform for { derived_signal } ." )
553
+
554
+ df2 = df2 .assign (
555
+ geo_type = df2 ["geo_type" ].fillna (method = "ffill" ),
556
+ geo_value = df2 ["geo_value" ].fillna (method = "ffill" ),
557
+ source = df2 ["source" ].fillna (method = "ffill" ),
558
+ signal = derived_signal ,
559
+ time_type = df2 ["time_type" ].fillna (method = "ffill" ),
560
+ direction = df2 ["direction" ].fillna (method = "ffill" ),
561
+ issue = df2 ["issue" ].rolling (window_length ).max (),
562
+ lag = df2 ["lag" ].rolling (window_length ).max (),
563
+ missing_value = np .where (df2 ["value" ].isna (), Nans .NOT_APPLICABLE , Nans .NOT_MISSING ),
564
+ missing_stderr = Nans .NOT_APPLICABLE ,
565
+ missing_sample_size = Nans .NOT_APPLICABLE ,
566
+ stderr = np .nan ,
567
+ sample_size = np .nan ,
568
+ )
569
+ df2 = df2 .iloc [window_length - 1 :]
570
+ for row in df2 .reset_index ().to_dict (orient = "records" ):
571
+ row .update ({
572
+ "issue" : int (row ["issue" ]) if not np .isnan (row ["issue" ]) else row ["issue" ],
573
+ "lag" : int (row ["lag" ]) if not np .isnan (row ["lag" ]) else row ["lag" ]
574
+ })
575
+ yield row
542
576
543
577
544
578
def get_basename_signal_and_jit_generator (source_signal_pairs : List [SourceSignalPair ], transform_args : Optional [Dict [str , Union [str , int ]]] = None ) -> Tuple [List [SourceSignalPair ], Generator ]:
0 commit comments