Skip to content

[PECO-1263] Add get_async_execution method #314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 18, 2024
21 changes: 19 additions & 2 deletions src/databricks/sql/ae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from enum import Enum
from typing import Optional, Union, TYPE_CHECKING
from databricks.sql.exc import RequestError
from databricks.sql.results import ResultSet

from dataclasses import dataclass
Expand Down Expand Up @@ -73,7 +74,7 @@ def __init__(
connection: "Connection",
query_id: UUID,
query_secret: UUID,
status: AsyncExecutionStatus,
status: Optional[AsyncExecutionStatus] = None,
execute_statement_response: Optional[ttypes.TExecuteStatementResp] = None,
):
self._connection = connection
Expand All @@ -83,6 +84,9 @@ def __init__(
self.query_secret = query_secret
self.status = status

if self.status is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love having potentially long running ops in the init. For one thing, it means one more thing that must be mocked when unit testing, regardless of whether the test needs to interact with status directly. When I have initialization that is non-trivial, but must be complete for an object to be functional, I tend to put that in a factory method, and try to make the constructor hidden...not sure if we have that capability in python though. Take this comment with a grain of salt, because depending on the user experience, this point may be outweighed by trying to make the client experience simpler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about this:

I'll add a new AsyncExecutionStatus.UNKNOWN and make that the default. Then modify the .status class member to become a property that accesses a private ._status member. If ._status==AsyncExecutionStatus.UNKNOWN, this will fire-off the poll_for_status and set it.

This way, the long-running op won't happen until a user actually tries to do something which depends on the status.

Copy link
Collaborator

@benc-db benc-db Jan 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason to have it pretend to be a property? As opposed to just get_status()? This might be a result of other languages I've programmed in, but in my mind, a property is ideally either a.) a field, or b.) something directly computable from fields. Making it a method suggests that work will be done to get the value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no objection to that. Either way is technically "Pythonic".

self.poll_for_status()

status: AsyncExecutionStatus
query_id: UUID

Expand Down Expand Up @@ -111,12 +115,25 @@ def _thrift_cancel_operation(self) -> None:
_output = self._thrift_backend.async_cancel_command(self.t_operation_handle)
self.status = AsyncExecutionStatus.CANCELED

def serialize(self) -> str:
"""Return a string representing the query_id and secret of this AsyncExecution.

Use this to preserve a reference to the query_id and query_secret."""
return f"{self.query_id}:{self.query_secret}"

def poll_for_status(self) -> None:
"""Check the thrift server for the status of this operation and set self.status

This will result in an error if the operation has been canceled or aborted at the server"""

_output = self._thrift_backend._poll_for_status(self.t_operation_handle)
try:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perfect example of why I don't love complexity in my constructors. An ideal constructor says, if you give me all the required parameters as input, you will get an instance of this object as object; here, however, our constructor could fail and throw an exception to the user. If you instead use a factory pattern, you have the choice of propagating the error, or just giving the user None, as you could name the method something like 'get_if_exists(...)'.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense to me.

I do have a factory function. It's the Connection.get_async_execution() method. I can push this "does it exist" checking into that function like you describe.

_output = self._thrift_backend._poll_for_status(self.t_operation_handle)
except RequestError as e:
if "RESOURCE_DOES_NOT_EXIST" in e.message:
raise AsyncExecutionException(
"Query not found: %s. Result may have already been fetched."
% self.query_id
) from e
self.status = _toperationstate_to_ae_status(_output.operationState)

def _thrift_fetch_result(self) -> None:
Expand Down
31 changes: 31 additions & 0 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@


from databricks.sql.ae import AsyncExecution, AsyncExecutionStatus
from uuid import UUID

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -403,6 +404,36 @@ def execute_async(
resp=execute_statement_resp,
)

def get_async_execution(
self, query_id: Union[str, UUID], query_secret: Union[str, UUID]
) -> "AsyncExecution":
"""Get an AsyncExecution object for an existing query.

Args:
query_id: The query id of the query to retrieve
query_secret: The query secret of the query to retrieve

Returns:
An AsyncExecution object that can be used to poll for status and retrieve results.
"""

if isinstance(query_id, UUID):
_qid = query_id
else:
_qid = UUID(hex=query_id)

if isinstance(query_secret, UUID):
_qs = query_secret
else:
_qs = UUID(hex=query_secret)

return AsyncExecution(
thrift_backend=self.thrift_backend,
connection=self,
query_id=_qid,
query_secret=_qs,
)


class Cursor:
def __init__(
Expand Down
150 changes: 141 additions & 9 deletions tests/e2e/test_execute_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,25 @@
import pytest
import time

LONG_RUNNING_QUERY = """
import threading

BASE_LONG_QUERY = """
SELECT SUM(A.id - B.id)
FROM range(1000000000) A CROSS JOIN range(100000000) B
FROM range({val}) A CROSS JOIN range({val}) B
GROUP BY (A.id - B.id)
"""
GT_ONE_MINUTE_VALUE = 100000000

# Arrived at this value through some manual testing on a serverless SQL warehouse
# The goal here is to have a query that takes longer than five seconds (therefore bypassing directResults)
# but not so long that I can't attempt to fetch its results in a reasonable amount of time
GT_FIVE_SECONDS_VALUE = 90000

LONG_RUNNING_QUERY = BASE_LONG_QUERY.format(val=GT_ONE_MINUTE_VALUE)
LONG_ISH_QUERY = BASE_LONG_QUERY.format(val=GT_FIVE_SECONDS_VALUE)

# This query should always return in < 5 seconds and therefore should be a direct result
DIRECT_RESULTS_QUERY = "select :param `col`"


class TestExecuteAsync(PySQLPytestTestCase):
Expand All @@ -26,12 +40,19 @@ def long_running_ae(self, scope="function") -> AsyncExecution:
# cancellation is idempotent
ae.cancel()

def test_basic_api(self):
@pytest.fixture
def long_ish_ae(self, scope="function") -> AsyncExecution:
"""Start a long-running query so we can make assertions about it."""
with self.connection() as conn:
ae = conn.execute_async(LONG_ISH_QUERY)
yield ae

def test_execute_async(self):
"""This is a WIP test of the basic API defined in PECO-1263"""
# This is taken directly from the design doc

with self.connection() as conn:
ae = conn.execute_async("select :param `col`", {"param": 1})
ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1})
while ae.is_running:
ae.poll_for_status()
time.sleep(1)
Expand All @@ -40,6 +61,15 @@ def test_basic_api(self):

assert result.col == 1

def test_direct_results_query_canary(self):
"""This test verifies that on the current endpoint, the DIRECT_RESULTS_QUERY returned a thrift operation state
other than FINISHED_STATE. If this test fails, it means the SQL warehouse got slower at executing this query
"""

with self.connection() as conn:
ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1})
assert not ae.is_running

def test_cancel_running_query(self, long_running_ae: AsyncExecution):
long_running_ae.cancel()
assert long_running_ae.status == AsyncExecutionStatus.CANCELED
Expand All @@ -53,10 +83,112 @@ def test_cant_get_results_after_cancel(self, long_running_ae: AsyncExecution):
with pytest.raises(AsyncExecutionException, match="Query was canceled"):
long_running_ae.get_result()

def test_get_async_execution_can_check_status(self, long_running_ae: AsyncExecution):
query_id, query_secret = str(long_running_ae.query_id), str(
long_running_ae.query_secret
)

with self.connection() as conn:
ae = conn.get_async_execution(query_id, query_secret)
assert ae.is_running

def test_get_async_execution_can_cancel_across_threads(self, long_running_ae: AsyncExecution):
query_id, query_secret = str(long_running_ae.query_id), str(
long_running_ae.query_secret
)

def cancel_query_in_separate_thread(query_id, query_secret):
with self.connection() as conn:
ae = conn.get_async_execution(query_id, query_secret)
ae.cancel()

threading.Thread(
target=cancel_query_in_separate_thread, args=(query_id, query_secret)
).start()

time.sleep(5)

long_running_ae.poll_for_status()
assert long_running_ae.status == AsyncExecutionStatus.CANCELED

def test_long_ish_query_canary(self, long_ish_ae: AsyncExecution):
"""This test verifies that on the current endpoint, the LONG_ISH_QUERY requires
at least one poll_for_status call before it is finished. If this test fails, it means
the SQL warehouse got faster at executing this query and we should increment the value
of GT_FIVE_SECONDS_VALUE

It would be easier to do this if Databricks SQL had a SLEEP() function :/
"""

poll_count = 0
while long_ish_ae.is_running:
time.sleep(1)
long_ish_ae.poll_for_status()
poll_count += 1

assert poll_count > 0

def test_get_async_execution_and_get_results_without_direct_results(
self, long_ish_ae: AsyncExecution
):
while long_ish_ae.is_running:
time.sleep(1)
long_ish_ae.poll_for_status()

def test_staging_operation(self):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This diff looks wacky here. I'm removing the test_staging_operation for the moment.

"""We need to test what happens with a staging operation since this query won't have a result set
that user needs. It could be sufficient to notify users that they shouldn't use this API for staging/volumes
queries...
result = long_ish_ae.get_result().fetchone()
assert len(result) == 1

def test_get_async_execution_with_bogus_query_id(self):

with self.connection() as conn:
with pytest.raises(AsyncExecutionException, match="Query not found"):
ae = conn.get_async_execution("bedc786d-64da-45d4-99da-5d3603525803", "ba469f82-cf3f-454e-b575-f4dcd58dd692")

def test_get_async_execution_with_badly_formed_query_id(self):
with self.connection() as conn:
with pytest.raises(ValueError, match="badly formed hexadecimal UUID string"):
ae = conn.get_async_execution("foo", "bar")

def test_serialize(self, long_running_ae: AsyncExecution):
serialized = long_running_ae.serialize()
query_id, query_secret = serialized.split(":")

with self.connection() as conn:
ae = conn.get_async_execution(query_id, query_secret)
assert ae.is_running

def test_get_async_execution_no_results_when_direct_results_were_sent(self):
"""It remains to be seen whether results can be fetched repeatedly from a "picked up" execution.
"""
assert False

with self.connection() as conn:
ae = conn.execute_async(DIRECT_RESULTS_QUERY, {"param": 1})
query_id, query_secret = ae.serialize().split(":")
ae.get_result()

with self.connection() as conn:
with pytest.raises(AsyncExecutionException, match="Query not found"):
ae_late = conn.get_async_execution(query_id, query_secret)

def test_get_async_execution_and_fetch_results(self, long_ish_ae: AsyncExecution):
"""This tests currently _fails_ because of how result fetching is factored.

Currently, thrift_backend.py can't fetch results unless it has a TExecuteStatementResp object.
But with async executions, we don't have the original TExecuteStatementResp. So we'll need to build
a way to "fake" this until we can refactor thrift_backend.py to be more testable.
"""

query_id, query_secret = long_ish_ae.serialize().split(":")

with self.connection() as conn:
ae = conn.get_async_execution(query_id, query_secret)

while ae.is_running:
time.sleep(1)
ae.poll_for_status()

result = ae.get_result().fetchone()

assert len(result) == 1