Skip to content

Commit 1955c6d

Browse files
committed
include pk in default JSON; detect custom pk name
(fixes #29)
1 parent b568f7c commit 1955c6d

File tree

7 files changed

+104
-15
lines changed

7 files changed

+104
-15
lines changed

rest_pandas/renderers.py

+37-7
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,21 @@ def render(self, data, accepted_media_type=None, renderer_context=None):
3636
raise Exception(
3737
RESPONSE_ERROR % type(data).__name__
3838
)
39+
3940
name = getattr(self, 'function', "to_%s" % self.format)
40-
function = getattr(data, name, None)
41-
if not function:
41+
if not hasattr(data, name):
4242
raise Exception("Data frame is missing %s property!" % name)
43+
4344
self.init_output()
4445
args = self.get_pandas_args(data)
4546
kwargs = self.get_pandas_kwargs(data, renderer_context)
46-
function(*args, **kwargs)
47+
self.render_dataframe(data, name, *args, **kwargs)
4748
return self.get_output()
4849

50+
def render_dataframe(self, data, name, *args, **kwargs):
51+
function = getattr(data, name)
52+
function(*args, **kwargs)
53+
4954
def init_output(self):
5055
self.output = StringIO()
5156

@@ -150,19 +155,44 @@ class PandasJSONRenderer(PandasBaseRenderer):
150155
media_type = "application/json"
151156
format = "json"
152157

158+
orient_choices = {
159+
'records-index', # Unique to DRP
160+
'split',
161+
'records',
162+
'index',
163+
'columns',
164+
'values',
165+
'table',
166+
}
167+
default_orient = 'records-index'
168+
169+
date_format_choices = {'epoch', 'iso'}
170+
default_date_format = 'iso'
171+
153172
def get_pandas_kwargs(self, data, renderer_context):
154173
request = renderer_context['request']
174+
155175
orient = request.GET.get('orient', '')
176+
if orient not in self.orient_choices:
177+
orient = self.default_orient
178+
156179
date_format = request.GET.get('date_format', '')
157-
if orient not in {'split', 'records', 'index', 'columns', 'values'}:
158-
orient = 'records'
159-
if date_format not in {'epoch', 'iso'}:
160-
date_format = 'iso'
180+
if date_format not in self.date_format_choices:
181+
date_format = self.default_date_format
182+
161183
return {
162184
'orient': orient,
163185
'date_format': date_format,
164186
}
165187

188+
def render_dataframe(self, data, name, *args, **kwargs):
189+
if kwargs.get('orient') == 'records-index':
190+
kwargs['orient'] = 'records'
191+
data.reset_index(inplace=True)
192+
return super(PandasJSONRenderer, self).render_dataframe(
193+
data, name, *args, **kwargs
194+
)
195+
166196

167197
class PandasExcelRenderer(PandasFileRenderer):
168198
"""

rest_pandas/serializers.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,17 @@ def get_index_fields(self):
6868
"""
6969
List of fields to use for index
7070
"""
71-
default_fields = []
72-
if getattr(self.model_serializer_meta, 'model', None):
73-
if 'id' in self.child.get_fields():
74-
default_fields = ['id']
75-
return self.get_meta_option('index', default_fields)
71+
index_fields = self.get_meta_option('index', [])
72+
if index_fields:
73+
return index_fields
74+
75+
model = getattr(self.model_serializer_meta, 'model', None)
76+
if model:
77+
pk_name = model._meta.pk.name
78+
if pk_name in self.child.get_fields():
79+
return [pk_name]
80+
81+
return []
7682

7783
def get_meta_option(self, name, default=None):
7884
meta_name = 'pandas_' + name

tests/test_views.py

+31-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from rest_framework.test import APITestCase
2-
from tests.testapp.models import TimeSeries
2+
from tests.testapp.models import TimeSeries, CustomIndexSeries
33
from wq.io import load_string
44
import json
55
import datetime
@@ -17,6 +17,10 @@ def setUp(self):
1717
)
1818
for date, value in data:
1919
TimeSeries.objects.create(date=date, value=value)
20+
CustomIndexSeries.objects.create(
21+
code='v' + date.replace('-', ''),
22+
value=value,
23+
)
2024

2125
def test_view_csv(self):
2226
response = self.client.get("/timeseries.csv")
@@ -41,6 +45,8 @@ def test_view_json(self):
4145
self.assertEqual(response.accepted_media_type, "application/json")
4246
data = json.loads(response.content.decode('utf-8'))
4347
self.assertEqual(len(data), 5)
48+
self.assertIn('id', data[0])
49+
self.assertEqual(data[0]["id"], 1)
4450
self.assertEqual(data[0]["value"], 0.5)
4551
self.assertEqual(data[0]["date"], "2014-01-01T00:00:00.000Z")
4652

@@ -53,6 +59,11 @@ def test_view_json_kwargs(self):
5359
date = datetime.datetime.utcfromtimestamp(data[0]["date"] / 1000)
5460
self.assertEqual(date, datetime.datetime(2014, 1, 1))
5561

62+
response = self.client.get("/timeseries.json?orient=index")
63+
data = json.loads(response.content.decode('utf-8'))
64+
self.assertEqual(len(data.values()), 5)
65+
self.assertEqual(data["1"]["value"], 0.5)
66+
5667
def test_view_html(self):
5768
response = self.client.get("/timeseries?test=1")
5869
expected = open(
@@ -111,5 +122,24 @@ def test_from_file(self):
111122
self.assertEqual(len(data), 4)
112123
self.assertEqual(data[0].x, '5')
113124

125+
def test_customindex_csv(self):
126+
response = self.client.get("/customindex.csv")
127+
data = self.load_string(response)
128+
self.assertEqual(len(data), 5)
129+
self.assertEqual(data[0].code, 'v20140101')
130+
self.assertEqual(data[0].value, '0.5')
131+
132+
def test_customindex_json(self):
133+
response = self.client.get("/customindex.json")
134+
data = json.loads(response.content.decode('utf-8'))
135+
self.assertEqual(len(data), 5)
136+
self.assertEqual(data[0]['code'], 'v20140101')
137+
self.assertEqual(data[0]['value'], 0.5)
138+
139+
response = self.client.get("/customindex.json?orient=index")
140+
data = json.loads(response.content.decode('utf-8'))
141+
self.assertEqual(len(data), 5)
142+
self.assertEqual(data['v20140101']['value'], 0.5)
143+
114144
def load_string(self, response):
115145
return load_string(response.content.decode('utf-8'))

tests/testapp/models.py

+5
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,8 @@ class ComplexTimeSeries(models.Model):
3333
# Values
3434
value = models.FloatField()
3535
flag = models.CharField(max_length=1, null=True, blank=True)
36+
37+
38+
class CustomIndexSeries(models.Model):
39+
code = models.SlugField(primary_key=True)
40+
value = models.FloatField()

tests/testapp/serializers.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from rest_framework.serializers import ModelSerializer
22
from rest_framework import serializers
33
from rest_pandas import PandasUnstackedSerializer
4-
from .models import TimeSeries, MultiTimeSeries, ComplexTimeSeries
4+
from .models import (
5+
TimeSeries, MultiTimeSeries, ComplexTimeSeries, CustomIndexSeries,
6+
)
57

68

79
class TimeSeriesSerializer(ModelSerializer):
@@ -62,3 +64,9 @@ class Meta:
6264
list_serializer_class = PandasUnstackedSerializer
6365
# pandas_unstacked_header = Missing
6466
pandas_index = ['series']
67+
68+
69+
class CustomIndexSeriesSerializer(ModelSerializer):
70+
class Meta:
71+
model = CustomIndexSeries
72+
fields = '__all__'

tests/testapp/urls.py

+2
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
DjangoPandasView, TimeSeriesViewSet,
99
MultiTimeSeriesView, MultiScatterView, MultiBoxplotView,
1010
ComplexTimeSeriesView, ComplexScatterView, ComplexBoxplotView,
11+
CustomIndexSeriesView,
1112
)
1213

1314
router = DefaultRouter()
@@ -28,6 +29,7 @@
2829
url(r'^complextimeseries$', ComplexTimeSeriesView.as_view()),
2930
url(r'^complexscatter$', ComplexScatterView.as_view()),
3031
url(r'^complexboxplot$', ComplexBoxplotView.as_view()),
32+
url(r'^customindex$', CustomIndexSeriesView.as_view()),
3133
]
3234
urlpatterns = format_suffix_patterns(urlpatterns)
3335
urlpatterns += [

tests/testapp/views.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,15 @@
55
from rest_framework import renderers
66
from rest_framework.generics import ListAPIView
77
from rest_pandas import renderers as pandas_renderers
8-
from .models import TimeSeries, MultiTimeSeries, ComplexTimeSeries
8+
from .models import (
9+
TimeSeries, MultiTimeSeries, ComplexTimeSeries, CustomIndexSeries,
10+
)
911
from .serializers import (
1012
TimeSeriesSerializer, TimeSeriesNoIdSerializer,
1113
MultiTimeSeriesSerializer,
1214
ComplexTimeSeriesSerializer, ComplexScatterSerializer,
1315
ComplexBoxplotSerializer,
16+
CustomIndexSeriesSerializer,
1417
)
1518
import pandas as pd
1619

@@ -122,3 +125,8 @@ class ComplexBoxplotView(PandasView):
122125
queryset = ComplexTimeSeries.objects.all()
123126
serializer_class = ComplexBoxplotSerializer
124127
pandas_serializer_class = PandasBoxplotSerializer
128+
129+
130+
class CustomIndexSeriesView(PandasView):
131+
queryset = CustomIndexSeries.objects.all()
132+
serializer_class = CustomIndexSeriesSerializer

0 commit comments

Comments
 (0)