Skip to content

ENH: add freq parameter to _BaseReader #199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions pandas_datareader/tests/test_wb.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,101 @@ def test_wdi_get_indicators(self):
self.assertTrue(result.columns.equals(exp_col))
self.assertTrue(len(result) > 10000)

def test_wdi_download_monthly(self):
    """Check monthly (freq='M') COPPER data against recorded values.

    The same query is issued through ``download`` and through
    ``WorldBankReader(...).read()``; both frames are sorted, rounded to
    the nearest thousand and compared with the recorded prices.
    """

    def _sorted(frame):
        # Choose the sorting method according to the PANDAS_0170 flag.
        return frame.sort_index() if PANDAS_0170 else frame.sort()

    recorded = {'COPPER': {('World', '2012M01'): 8040.47,
                           ('World', '2011M12'): 7565.48,
                           ('World', '2011M11'): 7581.02,
                           ('World', '2011M10'): 7394.19,
                           ('World', '2011M09'): 8300.14,
                           ('World', '2011M08'): 9000.76,
                           ('World', '2011M07'): 9650.46,
                           ('World', '2011M06'): 9066.85,
                           ('World', '2011M05'): 8959.90,
                           ('World', '2011M04'): 9492.79,
                           ('World', '2011M03'): 9503.36,
                           ('World', '2011M02'): 9867.60,
                           ('World', '2011M01'): 9555.70}}
    # Round, to ignore revisions to data.
    expected = _sorted(np.round(pd.DataFrame(recorded), decimals=-3))

    cntry_codes = 'ALL'
    inds = 'COPPER'
    # Keyword arguments shared by both ways of issuing the query.
    query = dict(start=2011, end=2012, freq='M', errors='ignore')

    result = download(country=cntry_codes, indicator=inds, **query)
    result = np.round(_sorted(result), decimals=-3)

    if PANDAS_0140:
        expected.index.names = ['country', 'year']
    else:
        # Prior versions do not allow setting multiple names on a
        # MultiIndex, so reuse the index of the result instead.
        expected.index = result.index

    tm.assert_frame_equal(result, expected)

    result = WorldBankReader(inds, countries=cntry_codes, **query).read()
    result = np.round(_sorted(result), decimals=-3)
    tm.assert_frame_equal(result, expected)

def test_wdi_download_quarterly(self):
    """Check quarterly (freq='Q') debt data for Albania against recorded values.

    The same query is issued through ``download`` and through
    ``WorldBankReader(...).read()``; both frames are sorted, rounded to
    the nearest thousand and compared with the recorded figures.
    """

    expected = {'DT.DOD.PUBS.CD.US': {('Albania', '2012Q1'): 3240539817.18,
                                      ('Albania', '2011Q4'): 3213979715.15,
                                      ('Albania', '2011Q3'): 3187681048.95,
                                      ('Albania', '2011Q2'): 3248041513.86,
                                      ('Albania', '2011Q1'): 3137210567.92}}
    expected = pd.DataFrame(expected)
    # Round, to ignore revisions to data.
    expected = np.round(expected, decimals=-3)
    if PANDAS_0170:
        expected = expected.sort_index()
    else:
        expected = expected.sort()

    cntry_codes = 'ALB'
    inds = 'DT.DOD.PUBS.CD.US'
    result = download(country=cntry_codes, indicator=inds,
                      start=2011, end=2012, freq='Q', errors='ignore')
    if PANDAS_0170:
        result = result.sort_index()
    else:
        result = result.sort()
    result = np.round(result, decimals=-3)

    if PANDAS_0140:
        expected.index.names = ['country', 'year']
    else:
        # Prior versions do not allow setting multiple names on a
        # MultiIndex, so reuse the index of the result instead.
        expected.index = result.index

    tm.assert_frame_equal(result, expected)

    result = WorldBankReader(inds, countries=cntry_codes,
                             start=2011, end=2012, freq='Q',
                             errors='ignore').read()
    if PANDAS_0170:
        result = result.sort_index()
    else:
        result = result.sort()
    # Use the same precision as ``expected`` (nearest thousand); rounding
    # only to the nearest ten would compare values at different precisions.
    result = np.round(result, decimals=-3)
    tm.assert_frame_equal(result, expected)

# Run this test module through nose when the file is executed directly.
if __name__ == '__main__':
    nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
                   exit=False)  # pragma: no cover
31 changes: 26 additions & 5 deletions pandas_datareader/wb.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ class WorldBankReader(_BaseReader):
_format = 'json'

def __init__(self, symbols=None, countries=None,
start=None, end=None,
start=None, end=None, freq=None,
retry_count=3, pause=0.001, session=None, errors='warn'):

if symbols is None:
Expand All @@ -144,6 +144,12 @@ def __init__(self, symbols=None, countries=None,
if errors == 'warn':
warnings.warn('Non-standard ISO country codes: %s' % tmp, UserWarning)

freq_symbols = ['M','Q','A', None]
if freq not in freq_symbols:
msg = 'The frequency `{0}` is not in the accepted list.'.format(freq)
raise ValueError(msg)

self.freq = freq
self.countries = countries
self.errors = errors

Expand All @@ -154,8 +160,18 @@ def url(self):

@property
def params(self):
return {'date': '{0}:{1}'.format(self.start.year, self.end.year),
'per_page': 25000, 'format': 'json'}
if self.freq == 'M':
return {'date': '{0}M{1:02d}:{2}M{3:02d}'.format(self.start.year,
self.start.month, self.end.year, self.end.month),
'per_page': 25000, 'format': 'json'}
if self.freq == 'Q':
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use an `if`/`elif`/`else` chain here.

return {'date': '{0}Q{1}:{2}Q{3}'.format(self.start.year,
divmod(self.start.month-1,3)[0]+1, self.end.year,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change `_sanitize_dates` to always return a `Timestamp`, so that `self.start.quarter` can be used here?

divmod(self.end.month-1,3)[0]+1),'per_page': 25000,
'format': 'json'}
if self.freq is None or self.freq == 'A':
return {'date': '{0}:{1}'.format(self.start.year, self.end.year),
'per_page': 25000, 'format': 'json'}

def read(self):
data = []
Expand Down Expand Up @@ -326,7 +342,7 @@ def search(self, string='gdp.*capi', field='name', case=False):
return out


def download(country=None, indicator=None, start=2003, end=2005,
def download(country=None, indicator=None, start=2003, end=2005, freq=None,
errors='warn', **kwargs):
"""
Download data series from the World Bank's World Development Indicators
Expand All @@ -352,6 +368,11 @@ def download(country=None, indicator=None, start=2003, end=2005,
end: int
Last year of the data series (inclusive)

freq: str
frequency or periodicity of the data to be retrieved (e.g. 'M' for
monthly, 'Q' for quarterly, and 'A' for annual). None defaults to
annual.

errors: str {'ignore', 'warn', 'raise'}, default 'warn'
Country codes are validated against a hardcoded list. This controls
the outcome of that validation, and attempts to also apply
Expand All @@ -370,7 +391,7 @@ def download(country=None, indicator=None, start=2003, end=2005,

"""
return WorldBankReader(symbols=indicator, countries=country,
start=start, end=end, errors=errors,
start=start, end=end, freq=freq, errors=errors,
**kwargs).read()


Expand Down