Skip to content

Commit 78df911

Browse files
authored
Merge pull request #528 from cmu-delphi/sgratzl/exclude_fields
feat: exclude fields
2 parents 6998a19 + eda64d1 commit 78df911

File tree

2 files changed

+49
-6
lines changed

2 files changed

+49
-6
lines changed

integrations/server/test_covidcast.py

+25
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,31 @@ def test_fields(self):
274274
'message': 'success',
275275
})
276276

277+
278+
# limit exclude fields
279+
response = requests.get(BASE_URL, params={
280+
'endpoint': 'covidcast',
281+
'data_source': 'src',
282+
'signal': 'sig',
283+
'time_type': 'day',
284+
'geo_type': 'county',
285+
'time_values': 20200414,
286+
'geo_value': '01234',
287+
'fields': '-value,-stderr,-sample_size,-direction,-issue,-lag,-signal'
288+
})
289+
response.raise_for_status()
290+
response = response.json()
291+
292+
# assert that the right data came back
293+
self.assertEqual(response, {
294+
'result': 1,
295+
'epidata': [{
296+
'time_value': 20200414,
297+
'geo_value': '01234'
298+
}],
299+
'message': 'success',
300+
})
301+
277302
def test_location_wildcard(self):
278303
"""Select all locations with a wildcard query."""
279304

src/server/_query.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -97,11 +97,21 @@ def filter_fields(generator: Iterable[Dict[str, Any]]):
9797
if not fields:
9898
yield from generator
9999
else:
100+
exclude_fields = {f[1:] for f in fields if f.startswith("-")}
101+
include_fields = [f for f in fields if not f.startswith("-") and f not in exclude_fields]
102+
100103
for row in generator:
101104
filtered = dict()
102-
for field in fields:
103-
if field in row:
104-
filtered[field] = row[field]
105+
if include_fields:
106+
# positive list
107+
for field in include_fields:
108+
if field in row:
109+
filtered[field] = row[field]
110+
elif exclude_fields:
111+
# negative list
112+
for k, v in row.items():
113+
if k not in exclude_fields:
114+
filtered[k] = v
105115
yield filtered
106116

107117

@@ -252,9 +262,17 @@ def execute_queries(
252262

253263
fields_to_send = set(extract_strings("fields") or [])
254264
if fields_to_send:
255-
fields_string = [v for v in fields_string if v in fields_to_send]
256-
fields_int = [v for v in fields_int if v in fields_to_send]
257-
fields_float = [v for v in fields_float if v in fields_to_send]
265+
exclude_fields = {f[1:] for f in fields_to_send if f.startswith("-")}
266+
include_fields = {f for f in fields_to_send if not f.startswith("-") and f not in exclude_fields}
267+
268+
if include_fields:
269+
fields_string = [v for v in fields_string if v in include_fields]
270+
fields_int = [v for v in fields_int if v in include_fields]
271+
fields_float = [v for v in fields_float if v in include_fields]
272+
if exclude_fields:
273+
fields_string = [v for v in fields_string if v not in exclude_fields]
274+
fields_int = [v for v in fields_int if v not in exclude_fields]
275+
fields_float = [v for v in fields_float if v not in exclude_fields]
258276

259277
query_list = list(queries)
260278

0 commit comments

Comments
 (0)