Skip to content

use second db handle only for user admin and writes #1184

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 26, 2023
9 changes: 5 additions & 4 deletions src/server/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from flask import Flask, g, request
from sqlalchemy import event
from sqlalchemy.engine import Connection
from sqlalchemy.engine import Connection, Engine
from werkzeug.exceptions import Unauthorized
from werkzeug.local import LocalProxy

Expand Down Expand Up @@ -85,12 +85,12 @@ def log_info_with_request_and_response(message, response, **kwargs):
**kwargs
)

@event.listens_for(engine, "before_cursor_execute")
@event.listens_for(Engine, "before_cursor_execute")
# Listener for SQLAlchemy's "before_cursor_execute" event: stamps the current wall-clock time onto the statement's execution context.
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# NOTE(review): presumably read back by the "after_cursor_execute" listener to compute elapsed time -- the line that does so is collapsed in this view; confirm.
context._query_start_time = time.time()


@event.listens_for(engine, "after_cursor_execute")
@event.listens_for(Engine, "after_cursor_execute")
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
# this timing info may be suspect, at least in terms of dbms cpu time...
# it is likely that it includes that time as well as any overhead that
Expand All @@ -101,7 +101,8 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema
# Convert to milliseconds
total_time *= 1000
get_structured_logger("server_api").info(
"Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time
"Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time,
engine_id=conn.get_execution_options().get('engine_id')
)


Expand Down
61 changes: 53 additions & 8 deletions src/server/_db.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from sqlalchemy import create_engine, MetaData
import functools
from inspect import signature, Parameter

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker

Expand All @@ -9,15 +12,57 @@
# previously `_common` imported from `_security` which imported from `admin.models`, which imported (back again) from `_common` for database connection objects


engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS)

if SQLALCHEMY_DATABASE_URI_PRIMARY:
user_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS)
else:
user_engine: Engine = engine
# a decorator to automatically provide a sqlalchemy session by default, if an existing session is not explicitly
# specified to override it. it is preferred to use a single session for a sequence of operations logically grouped
# together, but this allows individual operations to be run by themselves without having to provide an
# already-established session. requires an argument to the wrapped function named 'session'.
# for instance:
#
# @default_session(WriteSession)
# def foo(session):
# pass
#
# # calling:
# foo()
# # is identical to:
# with WriteSession() as s:
# foo(s)
def default_session(sess):
def decorator__default_session(func):
# make sure `func` is compatible w/ this decorator
func_params = signature(func).parameters
if 'session' not in func_params or func_params['session'].kind == Parameter.POSITIONAL_ONLY:
raise Exception(f"@default_session(): function {func.__name__}() must accept an argument 'session' that can be specified by keyword.")
# save position of 'session' arg, to later check if it's been passed in by position/order
sess_index = list(func_params).index('session')

@functools.wraps(func)
def wrapper__default_session(*args, **kwargs):
if 'session' in kwargs or len(args) >= sess_index+1:
# 'session' has been specified by the caller, so we have nothing to do here. pass along all args unchanged.
return func(*args, **kwargs)
# otherwise, we will wrap this call with a context manager for the default session provider, and pass that session instance to the wrapped function.
with sess() as session:
return func(*args, **kwargs, session=session)

metadata = MetaData(bind=user_engine)
return wrapper__default_session

Session = sessionmaker(bind=user_engine)
return decorator__default_session


engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'default'})
Session = sessionmaker(bind=engine)

if SQLALCHEMY_DATABASE_URI_PRIMARY and SQLALCHEMY_DATABASE_URI_PRIMARY != SQLALCHEMY_DATABASE_URI:
# if available, use the main/primary DB for write operations. DB replication processes should be in place to
# propagate any written changes to the regular (load balanced) replicas.
write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'write_engine'})
WriteSession = sessionmaker(bind=write_engine)
# TODO: insert log statement acknowledging this second session handle is in use?
else:
write_engine: Engine = engine
WriteSession = Session
# NOTE: `WriteSession` could be called `AdminSession`, as it's only (currently) used by the admin page, and the admin
# page is the only thing that should be writing to the db. It's tempting to let the admin page read from the
# regular `Session` and write with `WriteSession`, but concurrency problems may arise from sync/replication lag.
6 changes: 1 addition & 5 deletions src/server/_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
TEMPORARY_API_KEY,
URL_PREFIX,
)
from .admin.models import User, UserRole
from .admin.models import User

API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14)
API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14)
Expand Down Expand Up @@ -91,10 +91,6 @@ def _get_current_user():
current_user: User = cast(User, LocalProxy(_get_current_user))


def register_user_role(role_name: str) -> None:
UserRole.create_role(role_name)


def _is_public_route() -> bool:
public_routes_list = ["lib", "admin", "version"]
for route in public_routes_list:
Expand Down
114 changes: 51 additions & 63 deletions src/server/admin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sqlalchemy.orm import relationship
from copy import deepcopy

from .._db import Session
from .._db import Session, WriteSession, default_session
from delphi.epidata.common.logger import get_structured_logger

from typing import Set, Optional, List
Expand All @@ -25,7 +25,7 @@ def _default_date_now():
class User(Base):
__tablename__ = "api_user"
id = Column(Integer, primary_key=True, autoincrement=True)
roles = relationship("UserRole", secondary=association_table)
roles = relationship("UserRole", secondary=association_table, lazy="joined") # last arg does an eager load of this property from foreign tables
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, that is awesome. Didn't know about that param.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🌊 🏄

api_key = Column(String(50), unique=True, nullable=False)
email = Column(String(320), unique=True, nullable=False)
created = Column(Date, default=_default_date_now)
Expand All @@ -35,97 +35,85 @@ def __init__(self, api_key: str, email: str = None) -> None:
self.api_key = api_key
self.email = email

@staticmethod
def list_users() -> List["User"]:
with Session() as session:
return session.query(User).all()

@property
def as_dict(self):
return {
"id": self.id,
"api_key": self.api_key,
"email": self.email,
"roles": User.get_user_roles(self.id),
"roles": set(role.name for role in self.roles),
"created": self.created,
"last_time_used": self.last_time_used
}

@staticmethod
def get_user_roles(user_id: int) -> Set[str]:
with Session() as session:
user = session.query(User).filter(User.id == user_id).first()
return set([role.name for role in user.roles])

def has_role(self, required_role: str) -> bool:
return required_role in User.get_user_roles(self.id)
return required_role in set(role.name for role in self.roles)

@staticmethod
def _assign_roles(user: "User", roles: Optional[Set[str]], session) -> None:
# NOTE: this uses a borrowed/existing `session`, and thus does not do a `session.commit()`...
# that is the responsibility of the caller!
get_structured_logger("api_user_models").info("setting roles", roles=roles, user_id=user.id, api_key=user.api_key)
db_user = session.query(User).filter(User.id == user.id).first()
# TODO: would it be sufficient to use the passed-in `user` instead of looking up this `db_user`?
# or even use this as a bound method instead of a static??
# same goes for `update_user()` and `delete_user()` below...
if roles:
roles_to_assign = session.query(UserRole).filter(UserRole.name.in_(roles)).all()
db_user.roles = roles_to_assign
db_user.roles = session.query(UserRole).filter(UserRole.name.in_(roles)).all()
else:
db_user.roles = []
session.commit()
# retrieve the newly updated User object
return session.query(User).filter(User.id == user.id).first()
Comment on lines +63 to +65
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this still needed if we're using WriteSession for both writes and reads in admin?

oh i see; this is adding back the commit removed from 99, and lets you return directly instead of needing a separate return at 100

carry on

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yap!


@staticmethod
@default_session(Session)
def find_user(*, # asterisk forces explicit naming of all arguments when calling this method
user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None
session,
user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None
) -> "User":
# NOTE: be careful, using multiple arguments could match multiple users, but this will return only one!
with Session() as session:
user = (
session.query(User)
.filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email))
.first()
)
user = (
session.query(User)
.filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email))
.first()
)
return user if user else None

@staticmethod
def create_user(api_key: str, email: str, user_roles: Optional[Set[str]] = None) -> "User":
@default_session(WriteSession)
def create_user(api_key: str, email: str, session, user_roles: Optional[Set[str]] = None) -> "User":
get_structured_logger("api_user_models").info("creating user", api_key=api_key)
with Session() as session:
new_user = User(api_key=api_key, email=email)
# TODO: we may need to populate 'created' field/column here, if the default
# specified above gets bound to the time of when that line of python was evaluated.
session.add(new_user)
session.commit()
User._assign_roles(new_user, user_roles, session)
session.commit()
return new_user
new_user = User(api_key=api_key, email=email)
session.add(new_user)
session.commit()
return User._assign_roles(new_user, user_roles, session)

@staticmethod
@default_session(WriteSession)
def update_user(
user: "User",
email: Optional[str],
api_key: Optional[str],
roles: Optional[Set[str]]
roles: Optional[Set[str]],
session
) -> "User":
get_structured_logger("api_user_models").info("updating user", user_id=user.id, new_api_key=api_key)
with Session() as session:
user = User.find_user(user_id=user.id)
if user:
update_stmt = (
update(User)
.where(User.id == user.id)
.values(api_key=api_key, email=email)
)
session.execute(update_stmt)
User._assign_roles(user, roles, session)
session.commit()
return user
user = User.find_user(user_id=user.id, session=session)
if not user:
raise Exception('user not found')
update_stmt = (
update(User)
.where(User.id == user.id)
.values(api_key=api_key, email=email)
)
session.execute(update_stmt)
return User._assign_roles(user, roles, session)

@staticmethod
def delete_user(user_id: int) -> None:
@default_session(WriteSession)
def delete_user(user_id: int, session) -> None:
get_structured_logger("api_user_models").info("deleting user", user_id=user_id)
with Session() as session:
session.execute(delete(User).where(User.id == user_id))
session.commit()
session.execute(delete(User).where(User.id == user_id))
session.commit()


class UserRole(Base):
Expand All @@ -134,23 +122,23 @@ class UserRole(Base):
name = Column(String(50), unique=True)

@staticmethod
def create_role(name: str) -> None:
@default_session(WriteSession)
def create_role(name: str, session) -> None:
get_structured_logger("api_user_models").info("creating user role", role=name)
with Session() as session:
session.execute(
f"""
# TODO: check role doesn't already exist
session.execute(f"""
INSERT INTO user_role (name)
SELECT '{name}'
WHERE NOT EXISTS
(SELECT *
FROM user_role
WHERE name='{name}')
"""
)
session.commit()
""")
session.commit()
return session.query(UserRole).filter(UserRole.name == name).first()

@staticmethod
def list_all_roles():
with Session() as session:
roles = session.query(UserRole).all()
@default_session(Session)
def list_all_roles(session):
roles = session.query(UserRole).all()
return [role.name for role in roles]
Loading