@@ -2791,6 +2791,59 @@ def test_insert_error_msmgs(self):
2791
2791
with assertRaisesRegexp(TypeError, msg):
2792
2792
df['gr'] = df.groupby(['b', 'c']).count()
2793
2793
2794
+ def test_frame_subclassing_and_slicing(self):
2795
+ # Subclass frame and ensure it returns the right class on slicing it
2796
+ # In reference to PR 9632
2797
+
2798
+ class CustomSeries(Series):
2799
+ @property
2800
+ def _constructor(self):
2801
+ return CustomSeries
2802
+
2803
+ def custom_series_function(self):
2804
+ return 'OK'
2805
+
2806
+ class CustomDataFrame(DataFrame):
2807
+ "Subclasses pandas DF, fills DF with simulation results, adds some custom plotting functions."
2808
+
2809
+ def __init__(self, *args, **kw):
2810
+ super(CustomDataFrame, self).__init__(*args, **kw)
2811
+
2812
+ @property
2813
+ def _constructor(self):
2814
+ return CustomDataFrame
2815
+
2816
+ _constructor_sliced = CustomSeries
2817
+
2818
+ def custom_frame_function(self):
2819
+ return 'OK'
2820
+
2821
+ data = {'col1': range(10),
2822
+ 'col2': range(10)}
2823
+ cdf = CustomDataFrame(data)
2824
+
2825
+ # Did we get back our own DF class?
2826
+ self.assertTrue(isinstance(cdf, CustomDataFrame))
2827
+
2828
+ # Do we get back our own Series class after selecting a column?
2829
+ cdf_series = cdf.col1
2830
+ self.assertTrue(isinstance(cdf_series, CustomSeries))
2831
+ self.assertEqual(cdf_series.custom_series_function(), 'OK')
2832
+
2833
+ # Do we get back our own DF class after slicing row-wise?
2834
+ cdf_rows = cdf[1:5]
2835
+ self.assertTrue(isinstance(cdf_rows, CustomDataFrame))
2836
+ self.assertEqual(cdf_rows.custom_frame_function(), 'OK')
2837
+
2838
+ # Make sure sliced part of multi-index frame is custom class
2839
+ mcol = pd.MultiIndex.from_tuples([('A', 'A'), ('A', 'B')])
2840
+ cdf_multi = CustomDataFrame([[0, 1], [2, 3]], columns=mcol)
2841
+ self.assertTrue(isinstance(cdf_multi['A'], CustomDataFrame))
2842
+
2843
+ mcol = pd.MultiIndex.from_tuples([('A', ''), ('B', '')])
2844
+ cdf_multi2 = CustomDataFrame([[0, 1], [2, 3]], columns=mcol)
2845
+ self.assertTrue(isinstance(cdf_multi2['A'], CustomSeries))
2846
+
2794
2847
def test_constructor_subclass_dict(self):
2795
2848
# Test for passing dict subclass to constructor
2796
2849
data = {'col1': tm.TestSubDict((x, 10.0 * x) for x in range(10)),
0 commit comments