diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a371c8249eba..d433fb08209bf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -138,7 +138,7 @@ repos: entry: python scripts/check_for_inconsistent_pandas_namespace.py language: python types: [python] - files: ^pandas/tests/ + files: ^pandas/tests/frame/ - id: FrameOrSeriesUnion name: Check for use of Union[Series, DataFrame] instead of FrameOrSeriesUnion alias entry: Union\[.*(Series,.*DataFrame|DataFrame,.*Series).*\] diff --git a/pandas/tests/frame/indexing/test_indexing.py b/pandas/tests/frame/indexing/test_indexing.py index 4da9ed76844af..88d1e217fb45a 100644 --- a/pandas/tests/frame/indexing/test_indexing.py +++ b/pandas/tests/frame/indexing/test_indexing.py @@ -217,7 +217,7 @@ def test_setitem_multi_index(self): it = ["jim", "joe", "jolie"], ["first", "last"], ["left", "center", "right"] cols = MultiIndex.from_product(it) - index = pd.date_range("20141006", periods=20) + index = date_range("20141006", periods=20) vals = np.random.randint(1, 1000, (len(index), len(cols))) df = DataFrame(vals, columns=cols, index=index) @@ -1357,7 +1357,7 @@ def test_loc_duplicates(self): # gh-17105 # insert a duplicate element to the index - trange = pd.date_range( + trange = date_range( start=Timestamp(year=2017, month=1, day=1), end=Timestamp(year=2017, month=1, day=5), ) @@ -1421,7 +1421,7 @@ def test_setitem_with_unaligned_tz_aware_datetime_column(self): # GH 12981 # Assignment of unaligned offset-aware datetime series. # Make sure timezone isn't lost - column = Series(pd.date_range("2015-01-01", periods=3, tz="utc"), name="dates") + column = Series(date_range("2015-01-01", periods=3, tz="utc"), name="dates") df = DataFrame({"dates": column}) df["dates"] = column[[1, 0, 2]] tm.assert_series_equal(df["dates"], column) @@ -1716,7 +1716,7 @@ def test_object_casting_indexing_wraps_datetimelike(): df = DataFrame( { "A": [1, 2], - "B": pd.date_range("2000", periods=2), + "B": date_range("2000", periods=2), "C": pd.timedelta_range("1 Day", periods=2), } ) diff --git a/pandas/tests/frame/methods/test_describe.py b/pandas/tests/frame/methods/test_describe.py index 113e870c8879b..0b4ce0dfa80fc 100644 --- a/pandas/tests/frame/methods/test_describe.py +++ b/pandas/tests/frame/methods/test_describe.py @@ -283,7 +283,7 @@ def test_describe_tz_values(self, tz_naive_fixture): tm.assert_frame_equal(result, expected) def test_datetime_is_numeric_includes_datetime(self): - df = DataFrame({"a": pd.date_range("2012", periods=3), "b": [1, 2, 3]}) + df = DataFrame({"a": date_range("2012", periods=3), "b": [1, 2, 3]}) result = df.describe(datetime_is_numeric=True) expected = DataFrame( { diff --git a/pandas/tests/frame/methods/test_diff.py b/pandas/tests/frame/methods/test_diff.py index c3dfca4c121db..75d93ed2aafc6 100644 --- a/pandas/tests/frame/methods/test_diff.py +++ b/pandas/tests/frame/methods/test_diff.py @@ -80,7 +80,7 @@ def test_diff_datetime_axis0_with_nat(self, tz): @pytest.mark.parametrize("tz", [None, "UTC"]) def test_diff_datetime_with_nat_zero_periods(self, tz): # diff on NaT values should give NaT, not timedelta64(0) - dti = pd.date_range("2016-01-01", periods=4, tz=tz) + dti = date_range("2016-01-01", periods=4, tz=tz) ser = Series(dti) df = ser.to_frame() @@ -178,7 +178,7 @@ def test_diff_axis(self): def test_diff_period(self): # GH#32995 Don't pass an incorrect axis - pi = pd.date_range("2016-01-01", periods=3).to_period("D") + pi = date_range("2016-01-01", periods=3).to_period("D") df = DataFrame({"A": pi}) result = df.diff(1, axis=1) diff --git a/pandas/tests/frame/methods/test_drop.py b/pandas/tests/frame/methods/test_drop.py index f92899740f95f..dc9a1565aad1e 100644 --- a/pandas/tests/frame/methods/test_drop.py +++ b/pandas/tests/frame/methods/test_drop.py @@ -27,7 +27,7 @@ def test_drop_raise_exception_if_labels_not_in_level(msg, labels, level): # GH 8594 mi = MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]], names=["a", "b"]) - s = pd.Series([10, 20, 30], index=mi) + s = Series([10, 20, 30], index=mi) df = DataFrame([10, 20, 30], index=mi) with pytest.raises(KeyError, match=msg): @@ -40,7 +40,7 @@ def test_drop_raise_exception_if_labels_not_in_level(msg, labels, level): def test_drop_errors_ignore(labels, level): # GH 8594 mi = MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]], names=["a", "b"]) - s = pd.Series([10, 20, 30], index=mi) + s = Series([10, 20, 30], index=mi) df = DataFrame([10, 20, 30], index=mi) expected_s = s.drop(labels, level=level, errors="ignore") diff --git a/pandas/tests/frame/methods/test_join.py b/pandas/tests/frame/methods/test_join.py index 11e83a3d94151..46d32446e2553 100644 --- a/pandas/tests/frame/methods/test_join.py +++ b/pandas/tests/frame/methods/test_join.py @@ -310,7 +310,7 @@ def test_join_multiindex_leftright(self): tm.assert_frame_equal(df1.join(df2, how="left"), exp) tm.assert_frame_equal(df2.join(df1, how="right"), exp[["value2", "value1"]]) - exp_idx = pd.MultiIndex.from_product( + exp_idx = MultiIndex.from_product( [["a", "b"], ["x", "y", "z"]], names=["first", "second"] ) exp = DataFrame( diff --git a/pandas/tests/frame/methods/test_reset_index.py b/pandas/tests/frame/methods/test_reset_index.py index 2ca23127c751a..bd66d54792fba 100644 --- a/pandas/tests/frame/methods/test_reset_index.py +++ b/pandas/tests/frame/methods/test_reset_index.py @@ -426,7 +426,7 @@ def test_reset_index_multiindex_columns(self): def test_reset_index_datetime(self, tz_naive_fixture): # GH#3950 tz = tz_naive_fixture - idx1 = pd.date_range("1/1/2011", periods=5, freq="D", tz=tz, name="idx1") + idx1 = date_range("1/1/2011", periods=5, freq="D", tz=tz, name="idx1") idx2 = Index(range(5), name="idx2", dtype="int64") idx = MultiIndex.from_arrays([idx1, idx2]) df = DataFrame( @@ -453,7 +453,7 @@ def test_reset_index_datetime(self, tz_naive_fixture): tm.assert_frame_equal(df.reset_index(), expected) - idx3 = pd.date_range( + idx3 = date_range( "1/1/2012", periods=5, freq="MS", tz="Europe/Paris", name="idx3" ) idx = MultiIndex.from_arrays([idx1, idx2, idx3]) @@ -492,7 +492,7 @@ def test_reset_index_datetime(self, tz_naive_fixture): # GH#7793 idx = MultiIndex.from_product( - [["a", "b"], pd.date_range("20130101", periods=3, tz=tz)] + [["a", "b"], date_range("20130101", periods=3, tz=tz)] ) df = DataFrame( np.arange(6, dtype="int64").reshape(6, 1), columns=["a"], index=idx diff --git a/pandas/tests/frame/methods/test_to_csv.py b/pandas/tests/frame/methods/test_to_csv.py index 4cf0b1febf0af..aed784a6e4c3c 100644 --- a/pandas/tests/frame/methods/test_to_csv.py +++ b/pandas/tests/frame/methods/test_to_csv.py @@ -12,6 +12,7 @@ DataFrame, Index, MultiIndex, + NaT, Series, Timestamp, date_range, @@ -41,7 +42,7 @@ def read_csv(self, path, **kwargs): params = {"index_col": 0, "parse_dates": True} params.update(**kwargs) - return pd.read_csv(path, **params) + return read_csv(path, **params) def test_to_csv_from_csv1(self, float_frame, datetime_frame): @@ -123,7 +124,7 @@ def test_to_csv_from_csv3(self): df1.to_csv(path) df2.to_csv(path, mode="a", header=False) xp = pd.concat([df1, df2]) - rs = pd.read_csv(path, index_col=0) + rs = read_csv(path, index_col=0) rs.columns = [int(label) for label in rs.columns] xp.columns = [int(label) for label in xp.columns] tm.assert_frame_equal(xp, rs) @@ -139,7 +140,7 @@ def test_to_csv_from_csv4(self): ) df.to_csv(path) - result = pd.read_csv(path, index_col="dt_index") + result = read_csv(path, index_col="dt_index") result.index = pd.to_timedelta(result.index) # TODO: remove renaming when GH 10875 is solved result.index = result.index.rename("dt_index") @@ -153,7 +154,7 @@ def test_to_csv_from_csv5(self, timezone_frame): with tm.ensure_clean("__tmp_to_csv_from_csv5__") as path: timezone_frame.to_csv(path) - result = pd.read_csv(path, index_col=0, parse_dates=["A"]) + result = read_csv(path, index_col=0, parse_dates=["A"]) converter = ( lambda c: to_datetime(result[c]) @@ -166,8 +167,6 @@ def test_to_csv_from_csv5(self, timezone_frame): def test_to_csv_cols_reordering(self): # GH3454 - import pandas as pd - chunksize = 5 N = int(chunksize * 2.5) @@ -177,17 +176,15 @@ def test_to_csv_cols_reordering(self): with tm.ensure_clean() as path: df.to_csv(path, columns=cols, chunksize=chunksize) - rs_c = pd.read_csv(path, index_col=0) + rs_c = read_csv(path, index_col=0) tm.assert_frame_equal(df[cols], rs_c, check_names=False) def test_to_csv_new_dupe_cols(self): - import pandas as pd - def _check_df(df, cols=None): with tm.ensure_clean() as path: df.to_csv(path, columns=cols, chunksize=chunksize) - rs_c = pd.read_csv(path, index_col=0) + rs_c = read_csv(path, index_col=0) # we wrote them in a different order # so compare them in that order @@ -227,8 +224,6 @@ def _check_df(df, cols=None): @pytest.mark.slow def test_to_csv_dtnat(self): # GH3437 - from pandas import NaT - def make_dtnat_arr(n, nnat=None): if nnat is None: nnat = int(n * 0.1) # 10% @@ -999,7 +994,7 @@ def test_to_csv_path_is_none(self, float_frame): # Series.to_csv() csv_str = float_frame.to_csv(path_or_buf=None) assert isinstance(csv_str, str) - recons = pd.read_csv(StringIO(csv_str), index_col=0) + recons = read_csv(StringIO(csv_str), index_col=0) tm.assert_frame_equal(float_frame, recons) @pytest.mark.parametrize( @@ -1040,7 +1035,7 @@ def test_to_csv_compression(self, df, encoding, compression): df.to_csv(handles.handle, encoding=encoding) assert not handles.handle.closed - result = pd.read_csv( + result = read_csv( filename, compression=compression, encoding=encoding, @@ -1122,7 +1117,7 @@ def test_to_csv_with_dst_transitions(self): with tm.ensure_clean("csv_date_format_with_dst") as path: # make sure we are not failing on transitions - times = pd.date_range( + times = date_range( "2013-10-26 23:00", "2013-10-27 01:00", tz="Europe/London", @@ -1144,7 +1139,7 @@ def test_to_csv_with_dst_transitions(self): tm.assert_frame_equal(result, df) # GH11619 - idx = pd.date_range("2015-01-01", "2015-12-31", freq="H", tz="Europe/Paris") + idx = date_range("2015-01-01", "2015-12-31", freq="H", tz="Europe/Paris") idx = idx._with_freq(None) # freq does not round-trip idx._data._freq = None # otherwise there is trouble on unpickle df = DataFrame({"values": 1, "idx": idx}, index=idx) @@ -1250,7 +1245,7 @@ def test_to_csv_quoting(self): # presents with encoding? text_rows = ["a,b,c", '1,"test \r\n",3'] text = tm.convert_rows_list_to_csv_str(text_rows) - df = pd.read_csv(StringIO(text)) + df = read_csv(StringIO(text)) buf = StringIO() df.to_csv(buf, encoding="utf-8", index=False) @@ -1286,7 +1281,7 @@ def test_period_index_date_overflow(self): assert result == expected # Overflow with pd.NaT - dates = ["1990-01-01", pd.NaT, "3005-01-01"] + dates = ["1990-01-01", NaT, "3005-01-01"] index = pd.PeriodIndex(dates, freq="D") df = DataFrame([4, 5, 6], index=index) @@ -1298,7 +1293,7 @@ def test_period_index_date_overflow(self): def test_multi_index_header(self): # see gh-5539 - columns = pd.MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1), ("b", 2)]) + columns = MultiIndex.from_tuples([("a", 1), ("a", 2), ("b", 1), ("b", 2)]) df = DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]]) df.columns = columns diff --git a/pandas/tests/frame/test_api.py b/pandas/tests/frame/test_api.py index 8d3d049cf82f1..2f2de9764219b 100644 --- a/pandas/tests/frame/test_api.py +++ b/pandas/tests/frame/test_api.py @@ -73,7 +73,7 @@ def test_tab_completion(self): df = DataFrame([list("abcd"), list("efgh")], columns=list("ABCD")) for key in list("ABCD"): assert key in dir(df) - assert isinstance(df.__getitem__("A"), pd.Series) + assert isinstance(df.__getitem__("A"), Series) # DataFrame whose first-level columns are identifiers shall have # them in __dir__. @@ -85,7 +85,7 @@ def test_tab_completion(self): assert key in dir(df) for key in list("EFGH"): assert key not in dir(df) - assert isinstance(df.__getitem__("A"), pd.DataFrame) + assert isinstance(df.__getitem__("A"), DataFrame) def test_not_hashable(self): empty_frame = DataFrame() diff --git a/pandas/tests/frame/test_constructors.py b/pandas/tests/frame/test_constructors.py index 93bffd7fea95b..14adc8a992609 100644 --- a/pandas/tests/frame/test_constructors.py +++ b/pandas/tests/frame/test_constructors.py @@ -2208,7 +2208,7 @@ class DatetimeSubclass(datetime): def test_with_mismatched_index_length_raises(self): # GH#33437 - dti = pd.date_range("2016-01-01", periods=3, tz="US/Pacific") + dti = date_range("2016-01-01", periods=3, tz="US/Pacific") with pytest.raises(ValueError, match="Shape of passed values"): DataFrame(dti, index=range(4)) diff --git a/pandas/tests/frame/test_query_eval.py b/pandas/tests/frame/test_query_eval.py index 0cc4c4ad81208..fdbf8a93ddddf 100644 --- a/pandas/tests/frame/test_query_eval.py +++ b/pandas/tests/frame/test_query_eval.py @@ -719,7 +719,7 @@ def test_inf(self): def test_check_tz_aware_index_query(self, tz_aware_fixture): # https://github.com/pandas-dev/pandas/issues/29463 tz = tz_aware_fixture - df_index = pd.date_range( + df_index = date_range( start="2019-01-01", freq="1d", periods=10, tz=tz, name="time" ) expected = DataFrame(index=df_index) diff --git a/pandas/tests/frame/test_reductions.py b/pandas/tests/frame/test_reductions.py index e85b399e69874..de64aecc34ac9 100644 --- a/pandas/tests/frame/test_reductions.py +++ b/pandas/tests/frame/test_reductions.py @@ -831,7 +831,7 @@ def test_sum_nanops_timedelta(self): idx = ["a", "b", "c"] df = DataFrame({"a": [0, 0], "b": [0, np.nan], "c": [np.nan, np.nan]}) - df2 = df.apply(pd.to_timedelta) + df2 = df.apply(to_timedelta) # 0 by default result = df2.sum() @@ -861,9 +861,9 @@ def test_sum_bool(self, float_frame): def test_sum_mixed_datetime(self): # GH#30886 - df = DataFrame( - {"A": pd.date_range("2000", periods=4), "B": [1, 2, 3, 4]} - ).reindex([2, 3, 4]) + df = DataFrame({"A": date_range("2000", periods=4), "B": [1, 2, 3, 4]}).reindex( + [2, 3, 4] + ) result = df.sum() expected = Series({"B": 7.0}) @@ -893,7 +893,7 @@ def test_mean_datetimelike(self): df = DataFrame( { "A": np.arange(3), - "B": pd.date_range("2016-01-01", periods=3), + "B": date_range("2016-01-01", periods=3), "C": pd.timedelta_range("1D", periods=3), "D": pd.period_range("2016", periods=3, freq="A"), } @@ -912,7 +912,7 @@ def test_mean_datetimelike_numeric_only_false(self): df = DataFrame( { "A": np.arange(3), - "B": pd.date_range("2016-01-01", periods=3), + "B": date_range("2016-01-01", periods=3), "C": pd.timedelta_range("1D", periods=3), } ) @@ -983,7 +983,7 @@ def test_idxmax(self, float_frame, int_frame): def test_idxmax_mixed_dtype(self): # don't cast to object, which would raise in nanops - dti = pd.date_range("2016-01-01", periods=3) + dti = date_range("2016-01-01", periods=3) df = DataFrame({1: [0, 2, 1], 2: range(3)[::-1], 3: dti}) @@ -1273,8 +1273,8 @@ def test_min_max_dt64_api_consistency_with_NaT(self): # returned NaT for series. These tests check that the API is consistent in # min/max calls on empty Series/DataFrames. See GH:33704 for more # information - df = DataFrame({"x": pd.to_datetime([])}) - expected_dt_series = Series(pd.to_datetime([])) + df = DataFrame({"x": to_datetime([])}) + expected_dt_series = Series(to_datetime([])) # check axis 0 assert (df.min(axis=0).x is pd.NaT) == (expected_dt_series.min() is pd.NaT) assert (df.max(axis=0).x is pd.NaT) == (expected_dt_series.max() is pd.NaT) @@ -1302,7 +1302,7 @@ def test_min_max_dt64_api_consistency_empty_df(self): @pytest.mark.parametrize("method", ["min", "max"]) def test_preserve_timezone(self, initial: str, method): # GH 28552 - initial_dt = pd.to_datetime(initial) + initial_dt = to_datetime(initial) expected = Series([initial_dt]) df = DataFrame([expected]) result = getattr(df, method)(axis=1) @@ -1330,7 +1330,7 @@ def test_frame_any_with_timedelta(self): df = DataFrame( { "a": Series([0, 0]), - "t": Series([pd.to_timedelta(0, "s"), pd.to_timedelta(1, "ms")]), + "t": Series([to_timedelta(0, "s"), to_timedelta(1, "ms")]), } ) diff --git a/pandas/tests/frame/test_stack_unstack.py b/pandas/tests/frame/test_stack_unstack.py index 6f453ec5df71f..9945b739f8a87 100644 --- a/pandas/tests/frame/test_stack_unstack.py +++ b/pandas/tests/frame/test_stack_unstack.py @@ -158,7 +158,7 @@ def test_unstack_fill_frame(self): def test_unstack_fill_frame_datetime(self): # Test unstacking with date times - dv = pd.date_range("2012-01-01", periods=4).values + dv = date_range("2012-01-01", periods=4).values data = Series(dv) data.index = MultiIndex.from_tuples( [("x", "a"), ("x", "b"), ("y", "b"), ("z", "a")] @@ -608,7 +608,7 @@ def test_unstack_dtypes(self): "A": ["a"] * 5, "C": c, "D": d, - "B": pd.date_range("2012-01-01", periods=5), + "B": date_range("2012-01-01", periods=5), } ) @@ -942,7 +942,7 @@ def verify(df): df = DataFrame( { "1st": [1, 2, 1, 2, 1, 2], - "2nd": pd.date_range("2014-02-01", periods=6, freq="D"), + "2nd": date_range("2014-02-01", periods=6, freq="D"), "jim": 100 + np.arange(6), "joe": (np.random.randn(6) * 10).round(2), } @@ -1171,9 +1171,7 @@ def test_unstack_timezone_aware_values(): def test_stack_timezone_aware_values(): # GH 19420 - ts = pd.date_range( - freq="D", start="20180101", end="20180103", tz="America/New_York" - ) + ts = date_range(freq="D", start="20180101", end="20180103", tz="America/New_York") df = DataFrame({"A": ts}, index=["a", "b", "c"]) result = df.stack() expected = Series( diff --git a/pandas/tests/frame/test_subclass.py b/pandas/tests/frame/test_subclass.py index c32d8f63831ac..784ca03fa9c03 100644 --- a/pandas/tests/frame/test_subclass.py +++ b/pandas/tests/frame/test_subclass.py @@ -61,11 +61,11 @@ def custom_frame_function(self): assert cdf_rows.custom_frame_function() == "OK" # Make sure sliced part of multi-index frame is custom class - mcol = pd.MultiIndex.from_tuples([("A", "A"), ("A", "B")]) + mcol = MultiIndex.from_tuples([("A", "A"), ("A", "B")]) cdf_multi = CustomDataFrame([[0, 1], [2, 3]], columns=mcol) assert isinstance(cdf_multi["A"], CustomDataFrame) - mcol = pd.MultiIndex.from_tuples([("A", ""), ("B", "")]) + mcol = MultiIndex.from_tuples([("A", ""), ("B", "")]) cdf_multi2 = CustomDataFrame([[0, 1], [2, 3]], columns=mcol) assert isinstance(cdf_multi2["A"], CustomSeries) @@ -705,7 +705,7 @@ def test_idxmax_preserves_subclass(self): def test_equals_subclass(self): # https://github.com/pandas-dev/pandas/pull/34402 # allow subclass in both directions - df1 = pd.DataFrame({"a": [1, 2, 3]}) + df1 = DataFrame({"a": [1, 2, 3]}) df2 = tm.SubclassedDataFrame({"a": [1, 2, 3]}) assert df1.equals(df2) assert df2.equals(df1) diff --git a/scripts/check_for_inconsistent_pandas_namespace.py b/scripts/check_for_inconsistent_pandas_namespace.py index 11cdba6e821d2..87070e819b4a0 100644 --- a/scripts/check_for_inconsistent_pandas_namespace.py +++ b/scripts/check_for_inconsistent_pandas_namespace.py @@ -7,55 +7,115 @@ This is meant to be run as a pre-commit hook - to run it manually, you can do: pre-commit run inconsistent-namespace-usage --all-files + +To automatically fixup a given file, you can pass `--replace`, e.g. + + python scripts/check_for_inconsistent_pandas_namespace.py test_me.py --replace + +though note that you may need to manually fixup some imports and that you will also +need the additional dependency `tokenize-rt` (which is left out from the pre-commit +hook so that it uses the same virtualenv as the other local ones). """ import argparse -from pathlib import Path -import re +import ast from typing import ( + MutableMapping, Optional, Sequence, + Set, + Tuple, ) -PATTERN = r""" - ( - (? None: + self.pandas_namespace: MutableMapping[Offset, str] = {} + self.no_namespace: Set[str] = set() + + def visit_Attribute(self, node: ast.Attribute) -> None: + if ( + isinstance(node.value, ast.Name) + and node.value.id == "pd" + and node.attr not in EXCLUDE + ): + self.pandas_namespace[(node.lineno, node.col_offset)] = node.attr + self.generic_visit(node) + + def visit_Name(self, node: ast.Name) -> None: + if node.id not in EXCLUDE: + self.no_namespace.add(node.id) + self.generic_visit(node) + + +def replace_inconsistent_pandas_namespace(visitor: Visitor, content: str) -> str: + from tokenize_rt import ( + reversed_enumerate, + src_to_tokens, + tokens_to_src, ) - """ -ERROR_MESSAGE = "Found both `pd.{class_name}` and `{class_name}` in {path}" + + tokens = src_to_tokens(content) + for n, i in reversed_enumerate(tokens): + if ( + i.offset in visitor.pandas_namespace + and visitor.pandas_namespace[i.offset] in visitor.no_namespace + ): + # Replace `pd` + tokens[n] = i._replace(src="") + # Replace `.` + tokens[n + 1] = tokens[n + 1]._replace(src="") + + new_src: str = tokens_to_src(tokens) + return new_src + + +def check_for_inconsistent_pandas_namespace( + content: str, path: str, *, replace: bool +) -> Optional[str]: + tree = ast.parse(content) + + visitor = Visitor() + visitor.visit(tree) + + inconsistencies = visitor.no_namespace.intersection( + visitor.pandas_namespace.values() + ) + if not inconsistencies: + # No inconsistent namespace usage, nothing to replace. + return content + + if not replace: + msg = ERROR_MESSAGE.format(name=inconsistencies.pop(), path=path) + raise RuntimeError(msg) + + return replace_inconsistent_pandas_namespace(visitor, content) def main(argv: Optional[Sequence[str]] = None) -> None: parser = argparse.ArgumentParser() - parser.add_argument("paths", nargs="*", type=Path) + parser.add_argument("paths", nargs="*") + parser.add_argument("--replace", action="store_true") args = parser.parse_args(argv) - pattern = re.compile( - PATTERN.encode(), - flags=re.MULTILINE | re.DOTALL | re.VERBOSE, - ) for path in args.paths: - contents = path.read_bytes() - match = pattern.search(contents) - if match is None: + with open(path, encoding="utf-8") as fd: + content = fd.read() + new_content = check_for_inconsistent_pandas_namespace( + content, path, replace=args.replace + ) + if not args.replace or new_content is None: continue - if match.group(2) is not None: - raise AssertionError( - ERROR_MESSAGE.format(class_name=match.group(2).decode(), path=str(path)) - ) - if match.group(4) is not None: - raise AssertionError( - ERROR_MESSAGE.format(class_name=match.group(4).decode(), path=str(path)) - ) + with open(path, "w", encoding="utf-8") as fd: + fd.write(new_content) if __name__ == "__main__": diff --git a/scripts/tests/test_inconsistent_namespace_check.py b/scripts/tests/test_inconsistent_namespace_check.py index 37e6d288d9341..cc3509af5b138 100644 --- a/scripts/tests/test_inconsistent_namespace_check.py +++ b/scripts/tests/test_inconsistent_namespace_check.py @@ -1,28 +1,38 @@ -from pathlib import Path - import pytest -from scripts.check_for_inconsistent_pandas_namespace import main +from scripts.check_for_inconsistent_pandas_namespace import ( + check_for_inconsistent_pandas_namespace, +) BAD_FILE_0 = "cat_0 = Categorical()\ncat_1 = pd.Categorical()" BAD_FILE_1 = "cat_0 = pd.Categorical()\ncat_1 = Categorical()" GOOD_FILE_0 = "cat_0 = Categorical()\ncat_1 = Categorical()" GOOD_FILE_1 = "cat_0 = pd.Categorical()\ncat_1 = pd.Categorical()" +PATH = "t.py" + + +@pytest.mark.parametrize("content", [BAD_FILE_0, BAD_FILE_1]) +def test_inconsistent_usage(content): + msg = r"Found both `pd\.Categorical` and `Categorical` in t\.py" + with pytest.raises(RuntimeError, match=msg): + check_for_inconsistent_pandas_namespace(content, PATH, replace=False) + + +@pytest.mark.parametrize("content", [GOOD_FILE_0, GOOD_FILE_1]) +def test_consistent_usage(content): + # should not raise + check_for_inconsistent_pandas_namespace(content, PATH, replace=False) @pytest.mark.parametrize("content", [BAD_FILE_0, BAD_FILE_1]) -def test_inconsistent_usage(tmpdir, content): - tmpfile = Path(tmpdir / "tmpfile.py") - tmpfile.touch() - tmpfile.write_text(content) - msg = fr"Found both `pd\.Categorical` and `Categorical` in {str(tmpfile)}" - with pytest.raises(AssertionError, match=msg): - main((str(tmpfile),)) +def test_inconsistent_usage_with_replace(content): + result = check_for_inconsistent_pandas_namespace(content, PATH, replace=True) + expected = "cat_0 = Categorical()\ncat_1 = Categorical()" + assert result == expected @pytest.mark.parametrize("content", [GOOD_FILE_0, GOOD_FILE_1]) -def test_consistent_usage(tmpdir, content): - tmpfile = Path(tmpdir / "tmpfile.py") - tmpfile.touch() - tmpfile.write_text(content) - main((str(tmpfile),)) # Should not raise. +def test_consistent_usage_with_replace(content): + result = check_for_inconsistent_pandas_namespace(content, PATH, replace=True) + expected = content + assert result == expected