diff --git a/pandas/io/formats/style.py b/pandas/io/formats/style.py index ed142017a066b..97126975dc194 100644 --- a/pandas/io/formats/style.py +++ b/pandas/io/formats/style.py @@ -1814,7 +1814,7 @@ def set_sticky( self, axis: Axis = 0, pixel_size: int | None = None, - levels: list[int] | None = None, + levels: Level | list[Level] | None = None, ) -> Styler: """ Add CSS to permanently display the index or column headers in a scrolling frame. @@ -1827,7 +1827,7 @@ def set_sticky( Required to configure the width of index cells or the height of column header cells when sticking a MultiIndex (or with a named Index). Defaults to 75 and 25 respectively. - levels : list of int + levels : int, str, list, optional If ``axis`` is a MultiIndex the specific levels to stick. If ``None`` will stick all levels. @@ -1891,11 +1891,12 @@ def set_sticky( else: # handle the MultiIndex case range_idx = list(range(obj.nlevels)) - levels = sorted(levels) if levels else range_idx + levels_: list[int] = refactor_levels(levels, obj) if levels else range_idx + levels_ = sorted(levels_) if axis == 1: styles = [] - for i, level in enumerate(levels): + for i, level in enumerate(levels_): styles.append( { "selector": f"thead tr:nth-child({level+1}) th", @@ -1920,7 +1921,7 @@ def set_sticky( else: styles = [] - for i, level in enumerate(levels): + for i, level in enumerate(levels_): props_ = props + ( f"left:{i * pixel_size}px; " f"min-width:{pixel_size}px; " diff --git a/pandas/tests/io/formats/style/test_html.py b/pandas/tests/io/formats/style/test_html.py index d0b7e288332e2..bbb81e2dce24e 100644 --- a/pandas/tests/io/formats/style/test_html.py +++ b/pandas/tests/io/formats/style/test_html.py @@ -365,11 +365,13 @@ def test_sticky_mi(styler_mi, index, columns): @pytest.mark.parametrize("index", [False, True]) @pytest.mark.parametrize("columns", [False, True]) -def test_sticky_levels(styler_mi, index, columns): +@pytest.mark.parametrize("levels", [[1], ["one"], "one"]) +def test_sticky_levels(styler_mi, index, columns, levels): + styler_mi.index.names, styler_mi.columns.names = ["zero", "one"], ["zero", "one"] if index: - styler_mi.set_sticky(axis=0, levels=[1]) + styler_mi.set_sticky(axis=0, levels=levels) if columns: - styler_mi.set_sticky(axis=1, levels=[1]) + styler_mi.set_sticky(axis=1, levels=levels) left_css = ( "#T_ {0} {{\n position: sticky;\n background-color: white;\n"