Skip to content

Commit a16ba30

Browse files
authored
Merge pull request #264 from jonathangreen/sqlalchemy_core
Add support for SQLAlchemy Core
2 parents 3c1218e + 0c6b825 commit a16ba30

File tree

5 files changed

+199
-0
lines changed

5 files changed

+199
-0
lines changed

aws_xray_sdk/core/patcher.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
'pymysql',
2525
'psycopg2',
2626
'pg8000',
27+
'sqlalchemy_core',
2728
)
2829

2930
NO_DOUBLE_PATCH = (
@@ -37,6 +38,7 @@
3738
'pymysql',
3839
'psycopg2',
3940
'pg8000',
41+
'sqlalchemy_core',
4042
)
4143

4244
_PATCHED_MODULES = set()
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .patch import patch, unpatch
2+
3+
__all__ = ['patch', 'unpatch']
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import logging
2+
import sys
3+
if sys.version_info >= (3, 0, 0):
4+
from urllib.parse import urlparse, uses_netloc
5+
else:
6+
from urlparse import urlparse, uses_netloc
7+
8+
import wrapt
9+
10+
from aws_xray_sdk.core import xray_recorder
11+
from aws_xray_sdk.core.patcher import _PATCHED_MODULES
12+
from aws_xray_sdk.core.utils import stacktrace
13+
from aws_xray_sdk.ext.util import unwrap
14+
15+
16+
def _sql_meta(instance, args):
17+
try:
18+
metadata = {}
19+
url = urlparse(str(instance.engine.url))
20+
# Add Scheme to uses_netloc or // will be missing from url.
21+
uses_netloc.append(url.scheme)
22+
if url.password is None:
23+
metadata['url'] = url.geturl()
24+
name = url.netloc
25+
else:
26+
# Strip password from URL
27+
host_info = url.netloc.rpartition('@')[-1]
28+
parts = url._replace(netloc='{}@{}'.format(url.username, host_info))
29+
metadata['url'] = parts.geturl()
30+
name = host_info
31+
metadata['user'] = url.username
32+
metadata['database_type'] = instance.engine.name
33+
try:
34+
version = getattr(instance.dialect, '{}_version'.format(instance.engine.driver))
35+
version_str = '.'.join(map(str, version))
36+
metadata['driver_version'] = "{}-{}".format(instance.engine.driver, version_str)
37+
except AttributeError:
38+
metadata['driver_version'] = instance.engine.driver
39+
if instance.dialect.server_version_info is not None:
40+
metadata['database_version'] = '.'.join(map(str, instance.dialect.server_version_info))
41+
if xray_recorder.stream_sql:
42+
metadata['sanitized_query'] = str(args[0])
43+
except Exception:
44+
metadata = None
45+
name = None
46+
logging.getLogger(__name__).exception('Error parsing sql metadata.')
47+
return name, metadata
48+
49+
50+
def _xray_traced_sqlalchemy_execute(wrapped, instance, args, kwargs):
51+
name, sql = _sql_meta(instance, args)
52+
if sql is not None:
53+
subsegment = xray_recorder.begin_subsegment(name, namespace='remote')
54+
else:
55+
subsegment = None
56+
try:
57+
res = wrapped(*args, **kwargs)
58+
except Exception:
59+
if subsegment is not None:
60+
exception = sys.exc_info()[1]
61+
stack = stacktrace.get_stacktrace(limit=xray_recorder._max_trace_back)
62+
subsegment.add_exception(exception, stack)
63+
raise
64+
finally:
65+
if subsegment is not None:
66+
subsegment.set_sql(sql)
67+
xray_recorder.end_subsegment()
68+
return res
69+
70+
71+
def patch():
72+
wrapt.wrap_function_wrapper(
73+
'sqlalchemy.engine.base',
74+
'Connection.execute',
75+
_xray_traced_sqlalchemy_execute
76+
)
77+
78+
79+
def unpatch():
80+
"""
81+
Unpatch any previously patched modules.
82+
This operation is idempotent.
83+
"""
84+
_PATCHED_MODULES.discard('sqlalchemy_core')
85+
import sqlalchemy
86+
unwrap(sqlalchemy.engine.base.Connection, 'execute')

tests/ext/sqlalchemy_core/__init__.py

Whitespace-only changes.
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
from __future__ import absolute_import
2+
3+
import pytest
4+
from sqlalchemy import create_engine, Column, Integer, String
5+
from sqlalchemy.ext.declarative import declarative_base
6+
from sqlalchemy.orm import sessionmaker
7+
from sqlalchemy.sql.expression import Insert, Delete
8+
9+
from aws_xray_sdk.core import xray_recorder, patch
10+
from aws_xray_sdk.core.context import Context
11+
12+
Base = declarative_base()
13+
14+
15+
class User(Base):
16+
__tablename__ = 'users'
17+
18+
id = Column(Integer, primary_key=True)
19+
name = Column(String)
20+
fullname = Column(String)
21+
password = Column(String)
22+
23+
24+
@pytest.fixture()
25+
def engine():
26+
"""
27+
Clean up context storage on each test run and begin a segment
28+
so that later subsegment can be attached. After each test run
29+
it cleans up context storage again.
30+
"""
31+
from aws_xray_sdk.ext.sqlalchemy_core import unpatch
32+
patch(('sqlalchemy_core',))
33+
engine = create_engine('sqlite:///:memory:')
34+
xray_recorder.configure(service='test', sampling=False, context=Context())
35+
xray_recorder.begin_segment('name')
36+
Base.metadata.create_all(engine)
37+
xray_recorder.clear_trace_entities()
38+
xray_recorder.begin_segment('name')
39+
yield engine
40+
xray_recorder.clear_trace_entities()
41+
unpatch()
42+
43+
44+
@pytest.fixture()
45+
def connection(engine):
46+
return engine.connect()
47+
48+
49+
@pytest.fixture()
50+
def session(engine):
51+
Session = sessionmaker(bind=engine)
52+
return Session()
53+
54+
55+
def test_all(session):
56+
""" Test calling all() on get all records.
57+
Verify we run the query and return the SQL as metdata"""
58+
session.query(User).all()
59+
assert len(xray_recorder.current_segment().subsegments) == 1
60+
sql_meta = xray_recorder.current_segment().subsegments[0].sql
61+
assert sql_meta['url'] == 'sqlite:///:memory:'
62+
assert sql_meta['sanitized_query'].startswith('SELECT')
63+
assert sql_meta['sanitized_query'].endswith('FROM users')
64+
65+
66+
def test_add(session):
67+
""" Test calling add() on insert a row.
68+
Verify we that we capture trace for the add"""
69+
password = "123456"
70+
john = User(name='John', fullname="John Doe", password=password)
71+
session.add(john)
72+
session.commit()
73+
assert len(xray_recorder.current_segment().subsegments) == 1
74+
sql_meta = xray_recorder.current_segment().subsegments[0].sql
75+
assert sql_meta['sanitized_query'].startswith('INSERT INTO users')
76+
assert password not in sql_meta['sanitized_query']
77+
78+
79+
def test_filter_first(session):
80+
""" Test calling filter().first() on get first filtered records.
81+
Verify we run the query and return the SQL as metdata"""
82+
session.query(User).filter(User.password=="mypassword!").first()
83+
assert len(xray_recorder.current_segment().subsegments) == 1
84+
sql_meta = xray_recorder.current_segment().subsegments[0].sql
85+
assert sql_meta['sanitized_query'].startswith('SELECT')
86+
assert 'FROM users' in sql_meta['sanitized_query']
87+
assert "mypassword!" not in sql_meta['sanitized_query']
88+
89+
90+
def test_connection_add(connection):
91+
password = "123456"
92+
statement = Insert(User).values(name='John', fullname="John Doe", password=password)
93+
connection.execute(statement)
94+
assert len(xray_recorder.current_segment().subsegments) == 1
95+
sql_meta = xray_recorder.current_segment().subsegments[0].sql
96+
assert sql_meta['sanitized_query'].startswith('INSERT INTO users')
97+
assert sql_meta['url'] == 'sqlite:///:memory:'
98+
assert password not in sql_meta['sanitized_query']
99+
100+
def test_connection_query(connection):
101+
password = "123456"
102+
statement = Delete(User).where(User.name == 'John').where(User.password == password)
103+
connection.execute(statement)
104+
assert len(xray_recorder.current_segment().subsegments) == 1
105+
sql_meta = xray_recorder.current_segment().subsegments[0].sql
106+
assert sql_meta['sanitized_query'].startswith('DELETE FROM users')
107+
assert sql_meta['url'] == 'sqlite:///:memory:'
108+
assert password not in sql_meta['sanitized_query']

0 commit comments

Comments
 (0)