Skip to content

REF: collect attribute-setting at the end of create_axes #30029

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits on Dec 4, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 26 additions & 20 deletions pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -3734,7 +3734,7 @@ def read_axes(

return True

def get_object(self, obj):
def get_object(self, obj, transposed: bool):
""" return the data for this obj """
return obj

Expand Down Expand Up @@ -3838,15 +3838,13 @@ def create_axes(
)

# create according to the new data
self.non_index_axes = []
self.data_columns = []
new_non_index_axes: List = []
new_data_columns: List[Optional[str]] = []

# nan_representation
if nan_rep is None:
nan_rep = "nan"

self.nan_rep = nan_rep

# create axes to index and non_index
index_axes_map = dict()
for i, a in enumerate(obj.axes):
Expand All @@ -3863,7 +3861,7 @@ def create_axes(
# necessary
append_axis = list(a)
if existing_table is not None:
indexer = len(self.non_index_axes)
indexer = len(new_non_index_axes)
exist_axis = existing_table.non_index_axes[indexer][1]
if not array_equivalent(
np.array(append_axis), np.array(exist_axis)
Expand All @@ -3880,34 +3878,37 @@ def create_axes(
info["names"] = list(a.names)
info["type"] = type(a).__name__

self.non_index_axes.append((i, append_axis))
new_non_index_axes.append((i, append_axis))

self.non_index_axes = new_non_index_axes

# set axis positions (based on the axes)
new_index_axes = [index_axes_map[a] for a in axes]
for j, iax in enumerate(new_index_axes):
iax.set_pos(j)
iax.update_info(self.info)
self.index_axes = new_index_axes

j = len(self.index_axes)
j = len(new_index_axes)

# check for column conflicts
for a in self.axes:
for a in new_index_axes:
a.maybe_set_size(min_itemsize=min_itemsize)

# reindex by our non_index_axes & compute data_columns
for a in self.non_index_axes:
for a in new_non_index_axes:
obj = _reindex_axis(obj, a[0], a[1])

def get_blk_items(mgr, blocks):
return [mgr.items.take(blk.mgr_locs) for blk in blocks]

transposed = new_index_axes[0].axis == 1

# figure out data_columns and get out blocks
block_obj = self.get_object(obj)._consolidate()
block_obj = self.get_object(obj, transposed)._consolidate()
blocks = block_obj._data.blocks
blk_items = get_blk_items(block_obj._data, blocks)
if len(self.non_index_axes):
axis, axis_labels = self.non_index_axes[0]
if len(new_non_index_axes):
axis, axis_labels = new_non_index_axes[0]
data_columns = self.validate_data_columns(data_columns, min_itemsize)
if len(data_columns):
mgr = block_obj.reindex(
Expand Down Expand Up @@ -3945,7 +3946,7 @@ def get_blk_items(mgr, blocks):
blk_items = new_blk_items

# add my values
self.values_axes = []
vaxes = []
for i, (b, b_items) in enumerate(zip(blocks, blk_items)):

# shape of the data column are the indexable axes
Expand All @@ -3959,7 +3960,7 @@ def get_blk_items(mgr, blocks):
if not (name is None or isinstance(name, str)):
# TODO: should the message here be more specifically non-str?
raise ValueError("cannot have non-object label DataIndexableCol")
self.data_columns.append(name)
new_data_columns.append(name)

# make sure that we match up the existing columns
# if we have an existing table
Expand Down Expand Up @@ -3987,10 +3988,15 @@ def get_blk_items(mgr, blocks):
)
col.set_pos(j)

self.values_axes.append(col)
vaxes.append(col)

j += 1

self.nan_rep = nan_rep
self.data_columns = new_data_columns
self.values_axes = vaxes
self.index_axes = new_index_axes

# validate our min_itemsize
self.validate_min_itemsize(min_itemsize)

Expand Down Expand Up @@ -4428,9 +4434,9 @@ class AppendableFrameTable(AppendableTable):
def is_transposed(self) -> bool:
return self.index_axes[0].axis == 1

def get_object(self, obj):
def get_object(self, obj, transposed: bool):
""" these are written transposed """
if self.is_transposed:
if transposed:
obj = obj.T
return obj

Expand Down Expand Up @@ -4512,7 +4518,7 @@ class AppendableSeriesTable(AppendableFrameTable):
def is_transposed(self) -> bool:
return False

def get_object(self, obj):
def get_object(self, obj, transposed: bool):
return obj

def write(self, obj, data_columns=None, **kwargs):
Expand Down