Skip to content

Commit 1a32502

Browse files
committed
PythonAPI: more keypoint support (showAnns, loadRes)
1 parent af29c4f commit 1a32502

File tree

1 file changed

+42
-47
lines changed

1 file changed

+42
-47
lines changed

PythonAPI/pycocotools/coco.py

+42-47
Original file line numberDiff line numberDiff line change
@@ -45,18 +45,17 @@
4545
# Licensed under the Simplified BSD License [see bsd.txt]
4646

4747
import json
48-
import datetime
4948
import time
5049
import matplotlib.pyplot as plt
5150
from matplotlib.collections import PatchCollection
5251
from matplotlib.patches import Polygon
5352
import numpy as np
54-
from skimage.draw import polygon
5553
import urllib
5654
import copy
5755
import itertools
5856
import mask
5957
import os
58+
from collections import defaultdict
6059

6160
class COCO:
6261
def __init__(self, annotation_file=None):
@@ -67,12 +66,8 @@ def __init__(self, annotation_file=None):
6766
:return:
6867
"""
6968
# load dataset
70-
self.dataset = {}
71-
self.anns = []
72-
self.imgToAnns = {}
73-
self.catToImgs = {}
74-
self.imgs = {}
75-
self.cats = {}
69+
self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict()
70+
self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
7671
if not annotation_file == None:
7772
print 'loading annotations into memory...'
7873
tic = time.time()
@@ -84,30 +79,22 @@ def __init__(self, annotation_file=None):
8479
def createIndex(self):
8580
# create index
8681
print 'creating index...'
87-
anns = {}
88-
imgToAnns = {}
89-
catToImgs = {}
90-
cats = {}
91-
imgs = {}
82+
anns,cats,imgs = dict(),dict(),dict()
83+
imgToAnns,catToImgs = defaultdict(list),defaultdict(list)
9284
if 'annotations' in self.dataset:
93-
imgToAnns = {ann['image_id']: [] for ann in self.dataset['annotations']}
94-
anns = {ann['id']: [] for ann in self.dataset['annotations']}
9585
for ann in self.dataset['annotations']:
96-
imgToAnns[ann['image_id']] += [ann]
86+
imgToAnns[ann['image_id']].append(ann)
9787
anns[ann['id']] = ann
9888

9989
if 'images' in self.dataset:
100-
imgs = {im['id']: {} for im in self.dataset['images']}
10190
for img in self.dataset['images']:
10291
imgs[img['id']] = img
10392

10493
if 'categories' in self.dataset:
105-
cats = {cat['id']: [] for cat in self.dataset['categories']}
10694
for cat in self.dataset['categories']:
10795
cats[cat['id']] = cat
108-
catToImgs = {cat['id']: [] for cat in self.dataset['categories']}
10996
for ann in self.dataset['annotations']:
110-
catToImgs[ann['category_id']] += [ann['image_id']]
97+
catToImgs[ann['category_id']].append(ann['image_id'])
11198

11299
print 'index created!'
113100

@@ -142,7 +129,6 @@ def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
142129
anns = self.dataset['annotations']
143130
else:
144131
if not len(imgIds) == 0:
145-
# this can be changed by defaultdict
146132
lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns]
147133
anns = list(itertools.chain.from_iterable(lists))
148134
else:
@@ -239,39 +225,42 @@ def showAnns(self, anns):
239225
"""
240226
if len(anns) == 0:
241227
return 0
242-
if 'segmentation' in anns[0]:
228+
if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
243229
datasetType = 'instances'
244230
elif 'caption' in anns[0]:
245231
datasetType = 'captions'
232+
else:
233+
raise Exception("datasetType not supported")
246234
if datasetType == 'instances':
247235
ax = plt.gca()
248236
ax.set_autoscale_on(False)
249237
polygons = []
250238
color = []
251239
for ann in anns:
252240
c = (np.random.random((1, 3))*0.6+0.4).tolist()[0]
253-
if type(ann['segmentation']) == list:
254-
# polygon
255-
for seg in ann['segmentation']:
256-
poly = np.array(seg).reshape((len(seg)/2, 2))
257-
polygons.append(Polygon(poly))
258-
color.append(c)
259-
else:
260-
# mask
261-
t = self.imgs[ann['image_id']]
262-
if type(ann['segmentation']['counts']) == list:
263-
rle = mask.frPyObjects([ann['segmentation']], t['height'], t['width'])
241+
if 'segmentation' in ann:
242+
if type(ann['segmentation']) == list:
243+
# polygon
244+
for seg in ann['segmentation']:
245+
poly = np.array(seg).reshape((len(seg)/2, 2))
246+
polygons.append(Polygon(poly))
247+
color.append(c)
264248
else:
265-
rle = [ann['segmentation']]
266-
m = mask.decode(rle)
267-
img = np.ones( (m.shape[0], m.shape[1], 3) )
268-
if ann['iscrowd'] == 1:
269-
color_mask = np.array([2.0,166.0,101.0])/255
270-
if ann['iscrowd'] == 0:
271-
color_mask = np.random.random((1, 3)).tolist()[0]
272-
for i in range(3):
273-
img[:,:,i] = color_mask[i]
274-
ax.imshow(np.dstack( (img, m*0.5) ))
249+
# mask
250+
t = self.imgs[ann['image_id']]
251+
if type(ann['segmentation']['counts']) == list:
252+
rle = mask.frPyObjects([ann['segmentation']], t['height'], t['width'])
253+
else:
254+
rle = [ann['segmentation']]
255+
m = mask.decode(rle)
256+
img = np.ones( (m.shape[0], m.shape[1], 3) )
257+
if ann['iscrowd'] == 1:
258+
color_mask = np.array([2.0,166.0,101.0])/255
259+
if ann['iscrowd'] == 0:
260+
color_mask = np.random.random((1, 3)).tolist()[0]
261+
for i in range(3):
262+
img[:,:,i] = color_mask[i]
263+
ax.imshow(np.dstack( (img, m*0.5) ))
275264
if 'keypoints' in ann and type(ann['keypoints']) == list:
276265
# turn skeleton into zero-based index
277266
sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1
@@ -282,8 +271,8 @@ def showAnns(self, anns):
282271
for sk in sks:
283272
if np.all(v[sk]>0):
284273
plt.plot(x[sk],y[sk], linewidth=3, color=c)
285-
plt.plot(x[v==1], y[v==1],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
286-
plt.plot(x[v==2], y[v==2],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
274+
plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2)
275+
plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2)
287276
p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4)
288277
ax.add_collection(p)
289278
p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2)
@@ -300,8 +289,6 @@ def loadRes(self, resFile):
300289
"""
301290
res = COCO()
302291
res.dataset['images'] = [img for img in self.dataset['images']]
303-
# res.dataset['info'] = copy.deepcopy(self.dataset['info'])
304-
# res.dataset['licenses'] = copy.deepcopy(self.dataset['licenses'])
305292

306293
print 'Loading and preparing results... '
307294
tic = time.time()
@@ -339,6 +326,14 @@ def loadRes(self, resFile):
339326
ann['bbox'] = mask.toBbox([ann['segmentation']])[0]
340327
ann['id'] = id+1
341328
ann['iscrowd'] = 0
329+
elif 'keypoints' in anns[0]:
330+
res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
331+
for id, ann in enumerate(anns):
332+
s = ann['keypoints']
333+
x = s[0::3]
334+
y = s[1::3]
335+
ann['area'] = float((np.max(x)-np.min(x))*(np.max(y)-np.min(y)))
336+
ann['id'] = id + 1
342337
print 'DONE (t=%0.2fs)'%(time.time()- tic)
343338

344339
res.dataset['annotations'] = anns

0 commit comments

Comments
 (0)