Skip to content

Add patch support for pymysql (pure Python driver) #215

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions aws_xray_sdk/core/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
'mysql',
'httplib',
'pymongo',
'pymysql',
'psycopg2',
'pg8000',
)
Expand All @@ -33,6 +34,7 @@
'sqlite3',
'mysql',
'pymongo',
'pymysql',
'psycopg2',
'pg8000',
)
Expand Down
4 changes: 4 additions & 0 deletions aws_xray_sdk/ext/pymysql/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .patch import patch, unpatch


__all__ = ['patch', 'unpatch']
53 changes: 53 additions & 0 deletions aws_xray_sdk/ext/pymysql/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pymysql
import wrapt

from aws_xray_sdk.ext.dbapi2 import XRayTracedConn
from aws_xray_sdk.core.patcher import _PATCHED_MODULES
from aws_xray_sdk.ext.util import unwrap


def patch():

wrapt.wrap_function_wrapper(
'pymysql',
'connect',
_xray_traced_connect
)

# patch alias
if hasattr(pymysql, 'Connect'):
pymysql.Connect = pymysql.connect


def _xray_traced_connect(wrapped, instance, args, kwargs):

conn = wrapped(*args, **kwargs)
meta = {
'database_type': 'MySQL',
'user': conn.user.decode('utf-8'),
'driver_version': 'PyMySQL'
}

if hasattr(conn, 'server_version'):
version = sanitize_db_ver(getattr(conn, 'server_version'))
if version:
meta['database_version'] = version

return XRayTracedConn(conn, meta)


def sanitize_db_ver(raw):

if not raw or not isinstance(raw, tuple):
return raw

return '.'.join(str(num) for num in raw)


def unpatch():
"""
Unpatch any previously patched modules.
This operation is idempotent.
"""
_PATCHED_MODULES.discard('pymysql')
unwrap(pymysql, 'connect')
Empty file added tests/ext/pymysql/__init__.py
Empty file.
79 changes: 79 additions & 0 deletions tests/ext/pymysql/test_pymysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import pymysql

import pytest
import testing.mysqld

from aws_xray_sdk.core import patch
from aws_xray_sdk.core import xray_recorder
from aws_xray_sdk.core.context import Context
from aws_xray_sdk.ext.pymysql import unpatch


@pytest.fixture(scope='module', autouse=True)
def patch_module():
patch(('pymysql',))
yield
unpatch()


@pytest.fixture(autouse=True)
def construct_ctx():
"""
Clean up context storage on each test run and begin a segment
so that later subsegment can be attached. After each test run
it cleans up context storage again.
"""
xray_recorder.configure(service='test', sampling=False, context=Context())
xray_recorder.clear_trace_entities()
xray_recorder.begin_segment('name')
yield
xray_recorder.clear_trace_entities()


def test_execute_dsn_kwargs():
q = 'SELECT 1'
with testing.mysqld.Mysqld() as mysqld:
dsn = mysqld.dsn()
conn = pymysql.connect(database=dsn['db'],
user=dsn['user'],
password='',
host=dsn['host'],
port=dsn['port'])
cur = conn.cursor()
cur.execute(q)

subsegment = xray_recorder.current_segment().subsegments[-1]
assert subsegment.name == 'execute'
sql = subsegment.sql
assert sql['database_type'] == 'MySQL'
assert sql['user'] == dsn['user']
assert sql['driver_version'] == 'PyMySQL'
assert sql['database_version']


def test_execute_bad_query():
q = "SELECT blarg"
with testing.mysqld.Mysqld() as mysqld:
dsn = mysqld.dsn()
conn = pymysql.connect(database=dsn['db'],
user=dsn['user'],
password='',
host=dsn['host'],
port=dsn['port'])

cur = conn.cursor()
try:
cur.execute(q)
except Exception:
pass

subsegment = xray_recorder.current_segment().subsegments[-1]
assert subsegment.name == "execute"
sql = subsegment.sql
assert sql['database_type'] == 'MySQL'
assert sql['user'] == dsn['user']
assert sql['driver_version'] == 'PyMySQL'
assert sql['database_version']

exception = subsegment.cause['exceptions'][0]
assert exception.type == 'InternalError'
2 changes: 2 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@ deps =
django >= 1.10
django-fake-model
pynamodb >= 3.3.1
pymysql
psycopg2
pg8000
testing.postgresql
testing.mysqld
webtest

# Python2 only deps
Expand Down