Skip to content

Commit 4729d8f

Browse files
authored
STY/WIP: check for private imports/lookups (#36055)
1 parent c104622 commit 4729d8f

File tree

13 files changed

+119
-34
lines changed

13 files changed

+119
-34
lines changed

Makefile

+6
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,9 @@ check:
3232
--included-file-extensions="py" \
3333
--excluded-file-paths=pandas/tests,asv_bench/,pandas/_vendored \
3434
pandas/
35+
36+
python3 scripts/validate_unwanted_patterns.py \
37+
--validation-type="private_import_across_module" \
38+
--included-file-extensions="py" \
39+
--excluded-file-paths=pandas/tests,asv_bench/,pandas/_vendored,doc/
40+
pandas/

ci/code_checks.sh

+11-3
Original file line numberDiff line numberDiff line change
@@ -116,11 +116,19 @@ if [[ -z "$CHECK" || "$CHECK" == "lint" ]]; then
116116
fi
117117
RET=$(($RET + $?)) ; echo $MSG "DONE"
118118

119-
MSG='Check for use of private module attribute access' ; echo $MSG
119+
MSG='Check for import of private attributes across modules' ; echo $MSG
120120
if [[ "$GITHUB_ACTIONS" == "true" ]]; then
121-
$BASE_DIR/scripts/validate_unwanted_patterns.py --validation-type="private_function_across_module" --included-file-extensions="py" --excluded-file-paths=pandas/tests,asv_bench/,pandas/_vendored --format="##[error]{source_path}:{line_number}:{msg}" pandas/
121+
$BASE_DIR/scripts/validate_unwanted_patterns.py --validation-type="private_import_across_module" --included-file-extensions="py" --excluded-file-paths=pandas/tests,asv_bench/,pandas/_vendored --format="##[error]{source_path}:{line_number}:{msg}" pandas/
122122
else
123-
$BASE_DIR/scripts/validate_unwanted_patterns.py --validation-type="private_function_across_module" --included-file-extensions="py" --excluded-file-paths=pandas/tests,asv_bench/,pandas/_vendored pandas/
123+
$BASE_DIR/scripts/validate_unwanted_patterns.py --validation-type="private_import_across_module" --included-file-extensions="py" --excluded-file-paths=pandas/tests,asv_bench/,pandas/_vendored pandas/
124+
fi
125+
RET=$(($RET + $?)) ; echo $MSG "DONE"
126+
127+
MSG='Check for use of private functions across modules' ; echo $MSG
128+
if [[ "$GITHUB_ACTIONS" == "true" ]]; then
129+
$BASE_DIR/scripts/validate_unwanted_patterns.py --validation-type="private_function_across_module" --included-file-extensions="py" --excluded-file-paths=pandas/tests,asv_bench/,pandas/_vendored,doc/ --format="##[error]{source_path}:{line_number}:{msg}" pandas/
130+
else
131+
$BASE_DIR/scripts/validate_unwanted_patterns.py --validation-type="private_function_across_module" --included-file-extensions="py" --excluded-file-paths=pandas/tests,asv_bench/,pandas/_vendored,doc/ pandas/
124132
fi
125133
RET=$(($RET + $?)) ; echo $MSG "DONE"
126134

pandas/core/arrays/datetimelike.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454

5555
from pandas.core import missing, nanops, ops
5656
from pandas.core.algorithms import checked_add_with_arr, unique1d, value_counts
57-
from pandas.core.arrays._mixins import _T, NDArrayBackedExtensionArray
57+
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
5858
from pandas.core.arrays.base import ExtensionOpsMixin
5959
import pandas.core.common as com
6060
from pandas.core.construction import array, extract_array
@@ -472,11 +472,11 @@ class DatetimeLikeArrayMixin(
472472
def _ndarray(self) -> np.ndarray:
473473
return self._data
474474

475-
def _from_backing_data(self: _T, arr: np.ndarray) -> _T:
475+
def _from_backing_data(
476+
self: DatetimeLikeArrayT, arr: np.ndarray
477+
) -> DatetimeLikeArrayT:
476478
# Note: we do not retain `freq`
477-
return type(self)._simple_new( # type: ignore[attr-defined]
478-
arr, dtype=self.dtype
479-
)
479+
return type(self)._simple_new(arr, dtype=self.dtype)
480480

481481
# ------------------------------------------------------------------
482482

pandas/core/arrays/integer.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def _get_common_dtype(self, dtypes: List[DtypeObj]) -> Optional[DtypeObj]:
106106
[t.numpy_dtype if isinstance(t, BaseMaskedDtype) else t for t in dtypes], []
107107
)
108108
if np.issubdtype(np_dtype, np.integer):
109-
return _dtypes[str(np_dtype)]
109+
return STR_TO_DTYPE[str(np_dtype)]
110110
return None
111111

112112
def __from_arrow__(
@@ -214,7 +214,7 @@ def coerce_to_array(
214214

215215
if not issubclass(type(dtype), _IntegerDtype):
216216
try:
217-
dtype = _dtypes[str(np.dtype(dtype))]
217+
dtype = STR_TO_DTYPE[str(np.dtype(dtype))]
218218
except KeyError as err:
219219
raise ValueError(f"invalid dtype specified {dtype}") from err
220220

@@ -354,7 +354,7 @@ class IntegerArray(BaseMaskedArray):
354354

355355
@cache_readonly
356356
def dtype(self) -> _IntegerDtype:
357-
return _dtypes[str(self._data.dtype)]
357+
return STR_TO_DTYPE[str(self._data.dtype)]
358358

359359
def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
360360
if not (isinstance(values, np.ndarray) and values.dtype.kind in ["i", "u"]):
@@ -735,7 +735,7 @@ class UInt64Dtype(_IntegerDtype):
735735
__doc__ = _dtype_docstring.format(dtype="uint64")
736736

737737

738-
_dtypes: Dict[str, _IntegerDtype] = {
738+
STR_TO_DTYPE: Dict[str, _IntegerDtype] = {
739739
"int8": Int8Dtype(),
740740
"int16": Int16Dtype(),
741741
"int32": Int32Dtype(),

pandas/core/dtypes/cast.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1151,9 +1151,11 @@ def convert_dtypes(
11511151
target_int_dtype = "Int64"
11521152

11531153
if is_integer_dtype(input_array.dtype):
1154-
from pandas.core.arrays.integer import _dtypes
1154+
from pandas.core.arrays.integer import STR_TO_DTYPE
11551155

1156-
inferred_dtype = _dtypes.get(input_array.dtype.name, target_int_dtype)
1156+
inferred_dtype = STR_TO_DTYPE.get(
1157+
input_array.dtype.name, target_int_dtype
1158+
)
11571159
if not is_integer_dtype(input_array.dtype) and is_numeric_dtype(
11581160
input_array.dtype
11591161
):

pandas/core/groupby/groupby.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def f(self):
459459

460460

461461
@contextmanager
462-
def group_selection_context(groupby: "_GroupBy"):
462+
def group_selection_context(groupby: "BaseGroupBy"):
463463
"""
464464
Set / reset the group_selection_context.
465465
"""
@@ -479,7 +479,7 @@ def group_selection_context(groupby: "_GroupBy"):
479479
]
480480

481481

482-
class _GroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]):
482+
class BaseGroupBy(PandasObject, SelectionMixin, Generic[FrameOrSeries]):
483483
_group_selection = None
484484
_apply_allowlist: FrozenSet[str] = frozenset()
485485

@@ -1212,7 +1212,7 @@ def _apply_filter(self, indices, dropna):
12121212
OutputFrameOrSeries = TypeVar("OutputFrameOrSeries", bound=NDFrame)
12131213

12141214

1215-
class GroupBy(_GroupBy[FrameOrSeries]):
1215+
class GroupBy(BaseGroupBy[FrameOrSeries]):
12161216
"""
12171217
Class for grouping and aggregating relational data.
12181218

pandas/core/indexes/datetimes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,9 @@ def _is_dates_only(self) -> bool:
312312
-------
313313
bool
314314
"""
315-
from pandas.io.formats.format import _is_dates_only
315+
from pandas.io.formats.format import is_dates_only
316316

317-
return self.tz is None and _is_dates_only(self._values)
317+
return self.tz is None and is_dates_only(self._values)
318318

319319
def __reduce__(self):
320320

pandas/core/resample.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
from pandas.core.generic import NDFrame, _shared_docs
2727
from pandas.core.groupby.base import GroupByMixin
2828
from pandas.core.groupby.generic import SeriesGroupBy
29-
from pandas.core.groupby.groupby import GroupBy, _GroupBy, _pipe_template, get_groupby
29+
from pandas.core.groupby.groupby import (
30+
BaseGroupBy,
31+
GroupBy,
32+
_pipe_template,
33+
get_groupby,
34+
)
3035
from pandas.core.groupby.grouper import Grouper
3136
from pandas.core.groupby.ops import BinGrouper
3237
from pandas.core.indexes.api import Index
@@ -40,7 +45,7 @@
4045
_shared_docs_kwargs: Dict[str, str] = dict()
4146

4247

43-
class Resampler(_GroupBy, ShallowMixin):
48+
class Resampler(BaseGroupBy, ShallowMixin):
4449
"""
4550
Class for resampling datetimelike data, a groupby-like operation.
4651
See aggregate, transform, and apply functions on this object.

pandas/core/window/ewm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import pandas.core.common as common
1717
from pandas.core.window.common import _doc_template, _shared_docs, zsqrt
18-
from pandas.core.window.rolling import _Rolling, flex_binary_moment
18+
from pandas.core.window.rolling import RollingMixin, flex_binary_moment
1919

2020
_bias_template = """
2121
Parameters
@@ -60,7 +60,7 @@ def get_center_of_mass(
6060
return float(comass)
6161

6262

63-
class ExponentialMovingWindow(_Rolling):
63+
class ExponentialMovingWindow(RollingMixin):
6464
r"""
6565
Provide exponential weighted (EW) functions.
6666

pandas/core/window/expanding.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from pandas.util._decorators import Appender, Substitution, doc
66

77
from pandas.core.window.common import WindowGroupByMixin, _doc_template, _shared_docs
8-
from pandas.core.window.rolling import _Rolling_and_Expanding
8+
from pandas.core.window.rolling import RollingAndExpandingMixin
99

1010

11-
class Expanding(_Rolling_and_Expanding):
11+
class Expanding(RollingAndExpandingMixin):
1212
"""
1313
Provide expanding transformations.
1414

pandas/core/window/rolling.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1214,13 +1214,13 @@ def std(self, ddof=1, *args, **kwargs):
12141214
return zsqrt(self.var(ddof=ddof, name="std", **kwargs))
12151215

12161216

1217-
class _Rolling(_Window):
1217+
class RollingMixin(_Window):
12181218
@property
12191219
def _constructor(self):
12201220
return Rolling
12211221

12221222

1223-
class _Rolling_and_Expanding(_Rolling):
1223+
class RollingAndExpandingMixin(RollingMixin):
12241224

12251225
_shared_docs["count"] = dedent(
12261226
r"""
@@ -1917,7 +1917,7 @@ def _get_corr(a, b):
19171917
)
19181918

19191919

1920-
class Rolling(_Rolling_and_Expanding):
1920+
class Rolling(RollingAndExpandingMixin):
19211921
@cache_readonly
19221922
def is_datetimelike(self) -> bool:
19231923
return isinstance(

pandas/io/formats/format.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1586,7 +1586,7 @@ def format_percentiles(
15861586
return [i + "%" for i in out]
15871587

15881588

1589-
def _is_dates_only(
1589+
def is_dates_only(
15901590
values: Union[np.ndarray, DatetimeArray, Index, DatetimeIndex]
15911591
) -> bool:
15921592
# return a boolean if we are only dates (and don't have a timezone)
@@ -1658,8 +1658,8 @@ def get_format_datetime64_from_values(
16581658
# only accepts 1D values
16591659
values = values.ravel()
16601660

1661-
is_dates_only = _is_dates_only(values)
1662-
if is_dates_only:
1661+
ido = is_dates_only(values)
1662+
if ido:
16631663
return date_format or "%Y-%m-%d"
16641664
return date_format
16651665

@@ -1668,9 +1668,9 @@ class Datetime64TZFormatter(Datetime64Formatter):
16681668
def _format_strings(self) -> List[str]:
16691669
""" we by definition have a TZ """
16701670
values = self.values.astype(object)
1671-
is_dates_only = _is_dates_only(values)
1671+
ido = is_dates_only(values)
16721672
formatter = self.formatter or get_format_datetime64(
1673-
is_dates_only, date_format=self.date_format
1673+
ido, date_format=self.date_format
16741674
)
16751675
fmt_values = [formatter(x) for x in values]
16761676

scripts/validate_unwanted_patterns.py

+65-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,39 @@
1818
import tokenize
1919
from typing import IO, Callable, FrozenSet, Iterable, List, Set, Tuple
2020

21+
PRIVATE_IMPORTS_TO_IGNORE: Set[str] = {
22+
"_extension_array_shared_docs",
23+
"_index_shared_docs",
24+
"_interval_shared_docs",
25+
"_merge_doc",
26+
"_shared_docs",
27+
"_apply_docs",
28+
"_new_Index",
29+
"_new_PeriodIndex",
30+
"_doc_template",
31+
"_agg_template",
32+
"_pipe_template",
33+
"_get_version",
34+
"__main__",
35+
"_transform_template",
36+
"_arith_doc_FRAME",
37+
"_flex_comp_doc_FRAME",
38+
"_make_flex_doc",
39+
"_op_descriptions",
40+
"_IntegerDtype",
41+
"_use_inf_as_na",
42+
"_get_plot_backend",
43+
"_matplotlib",
44+
"_arrow_utils",
45+
"_registry",
46+
"_get_offset", # TODO: remove after get_offset deprecation enforced
47+
"_test_parse_iso8601",
48+
"_json_normalize", # TODO: remove after deprecation is enforced
49+
"_testing",
50+
"_test_decorators",
51+
"__version__", # check np.__version__ in compat.numpy.function
52+
}
53+
2154

2255
def _get_literal_string_prefix_len(token_string: str) -> int:
2356
"""
@@ -164,6 +197,36 @@ def private_function_across_module(file_obj: IO[str]) -> Iterable[Tuple[int, str
164197
yield (node.lineno, f"Private function '{module_name}.{function_name}'")
165198

166199

200+
def private_import_across_module(file_obj: IO[str]) -> Iterable[Tuple[int, str]]:
201+
"""
202+
Checking that a private function is not imported across modules.
203+
Parameters
204+
----------
205+
file_obj : IO
206+
File-like object containing the Python code to validate.
207+
Yields
208+
------
209+
line_number : int
210+
Line number of import statement, that imports the private function.
211+
msg : str
212+
Explenation of the error.
213+
"""
214+
contents = file_obj.read()
215+
tree = ast.parse(contents)
216+
217+
for node in ast.walk(tree):
218+
if not (isinstance(node, ast.Import) or isinstance(node, ast.ImportFrom)):
219+
continue
220+
221+
for module in node.names:
222+
module_name = module.name.split(".")[-1]
223+
if module_name in PRIVATE_IMPORTS_TO_IGNORE:
224+
continue
225+
226+
if module_name.startswith("_"):
227+
yield (node.lineno, f"Import of internal function {repr(module_name)}")
228+
229+
167230
def strings_to_concatenate(file_obj: IO[str]) -> Iterable[Tuple[int, str]]:
168231
"""
169232
This test case is necessary after 'Black' (https://github.com/psf/black),
@@ -419,6 +482,7 @@ def main(
419482
available_validation_types: List[str] = [
420483
"bare_pytest_raises",
421484
"private_function_across_module",
485+
"private_import_across_module",
422486
"strings_to_concatenate",
423487
"strings_with_wrong_placed_whitespace",
424488
]
@@ -449,7 +513,7 @@ def main(
449513
parser.add_argument(
450514
"--excluded-file-paths",
451515
default="asv_bench/env",
452-
help="Comma separated file extensions to check.",
516+
help="Comma separated file paths to exclude.",
453517
)
454518

455519
args = parser.parse_args()

0 commit comments

Comments
 (0)