|
1 |
| -import re |
2 |
| -from typing import Any, List, Optional, Dict, Union, Collection, Iterable, Tuple |
| 1 | +from typing import Any, List, Optional, Dict, Union, Iterable, Tuple |
3 | 2 |
|
4 | 3 | import databricks.sqlalchemy._ddl as dialect_ddl_impl
|
5 | 4 | import databricks.sqlalchemy._types as dialect_type_impl
|
6 | 5 | from databricks import sql
|
| 6 | +from databricks.sqlalchemy._information_schema import tables |
7 | 7 | from databricks.sqlalchemy._parse import (
|
8 | 8 | _describe_table_extended_result_to_dict_list,
|
9 | 9 | _match_table_not_found_string,
|
|
15 | 15 | )
|
16 | 16 |
|
17 | 17 | import sqlalchemy
|
18 |
| -from sqlalchemy import DDL, event |
| 18 | +from sqlalchemy import DDL, event, select, bindparam, exc |
19 | 19 | from sqlalchemy.engine import Connection, Engine, default, reflection
|
20 |
| -from sqlalchemy.engine.reflection import ObjectKind |
21 | 20 | from sqlalchemy.engine.interfaces import (
|
22 | 21 | ReflectedForeignKeyConstraint,
|
23 | 22 | ReflectedPrimaryKeyConstraint,
|
24 | 23 | ReflectedColumn,
|
| 24 | + ReflectedTableComment, |
25 | 25 | TableKey,
|
26 | 26 | )
|
| 27 | +from sqlalchemy.engine.reflection import ReflectionDefaults, ObjectKind, ObjectScope |
27 | 28 | from sqlalchemy.exc import DatabaseError, SQLAlchemyError
|
28 | 29 |
|
29 | 30 | try:
|
@@ -285,7 +286,7 @@ def get_table_names(self, connection: Connection, schema=None, **kwargs):
|
285 | 286 | views_result = self.get_view_names(connection=connection, schema=schema)
|
286 | 287 |
|
287 | 288 | # In Databricks, SHOW TABLES FROM <schema> returns both tables and views.
|
288 |
| - # Potential optimisation: rewrite this to instead query informtation_schema |
| 289 | + # Potential optimisation: rewrite this to instead query information_schema |
289 | 290 | tables_minus_views = [
|
290 | 291 | row.tableName for row in tables_result if row.tableName not in views_result
|
291 | 292 | ]
|
@@ -328,7 +329,7 @@ def get_materialized_view_names(
|
def get_temp_view_names(
    self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
    """Return the names of temporary views only.

    Thin wrapper that delegates to ``get_view_names`` with ``only_temp=True``.
    Extra keyword arguments in ``kw`` are accepted but not forwarded.
    """
    temp_view_names = self.get_view_names(connection, schema, only_temp=True)
    return temp_view_names
|
333 | 334 |
|
334 | 335 | def do_rollback(self, dbapi_connection):
|
@@ -375,6 +376,87 @@ def get_schema_names(self, connection, **kw):
|
375 | 376 | schema_list = [row[0] for row in result]
|
376 | 377 | return schema_list
|
377 | 378 |
|
def get_multi_table_comment(
    self,
    connection,
    schema=None,
    filter_names=None,
    scope=ObjectScope.ANY,
    kind=ObjectKind.ANY,
    **kw,
) -> Iterable[Tuple[TableKey, ReflectedTableComment]]:
    """Yield ``((schema, table_name), comment_dict)`` pairs for many tables at once.

    Comments for non-temporary objects are selected from the ``tables``
    construct imported from ``databricks.sqlalchemy._information_schema``,
    restricted to ``self.catalog`` and to ``schema`` (or ``self.schema`` when
    ``schema`` is falsy). Temporary views are appended with no comment.

    Args:
        connection: Connection used to execute the query.
        schema: Schema to inspect. The emitted keys carry this value as passed
            (possibly ``None``), not the resolved fallback schema.
        filter_names: Optional table names to restrict the result to; they are
            lower-cased before being compared.
        scope: ``ObjectScope`` flags selecting default and/or temporary objects.
        kind: ``ObjectKind`` flags selecting tables, views and/or
            materialized views.
        **kw: Accepted for interface compatibility; not used here.

    Returns:
        A lazy generator of ``((schema, table_name), comment)`` tuples where
        ``comment`` is ``{"text": ...}`` or ``ReflectionDefaults.table_comment()``
        when the stored comment is ``None``.
    """
    effective_schema = schema or self.schema
    rows = []

    if ObjectScope.DEFAULT in scope:
        stmt = select(tables.c.table_name, tables.c.comment).select_from(tables)
        stmt = stmt.where(
            tables.c.table_catalog == self.catalog,
            tables.c.table_schema == effective_schema,
        )

        # Only narrow by table_type when the caller asked for a subset of kinds.
        if ObjectKind.ANY not in kind:
            table_types = set()
            if ObjectKind.TABLE in kind:
                table_types |= {
                    "BASE TABLE",
                    "MANAGED",
                    "EXTERNAL",
                    "STREAMING_TABLE",
                }
            if ObjectKind.VIEW in kind:
                table_types.add("VIEW")
            if ObjectKind.MATERIALIZED_VIEW in kind:
                table_types.add("MATERIALIZED_VIEW")
            stmt = stmt.where(tables.c.table_type.in_(table_types))

        if filter_names:
            stmt = stmt.where(tables.c.table_name.in_(bindparam("filter_names")))
            params = {"filter_names": [name.lower() for name in filter_names]}
            rows = connection.execute(stmt, params)
        else:
            rows = connection.execute(stmt)

    if ObjectScope.TEMPORARY in scope and ObjectKind.VIEW in kind:
        # Materialize whatever was fetched so far so temp views can be appended.
        rows = list(rows)
        temp_view_names = self.get_view_names(connection, schema, only_temp=True)
        if filter_names:
            wanted = [name.lower() for name in filter_names]
            temp_view_names = set(temp_view_names).intersection(wanted)
        # Temporary views are reported without a comment.
        rows.extend((view_name, None) for view_name in temp_view_names)

    def _as_comment(text):
        # Map a missing comment to SQLAlchemy's default reflected comment.
        if text is None:
            return ReflectionDefaults.table_comment()
        return {"text": text}

    # TODO: figure out how to return sqlalchemy.interfaces in a way that mypy respects
    return (
        ((schema, name), _as_comment(text)) for name, text in rows
    )  # type: ignore
| 439 | + |
def get_table_comment(
    self,
    connection: Connection,
    table_name: str,
    schema: Optional[str] = None,
    **kw: Any,
) -> ReflectedTableComment:
    """Return the reflected comment for a single table.

    Delegates to ``get_multi_table_comment`` with ``table_name`` as the only
    filter name and looks the result up under ``(schema, table_name.lower())``.

    Args:
        connection: Connection passed through to ``get_multi_table_comment``.
        table_name: Name of the table; matched in lower case.
        schema: Optional schema, used both for the lookup key and for the
            error message.
        **kw: Forwarded unchanged to ``get_multi_table_comment``.

    Raises:
        sqlalchemy.exc.NoSuchTableError: when no entry exists for the table.
            The message is ``"<schema>.<table_name>"`` if ``schema`` is truthy,
            otherwise just ``table_name``.
    """
    comment_pairs = self.get_multi_table_comment(
        connection, schema, [table_name], **kw
    )
    try:
        # The pairs are consumed here, inside the try, on purpose: the lookup
        # and the consumption share the same KeyError handling.
        comments_by_key = dict(comment_pairs)
        return comments_by_key[(schema, table_name.lower())]
    except KeyError:
        qualified_name = table_name if not schema else f"{schema}.{table_name}"
        raise exc.NoSuchTableError(qualified_name) from None
| 459 | + |
378 | 460 |
|
379 | 461 | @event.listens_for(Engine, "do_connect")
|
380 | 462 | def receive_do_connect(dialect, conn_rec, cargs, cparams):
|
|
0 commit comments