Skip to content

Commit b887d3d

Browse files
nija-atHyukjinKwon
authored andcommitted
[SPARK-42477][CONNECT][PYTHON] accept user_agent in spark connect's connection string
### What changes were proposed in this pull request? Currently, the Spark Connect service's `client_type` attribute (which is really [user agent]) is set to `_SPARK_CONNECT_PYTHON` to signify PySpark. With this change, the connection for the Spark Connect remote accepts an optional `user_agent` parameter which is then passed down to the service. [user agent]: https://www.w3.org/WAI/UA/work/wiki/Definition_of_User_Agent ### Why are the changes needed? This enables partners using Spark Connect to set their application as the user agent, which then allows visibility and measurement of integrations and usages of spark connect. ### Does this PR introduce _any_ user-facing change? A new optional `user_agent` parameter is now recognized as part of the Spark Connect connection string. ### How was this patch tested? - unit tests attached - manually running the `pyspark` binary with the `user_agent` connection string set and verifying the payload sent to the server. Similar testing for the default. Closes #40054 from nija-at/user-agent. Authored-by: Niranjan Jayakar <[email protected]> Signed-off-by: Hyukjin Kwon <[email protected]>
1 parent d5fa41e commit b887d3d

File tree

5 files changed

+132
-5
lines changed

5 files changed

+132
-5
lines changed

connector/connect/docs/client-connection-string.md

+11-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ sc://hostname:port/;param1=value;param2=value
5858
<td>token</td>
5959
<td>String</td>
6060
<td>When this param is set in the URL, it will enable standard
61-
bearer token authentication using GRPC. By default this value is not set.</td>
61+
bearer token authentication using GRPC. By default this value is not set.
62+
Setting this value enables SSL.</td>
6263
<td><pre>token=ABCDEFGH</pre></td>
6364
</tr>
6465
<tr>
@@ -81,6 +82,15 @@ sc://hostname:port/;param1=value;param2=value
8182
<pre>user_id=Martin</pre>
8283
</td>
8384
</tr>
85+
<tr>
86+
<td>user_agent</td>
87+
<td>String</td>
88+
<td>The user agent acting on behalf of the user, typically applications
89+
that use Spark Connect to implement its functionality and execute Spark
90+
requests on behalf of the user.<br/>
91+
<i>Default: </i><pre>_SPARK_CONNECT_PYTHON</pre> in the Python client</td>
92+
<td><pre>user_agent=my_data_query_app</pre></td>
93+
</tr>
8494
</table>
8595

8696
## Examples

dev/sparktestsupport/modules.py

+1
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ def __hash__(self):
516516
"pyspark.sql.connect.dataframe",
517517
"pyspark.sql.connect.functions",
518518
# unittests
519+
"pyspark.sql.tests.connect.test_client",
519520
"pyspark.sql.tests.connect.test_connect_plan",
520521
"pyspark.sql.tests.connect.test_connect_basic",
521522
"pyspark.sql.tests.connect.test_connect_function",

python/pyspark/sql/connect/client.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
"SparkConnectClient",
2020
]
2121

22+
import string
23+
2224
from pyspark.sql.connect.utils import check_dependencies
2325

2426
check_dependencies(__name__, __file__)
@@ -120,6 +122,7 @@ class ChannelBuilder:
120122
PARAM_USE_SSL = "use_ssl"
121123
PARAM_TOKEN = "token"
122124
PARAM_USER_ID = "user_id"
125+
PARAM_USER_AGENT = "user_agent"
123126

124127
@staticmethod
125128
def default_port() -> int:
@@ -215,6 +218,7 @@ def metadata(self) -> Iterable[Tuple[str, str]]:
215218
ChannelBuilder.PARAM_TOKEN,
216219
ChannelBuilder.PARAM_USE_SSL,
217220
ChannelBuilder.PARAM_USER_ID,
221+
ChannelBuilder.PARAM_USER_AGENT,
218222
]
219223
]
220224

@@ -244,6 +248,27 @@ def userId(self) -> Optional[str]:
244248
"""
245249
return self.params.get(ChannelBuilder.PARAM_USER_ID, None)
246250

251+
@property
252+
def userAgent(self) -> str:
253+
"""
254+
Returns
255+
-------
256+
user_agent : str
257+
The user_agent parameter specified in the connection string,
258+
or "_SPARK_CONNECT_PYTHON" when not specified.
259+
"""
260+
user_agent = self.params.get(ChannelBuilder.PARAM_USER_AGENT, "_SPARK_CONNECT_PYTHON")
261+
allowed_chars = string.ascii_letters + string.punctuation
262+
if len(user_agent) > 200:
263+
raise SparkConnectException(
264+
"'user_agent' parameter cannot exceed 200 characters in length"
265+
)
266+
if set(user_agent).difference(allowed_chars):
267+
raise SparkConnectException(
268+
"Only alphanumeric and common punctuations are allowed for 'user_agent'"
269+
)
270+
return user_agent
271+
247272
def get(self, key: str) -> Any:
248273
"""
249274
Parameters
@@ -559,15 +584,15 @@ def close(self) -> None:
559584
def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
560585
req = pb2.ExecutePlanRequest()
561586
req.client_id = self._session_id
562-
req.client_type = "_SPARK_CONNECT_PYTHON"
587+
req.client_type = self._builder.userAgent
563588
if self._user_id:
564589
req.user_context.user_id = self._user_id
565590
return req
566591

567592
def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
568593
req = pb2.AnalyzePlanRequest()
569594
req.client_id = self._session_id
570-
req.client_type = "_SPARK_CONNECT_PYTHON"
595+
req.client_type = self._builder.userAgent
571596
if self._user_id:
572597
req.user_context.user_id = self._user_id
573598
return req
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one or more
3+
# contributor license agreements. See the NOTICE file distributed with
4+
# this work for additional information regarding copyright ownership.
5+
# The ASF licenses this file to You under the Apache License, Version 2.0
6+
# (the "License"); you may not use this file except in compliance with
7+
# the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
#
17+
18+
import unittest
19+
from typing import Optional
20+
21+
from pyspark.sql.connect.client import SparkConnectClient
22+
import pyspark.sql.connect.proto as proto
23+
24+
25+
class SparkConnectClientTestCase(unittest.TestCase):
26+
def test_user_agent_passthrough(self):
27+
client = SparkConnectClient("sc://foo/;user_agent=bar")
28+
mock = MockService(client._session_id)
29+
client._stub = mock
30+
31+
command = proto.Command()
32+
client.execute_command(command)
33+
34+
self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected")
35+
self.assertEqual(mock.req.client_type, "bar")
36+
37+
def test_user_agent_default(self):
38+
client = SparkConnectClient("sc://foo/")
39+
mock = MockService(client._session_id)
40+
client._stub = mock
41+
42+
command = proto.Command()
43+
client.execute_command(command)
44+
45+
self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected")
46+
self.assertEqual(mock.req.client_type, "_SPARK_CONNECT_PYTHON")
47+
48+
49+
class MockService:
50+
# Simplest mock of the SparkConnectService.
51+
# If this needs more complex logic, it needs to be replaced with Python mocking.
52+
53+
req: Optional[proto.ExecutePlanRequest]
54+
55+
def __init__(self, session_id: str):
56+
self._session_id = session_id
57+
self.req = None
58+
59+
def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
60+
self.req = req
61+
resp = proto.ExecutePlanResponse()
62+
resp.client_id = self._session_id
63+
return [resp]
64+
65+
66+
if __name__ == "__main__":
67+
unittest.main()

python/pyspark/sql/tests/connect/test_connect_basic.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -2949,10 +2949,33 @@ def test_sensible_defaults(self):
29492949

29502950
chan = ChannelBuilder("sc://host/;token=abcs")
29512951
self.assertTrue(chan.secure, "specifying a token must set the channel to secure")
2952-
2952+
self.assertEqual(chan.userAgent, "_SPARK_CONNECT_PYTHON")
29532953
chan = ChannelBuilder("sc://host/;use_ssl=abcs")
29542954
self.assertFalse(chan.secure, "Garbage in, false out")
29552955

2956+
def test_invalid_user_agent_charset(self):
2957+
# fmt: off
2958+
invalid_user_agents = [
2959+
"agent»", # non standard symbol
2960+
"age nt", # whitespace
2961+
"ägent", # non-ascii alphabet
2962+
]
2963+
# fmt: on
2964+
for user_agent in invalid_user_agents:
2965+
with self.subTest(user_agent=user_agent):
2966+
chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
2967+
with self.assertRaises(SparkConnectException) as err:
2968+
chan.userAgent
2969+
2970+
self.assertRegex(err.exception.message, "alphanumeric and common punctuations")
2971+
2972+
def test_invalid_user_agent_len(self):
2973+
user_agent = "x" * 201
2974+
chan = ChannelBuilder(f"sc://host/;user_agent={user_agent}")
2975+
with self.assertRaises(SparkConnectException) as err:
2976+
chan.userAgent
2977+
self.assertRegex(err.exception.message, "characters in length")
2978+
29562979
def test_valid_channel_creation(self):
29572980
chan = ChannelBuilder("sc://host").toChannel()
29582981
self.assertIsInstance(chan, grpc.Channel)
@@ -2965,8 +2988,9 @@ def test_valid_channel_creation(self):
29652988
self.assertIsInstance(chan, grpc.Channel)
29662989

29672990
def test_channel_properties(self):
2968-
chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;param1=120%2021")
2991+
chan = ChannelBuilder("sc://host/;use_ssl=true;token=abc;user_agent=foo;param1=120%2021")
29692992
self.assertEqual("host:15002", chan.endpoint)
2993+
self.assertEqual("foo", chan.userAgent)
29702994
self.assertEqual(True, chan.secure)
29712995
self.assertEqual("120 21", chan.get("param1"))
29722996

0 commit comments

Comments
 (0)