diff --git a/tests/integration/test_sqlalchemy_integration.py b/tests/integration/test_sqlalchemy_integration.py index 1dc8f05a..4296ab7a 100644 --- a/tests/integration/test_sqlalchemy_integration.py +++ b/tests/integration/test_sqlalchemy_integration.py @@ -9,6 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License +import os import pytest import sqlalchemy as sqla from sqlalchemy.sql import and_, or_, not_ @@ -177,6 +178,22 @@ def test_conjunctions(trino_connection): assert len(rows) == 1 +@pytest.mark.parametrize('trino_connection', ['system'], indirect=True) +def test_completed_states(trino_connection): + _, conn = trino_connection + metadata = sqla.MetaData() + queries = sqla.Table('queries', metadata, schema='runtime', autoload_with=conn) + s = sqla.select(queries.c.state).where(queries.c.query == "SELECT version()") + result = conn.execute(s) + rows = result.fetchall() + assert len(rows) > 0 + for row in rows: + if os.environ.get("TRINO_VERSION") == '351': + assert row['state'] == 'FAILED' + else: + assert row['state'] == 'FINISHED' + + @pytest.mark.parametrize('trino_connection', ['tpch'], indirect=True) def test_textual_sql(trino_connection): _, conn = trino_connection diff --git a/trino/sqlalchemy/dialect.py b/trino/sqlalchemy/dialect.py index c7056cc4..90cabe8f 100644 --- a/trino/sqlalchemy/dialect.py +++ b/trino/sqlalchemy/dialect.py @@ -317,7 +317,7 @@ def _get_server_version_info(self, connection: Connection) -> Any: query = "SELECT version()" try: res = connection.execute(sql.text(query)) - version = res.scalar() + version = res.scalar_one() return tuple([version]) except exc.ProgrammingError as e: logger.debug(f"Failed to get server version: {e.orig.message}")