diff --git a/README.md b/README.md index cf9c3cc9..99ff33a3 100644 --- a/README.md +++ b/README.md @@ -176,6 +176,41 @@ app.router.add_get("/", handler) web.run_app(app) ``` +**Use SQLAlchemy ORM** +The SQLAlchemy integration requires you to override the Session and Query classes for SQLAlchemy. + +The SQLAlchemy integration uses subsegments, so you need to have a segment started before you make a query. + +```python +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.ext.sqlalchemy.query import XRaySessionMaker + +xray_recorder.begin_segment('SQLAlchemyTest') + +Session = XRaySessionMaker(bind=engine) +session = Session() + +xray_recorder.end_segment() +app = Flask(__name__) + +xray_recorder.configure(service='fallback_name', dynamic_naming='*mysite.com*') +XRayMiddleware(app, xray_recorder) +``` + +**Add Flask-SQLAlchemy** + +```python +from aws_xray_sdk.core import xray_recorder +from aws_xray_sdk.ext.flask.middleware import XRayMiddleware +from aws_xray_sdk.ext.flask_sqlalchemy.query import XRayFlaskSqlAlchemy + +app = Flask(__name__) +app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:" + +XRayMiddleware(app, xray_recorder) +db = XRayFlaskSqlAlchemy(app) + +``` ## License The AWS X-Ray SDK for Python is licensed under the Apache 2.0 License. See LICENSE and NOTICE.txt for more information. 
diff --git a/aws_xray_sdk/ext/flask_sqlalchemy/__init__.py b/aws_xray_sdk/ext/flask_sqlalchemy/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/aws_xray_sdk/ext/flask_sqlalchemy/query.py b/aws_xray_sdk/ext/flask_sqlalchemy/query.py
new file mode 100644
index 00000000..8fb0dcc9
--- /dev/null
+++ b/aws_xray_sdk/ext/flask_sqlalchemy/query.py
@@ -0,0 +1,64 @@
+from builtins import super
+from flask_sqlalchemy.model import Model
+from sqlalchemy.orm.session import sessionmaker
+from flask_sqlalchemy import SQLAlchemy, BaseQuery, _SessionSignalEvents, get_state
+from aws_xray_sdk.ext.sqlalchemy.query import XRaySession, XRayQuery
+from aws_xray_sdk.ext.sqlalchemy.util.decerators import xray_on_call, decorate_all_functions
+
+
+@decorate_all_functions(xray_on_call)
+class XRayBaseQuery(BaseQuery):
+    """Query class for Flask-SQLAlchemy that is built on top of XRayQuery."""
+    # Re-parent Flask-SQLAlchemy's BaseQuery onto XRayQuery. This rebinds
+    # the bases of the shared BaseQuery class itself, at import time.
+    BaseQuery.__bases__ = (XRayQuery,)
+
+
+class XRaySignallingSession(XRaySession):
+    """The signalling session is the default session that Flask-SQLAlchemy
+    uses. It extends the default session system with bind selection and
+    modification tracking.
+    If you want to use a different session you can override the
+    :meth:`SQLAlchemy.create_session` function.
+    .. versionadded:: 2.0
+    .. versionadded:: 2.1
+    The `binds` option was added, which allows a session to be joined
+    to an external transaction.
+    """
+
+    def __init__(self, db, autocommit=False, autoflush=True, **options):
+        #: The application that this session belongs to.
+        self.app = app = db.get_app()
+        track_modifications = app.config['SQLALCHEMY_TRACK_MODIFICATIONS']
+        bind = options.pop('bind', None) or db.engine
+        binds = options.pop('binds', db.get_binds(app))
+
+        if track_modifications is None or track_modifications:
+            _SessionSignalEvents.register(self)
+
+        XRaySession.__init__(
+            self, autocommit=autocommit, autoflush=autoflush,
+            bind=bind, binds=binds, **options
+        )
+
+    def get_bind(self, mapper=None, clause=None):
+        # mapper is None if someone tries to just get a connection
+        if mapper is not None:
+            info = getattr(mapper.mapped_table, 'info', {})
+            bind_key = info.get('bind_key')
+            if bind_key is not None:
+                state = get_state(self.app)
+                return state.db.get_engine(self.app, bind=bind_key)
+        return XRaySession.get_bind(self, mapper, clause)
+
+
+class XRayFlaskSqlAlchemy(SQLAlchemy):
+    """``SQLAlchemy`` extension object that defaults ``query_class`` to
+    XRayBaseQuery and creates XRaySignallingSession sessions."""
+    def __init__(self, app=None, use_native_unicode=True, session_options=None,
+                 metadata=None, query_class=XRayBaseQuery, model_class=Model):
+        super().__init__(app, use_native_unicode, session_options,
+                         metadata, query_class, model_class)
+
+    def create_session(self, options):
+        return sessionmaker(class_=XRaySignallingSession, db=self, **options)
diff --git a/aws_xray_sdk/ext/sqlalchemy/__init__.py b/aws_xray_sdk/ext/sqlalchemy/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/aws_xray_sdk/ext/sqlalchemy/query.py b/aws_xray_sdk/ext/sqlalchemy/query.py
new file mode 100644
index 00000000..2a687d06
--- /dev/null
+++ b/aws_xray_sdk/ext/sqlalchemy/query.py
@@ -0,0 +1,26 @@
+from builtins import super
+from sqlalchemy.orm.query import Query
+from sqlalchemy.orm.session import Session, sessionmaker
+from .util.decerators import xray_on_call, decorate_all_functions
+
+
+@decorate_all_functions(xray_on_call)
+class XRaySession(Session):
+    """``Session`` subclass; xray_on_call only traces calls made on instances of it."""
+
+
+@decorate_all_functions(xray_on_call)
+class XRayQuery(Query):
+    """``Query`` subclass; xray_on_call only traces calls made on instances of it."""
+
+
+@decorate_all_functions(xray_on_call)
+class XRaySessionMaker(sessionmaker):
+    """``sessionmaker`` that defaults to XRaySession and always sets ``query_cls`` to XRayQuery."""
+    def __init__(self, bind=None, class_=XRaySession, autoflush=True,
+                 autocommit=False,
+                 expire_on_commit=True,
+                 info=None, **kw):
+        kw['query_cls'] = XRayQuery
+        super().__init__(bind, class_, autoflush, autocommit, expire_on_commit,
+                         info, **kw)
diff --git a/aws_xray_sdk/ext/sqlalchemy/util/__init__.py b/aws_xray_sdk/ext/sqlalchemy/util/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/aws_xray_sdk/ext/sqlalchemy/util/decerators.py b/aws_xray_sdk/ext/sqlalchemy/util/decerators.py
new file mode 100644
index 00000000..2d141861
--- /dev/null
+++ b/aws_xray_sdk/ext/sqlalchemy/util/decerators.py
@@ -0,0 +1,131 @@
+import re
+from aws_xray_sdk.core import xray_recorder
+from future.standard_library import install_aliases
+install_aliases()
+from urllib.parse import urlparse, uses_netloc
+
+
+def decorate_all_functions(function_decorator):
+    """Build a class decorator that wraps the public callables of a class's bases.
+
+    For every direct base of the decorated class, each callable attribute
+    whose name does not start with an underscore is replaced by
+    ``function_decorator(base, callable)``. The bases themselves are patched,
+    so every user of those base classes sees the wrapped callables. The
+    decorated class is returned as is.
+    """
+    def decorator(cls):
+        for c in cls.__bases__:
+            # Iterate over a snapshot: setattr() below writes into the
+            # namespace that vars(c) exposes.
+            for name, obj in list(vars(c).items()):
+                if name.startswith("_"):
+                    continue
+                if callable(obj):
+                    try:
+                        obj = obj.__func__  # unwrap Python 2 unbound method
+                    except AttributeError:
+                        pass  # not needed in Python 3
+                    setattr(c, name, function_decorator(c, obj))
+        return cls
+    return decorator
+
+
+def xray_on_call(cls, func):
+    """Wrap ``func`` so that a call can be recorded as an X-Ray subsegment.
+
+    ``cls`` is the class that owns ``func``; its module name selects how the
+    engine is located among the call arguments. A subsegment is opened only
+    when connection metadata was derived from an XRaySession or XRayQuery
+    argument and the recorder context already holds trace entities;
+    otherwise ``func`` runs untraced. The wrapper returns whatever ``func``
+    returns and lets its exceptions propagate.
+    """
+    def wrapper(*args, **kw):
+        # Imported here because ..query imports this module while it loads.
+        # XRaySignallingSession subclasses XRaySession, so checking for
+        # XRaySession covers it without importing the Flask extension.
+        from ..query import XRayQuery, XRaySession
+        class_name = str(cls.__module__)
+        c = xray_recorder._context
+        sql = None
+        subsegment = None
+        if class_name == "sqlalchemy.orm.session":
+            for arg in args:
+                if isinstance(arg, XRaySession):
+                    sql = parse_bind(arg.bind)
+        if class_name == 'sqlalchemy.orm.query':
+            for arg in args:
+                if isinstance(arg, XRayQuery):
+                    try:
+                        sql = parse_bind(arg.session.bind)
+                        # Commented out for a later PR
+                        # sql['sanitized_query'] = str(arg)
+                    except Exception:
+                        sql = None
+        if sql is not None:
+            if getattr(c._local, 'entities', None) is not None:
+                subsegment = xray_recorder.begin_subsegment(sql['url'], namespace='remote')
+        try:
+            return func(*args, **kw)
+        finally:
+            # Close the subsegment even when func raises, so a failed call
+            # does not leave it open on the recorder.
+            if subsegment is not None:
+                subsegment.set_sql(sql)
+                subsegment.put_annotation("sqlalchemy", class_name + '.' + func.__name__)
+                xray_recorder.end_subsegment()
+    return wrapper
+
+
+# URL Parse output
+# scheme    0   URL scheme specifier                scheme parameter
+# netloc    1   Network location part               empty string
+# path      2   Hierarchical path                   empty string
+# query     3   Query component                     empty string
+# fragment  4   Fragment identifier                 empty string
+# username      User name                           None
+# password      Password                            None
+# hostname      Host name (lower case)              None
+# port          Port number as integer, if present  None
+#
+# XRAY Trace SQL metaData Sample
+# "sql" : {
+#     "url": "jdbc:postgresql://aawijb5u25wdoy.cpamxznpdoq8.us-west-2.rds.amazonaws.com:5432/ebdb",
+#     "preparation": "statement",
+#     "database_type": "PostgreSQL",
+#     "database_version": "9.5.4",
+#     "driver_version": "PostgreSQL 9.4.1211.jre7",
+#     "user" : "dbuser",
+#     "sanitized_query" : "SELECT * FROM customers WHERE customer_id=?;"
+# }
+def parse_bind(bind):
+    """Parses a connection string and creates SQL trace metadata
+
+    ``str(bind)`` is expected to look like ``Engine(<url>)``. Returns a dict
+    with ``database_type`` (the URL scheme), ``url`` (the URL without the
+    password) and, when the URL has one, ``user``. Returns None when
+    ``str(bind)`` does not match.
+    """
+    m = re.match(r"Engine\((.*?)\)", str(bind))
+    if m is None:
+        return None
+    u = urlparse(m.group(1))
+    # Add Scheme to uses_netloc or // will be missing from url.
+    # Append only once: uses_netloc is shared, module-level state.
+    if u.scheme not in uses_netloc:
+        uses_netloc.append(u.scheme)
+    if u.password is None:
+        safe_url = u.geturl()
+    else:
+        # Strip password from URL
+        host_info = u.netloc.rpartition('@')[-1]
+        parts = u._replace(netloc='{}@{}'.format(u.username, host_info))
+        # Render the rebuilt parts; u itself still carries the password.
+        safe_url = parts.geturl()
+    sql = {}
+    sql['database_type'] = u.scheme
+    sql['url'] = safe_url
+    if u.username is not None:
+        sql['user'] = "{}".format(u.username)
+    return sql
diff --git a/setup.py b/setup.py
index 0a952af7..a4ceff35 100644
--- a/setup.py
+++ b/setup.py
@@ -35,7 +35,7 @@
         'Programming Language :: Python :: 3.6',
     ],
 
-    install_requires=['jsonpickle', 'wrapt', 'requests'],
+    install_requires=['jsonpickle', 'wrapt', 'requests', 'future'],
 
     keywords='aws xray sdk',
 
diff --git a/tests/ext/flask_sqlalchemy/__init__.py b/tests/ext/flask_sqlalchemy/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/ext/flask_sqlalchemy/test_query.py b/tests/ext/flask_sqlalchemy/test_query.py
new file mode 100644
index 00000000..937b267e
--- /dev/null
+++ b/tests/ext/flask_sqlalchemy/test_query.py
@@ -0,0 +1,56 @@
+from __future__ import absolute_import
+import pytest
+from aws_xray_sdk.core import xray_recorder
+from aws_xray_sdk.core.context import Context
+from aws_xray_sdk.ext.flask_sqlalchemy.query import XRayFlaskSqlAlchemy
+from flask import Flask
+from ...util import find_subsegment_by_annotation
+
+
+app = Flask(__name__)
+app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
+app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
+db = XRayFlaskSqlAlchemy(app)
+
+
+class User(db.Model):
+    __tablename__ = "users"
+
+    id = db.Column(db.Integer, primary_key=True)
+    name = db.Column(db.String(255), nullable=False, unique=True)
+    fullname = db.Column(db.String(255), nullable=False)
+    password = db.Column(db.String(255), nullable=False)
+
+
+@pytest.fixture()
+def session():
+    """Test Fixture to Create DataBase Tables and start a trace segment"""
+    xray_recorder.configure(service='test', sampling=False, context=Context())
+    xray_recorder.clear_trace_entities()
+    xray_recorder.begin_segment('SQLAlchemyTest')
+    db.create_all()
+    yield
+    xray_recorder.end_segment()
+    xray_recorder.clear_trace_entities()
+
+
+def test_all(capsys, session):
+    """ Test calling all() on get all records.
+    Verify that we capture trace of query and return the SQL as metadata"""
+    # with capsys.disabled():
+    User.query.all()
+    subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.query.all')
+    assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.query.all'
+    # assert subsegment['sql']['sanitized_query']
+    assert subsegment['sql']['url']
+
+
+def test_add(capsys, session):
+    """ Test calling add() on insert a row.
+    Verify that we capture trace for the add"""
+    # with capsys.disabled():
+    john = User(name='John', fullname="John Doe", password="password")
+    db.session.add(john)
+    subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.session.add')
+    assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.session.add'
+    assert subsegment['sql']['url']
diff --git a/tests/ext/sqlalchemy/__init__.py b/tests/ext/sqlalchemy/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/ext/sqlalchemy/test_query.py b/tests/ext/sqlalchemy/test_query.py
new file mode 100644
index 00000000..f329cadd
--- /dev/null
+++ b/tests/ext/sqlalchemy/test_query.py
@@ -0,0 +1,69 @@
+from __future__ import absolute_import
+import pytest
+from aws_xray_sdk.core import xray_recorder
+from aws_xray_sdk.core.context import Context
+from aws_xray_sdk.ext.sqlalchemy.query import XRaySessionMaker
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy import create_engine, Column, Integer, String
+from ...util import find_subsegment_by_annotation
+
+
+Base = declarative_base()
+
+
+class User(Base):
+    __tablename__ = 'users'
+
+    id = Column(Integer, primary_key=True)
+    name = Column(String)
+    fullname = Column(String)
+    password = Column(String)
+
+
+@pytest.fixture()
+def session():
+    """Test Fixture to Create DataBase Tables and start a trace segment"""
+    engine = create_engine('sqlite:///:memory:')
+    xray_recorder.configure(service='test', sampling=False, context=Context())
+    xray_recorder.clear_trace_entities()
+    xray_recorder.begin_segment('SQLAlchemyTest')
+    Session = XRaySessionMaker(bind=engine)
+    Base.metadata.create_all(engine)
+    session = Session()
+    yield session
+    xray_recorder.end_segment()
+    xray_recorder.clear_trace_entities()
+
+
+def test_all(capsys, session):
+    """ Test calling all() on get all records.
+    Verify we run the query and return the SQL as metadata"""
+    # with capsys.disabled():
+    session.query(User).all()
+    subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.query.all')
+    assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.query.all'
+    # assert subsegment['sql']['sanitized_query']
+    assert subsegment['sql']['url']
+
+
+def test_add(capsys, session):
+    """ Test calling add() on insert a row.
+    Verify that we capture trace for the add"""
+    # with capsys.disabled():
+    john = User(name='John', fullname="John Doe", password="password")
+    session.add(john)
+    subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.session.add')
+    assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.session.add'
+    assert subsegment['sql']['url']
+
+
+def test_filter(capsys, session):
+    """ Test calling filter() on a query.
+    Verify we trace the filter call and return the SQL as metadata"""
+    # with capsys.disabled():
+    session.query(User).filter(User.password=="mypassword!")
+    subsegment = find_subsegment_by_annotation(xray_recorder.current_segment(), 'sqlalchemy', 'sqlalchemy.orm.query.filter')
+    assert subsegment['annotations']['sqlalchemy'] == 'sqlalchemy.orm.query.filter'
+    # assert subsegment['sql']['sanitized_query']
+    # assert "mypassword!" not in subsegment['sql']['sanitized_query']
+    assert subsegment['sql']['url']
diff --git a/tests/util.py b/tests/util.py
index 229ecfa4..6c9fb166 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -1,5 +1,6 @@
 import json
 import threading
+import jsonpickle
 
 from aws_xray_sdk.core.recorder import AWSXRayRecorder
 from aws_xray_sdk.core.emitters.udp_emitter import UDPEmitter
@@ -43,3 +44,55 @@ def entity_to_dict(trace_entity):
 
     raw = json.loads(trace_entity.serialize())
     return raw
+
+
+def _search_entity(entity, name):
+    """Recursively search a serialized entity and its subsegments by name.
+
+    Returns the first serialized entity whose name matches, or None."""
+    if 'name' in entity and entity['name'] == name:
+        return entity
+    # Always descend, so a match nested under a non-matching entity is found.
+    for s in entity.get('subsegments') or ():
+        result = _search_entity(s, name)
+        if result is not None:
+            return result
+    return None
+
+
+def find_subsegment(segment, name):
+    """Helper function to find a subsegment by name in the entity tree"""
+    segment = jsonpickle.encode(segment, unpicklable=False)
+    segment = json.loads(segment)
+    for entity in segment.get('subsegments') or ():
+        result = _search_entity(entity, name)
+        if result is not None:
+            return result
+    return None
+
+
+def find_subsegment_by_annotation(segment, key, value):
+    """Helper function to find a subsegment by annotation key & value in the entity tree"""
+    segment = jsonpickle.encode(segment, unpicklable=False)
+    segment = json.loads(segment)
+    for entity in segment.get('subsegments') or ():
+        result = _search_entity_by_annotation(entity, key, value)
+        if result is not None:
+            return result
+    return None
+
+
+def _search_entity_by_annotation(entity, key, value):
+    """Recursively search a serialized entity and its subsegments by annotation.
+
+    Returns the first serialized entity whose annotations hold ``value``
+    under ``key``, or None."""
+    annotations = entity.get('annotations') or {}
+    if key in annotations and annotations[key] == value:
+        return entity
+    # Always descend, so a match nested under a non-matching entity is found.
+    for s in entity.get('subsegments') or ():
+        result = _search_entity_by_annotation(s, key, value)
+        if result is not None:
+            return result
+    return None
diff --git a/tox.ini b/tox.ini
index c1539742..fb2726f1 100644
--- a/tox.ini
+++ b/tox.ini
@@ -12,6 +12,9 @@ deps =
     botocore
     requests
     flask >= 0.10
+    sqlalchemy
+    Flask-SQLAlchemy
+    future
     # the sdk doesn't support earlier version of django
     django >= 1.10, <2.0
     pynamodb
@@ -32,8 +35,8 @@ deps = coverage
 skip_install = true
 commands =
     # might need to add coverage combine at some point
-    coverage report
-    coverage html
+    py{36}: coverage report
+    py{36}: coverage html
 
 [flake8]
 max-line-length=120