diff --git a/src/server/_common.py b/src/server/_common.py index 56d4c38ec..f7c28c7ef 100644 --- a/src/server/_common.py +++ b/src/server/_common.py @@ -3,7 +3,7 @@ from flask import Flask, g, request from sqlalchemy import event -from sqlalchemy.engine import Connection +from sqlalchemy.engine import Connection, Engine from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy @@ -85,12 +85,12 @@ def log_info_with_request_and_response(message, response, **kwargs): **kwargs ) -@event.listens_for(engine, "before_cursor_execute") +@event.listens_for(Engine, "before_cursor_execute") def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): context._query_start_time = time.time() -@event.listens_for(engine, "after_cursor_execute") +@event.listens_for(Engine, "after_cursor_execute") def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): # this timing info may be suspect, at least in terms of dbms cpu time... # it is likely that it includes that time as well as any overhead that @@ -101,7 +101,8 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema # Convert to milliseconds total_time *= 1000 get_structured_logger("server_api").info( - "Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time + "Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time, + engine_id=conn.get_execution_options().get('engine_id') ) diff --git a/src/server/_db.py b/src/server/_db.py index 53e632cdf..e65c885ff 100644 --- a/src/server/_db.py +++ b/src/server/_db.py @@ -1,4 +1,7 @@ -from sqlalchemy import create_engine, MetaData +import functools +from inspect import signature, Parameter + +from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -9,15 +12,57 @@ # previously `_common` imported from `_security` which imported from `admin.models`, which imported (back again) from `_common` for database connection objects -engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS) -if SQLALCHEMY_DATABASE_URI_PRIMARY: - user_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS) -else: - user_engine: Engine = engine +# a decorator to automatically provide a sqlalchemy session by default, if an existing session is not explicitly +# specified to override it. it is preferred to use a single session for a sequence of operations logically grouped +# together, but this allows individual operations to be run by themselves without having to provide an +# already-established session. requires an argument to the wrapped function named 'session'. +# for instance: +# +# @default_session(WriteSession) +# def foo(session): +# pass +# +# # calling: +# foo() +# # is identical to: +# with WriteSession() as s: +# foo(s) +def default_session(sess): + def decorator__default_session(func): + # make sure `func` is compatible w/ this decorator + func_params = signature(func).parameters + if 'session' not in func_params or func_params['session'].kind == Parameter.POSITIONAL_ONLY: + raise Exception(f"@default_session(): function {func.__name__}() must accept an argument 'session' that can be specified by keyword.") + # save position of 'session' arg, to later check if its been passed in by position/order + sess_index = list(func_params).index('session') + + @functools.wraps(func) + def wrapper__default_session(*args, **kwargs): + if 'session' in kwargs or len(args) >= sess_index+1: + # 'session' has been specified by the caller, so we have nothing to do here. pass along all args unchanged. + return func(*args, **kwargs) + # otherwise, we will wrap this call with a context manager for the default session provider, and pass that session instance to the wrapped function. + with sess() as session: + return func(*args, **kwargs, session=session) -metadata = MetaData(bind=user_engine) + return wrapper__default_session -Session = sessionmaker(bind=user_engine) + return decorator__default_session +engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'default'}) +Session = sessionmaker(bind=engine) + +if SQLALCHEMY_DATABASE_URI_PRIMARY and SQLALCHEMY_DATABASE_URI_PRIMARY != SQLALCHEMY_DATABASE_URI: + # if available, use the main/primary DB for write operations. DB replication processes should be in place to + # propagate any written changes to the regular (load balanced) replicas. + write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'write_engine'}) + WriteSession = sessionmaker(bind=write_engine) + # TODO: insert log statement acknowledging this second session handle is in use? +else: + write_engine: Engine = engine + WriteSession = Session +# NOTE: `WriteSession` could be called `AdminSession`, as its only (currently) used by the admin page, and the admin +# page is the only thing that should be writing to the db. its tempting to let the admin page read from the +# regular `Session` and write with `WriteSession`, but concurrency problems may arise from sync/replication lag. diff --git a/src/server/_security.py b/src/server/_security.py index 761d088c3..61e2608b2 100644 --- a/src/server/_security.py +++ b/src/server/_security.py @@ -16,7 +16,7 @@ TEMPORARY_API_KEY, URL_PREFIX, ) -from .admin.models import User, UserRole +from .admin.models import User API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14) API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14) @@ -91,10 +91,6 @@ def _get_current_user(): current_user: User = cast(User, LocalProxy(_get_current_user)) -def register_user_role(role_name: str) -> None: - UserRole.create_role(role_name) - - def _is_public_route() -> bool: public_routes_list = ["lib", "admin", "version"] for route in public_routes_list: diff --git a/src/server/admin/models.py b/src/server/admin/models.py index 62cbc186d..f5c0d54ed 100644 --- a/src/server/admin/models.py +++ b/src/server/admin/models.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import relationship from copy import deepcopy -from .._db import Session +from .._db import Session, WriteSession, default_session from delphi.epidata.common.logger import get_structured_logger from typing import Set, Optional, List @@ -25,7 +25,7 @@ def _default_date_now(): class User(Base): __tablename__ = "api_user" id = Column(Integer, primary_key=True, autoincrement=True) - roles = relationship("UserRole", secondary=association_table) + roles = relationship("UserRole", secondary=association_table, lazy="joined") # last arg does an eager load of this property from foreign tables api_key = Column(String(50), unique=True, nullable=False) email = Column(String(320), unique=True, nullable=False) created = Column(Date, default=_default_date_now) @@ -35,97 +35,85 @@ def __init__(self, api_key: str, email: str = None) -> None: self.api_key = api_key self.email = email - @staticmethod - def list_users() -> List["User"]: - with Session() as session: - return session.query(User).all() - @property def as_dict(self): return { "id": self.id, "api_key": self.api_key, "email": self.email, - "roles": User.get_user_roles(self.id), + "roles": set(role.name for role in self.roles), "created": self.created, "last_time_used": self.last_time_used } - @staticmethod - def get_user_roles(user_id: int) -> Set[str]: - with Session() as session: - user = session.query(User).filter(User.id == user_id).first() - return set([role.name for role in user.roles]) - def has_role(self, required_role: str) -> bool: - return required_role in User.get_user_roles(self.id) + return required_role in set(role.name for role in self.roles) @staticmethod def _assign_roles(user: "User", roles: Optional[Set[str]], session) -> None: - # NOTE: this uses a borrowed/existing `session`, and thus does not do a `session.commit()`... - # that is the responsibility of the caller! get_structured_logger("api_user_models").info("setting roles", roles=roles, user_id=user.id, api_key=user.api_key) db_user = session.query(User).filter(User.id == user.id).first() # TODO: would it be sufficient to use the passed-in `user` instead of looking up this `db_user`? + # or even use this as a bound method instead of a static?? + # same goes for `update_user()` and `delete_user()` below... if roles: - roles_to_assign = session.query(UserRole).filter(UserRole.name.in_(roles)).all() - db_user.roles = roles_to_assign + db_user.roles = session.query(UserRole).filter(UserRole.name.in_(roles)).all() else: db_user.roles = [] + session.commit() + # retrieve the newly updated User object + return session.query(User).filter(User.id == user.id).first() @staticmethod + @default_session(Session) def find_user(*, # asterisk forces explicit naming of all arguments when calling this method - user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None + session, + user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None ) -> "User": # NOTE: be careful, using multiple arguments could match multiple users, but this will return only one! - with Session() as session: - user = ( - session.query(User) - .filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email)) - .first() - ) + user = ( + session.query(User) + .filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email)) + .first() + ) return user if user else None @staticmethod - def create_user(api_key: str, email: str, user_roles: Optional[Set[str]] = None) -> "User": + @default_session(WriteSession) + def create_user(api_key: str, email: str, session, user_roles: Optional[Set[str]] = None) -> "User": get_structured_logger("api_user_models").info("creating user", api_key=api_key) - with Session() as session: - new_user = User(api_key=api_key, email=email) - # TODO: we may need to populate 'created' field/column here, if the default - # specified above gets bound to the time of when that line of python was evaluated. - session.add(new_user) - session.commit() - User._assign_roles(new_user, user_roles, session) - session.commit() - return new_user + new_user = User(api_key=api_key, email=email) + session.add(new_user) + session.commit() + return User._assign_roles(new_user, user_roles, session) @staticmethod + @default_session(WriteSession) def update_user( user: "User", email: Optional[str], api_key: Optional[str], - roles: Optional[Set[str]] + roles: Optional[Set[str]], + session ) -> "User": get_structured_logger("api_user_models").info("updating user", user_id=user.id, new_api_key=api_key) - with Session() as session: - user = User.find_user(user_id=user.id) - if user: - update_stmt = ( - update(User) - .where(User.id == user.id) - .values(api_key=api_key, email=email) - ) - session.execute(update_stmt) - User._assign_roles(user, roles, session) - session.commit() - return user + user = User.find_user(user_id=user.id, session=session) + if not user: + raise Exception('user not found') + update_stmt = ( + update(User) + .where(User.id == user.id) + .values(api_key=api_key, email=email) + ) + session.execute(update_stmt) + return User._assign_roles(user, roles, session) @staticmethod - def delete_user(user_id: int) -> None: + @default_session(WriteSession) + def delete_user(user_id: int, session) -> None: get_structured_logger("api_user_models").info("deleting user", user_id=user_id) - with Session() as session: - session.execute(delete(User).where(User.id == user_id)) - session.commit() + session.execute(delete(User).where(User.id == user_id)) + session.commit() class UserRole(Base): @@ -134,23 +122,23 @@ class UserRole(Base): name = Column(String(50), unique=True) @staticmethod - def create_role(name: str) -> None: + @default_session(WriteSession) + def create_role(name: str, session) -> None: get_structured_logger("api_user_models").info("creating user role", role=name) - with Session() as session: - session.execute( - f""" + # TODO: check role doesnt already exist + session.execute(f""" INSERT INTO user_role (name) SELECT '{name}' WHERE NOT EXISTS (SELECT * FROM user_role WHERE name='{name}') - """ - ) - session.commit() + """) + session.commit() + return session.query(UserRole).filter(UserRole.name == name).first() @staticmethod - def list_all_roles(): - with Session() as session: - roles = session.query(UserRole).all() + @default_session(Session) + def list_all_roles(session): + roles = session.query(UserRole).all() return [role.name for role in roles] diff --git a/src/server/endpoints/admin.py b/src/server/endpoints/admin.py index 17bc9ca9b..a6f941b48 100644 --- a/src/server/endpoints/admin.py +++ b/src/server/endpoints/admin.py @@ -7,6 +7,7 @@ from .._common import log_info_with_request from .._config import ADMIN_PASSWORD, API_KEY_REGISTRATION_FORM_LINK, API_KEY_REMOVAL_REQUEST_LINK, REGISTER_WEBHOOK_TOKEN +from .._db import WriteSession from .._security import resolve_auth_token from ..admin.models import User, UserRole @@ -29,22 +30,13 @@ def _require_admin(): return token -def _parse_roles(roles: List[str]) -> Set[str]: - return set(roles) - - -def _render(mode: str, token: str, flags: Dict, **kwargs): +def _render(mode: str, token: str, flags: Dict, session, **kwargs): template = (templates_dir / "index.html").read_text("utf8") return render_template_string( - template, mode=mode, token=token, flags=flags, roles=UserRole.list_all_roles(), **kwargs + template, mode=mode, token=token, flags=flags, roles=UserRole.list_all_roles(session), **kwargs ) -def user_exists(user_email: str = None, api_key: str = None): - user = User.find_user(user_email=user_email, api_key=api_key) - return True if user else False - - # ~~~~ PUBLIC ROUTES ~~~~ @@ -67,44 +59,50 @@ def removal_request_redirect(): def _index(): token = _require_admin() flags = dict() - if request.method == "POST": - # register a new user - if not user_exists(user_email=request.values["email"], api_key=request.values["api_key"]): - User.create_user( - request.values["api_key"], - request.values["email"], - _parse_roles(request.values.getlist("roles")), - ) - flags["banner"] = "Successfully Added" - else: - flags["banner"] = "User with such email and/or api key already exists." - users = [user.as_dict for user in User.list_users()] - return _render("overview", token, flags, users=users, user=dict()) + with WriteSession() as session: + if request.method == "POST": + # register a new user + if not User.find_user( + user_email=request.values["email"], api_key=request.values["api_key"], + session=session): + User.create_user( + api_key=request.values["api_key"], + email=request.values["email"], + user_roles=set(request.values.getlist("roles")), + session=session + ) + flags["banner"] = "Successfully Added" + else: + flags["banner"] = "User with such email and/or api key already exists." + users = [user.as_dict for user in session.query(User).all()] + return _render("overview", token, flags, session=session, users=users, user=dict()) @bp.route("/", methods=["GET", "PUT", "POST", "DELETE"]) def _detail(user_id: int): token = _require_admin() - user = User.find_user(user_id=user_id) - if not user: - raise NotFound() - if request.method == "DELETE" or "delete" in request.values: - User.delete_user(user.id) - return redirect(f"./?auth={token}") - flags = dict() - if request.method in ["PUT", "POST"]: - user_check = User.find_user(api_key=request.values["api_key"], user_email=request.values["email"]) - if user_check and user_check.id != user.id: - flags["banner"] = "Could not update user; same api_key and/or email already exists." - else: - user = user.update_user( - user=user, - api_key=request.values["api_key"], - email=request.values["email"], - roles=_parse_roles(request.values.getlist("roles")), - ) - flags["banner"] = "Successfully Saved" - return _render("detail", token, flags, user=user.as_dict) + with WriteSession() as session: + user = User.find_user(user_id=user_id, session=session) + if not user: + raise NotFound() + if request.method == "DELETE" or "delete" in request.values: + User.delete_user(user.id, session=session) + return redirect(f"./?auth={token}") + flags = dict() + if request.method in ["PUT", "POST"]: + user_check = User.find_user(api_key=request.values["api_key"], user_email=request.values["email"], session=session) + if user_check and user_check.id != user.id: + flags["banner"] = "Could not update user; same api_key and/or email already exists." + else: + user = User.update_user( + user=user, + api_key=request.values["api_key"], + email=request.values["email"], + roles=set(request.values.getlist("roles")), + session=session + ) + flags["banner"] = "Successfully Saved" + return _render("detail", token, flags, session=session, user=user.as_dict) @bp.route("/register", methods=["POST"]) @@ -116,12 +114,13 @@ def _register(): user_api_key = body["user_api_key"] user_email = body["user_email"] - if user_exists(user_email=user_email, api_key=user_api_key): - return make_response( - "User with email and/or API Key already exists, use different parameters or contact us for help", - 409, - ) - User.create_user(api_key=user_api_key, email=user_email) + with WriteSession() as session: + if User.find_user(user_email=user_email, api_key=user_api_key, session=session): + return make_response( + "User with email and/or API Key already exists, use different parameters or contact us for help", + 409, + ) + User.create_user(api_key=user_api_key, email=user_email, session=session) return make_response(f"Successfully registered API key '{user_api_key}'", 200) diff --git a/src/server/main.py b/src/server/main.py index c05b9d0d3..a91a91ee2 100644 --- a/src/server/main.py +++ b/src/server/main.py @@ -8,11 +8,9 @@ from ._config import URL_PREFIX, VERSION from ._common import app, set_compatibility_mode -from ._db import metadata, engine from ._exceptions import MissingOrWrongSourceException from .endpoints import endpoints from .endpoints.admin import bp as admin_bp, enable_admin -from ._security import register_user_role from ._limiter import limiter, apply_limit __all__ = ["app"] @@ -65,8 +63,6 @@ def send_lib_file(path: str): return send_from_directory(pathlib.Path(__file__).parent / "lib", path) -metadata.create_all(engine) - if __name__ == "__main__": app.run(host="0.0.0.0", port=5000) else: