From d36e29ae6f297cf85c5f8f6c25821e7e474ab5cd Mon Sep 17 00:00:00 2001 From: george haff Date: Thu, 25 May 2023 19:57:45 -0400 Subject: [PATCH 1/6] change 'user_engine' to a 'WriteSession' instead, so the master db connection is used for writes only --- src/server/_db.py | 16 ++++++++-------- src/server/admin/models.py | 10 +++++----- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/server/_db.py b/src/server/_db.py index 53e632cdf..a0399e469 100644 --- a/src/server/_db.py +++ b/src/server/_db.py @@ -10,14 +10,14 @@ engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS) +metadata = MetaData(bind=engine) +Session = sessionmaker(bind=engine) if SQLALCHEMY_DATABASE_URI_PRIMARY: - user_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS) + write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS) + write_metadata = MetaData(bind=write_engine) + WriteSession = sessionmaker(bind=write_engine) else: - user_engine: Engine = engine - -metadata = MetaData(bind=user_engine) - -Session = sessionmaker(bind=user_engine) - - + write_engine: Engine = engine + write_metadata = metadata + WriteSession = Session diff --git a/src/server/admin/models.py b/src/server/admin/models.py index 62cbc186d..056c547fa 100644 --- a/src/server/admin/models.py +++ b/src/server/admin/models.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import relationship from copy import deepcopy -from .._db import Session +from .._db import Session, WriteSession from delphi.epidata.common.logger import get_structured_logger from typing import Set, Optional, List @@ -89,7 +89,7 @@ def find_user(*, # asterisk forces explicit naming of all arguments when calling @staticmethod def create_user(api_key: str, email: str, user_roles: Optional[Set[str]] = None) -> "User": get_structured_logger("api_user_models").info("creating user", api_key=api_key) - with Session() as session: + with WriteSession() as session: new_user = User(api_key=api_key, email=email) # TODO: we may need to populate 'created' field/column here, if the default # specified above gets bound to the time of when that line of python was evaluated. @@ -107,7 +107,7 @@ def update_user( roles: Optional[Set[str]] ) -> "User": get_structured_logger("api_user_models").info("updating user", user_id=user.id, new_api_key=api_key) - with Session() as session: + with WriteSession() as session: user = User.find_user(user_id=user.id) if user: update_stmt = ( @@ -123,7 +123,7 @@ def update_user( @staticmethod def delete_user(user_id: int) -> None: get_structured_logger("api_user_models").info("deleting user", user_id=user_id) - with Session() as session: + with WriteSession() as session: session.execute(delete(User).where(User.id == user_id)) session.commit() @@ -136,7 +136,7 @@ class UserRole(Base): @staticmethod def create_role(name: str) -> None: get_structured_logger("api_user_models").info("creating user role", role=name) - with Session() as session: + with WriteSession() as session: session.execute( f""" INSERT INTO user_role (name) From 03c27f2537cee348736dad93cdf7065d33756562 Mon Sep 17 00:00:00 2001 From: george haff Date: Tue, 13 Jun 2023 16:33:44 -0400 Subject: [PATCH 2/6] make sure sql statements and timing are logged for all engines, plus tag engines with id and log those too, and superfluous user method cleanup --- src/server/_common.py | 9 +++++---- src/server/_db.py | 4 ++-- src/server/endpoints/admin.py | 17 ++++------------- 3 files changed, 11 insertions(+), 19 deletions(-) diff --git a/src/server/_common.py b/src/server/_common.py index 56d4c38ec..f7c28c7ef 100644 --- a/src/server/_common.py +++ b/src/server/_common.py @@ -3,7 +3,7 @@ from flask import Flask, g, request from sqlalchemy import event -from sqlalchemy.engine import Connection +from sqlalchemy.engine import Connection, Engine from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy @@ -85,12 +85,12 @@ def log_info_with_request_and_response(message, response, **kwargs): **kwargs ) -@event.listens_for(engine, "before_cursor_execute") +@event.listens_for(Engine, "before_cursor_execute") def before_cursor_execute(conn, cursor, statement, parameters, context, executemany): context._query_start_time = time.time() -@event.listens_for(engine, "after_cursor_execute") +@event.listens_for(Engine, "after_cursor_execute") def after_cursor_execute(conn, cursor, statement, parameters, context, executemany): # this timing info may be suspect, at least in terms of dbms cpu time... # it is likely that it includes that time as well as any overhead that @@ -101,7 +101,8 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema # Convert to milliseconds total_time *= 1000 get_structured_logger("server_api").info( - "Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time + "Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time, + engine_id=conn.get_execution_options().get('engine_id') ) diff --git a/src/server/_db.py b/src/server/_db.py index a0399e469..a784031a5 100644 --- a/src/server/_db.py +++ b/src/server/_db.py @@ -9,12 +9,12 @@ # previously `_common` imported from `_security` which imported from `admin.models`, which imported (back again) from `_common` for database connection objects -engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS) +engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS).execution_options(engine_id='default') metadata = MetaData(bind=engine) Session = sessionmaker(bind=engine) if SQLALCHEMY_DATABASE_URI_PRIMARY: - write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS) + write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS).execution_options(engine_id='write_engine') write_metadata = MetaData(bind=write_engine) WriteSession = sessionmaker(bind=write_engine) else: diff --git a/src/server/endpoints/admin.py b/src/server/endpoints/admin.py index 17bc9ca9b..e3af3fe0d 100644 --- a/src/server/endpoints/admin.py +++ b/src/server/endpoints/admin.py @@ -29,10 +29,6 @@ def _require_admin(): return token -def _parse_roles(roles: List[str]) -> Set[str]: - return set(roles) - - def _render(mode: str, token: str, flags: Dict, **kwargs): template = (templates_dir / "index.html").read_text("utf8") return render_template_string( @@ -40,11 +36,6 @@ def _render(mode: str, token: str, flags: Dict, **kwargs): ) -def user_exists(user_email: str = None, api_key: str = None): - user = User.find_user(user_email=user_email, api_key=api_key) - return True if user else False - - # ~~~~ PUBLIC ROUTES ~~~~ @@ -69,11 +60,11 @@ def _index(): flags = dict() if request.method == "POST": # register a new user - if not user_exists(user_email=request.values["email"], api_key=request.values["api_key"]): + if not User.find_user(user_email=request.values["email"], api_key=request.values["api_key"]): User.create_user( request.values["api_key"], request.values["email"], - _parse_roles(request.values.getlist("roles")), + set(request.values.getlist("roles")), ) flags["banner"] = "Successfully Added" else: @@ -101,7 +92,7 @@ def _detail(user_id: int): user=user, api_key=request.values["api_key"], email=request.values["email"], - roles=_parse_roles(request.values.getlist("roles")), + roles=set(request.values.getlist("roles")), ) flags["banner"] = "Successfully Saved" return _render("detail", token, flags, user=user.as_dict) @@ -116,7 +107,7 @@ def _register(): user_api_key = body["user_api_key"] user_email = body["user_email"] - if user_exists(user_email=user_email, api_key=user_api_key): + if User.find_user(user_email=user_email, api_key=user_api_key): return make_response( "User with email and/or API Key already exists, use different parameters or contact us for help", 409, From 11bafa2023f4a9d85f531024e25e8f3ff76d44ae Mon Sep 17 00:00:00 2001 From: george haff Date: Wed, 14 Jun 2023 17:10:48 -0400 Subject: [PATCH 3/6] sqlalchemy cleanup: removed superfluous bits, improved argument passing for engine creation --- src/server/_db.py | 12 +++++------- src/server/_security.py | 6 +----- src/server/main.py | 4 ---- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/src/server/_db.py b/src/server/_db.py index a784031a5..b8fe12950 100644 --- a/src/server/_db.py +++ b/src/server/_db.py @@ -1,4 +1,4 @@ -from sqlalchemy import create_engine, MetaData +from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -9,15 +9,13 @@ # previously `_common` imported from `_security` which imported from `admin.models`, which imported (back again) from `_common` for database connection objects -engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS).execution_options(engine_id='default') -metadata = MetaData(bind=engine) +engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'default'}) Session = sessionmaker(bind=engine) -if SQLALCHEMY_DATABASE_URI_PRIMARY: - write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS).execution_options(engine_id='write_engine') - write_metadata = MetaData(bind=write_engine) +if SQLALCHEMY_DATABASE_URI_PRIMARY and SQLALCHEMY_DATABASE_URI_PRIMARY != SQLALCHEMY_DATABASE_URI: + # TODO: insert log statement about this? + write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'write_engine'}) WriteSession = sessionmaker(bind=write_engine) else: write_engine: Engine = engine - write_metadata = metadata WriteSession = Session diff --git a/src/server/_security.py b/src/server/_security.py index 761d088c3..61e2608b2 100644 --- a/src/server/_security.py +++ b/src/server/_security.py @@ -16,7 +16,7 @@ TEMPORARY_API_KEY, URL_PREFIX, ) -from .admin.models import User, UserRole +from .admin.models import User API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14) API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14) @@ -91,10 +91,6 @@ def _get_current_user(): current_user: User = cast(User, LocalProxy(_get_current_user)) -def register_user_role(role_name: str) -> None: - UserRole.create_role(role_name) - - def _is_public_route() -> bool: public_routes_list = ["lib", "admin", "version"] for route in public_routes_list: diff --git a/src/server/main.py b/src/server/main.py index c05b9d0d3..a91a91ee2 100644 --- a/src/server/main.py +++ b/src/server/main.py @@ -8,11 +8,9 @@ from ._config import URL_PREFIX, VERSION from ._common import app, set_compatibility_mode -from ._db import metadata, engine from ._exceptions import MissingOrWrongSourceException from .endpoints import endpoints from .endpoints.admin import bp as admin_bp, enable_admin -from ._security import register_user_role from ._limiter import limiter, apply_limit __all__ = ["app"] @@ -65,8 +63,6 @@ def send_lib_file(path: str): return send_from_directory(pathlib.Path(__file__).parent / "lib", path) -metadata.create_all(engine) - if __name__ == "__main__": app.run(host="0.0.0.0", port=5000) else: From 2822d9de847d97ae862bb533d92eafd04a43ca95 Mon Sep 17 00:00:00 2001 From: george haff Date: Tue, 20 Jun 2023 16:53:28 -0400 Subject: [PATCH 4/6] eager-load roles, remove unnecessary methods, add @default_session, move session ctx mgrs to admin page --- src/server/_db.py | 49 +++++++++++++++- src/server/admin/models.py | 104 +++++++++++++++------------------- src/server/endpoints/admin.py | 90 +++++++++++++++-------------- 3 files changed, 144 insertions(+), 99 deletions(-) diff --git a/src/server/_db.py b/src/server/_db.py index b8fe12950..e65c885ff 100644 --- a/src/server/_db.py +++ b/src/server/_db.py @@ -1,3 +1,6 @@ +import functools +from inspect import signature, Parameter + from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -9,13 +12,57 @@ # previously `_common` imported from `_security` which imported from `admin.models`, which imported (back again) from `_common` for database connection objects + +# a decorator to automatically provide a sqlalchemy session by default, if an existing session is not explicitly +# specified to override it. it is preferred to use a single session for a sequence of operations logically grouped +# together, but this allows individual operations to be run by themselves without having to provide an +# already-established session. requires an argument to the wrapped function named 'session'. +# for instance: +# +# @default_session(WriteSession) +# def foo(session): +# pass +# +# # calling: +# foo() +# # is identical to: +# with WriteSession() as s: +# foo(s) +def default_session(sess): + def decorator__default_session(func): + # make sure `func` is compatible w/ this decorator + func_params = signature(func).parameters + if 'session' not in func_params or func_params['session'].kind == Parameter.POSITIONAL_ONLY: + raise Exception(f"@default_session(): function {func.__name__}() must accept an argument 'session' that can be specified by keyword.") + # save position of 'session' arg, to later check if its been passed in by position/order + sess_index = list(func_params).index('session') + + @functools.wraps(func) + def wrapper__default_session(*args, **kwargs): + if 'session' in kwargs or len(args) >= sess_index+1: + # 'session' has been specified by the caller, so we have nothing to do here. pass along all args unchanged. + return func(*args, **kwargs) + # otherwise, we will wrap this call with a context manager for the default session provider, and pass that session instance to the wrapped function. + with sess() as session: + return func(*args, **kwargs, session=session) + + return wrapper__default_session + + return decorator__default_session + + engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'default'}) Session = sessionmaker(bind=engine) if SQLALCHEMY_DATABASE_URI_PRIMARY and SQLALCHEMY_DATABASE_URI_PRIMARY != SQLALCHEMY_DATABASE_URI: - # TODO: insert log statement about this? + # if available, use the main/primary DB for write operations. DB replication processes should be in place to + # propagate any written changes to the regular (load balanced) replicas. write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'write_engine'}) WriteSession = sessionmaker(bind=write_engine) + # TODO: insert log statement acknowledging this second session handle is in use? else: write_engine: Engine = engine WriteSession = Session +# NOTE: `WriteSession` could be called `AdminSession`, as its only (currently) used by the admin page, and the admin +# page is the only thing that should be writing to the db. its tempting to let the admin page read from the +# regular `Session` and write with `WriteSession`, but concurrency problems may arise from sync/replication lag. diff --git a/src/server/admin/models.py b/src/server/admin/models.py index 056c547fa..8e4a0316f 100644 --- a/src/server/admin/models.py +++ b/src/server/admin/models.py @@ -3,7 +3,7 @@ from sqlalchemy.orm import relationship from copy import deepcopy -from .._db import Session, WriteSession +from .._db import Session, WriteSession, default_session from delphi.epidata.common.logger import get_structured_logger from typing import Set, Optional, List @@ -25,7 +25,7 @@ def _default_date_now(): class User(Base): __tablename__ = "api_user" id = Column(Integer, primary_key=True, autoincrement=True) - roles = relationship("UserRole", secondary=association_table) + roles = relationship("UserRole", secondary=association_table, lazy="joined") # last arg does an eager load of this property from foreign tables api_key = Column(String(50), unique=True, nullable=False) email = Column(String(320), unique=True, nullable=False) created = Column(Date, default=_default_date_now) @@ -35,30 +35,19 @@ def __init__(self, api_key: str, email: str = None) -> None: self.api_key = api_key self.email = email - @staticmethod - def list_users() -> List["User"]: - with Session() as session: - return session.query(User).all() - @property def as_dict(self): return { "id": self.id, "api_key": self.api_key, "email": self.email, - "roles": User.get_user_roles(self.id), + "roles": set(role.name for role in self.roles), "created": self.created, "last_time_used": self.last_time_used } - @staticmethod - def get_user_roles(user_id: int) -> Set[str]: - with Session() as session: - user = session.query(User).filter(User.id == user_id).first() - return set([role.name for role in user.roles]) - def has_role(self, required_role: str) -> bool: - return required_role in User.get_user_roles(self.id) + return required_role in set(role.name for role in self.roles) @staticmethod def _assign_roles(user: "User", roles: Optional[Set[str]], session) -> None: @@ -74,58 +63,59 @@ def _assign_roles(user: "User", roles: Optional[Set[str]], session) -> None: db_user.roles = [] @staticmethod + @default_session(Session) def find_user(*, # asterisk forces explicit naming of all arguments when calling this method - user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None + session, + user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None ) -> "User": # NOTE: be careful, using multiple arguments could match multiple users, but this will return only one! - with Session() as session: - user = ( - session.query(User) - .filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email)) - .first() - ) + user = ( + session.query(User) + .filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email)) + .first() + ) return user if user else None @staticmethod - def create_user(api_key: str, email: str, user_roles: Optional[Set[str]] = None) -> "User": + @default_session(WriteSession) + def create_user(api_key: str, email: str, session, user_roles: Optional[Set[str]] = None) -> "User": get_structured_logger("api_user_models").info("creating user", api_key=api_key) - with WriteSession() as session: - new_user = User(api_key=api_key, email=email) - # TODO: we may need to populate 'created' field/column here, if the default - # specified above gets bound to the time of when that line of python was evaluated. - session.add(new_user) - session.commit() - User._assign_roles(new_user, user_roles, session) - session.commit() + new_user = User(api_key=api_key, email=email) + session.add(new_user) + session.commit() + User._assign_roles(new_user, user_roles, session) + session.commit() return new_user @staticmethod + @default_session(WriteSession) def update_user( user: "User", email: Optional[str], api_key: Optional[str], - roles: Optional[Set[str]] + roles: Optional[Set[str]], + session ) -> "User": get_structured_logger("api_user_models").info("updating user", user_id=user.id, new_api_key=api_key) - with WriteSession() as session: - user = User.find_user(user_id=user.id) - if user: - update_stmt = ( - update(User) - .where(User.id == user.id) - .values(api_key=api_key, email=email) - ) - session.execute(update_stmt) - User._assign_roles(user, roles, session) - session.commit() + user = User.find_user(user_id=user.id, session=session) + if user: + update_stmt = ( + update(User) + .where(User.id == user.id) + .values(api_key=api_key, email=email) + ) + session.execute(update_stmt) + User._assign_roles(user, roles, session) + session.commit() + # TODO: else: raise Exception() ?? return user @staticmethod - def delete_user(user_id: int) -> None: + @default_session(WriteSession) + def delete_user(user_id: int, session) -> None: get_structured_logger("api_user_models").info("deleting user", user_id=user_id) - with WriteSession() as session: - session.execute(delete(User).where(User.id == user_id)) - session.commit() + session.execute(delete(User).where(User.id == user_id)) + session.commit() class UserRole(Base): @@ -134,23 +124,23 @@ class UserRole(Base): name = Column(String(50), unique=True) @staticmethod - def create_role(name: str) -> None: + @default_session(WriteSession) + def create_role(name: str, session) -> None: get_structured_logger("api_user_models").info("creating user role", role=name) - with WriteSession() as session: - session.execute( - f""" + # TODO: check role doesnt already exist + session.execute(f""" INSERT INTO user_role (name) SELECT '{name}' WHERE NOT EXISTS (SELECT * FROM user_role WHERE name='{name}') - """ - ) - session.commit() + """) + session.commit() + # TODO: look up and return new role @staticmethod - def list_all_roles(): - with Session() as session: - roles = session.query(UserRole).all() + @default_session(Session) + def list_all_roles(session): + roles = session.query(UserRole).all() return [role.name for role in roles] diff --git a/src/server/endpoints/admin.py b/src/server/endpoints/admin.py index e3af3fe0d..118ee7441 100644 --- a/src/server/endpoints/admin.py +++ b/src/server/endpoints/admin.py @@ -7,6 +7,7 @@ from .._common import log_info_with_request from .._config import ADMIN_PASSWORD, API_KEY_REGISTRATION_FORM_LINK, API_KEY_REMOVAL_REQUEST_LINK, REGISTER_WEBHOOK_TOKEN +from .._db import WriteSession from .._security import resolve_auth_token from ..admin.models import User, UserRole @@ -29,10 +30,10 @@ def _require_admin(): return token -def _render(mode: str, token: str, flags: Dict, **kwargs): +def _render(mode: str, token: str, flags: Dict, session, **kwargs): template = (templates_dir / "index.html").read_text("utf8") return render_template_string( - template, mode=mode, token=token, flags=flags, roles=UserRole.list_all_roles(), **kwargs + template, mode=mode, token=token, flags=flags, roles=UserRole.list_all_roles(session), **kwargs ) @@ -58,44 +59,50 @@ def removal_request_redirect(): def _index(): token = _require_admin() flags = dict() - if request.method == "POST": - # register a new user - if not User.find_user(user_email=request.values["email"], api_key=request.values["api_key"]): - User.create_user( - request.values["api_key"], - request.values["email"], - set(request.values.getlist("roles")), - ) - flags["banner"] = "Successfully Added" - else: - flags["banner"] = "User with such email and/or api key already exists." - users = [user.as_dict for user in User.list_users()] - return _render("overview", token, flags, users=users, user=dict()) + with WriteSession() as session: + if request.method == "POST": + # register a new user + if not User.find_user( + user_email=request.values["email"], api_key=request.values["api_key"], + session=session): + User.create_user( + api_key=request.values["api_key"], + email=request.values["email"], + user_roles=set(request.values.getlist("roles")), + session=session + ) + flags["banner"] = "Successfully Added" + else: + flags["banner"] = "User with such email and/or api key already exists." + users = [user.as_dict for user in session.query(User).all()] + return _render("overview", token, flags, session=session, users=users, user=dict()) @bp.route("/", methods=["GET", "PUT", "POST", "DELETE"]) def _detail(user_id: int): token = _require_admin() - user = User.find_user(user_id=user_id) - if not user: - raise NotFound() - if request.method == "DELETE" or "delete" in request.values: - User.delete_user(user.id) - return redirect(f"./?auth={token}") - flags = dict() - if request.method in ["PUT", "POST"]: - user_check = User.find_user(api_key=request.values["api_key"], user_email=request.values["email"]) - if user_check and user_check.id != user.id: - flags["banner"] = "Could not update user; same api_key and/or email already exists." - else: - user = user.update_user( - user=user, - api_key=request.values["api_key"], - email=request.values["email"], - roles=set(request.values.getlist("roles")), - ) - flags["banner"] = "Successfully Saved" - return _render("detail", token, flags, user=user.as_dict) + with WriteSession() as session: + user = User.find_user(user_id=user_id, session=session) + if not user: + raise NotFound() + if request.method == "DELETE" or "delete" in request.values: + User.delete_user(user.id, session=session) + return redirect(f"./?auth={token}") + flags = dict() + if request.method in ["PUT", "POST"]: + user_check = User.find_user(api_key=request.values["api_key"], user_email=request.values["email"], session=session) + if user_check and user_check.id != user.id: + flags["banner"] = "Could not update user; same api_key and/or email already exists." + else: + user = user.update_user( + user=user, + api_key=request.values["api_key"], + email=request.values["email"], + roles=set(request.values.getlist("roles")), + session=session + ) + flags["banner"] = "Successfully Saved" + return _render("detail", token, flags, session=session, user=user.as_dict) @bp.route("/register", methods=["POST"]) @@ -107,12 +114,13 @@ def _register(): user_api_key = body["user_api_key"] user_email = body["user_email"] - if User.find_user(user_email=user_email, api_key=user_api_key): - return make_response( - "User with email and/or API Key already exists, use different parameters or contact us for help", - 409, - ) - User.create_user(api_key=user_api_key, email=user_email) + with WriteSession() as session: + if User.find_user(user_email=user_email, api_key=user_api_key, session=session): + return make_response( + "User with email and/or API Key already exists, use different parameters or contact us for help", + 409, + ) + User.create_user(api_key=user_api_key, email=user_email, session=session) return make_response(f"Successfully registered API key '{user_api_key}'", 200) From 3bdd09ffdc2a19f5e5cf0a2cc90bbba209205a62 Mon Sep 17 00:00:00 2001 From: george haff Date: Tue, 20 Jun 2023 18:05:35 -0400 Subject: [PATCH 5/6] _assign_roles() does its own commit() and returns an instance of the newly updated User --- src/server/admin/models.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/src/server/admin/models.py b/src/server/admin/models.py index 8e4a0316f..40d32c0dd 100644 --- a/src/server/admin/models.py +++ b/src/server/admin/models.py @@ -51,16 +51,17 @@ def has_role(self, required_role: str) -> bool: @staticmethod def _assign_roles(user: "User", roles: Optional[Set[str]], session) -> None: - # NOTE: this uses a borrowed/existing `session`, and thus does not do a `session.commit()`... - # that is the responsibility of the caller! get_structured_logger("api_user_models").info("setting roles", roles=roles, user_id=user.id, api_key=user.api_key) db_user = session.query(User).filter(User.id == user.id).first() # TODO: would it be sufficient to use the passed-in `user` instead of looking up this `db_user`? + # or even use this as a bound method instead of a static?? if roles: - roles_to_assign = session.query(UserRole).filter(UserRole.name.in_(roles)).all() - db_user.roles = roles_to_assign + db_user.roles = session.query(UserRole).filter(UserRole.name.in_(roles)).all() else: db_user.roles = [] + session.commit() + # retrieve the newly updated User object + return session.query(User).filter(User.id == user.id).first() @staticmethod @default_session(Session) @@ -83,9 +84,7 @@ def create_user(api_key: str, email: str, session, user_roles: Optional[Set[str] new_user = User(api_key=api_key, email=email) session.add(new_user) session.commit() - User._assign_roles(new_user, user_roles, session) - session.commit() - return new_user + return User._assign_roles(new_user, user_roles, session) @staticmethod @default_session(WriteSession) @@ -105,10 +104,9 @@ def update_user( .values(api_key=api_key, email=email) ) session.execute(update_stmt) - User._assign_roles(user, roles, session) - session.commit() + return User._assign_roles(user, roles, session) # TODO: else: raise Exception() ?? - return user + return None @staticmethod @default_session(WriteSession) From 1c926e7c653839863d65c98fc3aa81b3519d9d72 Mon Sep 17 00:00:00 2001 From: george haff Date: Fri, 23 Jun 2023 14:51:14 -0400 Subject: [PATCH 6/6] address TODOs pointed out in review ++ TODOs done: raise Exception when trying to update non-existent User, return UserRole on creation. also use more appropriate reciever for static method call, and expand comment on static vs bound methods in User. --- src/server/admin/models.py | 22 +++++++++++----------- src/server/endpoints/admin.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/server/admin/models.py b/src/server/admin/models.py index 40d32c0dd..f5c0d54ed 100644 --- a/src/server/admin/models.py +++ b/src/server/admin/models.py @@ -55,6 +55,7 @@ def _assign_roles(user: "User", roles: Optional[Set[str]], session) -> None: db_user = session.query(User).filter(User.id == user.id).first() # TODO: would it be sufficient to use the passed-in `user` instead of looking up this `db_user`? # or even use this as a bound method instead of a static?? + # same goes for `update_user()` and `delete_user()` below... if roles: db_user.roles = session.query(UserRole).filter(UserRole.name.in_(roles)).all() else: @@ -97,16 +98,15 @@ def update_user( ) -> "User": get_structured_logger("api_user_models").info("updating user", user_id=user.id, new_api_key=api_key) user = User.find_user(user_id=user.id, session=session) - if user: - update_stmt = ( - update(User) - .where(User.id == user.id) - .values(api_key=api_key, email=email) - ) - session.execute(update_stmt) - return User._assign_roles(user, roles, session) - # TODO: else: raise Exception() ?? - return None + if not user: + raise Exception('user not found') + update_stmt = ( + update(User) + .where(User.id == user.id) + .values(api_key=api_key, email=email) + ) + session.execute(update_stmt) + return User._assign_roles(user, roles, session) @staticmethod @default_session(WriteSession) @@ -135,7 +135,7 @@ def create_role(name: str, session) -> None: WHERE name='{name}') """) session.commit() - # TODO: look up and return new role + return session.query(UserRole).filter(UserRole.name == name).first() @staticmethod @default_session(Session) diff --git a/src/server/endpoints/admin.py b/src/server/endpoints/admin.py index 118ee7441..a6f941b48 100644 --- a/src/server/endpoints/admin.py +++ b/src/server/endpoints/admin.py @@ -94,7 +94,7 @@ def _detail(user_id: int): if user_check and user_check.id != user.id: flags["banner"] = "Could not update user; same api_key and/or email already exists." else: - user = user.update_user( + user = User.update_user( user=user, api_key=request.values["api_key"], email=request.values["email"],