diff --git a/pysimplesql/pysimplesql.py b/pysimplesql/pysimplesql.py index 55608764..ccdd142c 100644 --- a/pysimplesql/pysimplesql.py +++ b/pysimplesql/pysimplesql.py @@ -1002,13 +1002,17 @@ def requery( # We want to store our sort settings before we wipe out the current DataFrame try: sort_settings = self.store_sort_settings() - except AttributeError: + except (AttributeError, KeyError): sort_settings = [None, SORT_NONE] # default for first query rows = self.driver.execute(query) self.rows = rows - print(self.rows) + if len(self.rows.index): + if "sort_order" not in self.rows.attrs: + # Store the sort order as a dictionary in the attrs of the DataFrame + sort_order = self.rows[self.pk_column].to_list() + self.rows.attrs["sort_order"] = {self.pk_column: sort_order} # now we can restore the sort order self.load_sort_settings(sort_settings) self.sort(self.table) @@ -2228,7 +2232,7 @@ def sort_by_column(self, column: str, table: str, reverse=False) -> None: finally: # Drop the temporary description column (if it exists) if tmp_column is not None: - self.rows.drop(columns=tmp, inplace=True, errors="ignore") + self.rows.drop(columns=tmp_column, inplace=True, errors="ignore") def sort_by_index(self, index: int, table: str, reverse=False): """ @@ -2273,13 +2277,18 @@ def sort_reset(self) -> None: # Restore the original sort order self.rows.sort_index(inplace=True) - def sort(self, table: str) -> None: + def sort(self, table: str, update_elements: bool = True, sort_order=None) -> None: """ Sort according to the internal sort_column and sort_reverse variables. This is a good way to re-sort without changing the sort_cycle. :param table: The table associated with this DataSet. Passed along to `DataSet.sort_by_column()` + :param update_elements: Update associated selectors and navigation buttons, and + table header sort marker. + :param sort_order: Passed to `Dataset.update_headings`. A SORT_* constant + (SORT_NONE, SORT_ASC, SORT_DESC). Note that the update_elements parameter + must = True to use this parameter. :returns: None """ pk = self.get_current_pk() @@ -2292,32 +2301,49 @@ def sort(self, table: str) -> None: self.rows.attrs["sort_column"], table, self.rows.attrs["sort_reverse"] ) self.set_by_pk( - pk, update_elements=True, requery_dependents=False, skip_prompt_save=True + pk, + update_elements=False, + requery_dependents=False, + skip_prompt_save=True, ) + if update_elements and len(self.rows.index): + self.frm.update_selectors(self.table) + self.frm.update_elements(self.table, edit_protect_only=True) + self.update_headings(self.rows.attrs["sort_column"], sort_order) - def sort_cycle(self, column: str, table: str) -> int: + def sort_cycle(self, column: str, table: str, update_elements: bool = True) -> int: """ Cycle between original sort order of the DataFrame, ASC by column, and DESC by column with each call. :param column: The column name to cycle the sort on :param table: The table that the column belongs to + :param update_elements: Passed to `Dataset.sort` to update update associated + selectors and navigation buttons, and table header sort marker. :returns: A sort constant; SORT_NONE, SORT_ASC, or SORT_DESC """ if column != self.rows.attrs["sort_column"]: self.rows.attrs["sort_column"] = column self.rows.attrs["sort_reverse"] = False - self.sort(table) + self.sort(table, update_elements=update_elements, sort_order=SORT_ASC) return SORT_ASC if not self.rows.attrs["sort_reverse"]: self.rows.attrs["sort_reverse"] = True - self.sort(table) + self.sort(table, update_elements=update_elements, sort_order=SORT_DESC) return SORT_DESC self.rows.attrs["sort_reverse"] = False self.rows.attrs["sort_column"] = None - self.sort(table) + self.sort(table, update_elements=update_elements, sort_order=SORT_NONE) return SORT_NONE + def update_headings(self, column, sort_order): + for e in self.selector: + element = e["element"] + if element.metadata["TableHeading"]: + element.metadata["TableHeading"].update_headings( + element, column, sort_order + ) + def insert_row(self, row: dict, idx: int = None) -> None: """ Insert a new virtual row into the DataFrame. Virtual rows are ones that exist @@ -2877,9 +2903,7 @@ def auto_map_elements(self, win: sg.Window, keys: List[str] = None) -> None: # 3 Run update_elements() to see the changes table_heading.enable_sorting( element, - _SortCallbackWrapper( - self, data_key, element, table_heading - ), + _SortCallbackWrapper(self, data_key), ) else: @@ -5246,34 +5270,19 @@ class _SortCallbackWrapper: """Internal class used when sg.Table column headers are clicked.""" - def __init__( - self, frm_reference: Form, data_key: str, element: sg.Element, table_heading - ): + def __init__(self, frm_reference: Form, data_key: str): """ Create a new _SortCallbackWrapper object. :param frm_reference: `Form` object :param data_key: `DataSet` key - :param element: PySimpleGUI sg.Table element - :param table_heading: `TableHeading` object :returns: None """ self.frm: Form = frm_reference self.data_key = data_key - self.element = element - self.table_heading: TableHeadings = table_heading def __call__(self, column): - # store the pk: - pk = self.frm[self.data_key].get_current_pk() - sort_order = self.frm[self.data_key].sort_cycle(column, self.data_key) - # We only need to update the selectors not all elements, - # so first set by the primary key, then update_selectors() - self.frm[self.data_key].set_by_pk( - pk, update_elements=False, requery_dependents=False, skip_prompt_save=True - ) - self.frm.update_selectors(self.data_key) - self.table_heading.update_headings(self.element, column, sort_order) + self.frm[self.data_key].sort_cycle(column, self.data_key, update_elements=True) # ======================================================================================