
Commit aeb3649

BryanCutler authored and HyukjinKwon committed
[SPARK-33613][PYTHON][TESTS] Replace deprecated APIs in pyspark tests
### What changes were proposed in this pull request?

This replaces deprecated API usage in PySpark tests with the preferred APIs. These have been deprecated for some time and usage is not consistent within tests.

- https://docs.python.org/3/library/unittest.html#deprecated-aliases

### Why are the changes needed?

For consistency and eventual removal of deprecated APIs.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests

Closes #30557 from BryanCutler/replace-deprecated-apis-in-tests.

Authored-by: Bryan Cutler <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
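For reference, a minimal sketch of the renames this commit applies throughout the test suite. The test case below is hypothetical, written only to contrast the deprecated aliases with the preferred names; only `assertEqual` and `assertRaisesRegex` themselves come from this commit.

```python
import unittest


class DeprecatedAliasExample(unittest.TestCase):
    """Hypothetical test illustrating the renames applied in this commit."""

    def test_preferred_assertions(self):
        # Before: self.assertEquals(1 + 1, 2) -- a deprecated alias of assertEqual.
        self.assertEqual(1 + 1, 2)

        # Before: with self.assertRaisesRegexp(ValueError, "invalid"):
        # assertRaisesRegex is the preferred name; the Regexp alias has been
        # deprecated since Python 3.2 (and was later removed in Python 3.12).
        with self.assertRaisesRegex(ValueError, "invalid"):
            int("not a number")  # raises ValueError: invalid literal for int()


if __name__ == "__main__":
    unittest.main()
```

The pandas change in these tests follows the same pattern: `from pandas.util.testing import assert_frame_equal` becomes `from pandas.testing import assert_frame_equal`, since `pandas.util.testing` is deprecated in favor of the public `pandas.testing` module.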
1 parent 596fbc1 commit aeb3649

27 files changed, +274 −275 lines changed

python/pyspark/ml/tests/test_feature.py

Lines changed: 1 addition & 1 deletion
@@ -169,7 +169,7 @@ def test_count_vectorizer_from_vocab(self):

         # Test an empty vocabulary
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, "vocabSize.*invalid.*0"):
+            with self.assertRaisesRegex(Exception, "vocabSize.*invalid.*0"):
                 CountVectorizerModel.from_vocabulary([], inputCol="words")

         # Test model with default settings can transform

python/pyspark/ml/tests/test_image.py

Lines changed: 3 additions & 3 deletions
@@ -47,19 +47,19 @@ def test_read_images(self):
         self.assertEqual(ImageSchema.undefinedImageType, "Undefined")

         with QuietTest(self.sc):
-            self.assertRaisesRegexp(
+            self.assertRaisesRegex(
                 TypeError,
                 "image argument should be pyspark.sql.types.Row; however",
                 lambda: ImageSchema.toNDArray("a"))

         with QuietTest(self.sc):
-            self.assertRaisesRegexp(
+            self.assertRaisesRegex(
                 ValueError,
                 "image argument should have attributes specified in",
                 lambda: ImageSchema.toNDArray(Row(a=1)))

         with QuietTest(self.sc):
-            self.assertRaisesRegexp(
+            self.assertRaisesRegex(
                 TypeError,
                 "array argument should be numpy.ndarray; however, it got",
                 lambda: ImageSchema.toImage("a"))

python/pyspark/ml/tests/test_param.py

Lines changed: 1 addition & 1 deletion
@@ -308,7 +308,7 @@ def test_logistic_regression_check_thresholds(self):
             LogisticRegression
         )

-        self.assertRaisesRegexp(
+        self.assertRaisesRegex(
             ValueError,
             "Logistic Regression getThreshold found inconsistent.*$",
             LogisticRegression, threshold=0.42, thresholds=[0.5, 0.5]

python/pyspark/ml/tests/test_persistence.py

Lines changed: 1 addition & 1 deletion
@@ -442,7 +442,7 @@ def test_default_read_write_default_params(self):
         del metadata['defaultParamMap']
         metadataStr = json.dumps(metadata, separators=[',', ':'])
         loadedMetadata = reader._parseMetaData(metadataStr, )
-        with self.assertRaisesRegexp(AssertionError, "`defaultParamMap` section not found"):
+        with self.assertRaisesRegex(AssertionError, "`defaultParamMap` section not found"):
             reader.getAndSetParams(lr, loadedMetadata)

         # Prior to 2.4.0, metadata doesn't have `defaultParamMap`.

python/pyspark/ml/tests/test_tuning.py

Lines changed: 2 additions & 2 deletions
@@ -499,15 +499,15 @@ def test_invalid_user_specified_folds(self):
                             evaluator=evaluator,
                             numFolds=2,
                             foldCol="fold")
-        with self.assertRaisesRegexp(Exception, "Fold number must be in range"):
+        with self.assertRaisesRegex(Exception, "Fold number must be in range"):
             cv.fit(dataset_with_folds)

         cv = CrossValidator(estimator=lr,
                             estimatorParamMaps=grid,
                             evaluator=evaluator,
                             numFolds=4,
                             foldCol="fold")
-        with self.assertRaisesRegexp(Exception, "The validation data at fold 3 is empty"):
+        with self.assertRaisesRegex(Exception, "The validation data at fold 3 is empty"):
             cv.fit(dataset_with_folds)

python/pyspark/ml/tests/test_wrapper.py

Lines changed: 3 additions & 3 deletions
@@ -54,7 +54,7 @@ def test_java_object_gets_detached(self):
         model.__del__()

         def condition():
-            with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
                 model._java_obj.toString()
             self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
             return True
@@ -67,9 +67,9 @@ def condition():
            pass

         def condition():
-            with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
                 model._java_obj.toString()
-            with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+            with self.assertRaisesRegex(py4j.protocol.Py4JError, error_no_object):
                 summary._java_obj.toString()
             return True

python/pyspark/sql/tests/test_arrow.py

Lines changed: 14 additions & 14 deletions
@@ -34,7 +34,7 @@

 if have_pandas:
     import pandas as pd
-    from pandas.util.testing import assert_frame_equal
+    from pandas.testing import assert_frame_equal

 if have_pyarrow:
     import pyarrow as pa  # noqa: F401
@@ -137,7 +137,7 @@ def test_toPandas_fallback_disabled(self):
         df = self.spark.createDataFrame([(None,)], schema=schema)
         with QuietTest(self.sc):
             with self.warnings_lock:
-                with self.assertRaisesRegexp(Exception, 'Unsupported type'):
+                with self.assertRaisesRegex(Exception, 'Unsupported type'):
                     df.toPandas()

     def test_null_conversion(self):
@@ -214,7 +214,7 @@ def raise_exception():
         exception_udf = udf(raise_exception, IntegerType())
         df = df.withColumn("error", exception_udf())
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'My error'):
+            with self.assertRaisesRegex(Exception, 'My error'):
                 df.toPandas()

     def _createDataFrame_toggle(self, pdf, schema=None):
@@ -228,7 +228,7 @@ def _createDataFrame_toggle(self, pdf, schema=None):
     def test_createDataFrame_toggle(self):
         pdf = self.create_pandas_data_frame()
         df_no_arrow, df_arrow = self._createDataFrame_toggle(pdf, schema=self.schema)
-        self.assertEquals(df_no_arrow.collect(), df_arrow.collect())
+        self.assertEqual(df_no_arrow.collect(), df_arrow.collect())

     def test_createDataFrame_respect_session_timezone(self):
         from datetime import timedelta
@@ -258,7 +258,7 @@ def test_createDataFrame_respect_session_timezone(self):
     def test_createDataFrame_with_schema(self):
         pdf = self.create_pandas_data_frame()
         df = self.spark.createDataFrame(pdf, schema=self.schema)
-        self.assertEquals(self.schema, df.schema)
+        self.assertEqual(self.schema, df.schema)
         pdf_arrow = df.toPandas()
         assert_frame_equal(pdf_arrow, pdf)

@@ -269,31 +269,31 @@ def test_createDataFrame_with_incorrect_schema(self):
         wrong_schema = StructType(fields)
         with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": False}):
             with QuietTest(self.sc):
-                with self.assertRaisesRegexp(Exception, "[D|d]ecimal.*got.*date"):
+                with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
                     self.spark.createDataFrame(pdf, schema=wrong_schema)

     def test_createDataFrame_with_names(self):
         pdf = self.create_pandas_data_frame()
         new_names = list(map(str, range(len(self.schema.fieldNames()))))
         # Test that schema as a list of column names gets applied
         df = self.spark.createDataFrame(pdf, schema=list(new_names))
-        self.assertEquals(df.schema.fieldNames(), new_names)
+        self.assertEqual(df.schema.fieldNames(), new_names)
         # Test that schema as tuple of column names gets applied
         df = self.spark.createDataFrame(pdf, schema=tuple(new_names))
-        self.assertEquals(df.schema.fieldNames(), new_names)
+        self.assertEqual(df.schema.fieldNames(), new_names)

     def test_createDataFrame_column_name_encoding(self):
         pdf = pd.DataFrame({u'a': [1]})
         columns = self.spark.createDataFrame(pdf).columns
         self.assertTrue(isinstance(columns[0], str))
-        self.assertEquals(columns[0], 'a')
+        self.assertEqual(columns[0], 'a')
         columns = self.spark.createDataFrame(pdf, [u'b']).columns
         self.assertTrue(isinstance(columns[0], str))
-        self.assertEquals(columns[0], 'b')
+        self.assertEqual(columns[0], 'b')

     def test_createDataFrame_with_single_data_type(self):
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(ValueError, ".*IntegerType.*not supported.*"):
+            with self.assertRaisesRegex(ValueError, ".*IntegerType.*not supported.*"):
                 self.spark.createDataFrame(pd.DataFrame({"a": [1]}), schema="int")

     def test_createDataFrame_does_not_modify_input(self):
@@ -311,7 +311,7 @@ def test_schema_conversion_roundtrip(self):
         from pyspark.sql.pandas.types import from_arrow_schema, to_arrow_schema
         arrow_schema = to_arrow_schema(self.schema)
         schema_rt = from_arrow_schema(arrow_schema)
-        self.assertEquals(self.schema, schema_rt)
+        self.assertEqual(self.schema, schema_rt)

     def test_createDataFrame_with_array_type(self):
         pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
@@ -420,7 +420,7 @@ def test_createDataFrame_fallback_enabled(self):

     def test_createDataFrame_fallback_disabled(self):
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(TypeError, 'Unsupported type'):
+            with self.assertRaisesRegex(TypeError, 'Unsupported type'):
                 self.spark.createDataFrame(
                     pd.DataFrame({"a": [[datetime.datetime(2015, 11, 1, 0, 30)]]}),
                     "a: array<timestamp>")
@@ -545,7 +545,7 @@ def tearDownClass(cls):
         cls.spark.stop()

     def test_exception_by_max_results(self):
-        with self.assertRaisesRegexp(Exception, "is bigger than"):
+        with self.assertRaisesRegex(Exception, "is bigger than"):
             self.spark.range(0, 10000, 1, 100).toPandas()

python/pyspark/sql/tests/test_catalog.py

Lines changed: 28 additions & 28 deletions
@@ -25,11 +25,11 @@ class CatalogTests(ReusedSQLTestCase):
     def test_current_database(self):
         spark = self.spark
         with self.database("some_db"):
-            self.assertEquals(spark.catalog.currentDatabase(), "default")
+            self.assertEqual(spark.catalog.currentDatabase(), "default")
             spark.sql("CREATE DATABASE some_db")
             spark.catalog.setCurrentDatabase("some_db")
-            self.assertEquals(spark.catalog.currentDatabase(), "some_db")
-            self.assertRaisesRegexp(
+            self.assertEqual(spark.catalog.currentDatabase(), "some_db")
+            self.assertRaisesRegex(
                 AnalysisException,
                 "does_not_exist",
                 lambda: spark.catalog.setCurrentDatabase("does_not_exist"))
@@ -38,10 +38,10 @@ def test_list_databases(self):
         spark = self.spark
         with self.database("some_db"):
             databases = [db.name for db in spark.catalog.listDatabases()]
-            self.assertEquals(databases, ["default"])
+            self.assertEqual(databases, ["default"])
             spark.sql("CREATE DATABASE some_db")
             databases = [db.name for db in spark.catalog.listDatabases()]
-            self.assertEquals(sorted(databases), ["default", "some_db"])
+            self.assertEqual(sorted(databases), ["default", "some_db"])

     def test_list_tables(self):
         from pyspark.sql.catalog import Table
@@ -50,8 +50,8 @@ def test_list_tables(self):
             spark.sql("CREATE DATABASE some_db")
             with self.table("tab1", "some_db.tab2", "tab3_via_catalog"):
                 with self.tempView("temp_tab"):
-                    self.assertEquals(spark.catalog.listTables(), [])
-                    self.assertEquals(spark.catalog.listTables("some_db"), [])
+                    self.assertEqual(spark.catalog.listTables(), [])
+                    self.assertEqual(spark.catalog.listTables("some_db"), [])
                     spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
                     spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                     spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
@@ -66,40 +66,40 @@ def test_list_tables(self):
                         sorted(spark.catalog.listTables("default"), key=lambda t: t.name)
                     tablesSomeDb = \
                         sorted(spark.catalog.listTables("some_db"), key=lambda t: t.name)
-                    self.assertEquals(tables, tablesDefault)
-                    self.assertEquals(len(tables), 3)
-                    self.assertEquals(len(tablesSomeDb), 2)
-                    self.assertEquals(tables[0], Table(
+                    self.assertEqual(tables, tablesDefault)
+                    self.assertEqual(len(tables), 3)
+                    self.assertEqual(len(tablesSomeDb), 2)
+                    self.assertEqual(tables[0], Table(
                         name="tab1",
                         database="default",
                         description=None,
                         tableType="MANAGED",
                         isTemporary=False))
-                    self.assertEquals(tables[1], Table(
+                    self.assertEqual(tables[1], Table(
                         name="tab3_via_catalog",
                         database="default",
                         description=description,
                         tableType="MANAGED",
                         isTemporary=False))
-                    self.assertEquals(tables[2], Table(
+                    self.assertEqual(tables[2], Table(
                         name="temp_tab",
                         database=None,
                         description=None,
                         tableType="TEMPORARY",
                         isTemporary=True))
-                    self.assertEquals(tablesSomeDb[0], Table(
+                    self.assertEqual(tablesSomeDb[0], Table(
                         name="tab2",
                         database="some_db",
                         description=None,
                         tableType="MANAGED",
                         isTemporary=False))
-                    self.assertEquals(tablesSomeDb[1], Table(
+                    self.assertEqual(tablesSomeDb[1], Table(
                         name="temp_tab",
                         database=None,
                         description=None,
                         tableType="TEMPORARY",
                         isTemporary=True))
-                    self.assertRaisesRegexp(
+                    self.assertRaisesRegex(
                         AnalysisException,
                         "does_not_exist",
                         lambda: spark.catalog.listTables("does_not_exist"))
@@ -119,12 +119,12 @@ def test_list_functions(self):
             self.assertTrue("to_timestamp" in functions)
             self.assertTrue("to_unix_timestamp" in functions)
             self.assertTrue("current_database" in functions)
-            self.assertEquals(functions["+"], Function(
+            self.assertEqual(functions["+"], Function(
                 name="+",
                 description=None,
                 className="org.apache.spark.sql.catalyst.expressions.Add",
                 isTemporary=True))
-            self.assertEquals(functions, functionsDefault)
+            self.assertEqual(functions, functionsDefault)

         with self.function("func1", "some_db.func2"):
             spark.catalog.registerFunction("temp_func", lambda x: str(x))
@@ -141,7 +141,7 @@ def test_list_functions(self):
             self.assertTrue("temp_func" in newFunctionsSomeDb)
             self.assertTrue("func1" not in newFunctionsSomeDb)
             self.assertTrue("func2" in newFunctionsSomeDb)
-            self.assertRaisesRegexp(
+            self.assertRaisesRegex(
                 AnalysisException,
                 "does_not_exist",
                 lambda: spark.catalog.listFunctions("does_not_exist"))
@@ -158,16 +158,16 @@ def test_list_columns(self):
                 columns = sorted(spark.catalog.listColumns("tab1"), key=lambda c: c.name)
                 columnsDefault = \
                     sorted(spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name)
-                self.assertEquals(columns, columnsDefault)
-                self.assertEquals(len(columns), 2)
-                self.assertEquals(columns[0], Column(
+                self.assertEqual(columns, columnsDefault)
+                self.assertEqual(len(columns), 2)
+                self.assertEqual(columns[0], Column(
                     name="age",
                     description=None,
                     dataType="int",
                     nullable=True,
                     isPartition=False,
                     isBucket=False))
-                self.assertEquals(columns[1], Column(
+                self.assertEqual(columns[1], Column(
                     name="name",
                     description=None,
                     dataType="string",
@@ -176,26 +176,26 @@ def test_list_columns(self):
                     isBucket=False))
                 columns2 = \
                     sorted(spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name)
-                self.assertEquals(len(columns2), 2)
-                self.assertEquals(columns2[0], Column(
+                self.assertEqual(len(columns2), 2)
+                self.assertEqual(columns2[0], Column(
                     name="nickname",
                     description=None,
                     dataType="string",
                     nullable=True,
                     isPartition=False,
                     isBucket=False))
-                self.assertEquals(columns2[1], Column(
+                self.assertEqual(columns2[1], Column(
                     name="tolerance",
                     description=None,
                     dataType="float",
                     nullable=True,
                     isPartition=False,
                     isBucket=False))
-                self.assertRaisesRegexp(
+                self.assertRaisesRegex(
                     AnalysisException,
                     "tab2",
                     lambda: spark.catalog.listColumns("tab2"))
-                self.assertRaisesRegexp(
+                self.assertRaisesRegex(
                     AnalysisException,
                     "does_not_exist",
                     lambda: spark.catalog.listColumns("does_not_exist"))

python/pyspark/sql/tests/test_column.py

Lines changed: 5 additions & 5 deletions
@@ -47,7 +47,7 @@ def test_validate_column_types(self):
         self.assertTrue("Column" in _to_java_column(u"a").getClass().toString())
         self.assertTrue("Column" in _to_java_column(self.spark.range(1).id).getClass().toString())

-        self.assertRaisesRegexp(
+        self.assertRaisesRegex(
             TypeError,
             "Invalid argument, not a string or column",
             lambda: _to_java_column(1))
@@ -58,7 +58,7 @@ class A():
         self.assertRaises(TypeError, lambda: _to_java_column(A()))
         self.assertRaises(TypeError, lambda: _to_java_column([]))

-        self.assertRaisesRegexp(
+        self.assertRaisesRegex(
             TypeError,
             "Invalid argument, not a string or column",
             lambda: udf(lambda x: x)(None))
@@ -79,9 +79,9 @@ def test_column_operators(self):
             cs.startswith('a'), cs.endswith('a'), ci.eqNullSafe(cs)
         self.assertTrue(all(isinstance(c, Column) for c in css))
         self.assertTrue(isinstance(ci.cast(LongType()), Column))
-        self.assertRaisesRegexp(ValueError,
-                                "Cannot apply 'in' operator against a column",
-                                lambda: 1 in cs)
+        self.assertRaisesRegex(ValueError,
+                               "Cannot apply 'in' operator against a column",
+                               lambda: 1 in cs)

     def test_column_accessor(self):
         from pyspark.sql.functions import col
