diff --git a/safegraph/delphi_safegraph/process.py b/safegraph/delphi_safegraph/process.py index be7ccd4e1..dae1e137e 100644 --- a/safegraph/delphi_safegraph/process.py +++ b/safegraph/delphi_safegraph/process.py @@ -17,32 +17,30 @@ def add_prefix(signal_names, wip_signal, prefix: str): prefix : 'wip_' prefix for new/non public signals wip_signal : List[str] or bool - Either takes a list of wip signals: [], OR - incorporated all signals in the registry: True OR - no signals: False + a list of wip signals: [], OR + all signals in the registry: True OR + only signals that have never been published: False Returns ------- List of signal names wip/non wip signals for further computation """ - if wip_signal in ("", False): - return signal_names - elif wip_signal and isinstance(wip_signal, bool): + if wip_signal is True: + return [prefix + signal for signal in signal_names] + if isinstance(wip_signal,list): + make_wip = set(wip_signal) return [ - (prefix + signal) if public_signal(signal) - else signal + (prefix if signal in make_wip else "") + signal for signal in signal_names ] - elif isinstance(wip_signal, list): - for signal in wip_signal: - if public_signal(signal): - signal_names.append(prefix + signal) - signal_names.remove(signal) - return signal_names - else: - raise ValueError("Supply True | False or '' or [] | list()") - + if wip_signal in {False,""}: + return [ + signal if public_signal(signal) + else prefix + signal + for signal in signal_names + ] + raise ValueError("Supply True | False or '' or [] | list()") # Check if the signal name is public def public_signal(signal_): @@ -54,15 +52,14 @@ def public_signal(signal_): Returns ------- bool - True if the signal is not present - False if the signal is present + True if the signal is present + False if the signal is not present """ - epidata_df = covidcast.meta() + epidata_df = covidcast.metadata() for index in range(len(epidata_df)): - if 'signal' in epidata_df[index]: - if epidata_df[index]['signal'] == signal_: - return False - return True + if epidata_df['signal'][index] == signal_: + return True + return False def construct_signals(cbg_df, signal_names): diff --git a/safegraph/tests/test_process.py b/safegraph/tests/test_process.py index c69ab206d..0f6fab3fe 100644 --- a/safegraph/tests/test_process.py +++ b/safegraph/tests/test_process.py @@ -50,12 +50,17 @@ def test_aggregate_state(self): assert np.all(x[~np.isnan(x)] >= 0) def test_handle_wip_signal(self): - wip_signal = read_params()["wip_signal"] - assert isinstance(wip_signal, (list, bool)) or wip_signal == "", "Supply True | False or "" or [] | list()" - if isinstance(wip_signal, list): - assert set(wip_signal).issubset(set(SIGNALS)), "signal in params don't belong in the registry" - updated_signal_names = add_prefix(SIGNALS, wip_signal, prefix='wip_') - assert (len(updated_signal_names) >= len(SIGNALS)) + # Test wip_signal = True + signal_names = add_prefix(SIGNALS, True, prefix="wip_") + assert all(s.startswith("wip_") for s in signal_names) + # Test wip_signal = list + signal_names = add_prefix(SIGNALS, [SIGNALS[0]], prefix="wip_") + assert signal_names[0].startswith("wip_") + assert all(not s.startswith("wip_") for s in signal_names[1:]) + # Test wip_signal = False + signal_names = add_prefix(["xyzzy", SIGNALS[0]], False, prefix="wip_") + assert signal_names[0].startswith("wip_") + assert all(not s.startswith("wip_") for s in signal_names[1:])