Skip to content

Commit cb01abd

Browse files
authored
Remove InternalFrame.index_map. (#1901)
Removes `InternalFrame.index_map`, which is not suitable in many cases and can easily be replaced with `zip(column_names, index_names)`.
1 parent 9285f95 commit cb01abd

File tree

7 files changed

+137
-130
lines changed

7 files changed

+137
-130
lines changed

databricks/koalas/frame.py

Lines changed: 72 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -3432,11 +3432,17 @@ def rename(index):
34323432
return ("level_{}".format(index),)
34333433

34343434
if level is None:
3435-
new_index_map = [
3436-
(column, name if name is not None else rename(i))
3437-
for i, (column, name) in enumerate(self._internal.index_map.items())
3435+
new_column_labels = [
3436+
name if name is not None else rename(i)
3437+
for i, name in enumerate(self._internal.index_names)
34383438
]
3439-
index_map = {} # type: Dict
3439+
new_data_spark_columns = [
3440+
scol.alias(name_like_string(label))
3441+
for scol, label in zip(self._internal.index_spark_columns, new_column_labels)
3442+
]
3443+
3444+
index_spark_column_names = []
3445+
index_names = []
34403446
else:
34413447
if is_list_like(level):
34423448
level = list(level)
@@ -3478,35 +3484,29 @@ def rename(index):
34783484
raise ValueError("Level should be all int or all string.")
34793485
idx.sort()
34803486

3481-
new_index_map = []
3482-
index_map_items = list(self._internal.index_map.items())
3483-
new_index_map_items = index_map_items.copy()
3484-
for i in idx:
3485-
info = index_map_items[i]
3486-
index_column, index_name = info
3487-
new_index_map.append(
3488-
(index_column, index_name if index_name is not None else rename(i))
3489-
)
3490-
new_index_map_items.remove(info)
3487+
new_column_labels = []
3488+
new_data_spark_columns = []
34913489

3492-
index_map = OrderedDict(new_index_map_items)
3490+
index_spark_column_names = self._internal.index_spark_column_names.copy()
3491+
index_spark_columns = self._internal.index_spark_columns.copy()
3492+
index_names = self._internal.index_names.copy()
34933493

3494-
if drop:
3495-
new_index_map = []
3494+
for i in idx[::-1]:
3495+
index_spark_column_names.pop(i)
34963496

3497-
for _, name in new_index_map:
3498-
if name in self._internal.column_labels:
3499-
raise ValueError("cannot insert {}, already exists".format(name_like_string(name)))
3497+
name = index_names.pop(i)
3498+
new_column_labels.insert(0, name if name is not None else rename(i))
35003499

3501-
sdf = self._internal.spark_frame
3502-
new_data_scols = [
3503-
scol_for(sdf, column).alias(name_like_string(name)) for column, name in new_index_map
3504-
]
3500+
scol = index_spark_columns.pop(i)
3501+
new_data_spark_columns.insert(0, scol.alias(name_like_string(name)))
35053502

3506-
index_scols = [scol_for(sdf, column) for column in index_map]
3507-
sdf = sdf.select(
3508-
index_scols + new_data_scols + self._internal.data_spark_columns + list(HIDDEN_COLUMNS)
3509-
)
3503+
if drop:
3504+
new_data_spark_columns = []
3505+
new_column_labels = []
3506+
3507+
for label in new_column_labels:
3508+
if label in self._internal.column_labels:
3509+
raise ValueError("cannot insert {}, already exists".format(name_like_string(label)))
35103510

35113511
if self._internal.column_labels_level > 1:
35123512
column_depth = len(self._internal.column_labels[0])
@@ -3516,28 +3516,22 @@ def rename(index):
35163516
column_depth, col_level + 1
35173517
)
35183518
)
3519-
if any(col_level + len(name) > column_depth for _, name in new_index_map):
3519+
if any(col_level + len(label) > column_depth for label in new_column_labels):
35203520
raise ValueError("Item must have length equal to number of levels.")
3521-
column_labels = [
3521+
new_column_labels = [
35223522
tuple(
35233523
([col_fill] * col_level)
3524-
+ list(name)
3525-
+ ([col_fill] * (column_depth - (len(name) + col_level)))
3524+
+ list(label)
3525+
+ ([col_fill] * (column_depth - (len(label) + col_level)))
35263526
)
3527-
for _, name in new_index_map
3528-
] + self._internal.column_labels
3529-
else:
3530-
column_labels = [name for _, name in new_index_map] + self._internal.column_labels
3527+
for label in new_column_labels
3528+
]
35313529

35323530
internal = self._internal.copy(
3533-
spark_frame=sdf,
3534-
index_spark_column_names=list(index_map.keys()),
3535-
index_names=list(index_map.values()),
3536-
column_labels=column_labels,
3537-
data_spark_columns=(
3538-
[scol_for(sdf, name_like_string(name)) for _, name in new_index_map]
3539-
+ [scol_for(sdf, col) for col in self._internal.data_spark_column_names]
3540-
),
3531+
index_spark_column_names=index_spark_column_names,
3532+
index_names=index_names,
3533+
column_labels=new_column_labels + self._internal.column_labels,
3534+
data_spark_columns=new_data_spark_columns + self._internal.data_spark_columns,
35413535
)
35423536

35433537
if inplace:
@@ -5957,11 +5951,10 @@ def droplevel(self, level, axis=0) -> "DataFrame":
59575951
if not isinstance(level, (tuple, list)): # huh?
59585952
level = [level]
59595953

5960-
spark_frame = self._internal.spark_frame
5961-
index_map = self._internal.index_map.copy()
59625954
index_names = self.index.names
5963-
nlevels = self.index.nlevels
5964-
int_levels = list()
5955+
nlevels = self._internal.index_level
5956+
5957+
int_level = set()
59655958
for n in level:
59665959
if isinstance(n, int):
59675960
if n < 0:
@@ -5981,22 +5974,27 @@ def droplevel(self, level, axis=0) -> "DataFrame":
59815974
if n not in index_names:
59825975
raise KeyError("Level {} not found".format(n))
59835976
n = index_names.index(n)
5984-
int_levels.append(n)
5977+
int_level.add(n)
59855978

5986-
if len(int_levels) >= nlevels:
5979+
if len(level) >= nlevels:
59875980
raise ValueError(
59885981
"Cannot remove {} levels from an index with {} levels: "
5989-
"at least one level must be left.".format(len(int_levels), nlevels)
5982+
"at least one level must be left.".format(len(level), nlevels)
59905983
)
59915984

5992-
for int_level in int_levels:
5993-
index_spark_column = self._internal.index_spark_column_names[int_level]
5994-
spark_frame = spark_frame.drop(index_spark_column)
5995-
index_map.pop(index_spark_column)
5985+
index_spark_column_names, index_names = zip(
5986+
*[
5987+
item
5988+
for i, item in enumerate(
5989+
zip(self._internal.index_spark_column_names, self._internal.index_names)
5990+
)
5991+
if i not in int_level
5992+
]
5993+
)
5994+
59965995
internal = self._internal.copy(
5997-
spark_frame=spark_frame,
5998-
index_spark_column_names=list(index_map.keys()),
5999-
index_names=list(index_map.values()),
5996+
index_spark_column_names=list(index_spark_column_names),
5997+
index_names=list(index_names),
60005998
)
60015999
return DataFrame(internal)
60026000
else:
@@ -6845,33 +6843,38 @@ def to_list(os: Optional[Union[Any, List[Any], Tuple, List[Tuple]]]) -> List[Tup
68456843
if right_index:
68466844
if how in ("inner", "left"):
68476845
exprs.extend(left_index_scols)
6848-
index_map = self._internal.index_map
6846+
index_spark_column_names = self._internal.index_spark_column_names
6847+
index_names = self._internal.index_names
68496848
elif how == "right":
68506849
exprs.extend(right_index_scols)
6851-
index_map = right._internal.index_map
6850+
index_spark_column_names = right._internal.index_spark_column_names
6851+
index_names = right._internal.index_names
68526852
else:
6853-
index_map = OrderedDict()
6854-
for (col, name), left_scol, right_scol in zip(
6855-
self._internal.index_map.items(), left_index_scols, right_index_scols
6853+
index_spark_column_names = self._internal.index_spark_column_names
6854+
index_names = self._internal.index_names
6855+
for col, left_scol, right_scol in zip(
6856+
index_spark_column_names, left_index_scols, right_index_scols
68566857
):
68576858
scol = F.when(left_scol.isNotNull(), left_scol).otherwise(right_scol)
68586859
exprs.append(scol.alias(col))
6859-
index_map[col] = name
68606860
else:
68616861
exprs.extend(right_index_scols)
6862-
index_map = right._internal.index_map
6862+
index_spark_column_names = right._internal.index_spark_column_names
6863+
index_names = right._internal.index_names
68636864
elif right_index:
68646865
exprs.extend(left_index_scols)
6865-
index_map = self._internal.index_map
6866+
index_spark_column_names = self._internal.index_spark_column_names
6867+
index_names = self._internal.index_names
68666868
else:
6867-
index_map = OrderedDict()
6869+
index_spark_column_names = None
6870+
index_names = None
68686871

68696872
selected_columns = joined_table.select(*exprs)
68706873

68716874
internal = InternalFrame(
68726875
spark_frame=selected_columns,
6873-
index_spark_column_names=list(index_map.keys()) if index_map else None,
6874-
index_names=list(index_map.values()) if index_map else None,
6876+
index_spark_column_names=index_spark_column_names,
6877+
index_names=index_names,
68756878
column_labels=column_labels,
68766879
data_spark_columns=[scol_for(selected_columns, col) for col in data_columns],
68776880
)

databricks/koalas/indexes.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,14 +1233,25 @@ def droplevel(self, level) -> "Index":
12331233
if not is_list_like(level):
12341234
level = [level]
12351235

1236+
int_level = set()
12361237
for n in level:
12371238
if isinstance(n, int):
1238-
if n > nlevels - 1:
1239+
if n < 0:
1240+
n = n + nlevels
1241+
if n < 0:
1242+
raise IndexError(
1243+
"Too many levels: Index has only {} levels, "
1244+
"{} is not a valid level number".format(nlevels, (n - nlevels))
1245+
)
1246+
if n >= nlevels:
12391247
raise IndexError(
12401248
"Too many levels: Index has only {} levels, not {}".format(nlevels, n + 1)
12411249
)
1242-
elif n not in names:
1243-
raise KeyError("Level {} not found".format(n))
1250+
else:
1251+
if n not in names:
1252+
raise KeyError("Level {} not found".format(n))
1253+
n = names.index(n)
1254+
int_level.add(n)
12441255

12451256
if len(level) >= nlevels:
12461257
raise ValueError(
@@ -1249,19 +1260,24 @@ def droplevel(self, level) -> "Index":
12491260
"left.".format(len(level), nlevels)
12501261
)
12511262

1252-
int_level = set(n if isinstance(n, int) else names.index(n) for n in level)
12531263
index_spark_column_names, index_names = zip(
1254-
*[item for i, item in enumerate(self._internal.index_map.items()) if i not in int_level]
1264+
*[
1265+
item
1266+
for i, item in enumerate(
1267+
zip(self._internal.index_spark_column_names, self._internal.index_names)
1268+
)
1269+
if i not in int_level
1270+
]
12551271
)
12561272

12571273
sdf = self._internal.spark_frame
12581274
sdf = sdf.select(*index_spark_column_names)
1259-
result = InternalFrame(
1275+
internal = InternalFrame(
12601276
spark_frame=sdf,
12611277
index_spark_column_names=list(index_spark_column_names),
12621278
index_names=list(index_names),
12631279
)
1264-
return DataFrame(result).index
1280+
return DataFrame(internal).index
12651281

12661282
def symmetric_difference(self, other, result_name=None, sort=None) -> "Index":
12671283
"""
@@ -2632,7 +2648,7 @@ def swaplevel(self, i=-2, j=-1) -> "MultiIndex":
26322648
"%s is not a valid level number" % (len(self.names), index)
26332649
)
26342650

2635-
index_map = list(self._internal.index_map.items())
2651+
index_map = list(zip(self._internal.index_spark_column_names, self._internal.index_names))
26362652
index_map[i], index_map[j], = index_map[j], index_map[i]
26372653
index_spark_column_names, index_names = zip(*index_map)
26382654
internal = self._kdf._internal.copy(
@@ -3011,22 +3027,24 @@ def drop(self, codes, level=None) -> "MultiIndex":
30113027
elif isinstance(level, int):
30123028
scol = index_scols[level]
30133029
else:
3014-
spark_column_name = None
3015-
for index_spark_column_name, index_name in self._internal.index_map.items():
3030+
scol = None
3031+
for index_spark_column, index_name in zip(
3032+
self._internal.index_spark_columns, self._internal.index_names
3033+
):
30163034
if not isinstance(level, tuple):
30173035
level = (level,)
30183036
if level == index_name:
3019-
if spark_column_name is not None:
3037+
if scol is not None:
30203038
raise ValueError(
30213039
"The name {} occurs multiple times, use a level number".format(
30223040
name_like_string(level)
30233041
)
30243042
)
3025-
spark_column_name = index_spark_column_name
3026-
if spark_column_name is None:
3043+
scol = index_spark_column
3044+
if scol is None:
30273045
raise KeyError("Level {} not found".format(name_like_string(level)))
3028-
scol = scol_for(sdf, spark_column_name)
30293046
sdf = sdf[~scol.isin(codes)]
3047+
30303048
return MultiIndex(
30313049
DataFrame(
30323050
InternalFrame(
@@ -3220,7 +3238,6 @@ def insert(self, loc: int, item) -> Index:
32203238
)
32213239

32223240
index_name = self._internal.index_spark_column_names
3223-
sdf = self._internal.spark_frame
32243241
sdf_before = self.to_frame(name=index_name)[:loc].to_spark()
32253242
sdf_middle = Index([item]).to_frame(name=index_name).to_spark()
32263243
sdf_after = self.to_frame(name=index_name)[loc:].to_spark()
@@ -3277,8 +3294,6 @@ def intersection(self, other) -> "MultiIndex":
32773294
MultiIndex([('c', 'z')],
32783295
)
32793296
"""
3280-
keep_name = True
3281-
32823297
if isinstance(other, Series) or not is_list_like(other):
32833298
raise TypeError("other must be a MultiIndex or a list of tuples")
32843299
elif isinstance(other, DataFrame):

0 commit comments

Comments
 (0)