Skip to content

Commit 91e97f9

Browse files
grundprinzip authored and HyukjinKwon committed
[SPARK-44528][CONNECT] Support proper usage of hasattr() for Connect dataframe
### What changes were proposed in this pull request?

Currently Connect does not allow the proper usage of Python's `hasattr()` to identify if an attribute is defined or not. This patch fixes that bug (it's working in regular PySpark).

### Why are the changes needed?

Bugfix

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

UT

Closes #42132 from grundprinzip/SPARK-44528.

Lead-authored-by: Martin Grund <[email protected]>
Co-authored-by: Martin Grund <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent 6b6216c commit 91e97f9

File tree

3 files changed

+46
-7
lines changed

3 files changed

+46
-7
lines changed

python/pyspark/sql/connect/dataframe.py

+8
Original file line numberDiff line numberDiff line change
@@ -1584,8 +1584,16 @@ def __getattr__(self, name: str) -> "Column":
15841584
error_class="NOT_IMPLEMENTED",
15851585
message_parameters={"feature": f"{name}()"},
15861586
)
1587+
1588+
if name not in self.columns:
1589+
raise AttributeError(
1590+
"'%s' object has no attribute '%s'" % (self.__class__.__name__, name)
1591+
)
1592+
15871593
return self[name]
15881594

1595+
__getattr__.__doc__ = PySparkDataFrame.__getattr__.__doc__
1596+
15891597
@overload
15901598
def __getitem__(self, item: Union[int, str]) -> Column:
15911599
...

python/pyspark/sql/tests/connect/test_connect_basic.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,19 @@ def spark_connect_clean_up_test_data(cls):
157157

158158

159159
class SparkConnectBasicTests(SparkConnectSQLTestCase):
160+
def test_df_getattr_behavior(self):
161+
cdf = self.connect.range(10)
162+
sdf = self.spark.range(10)
163+
164+
sdf._simple_extension = 10
165+
cdf._simple_extension = 10
166+
167+
self.assertEqual(sdf._simple_extension, cdf._simple_extension)
168+
self.assertEqual(type(sdf._simple_extension), type(cdf._simple_extension))
169+
170+
self.assertTrue(hasattr(cdf, "_simple_extension"))
171+
self.assertFalse(hasattr(cdf, "_simple_extension_does_not_exsit"))
172+
160173
def test_df_get_item(self):
161174
# SPARK-41779: test __getitem__
162175

@@ -1296,8 +1309,8 @@ def test_drop(self):
12961309
sdf.drop("a", "x").toPandas(),
12971310
)
12981311
self.assert_eq(
1299-
cdf.drop(cdf.a, cdf.x).toPandas(),
1300-
sdf.drop("a", "x").toPandas(),
1312+
cdf.drop(cdf.a, "x").toPandas(),
1313+
sdf.drop(sdf.a, "x").toPandas(),
13011314
)
13021315

13031316
def test_subquery_alias(self) -> None:

python/pyspark/testing/connectutils.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#
1717
import shutil
1818
import tempfile
19+
import types
1920
import typing
2021
import os
2122
import functools
@@ -67,7 +68,7 @@
6768

6869
if should_test_connect:
6970
from pyspark.sql.connect.dataframe import DataFrame
70-
from pyspark.sql.connect.plan import Read, Range, SQL
71+
from pyspark.sql.connect.plan import Read, Range, SQL, LogicalPlan
7172
from pyspark.sql.connect.session import SparkSession
7273

7374

@@ -88,16 +89,33 @@ def __getattr__(self, item):
8889
return functools.partial(self.hooks[item])
8990

9091

92+
class MockDF(DataFrame):
93+
"""Helper class that must only be used for the mock plan tests."""
94+
95+
def __init__(self, session: SparkSession, plan: LogicalPlan):
96+
super().__init__(session)
97+
self._plan = plan
98+
99+
def __getattr__(self, name):
100+
"""All attributes are resolved to columns, because none really exist in the
101+
mocked DataFrame."""
102+
return self[name]
103+
104+
91105
@unittest.skipIf(not should_test_connect, connect_requirement_message)
92106
class PlanOnlyTestFixture(unittest.TestCase, PySparkErrorTestUtils):
93107
@classmethod
94108
def _read_table(cls, table_name):
95-
return DataFrame.withPlan(Read(table_name), cls.connect)
109+
return cls._df_mock(Read(table_name))
96110

97111
@classmethod
98112
def _udf_mock(cls, *args, **kwargs):
99113
return "internal_name"
100114

115+
@classmethod
116+
def _df_mock(cls, plan: LogicalPlan) -> MockDF:
117+
return MockDF(cls.connect, plan)
118+
101119
@classmethod
102120
def _session_range(
103121
cls,
@@ -106,17 +124,17 @@ def _session_range(
106124
step=1,
107125
num_partitions=None,
108126
):
109-
return DataFrame.withPlan(Range(start, end, step, num_partitions), cls.connect)
127+
return cls._df_mock(Range(start, end, step, num_partitions))
110128

111129
@classmethod
112130
def _session_sql(cls, query):
113-
return DataFrame.withPlan(SQL(query), cls.connect)
131+
return cls._df_mock(SQL(query))
114132

115133
if have_pandas:
116134

117135
@classmethod
118136
def _with_plan(cls, plan):
119-
return DataFrame.withPlan(plan, cls.connect)
137+
return cls._df_mock(plan)
120138

121139
@classmethod
122140
def setUpClass(cls):

0 commit comments

Comments (0)