45
45
# Licensed under the Simplified BSD License [see bsd.txt]
46
46
47
47
import json
48
- import datetime
49
48
import time
50
49
import matplotlib .pyplot as plt
51
50
from matplotlib .collections import PatchCollection
52
51
from matplotlib .patches import Polygon
53
52
import numpy as np
54
- from skimage .draw import polygon
55
53
import urllib
56
54
import copy
57
55
import itertools
58
56
import mask
59
57
import os
58
+ from collections import defaultdict
60
59
61
60
class COCO :
62
61
def __init__ (self , annotation_file = None ):
@@ -67,12 +66,8 @@ def __init__(self, annotation_file=None):
67
66
:return:
68
67
"""
69
68
# load dataset
70
- self .dataset = {}
71
- self .anns = []
72
- self .imgToAnns = {}
73
- self .catToImgs = {}
74
- self .imgs = {}
75
- self .cats = {}
69
+ self .dataset ,self .anns ,self .cats ,self .imgs = dict (),dict (),dict (),dict ()
70
+ self .imgToAnns , self .catToImgs = defaultdict (list ), defaultdict (list )
76
71
if not annotation_file == None :
77
72
print 'loading annotations into memory...'
78
73
tic = time .time ()
@@ -84,30 +79,22 @@ def __init__(self, annotation_file=None):
84
79
def createIndex (self ):
85
80
# create index
86
81
print 'creating index...'
87
- anns = {}
88
- imgToAnns = {}
89
- catToImgs = {}
90
- cats = {}
91
- imgs = {}
82
+ anns ,cats ,imgs = dict (),dict (),dict ()
83
+ imgToAnns ,catToImgs = defaultdict (list ),defaultdict (list )
92
84
if 'annotations' in self .dataset :
93
- imgToAnns = {ann ['image_id' ]: [] for ann in self .dataset ['annotations' ]}
94
- anns = {ann ['id' ]: [] for ann in self .dataset ['annotations' ]}
95
85
for ann in self .dataset ['annotations' ]:
96
- imgToAnns [ann ['image_id' ]] += [ ann ]
86
+ imgToAnns [ann ['image_id' ]]. append ( ann )
97
87
anns [ann ['id' ]] = ann
98
88
99
89
if 'images' in self .dataset :
100
- imgs = {im ['id' ]: {} for im in self .dataset ['images' ]}
101
90
for img in self .dataset ['images' ]:
102
91
imgs [img ['id' ]] = img
103
92
104
93
if 'categories' in self .dataset :
105
- cats = {cat ['id' ]: [] for cat in self .dataset ['categories' ]}
106
94
for cat in self .dataset ['categories' ]:
107
95
cats [cat ['id' ]] = cat
108
- catToImgs = {cat ['id' ]: [] for cat in self .dataset ['categories' ]}
109
96
for ann in self .dataset ['annotations' ]:
110
- catToImgs [ann ['category_id' ]] += [ ann ['image_id' ]]
97
+ catToImgs [ann ['category_id' ]]. append ( ann ['image_id' ])
111
98
112
99
print 'index created!'
113
100
@@ -142,7 +129,6 @@ def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
142
129
anns = self .dataset ['annotations' ]
143
130
else :
144
131
if not len (imgIds ) == 0 :
145
- # this can be changed by defaultdict
146
132
lists = [self .imgToAnns [imgId ] for imgId in imgIds if imgId in self .imgToAnns ]
147
133
anns = list (itertools .chain .from_iterable (lists ))
148
134
else :
@@ -239,39 +225,42 @@ def showAnns(self, anns):
239
225
"""
240
226
if len (anns ) == 0 :
241
227
return 0
242
- if 'segmentation' in anns [0 ]:
228
+ if 'segmentation' in anns [0 ] or 'keypoints' in anns [ 0 ] :
243
229
datasetType = 'instances'
244
230
elif 'caption' in anns [0 ]:
245
231
datasetType = 'captions'
232
+ else :
233
+ raise Exception ("datasetType not supported" )
246
234
if datasetType == 'instances' :
247
235
ax = plt .gca ()
248
236
ax .set_autoscale_on (False )
249
237
polygons = []
250
238
color = []
251
239
for ann in anns :
252
240
c = (np .random .random ((1 , 3 ))* 0.6 + 0.4 ).tolist ()[0 ]
253
- if type (ann ['segmentation' ]) == list :
254
- # polygon
255
- for seg in ann ['segmentation' ]:
256
- poly = np .array (seg ).reshape ((len (seg )/ 2 , 2 ))
257
- polygons .append (Polygon (poly ))
258
- color .append (c )
259
- else :
260
- # mask
261
- t = self .imgs [ann ['image_id' ]]
262
- if type (ann ['segmentation' ]['counts' ]) == list :
263
- rle = mask .frPyObjects ([ann ['segmentation' ]], t ['height' ], t ['width' ])
241
+ if 'segmentation' in ann :
242
+ if type (ann ['segmentation' ]) == list :
243
+ # polygon
244
+ for seg in ann ['segmentation' ]:
245
+ poly = np .array (seg ).reshape ((len (seg )/ 2 , 2 ))
246
+ polygons .append (Polygon (poly ))
247
+ color .append (c )
264
248
else :
265
- rle = [ann ['segmentation' ]]
266
- m = mask .decode (rle )
267
- img = np .ones ( (m .shape [0 ], m .shape [1 ], 3 ) )
268
- if ann ['iscrowd' ] == 1 :
269
- color_mask = np .array ([2.0 ,166.0 ,101.0 ])/ 255
270
- if ann ['iscrowd' ] == 0 :
271
- color_mask = np .random .random ((1 , 3 )).tolist ()[0 ]
272
- for i in range (3 ):
273
- img [:,:,i ] = color_mask [i ]
274
- ax .imshow (np .dstack ( (img , m * 0.5 ) ))
249
+ # mask
250
+ t = self .imgs [ann ['image_id' ]]
251
+ if type (ann ['segmentation' ]['counts' ]) == list :
252
+ rle = mask .frPyObjects ([ann ['segmentation' ]], t ['height' ], t ['width' ])
253
+ else :
254
+ rle = [ann ['segmentation' ]]
255
+ m = mask .decode (rle )
256
+ img = np .ones ( (m .shape [0 ], m .shape [1 ], 3 ) )
257
+ if ann ['iscrowd' ] == 1 :
258
+ color_mask = np .array ([2.0 ,166.0 ,101.0 ])/ 255
259
+ if ann ['iscrowd' ] == 0 :
260
+ color_mask = np .random .random ((1 , 3 )).tolist ()[0 ]
261
+ for i in range (3 ):
262
+ img [:,:,i ] = color_mask [i ]
263
+ ax .imshow (np .dstack ( (img , m * 0.5 ) ))
275
264
if 'keypoints' in ann and type (ann ['keypoints' ]) == list :
276
265
# turn skeleton into zero-based index
277
266
sks = np .array (self .loadCats (ann ['category_id' ])[0 ]['skeleton' ])- 1
@@ -282,8 +271,8 @@ def showAnns(self, anns):
282
271
for sk in sks :
283
272
if np .all (v [sk ]> 0 ):
284
273
plt .plot (x [sk ],y [sk ], linewidth = 3 , color = c )
285
- plt .plot (x [v == 1 ], y [v == 1 ],'o' ,markersize = 8 , markerfacecolor = c , markeredgecolor = 'k' ,markeredgewidth = 2 )
286
- plt .plot (x [v == 2 ], y [v == 2 ],'o' ,markersize = 8 , markerfacecolor = c , markeredgecolor = c , markeredgewidth = 2 )
274
+ plt .plot (x [v > 0 ], y [v > 0 ],'o' ,markersize = 8 , markerfacecolor = c , markeredgecolor = 'k' ,markeredgewidth = 2 )
275
+ plt .plot (x [v > 1 ], y [v > 1 ],'o' ,markersize = 8 , markerfacecolor = c , markeredgecolor = c , markeredgewidth = 2 )
287
276
p = PatchCollection (polygons , facecolor = color , linewidths = 0 , alpha = 0.4 )
288
277
ax .add_collection (p )
289
278
p = PatchCollection (polygons , facecolor = "none" , edgecolors = color , linewidths = 2 )
@@ -300,8 +289,6 @@ def loadRes(self, resFile):
300
289
"""
301
290
res = COCO ()
302
291
res .dataset ['images' ] = [img for img in self .dataset ['images' ]]
303
- # res.dataset['info'] = copy.deepcopy(self.dataset['info'])
304
- # res.dataset['licenses'] = copy.deepcopy(self.dataset['licenses'])
305
292
306
293
print 'Loading and preparing results... '
307
294
tic = time .time ()
@@ -339,6 +326,14 @@ def loadRes(self, resFile):
339
326
ann ['bbox' ] = mask .toBbox ([ann ['segmentation' ]])[0 ]
340
327
ann ['id' ] = id + 1
341
328
ann ['iscrowd' ] = 0
329
+ elif 'keypoints' in anns [0 ]:
330
+ res .dataset ['categories' ] = copy .deepcopy (self .dataset ['categories' ])
331
+ for id , ann in enumerate (anns ):
332
+ s = ann ['keypoints' ]
333
+ x = s [0 ::3 ]
334
+ y = s [1 ::3 ]
335
+ ann ['area' ] = float ((np .max (x )- np .min (x ))* (np .max (y )- np .min (y )))
336
+ ann ['id' ] = id + 1
342
337
print 'DONE (t=%0.2fs)' % (time .time ()- tic )
343
338
344
339
res .dataset ['annotations' ] = anns
0 commit comments