diff --git a/safegraph/delphi_safegraph/constants.py b/safegraph/delphi_safegraph/constants.py
index c2fe606cb..e6679df52 100644
--- a/safegraph/delphi_safegraph/constants.py
+++ b/safegraph/delphi_safegraph/constants.py
@@ -15,4 +15,6 @@ GEO_RESOLUTIONS = [
     'county',
     'state',
+    'msa',
+    'hrr'
 ]
diff --git a/safegraph/delphi_safegraph/process.py b/safegraph/delphi_safegraph/process.py
index 6c3ed8bb3..6f76a3750 100644
--- a/safegraph/delphi_safegraph/process.py
+++ b/safegraph/delphi_safegraph/process.py
@@ -123,17 +123,31 @@ def aggregate(df, signal_names, geo_resolution='county'):
         signals, standard errors, and sample sizes.
     """
     # Prepare geo resolution
+    gmpr = GeoMapper()
     if geo_resolution == 'county':
         geo_transformed_df = df.copy()
         geo_transformed_df['geo_id'] = df['county_fips']
     elif geo_resolution == 'state':
-        gmpr = GeoMapper()
         geo_transformed_df = gmpr.add_geocode(df,
-                                              from_col='county_fips',
-                                              from_code='fips',
-                                              new_code='state_id',
-                                              new_col='geo_id',
-                                              dropna=False)
+                                              from_col='county_fips',
+                                              from_code='fips',
+                                              new_code='state_id',
+                                              new_col='geo_id',
+                                              dropna=False)
+    elif geo_resolution == 'msa':
+        geo_transformed_df = gmpr.add_geocode(df,
+                                              from_col='county_fips',
+                                              from_code='fips',
+                                              new_code='msa',
+                                              new_col='geo_id',
+                                              dropna=False)
+    elif geo_resolution == 'hrr':
+        geo_transformed_df = gmpr.add_geocode(df,
+                                              from_col='county_fips',
+                                              from_code='fips',
+                                              new_code='hrr',
+                                              new_col='geo_id',
+                                              dropna=False)
     else:
         raise ValueError(
             f'`geo_resolution` must be one of {GEO_RESOLUTIONS}.')
diff --git a/safegraph/tests/test_process.py b/safegraph/tests/test_process.py
index c8332e8d6..57ad8b7a4 100644
--- a/safegraph/tests/test_process.py
+++ b/safegraph/tests/test_process.py
@@ -42,6 +42,7 @@ def test_aggregate_county(self):
         assert np.all(df[f'{SIGNALS[0]}_n'].values > 0)
         x = df[f'{SIGNALS[0]}_se'].values
         assert np.all(x[~np.isnan(x)] >= 0)
+        assert df.shape == (1472, 17)

     def test_aggregate_state(self):
         """Tests that aggregation at the state level creates non-zero-valued
         signals."""
@@ -53,6 +54,31 @@ def test_aggregate_state(self):
         assert np.all(df[f'{SIGNALS[0]}_n'].values > 0)
         x = df[f'{SIGNALS[0]}_se'].values
         assert np.all(x[~np.isnan(x)] >= 0)
+        assert df.shape == (52, 17)
+
+    def test_aggregate_msa(self):
+        """Tests that aggregation at the msa level creates non-zero-valued
+        signals."""
+        cbg_df = construct_signals(pd.read_csv('raw_data/sample_raw_data.csv'),
+                                   SIGNALS)
+        df = aggregate(cbg_df, SIGNALS, 'msa')
+
+        assert np.all(df[f'{SIGNALS[0]}_n'].values > 0)
+        x = df[f'{SIGNALS[0]}_se'].values
+        assert np.all(x[~np.isnan(x)] >= 0)
+        assert df.shape == (372, 17)
+
+    def test_aggregate_hrr(self):
+        """Tests that aggregation at the hrr level creates non-zero-valued
+        signals."""
+        cbg_df = construct_signals(pd.read_csv('raw_data/sample_raw_data.csv'),
+                                   SIGNALS)
+        df = aggregate(cbg_df, SIGNALS, 'hrr')
+
+        assert np.all(df[f'{SIGNALS[0]}_n'].values > 0)
+        x = df[f'{SIGNALS[0]}_se'].values
+        assert np.all(x[~np.isnan(x)] >= 0)
+        assert df.shape == (306, 17)

     def test_files_in_past_week(self):
         """Tests that `files_in_past_week()` finds the file names corresponding
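
Note (not part of the diff): a minimal sketch of the GeoMapper call pattern the new 'msa' branch relies on, mirroring the arguments added in process.py. It assumes delphi_utils and pandas are installed; the sample FIPS codes and the 'visits' column are illustrative only.

    # Sketch of the fips -> msa mapping used by the new branch; values here are examples.
    import pandas as pd
    from delphi_utils import GeoMapper

    df = pd.DataFrame({
        "county_fips": ["06075", "36061"],  # San Francisco County, New York County
        "visits": [10, 20],                 # illustrative signal column
    })

    gmpr = GeoMapper()
    # Same call pattern the diff adds for geo_resolution == 'msa': map county
    # FIPS codes to MSA codes, writing the result into a new `geo_id` column
    # and keeping rows that do not map (dropna=False).
    msa_df = gmpr.add_geocode(df,
                              from_col="county_fips",
                              from_code="fips",
                              new_code="msa",
                              new_col="geo_id",
                              dropna=False)
    print(msa_df[["county_fips", "geo_id", "visits"]])

The 'hrr' branch follows the same pattern with new_code="hrr".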