From 63004c185c5fb37b14e630a9d8e6fa241e172567 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Mon, 28 Oct 2024 19:03:33 -0500 Subject: [PATCH 1/2] Implement a DefaultNone so that a log message can be emitted if the user passes an access_token which evaluates to None --- src/databricks/sql/__init__.py | 19 +++++++++++++++++-- src/databricks/sql/client.py | 12 +++++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 42167b00..3913f490 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -10,7 +10,7 @@ import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Union if TYPE_CHECKING: # Use this import purely for type annotations, a la https://mypy.readthedocs.io/en/latest/runtime_troubles.html#import-cycles @@ -83,8 +83,23 @@ def DateFromTicks(ticks): def TimestampFromTicks(ticks): return Timestamp(*time.localtime(ticks)[:6]) +def singleton(class_): + instances = {} + def getinstance(*args, **kwargs): + if class_ not in instances: + instances[class_] = class_(*args, **kwargs) + return instances[class_] + return getinstance -def connect(server_hostname, http_path, access_token=None, **kwargs) -> "Connection": +@singleton +class DefaultNone(object): + """Used to represent a default value of None so that this code can distinguish between + the user passing None versus a default value of None being used. + """ + pass + + +def connect(server_hostname, http_path, access_token: Optional[Union[str, DefaultNone]]=DefaultNone, **kwargs) -> "Connection": from .client import Connection return Connection(server_hostname, http_path, access_token, **kwargs) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index 4e0ab941..c89058fd 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -12,7 +12,7 @@ import decimal from uuid import UUID -from databricks.sql import __version__ +from databricks.sql import __version__, DefaultNone from databricks.sql import * from databricks.sql.exc import ( OperationalError, @@ -63,7 +63,7 @@ def __init__( self, server_hostname: str, http_path: str, - access_token: Optional[str] = None, + access_token: Optional[Union[str, DefaultNone]] = None, http_headers: Optional[List[Tuple[str, str]]] = None, session_configuration: Optional[Dict[str, Any]] = None, catalog: Optional[str] = None, @@ -204,7 +204,13 @@ def read(self) -> Optional[OAuthToken]: # use_cloud_fetch # Enable use of cloud fetch to extract large query results in parallel via cloud storage - if access_token: + if access_token is DefaultNone: + access_token = None + elif access_token is None: + logger.info( + "Connection access_token was passed a None value. U2M OAuth will be attempted" + ) + else: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv} From d10270af96c25c7825e5b5e15431e87b26c424b6 Mon Sep 17 00:00:00 2001 From: Jesse Whitehouse Date: Mon, 28 Oct 2024 19:16:43 -0500 Subject: [PATCH 2/2] Black the files --- src/databricks/sql/__init__.py | 12 +++++++++++- src/databricks/sql/client.py | 2 +- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/databricks/sql/__init__.py b/src/databricks/sql/__init__.py index 3913f490..6387fff8 100644 --- a/src/databricks/sql/__init__.py +++ b/src/databricks/sql/__init__.py @@ -83,23 +83,33 @@ def DateFromTicks(ticks): def TimestampFromTicks(ticks): return Timestamp(*time.localtime(ticks)[:6]) + def singleton(class_): instances = {} + def getinstance(*args, **kwargs): if class_ not in instances: instances[class_] = class_(*args, **kwargs) return instances[class_] + return getinstance + @singleton class DefaultNone(object): """Used to represent a default value of None so that this code can distinguish between the user passing None versus a default value of None being used. """ + pass -def connect(server_hostname, http_path, access_token: Optional[Union[str, DefaultNone]]=DefaultNone, **kwargs) -> "Connection": +def connect( + server_hostname, + http_path, + access_token: Optional[Union[str, DefaultNone]] = DefaultNone, + **kwargs +) -> "Connection": from .client import Connection return Connection(server_hostname, http_path, access_token, **kwargs) diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index c89058fd..2b83465f 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -209,7 +209,7 @@ def read(self) -> Optional[OAuthToken]: elif access_token is None: logger.info( "Connection access_token was passed a None value. U2M OAuth will be attempted" - ) + ) else: access_token_kv = {"access_token": access_token} kwargs = {**kwargs, **access_token_kv}