|
19 | 19 | "SparkConnectClient",
|
20 | 20 | ]
|
21 | 21 |
|
| 22 | +import string |
| 23 | + |
22 | 24 | from pyspark.sql.connect.utils import check_dependencies
|
23 | 25 |
|
24 | 26 | check_dependencies(__name__, __file__)
|
@@ -120,6 +122,7 @@ class ChannelBuilder:
|
120 | 122 | PARAM_USE_SSL = "use_ssl"
|
121 | 123 | PARAM_TOKEN = "token"
|
122 | 124 | PARAM_USER_ID = "user_id"
|
| 125 | + PARAM_USER_AGENT = "user_agent" |
123 | 126 |
|
124 | 127 | @staticmethod
|
125 | 128 | def default_port() -> int:
|
@@ -215,6 +218,7 @@ def metadata(self) -> Iterable[Tuple[str, str]]:
|
215 | 218 | ChannelBuilder.PARAM_TOKEN,
|
216 | 219 | ChannelBuilder.PARAM_USE_SSL,
|
217 | 220 | ChannelBuilder.PARAM_USER_ID,
|
| 221 | + ChannelBuilder.PARAM_USER_AGENT, |
218 | 222 | ]
|
219 | 223 | ]
|
220 | 224 |
|
@@ -244,6 +248,27 @@ def userId(self) -> Optional[str]:
|
244 | 248 | """
|
245 | 249 | return self.params.get(ChannelBuilder.PARAM_USER_ID, None)
|
246 | 250 |
|
| 251 | + @property |
| 252 | + def userAgent(self) -> str: |
| 253 | + """ |
| 254 | + Returns |
| 255 | + ------- |
| 256 | + user_agent : str |
| 257 | + The user_agent parameter specified in the connection string, |
| 258 | + or "_SPARK_CONNECT_PYTHON" when not specified. |
| 259 | + """ |
| 260 | + user_agent = self.params.get(ChannelBuilder.PARAM_USER_AGENT, "_SPARK_CONNECT_PYTHON") |
| 261 | + allowed_chars = string.ascii_letters + string.punctuation |
| 262 | + if len(user_agent) > 200: |
| 263 | + raise SparkConnectException( |
| 264 | + "'user_agent' parameter cannot exceed 200 characters in length" |
| 265 | + ) |
| 266 | + if set(user_agent).difference(allowed_chars): |
| 267 | + raise SparkConnectException( |
| 268 | + "Only alphanumeric and common punctuations are allowed for 'user_agent'" |
| 269 | + ) |
| 270 | + return user_agent |
| 271 | + |
247 | 272 | def get(self, key: str) -> Any:
|
248 | 273 | """
|
249 | 274 | Parameters
|
@@ -559,15 +584,15 @@ def close(self) -> None:
|
559 | 584 | def _execute_plan_request_with_metadata(self) -> pb2.ExecutePlanRequest:
|
560 | 585 | req = pb2.ExecutePlanRequest()
|
561 | 586 | req.client_id = self._session_id
|
562 |
| - req.client_type = "_SPARK_CONNECT_PYTHON" |
| 587 | + req.client_type = self._builder.userAgent |
563 | 588 | if self._user_id:
|
564 | 589 | req.user_context.user_id = self._user_id
|
565 | 590 | return req
|
566 | 591 |
|
567 | 592 | def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
|
568 | 593 | req = pb2.AnalyzePlanRequest()
|
569 | 594 | req.client_id = self._session_id
|
570 |
| - req.client_type = "_SPARK_CONNECT_PYTHON" |
| 595 | + req.client_type = self._builder.userAgent |
571 | 596 | if self._user_id:
|
572 | 597 | req.user_context.user_id = self._user_id
|
573 | 598 | return req
|
|
0 commit comments