Skip to content

Commit 6fe6e7a

Browse files
authored
use second db handle for only for user admin and writes (#1184)
* change 'user_engine' to a 'WriteSession' instead, so the master db connection is used for writes [and associated admin session reads] only * eager-load roles, remove unnecessary methods, add @default_session, move session ctx mgrs to admin page * make sure sql statements and timing are logged for all engines, plus tag engines with id and log those too, and superfluous user method cleanup * sqlalchemy cleanup: removed superfluous bits, improved argument passing for engine creation * _assign_roles() does its own commit() and returns an instance of the newly updated User * raise Exception when trying to update non-existent User, return UserRole on creation. * use more appropriate reciever for static method call, and expand comment on static vs bound methods in User.
1 parent 13fcfe3 commit 6fe6e7a

File tree

6 files changed

+159
-134
lines changed

6 files changed

+159
-134
lines changed

src/server/_common.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from flask import Flask, g, request
55
from sqlalchemy import event
6-
from sqlalchemy.engine import Connection
6+
from sqlalchemy.engine import Connection, Engine
77
from werkzeug.exceptions import Unauthorized
88
from werkzeug.local import LocalProxy
99

@@ -85,12 +85,12 @@ def log_info_with_request_and_response(message, response, **kwargs):
8585
**kwargs
8686
)
8787

88-
@event.listens_for(engine, "before_cursor_execute")
88+
@event.listens_for(Engine, "before_cursor_execute")
8989
def before_cursor_execute(conn, cursor, statement, parameters, context, executemany):
9090
context._query_start_time = time.time()
9191

9292

93-
@event.listens_for(engine, "after_cursor_execute")
93+
@event.listens_for(Engine, "after_cursor_execute")
9494
def after_cursor_execute(conn, cursor, statement, parameters, context, executemany):
9595
# this timing info may be suspect, at least in terms of dbms cpu time...
9696
# it is likely that it includes that time as well as any overhead that
@@ -101,7 +101,8 @@ def after_cursor_execute(conn, cursor, statement, parameters, context, executema
101101
# Convert to milliseconds
102102
total_time *= 1000
103103
get_structured_logger("server_api").info(
104-
"Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time
104+
"Executed SQL", statement=statement, params=parameters, elapsed_time_ms=total_time,
105+
engine_id=conn.get_execution_options().get('engine_id')
105106
)
106107

107108

src/server/_db.py

+53-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
from sqlalchemy import create_engine, MetaData
1+
import functools
2+
from inspect import signature, Parameter
3+
4+
from sqlalchemy import create_engine
25
from sqlalchemy.engine import Engine
36
from sqlalchemy.orm import sessionmaker
47

@@ -9,15 +12,57 @@
912
# previously `_common` imported from `_security` which imported from `admin.models`, which imported (back again) from `_common` for database connection objects
1013

1114

12-
engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS)
1315

14-
if SQLALCHEMY_DATABASE_URI_PRIMARY:
15-
user_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS)
16-
else:
17-
user_engine: Engine = engine
16+
# a decorator to automatically provide a sqlalchemy session by default, if an existing session is not explicitly
17+
# specified to override it. it is preferred to use a single session for a sequence of operations logically grouped
18+
# together, but this allows individual operations to be run by themselves without having to provide an
19+
# already-established session. requires an argument to the wrapped function named 'session'.
20+
# for instance:
21+
#
22+
# @default_session(WriteSession)
23+
# def foo(session):
24+
# pass
25+
#
26+
# # calling:
27+
# foo()
28+
# # is identical to:
29+
# with WriteSession() as s:
30+
# foo(s)
31+
def default_session(sess):
32+
def decorator__default_session(func):
33+
# make sure `func` is compatible w/ this decorator
34+
func_params = signature(func).parameters
35+
if 'session' not in func_params or func_params['session'].kind == Parameter.POSITIONAL_ONLY:
36+
raise Exception(f"@default_session(): function {func.__name__}() must accept an argument 'session' that can be specified by keyword.")
37+
# save position of 'session' arg, to later check if its been passed in by position/order
38+
sess_index = list(func_params).index('session')
39+
40+
@functools.wraps(func)
41+
def wrapper__default_session(*args, **kwargs):
42+
if 'session' in kwargs or len(args) >= sess_index+1:
43+
# 'session' has been specified by the caller, so we have nothing to do here. pass along all args unchanged.
44+
return func(*args, **kwargs)
45+
# otherwise, we will wrap this call with a context manager for the default session provider, and pass that session instance to the wrapped function.
46+
with sess() as session:
47+
return func(*args, **kwargs, session=session)
1848

19-
metadata = MetaData(bind=user_engine)
49+
return wrapper__default_session
2050

21-
Session = sessionmaker(bind=user_engine)
51+
return decorator__default_session
2252

2353

54+
engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'default'})
55+
Session = sessionmaker(bind=engine)
56+
57+
if SQLALCHEMY_DATABASE_URI_PRIMARY and SQLALCHEMY_DATABASE_URI_PRIMARY != SQLALCHEMY_DATABASE_URI:
58+
# if available, use the main/primary DB for write operations. DB replication processes should be in place to
59+
# propagate any written changes to the regular (load balanced) replicas.
60+
write_engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI_PRIMARY, **SQLALCHEMY_ENGINE_OPTIONS, execution_options={'engine_id': 'write_engine'})
61+
WriteSession = sessionmaker(bind=write_engine)
62+
# TODO: insert log statement acknowledging this second session handle is in use?
63+
else:
64+
write_engine: Engine = engine
65+
WriteSession = Session
66+
# NOTE: `WriteSession` could be called `AdminSession`, as its only (currently) used by the admin page, and the admin
67+
# page is the only thing that should be writing to the db. its tempting to let the admin page read from the
68+
# regular `Session` and write with `WriteSession`, but concurrency problems may arise from sync/replication lag.

src/server/_security.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
TEMPORARY_API_KEY,
1717
URL_PREFIX,
1818
)
19-
from .admin.models import User, UserRole
19+
from .admin.models import User
2020

2121
API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14)
2222
API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14)
@@ -91,10 +91,6 @@ def _get_current_user():
9191
current_user: User = cast(User, LocalProxy(_get_current_user))
9292

9393

94-
def register_user_role(role_name: str) -> None:
95-
UserRole.create_role(role_name)
96-
97-
9894
def _is_public_route() -> bool:
9995
public_routes_list = ["lib", "admin", "version"]
10096
for route in public_routes_list:

src/server/admin/models.py

+51-63
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sqlalchemy.orm import relationship
44
from copy import deepcopy
55

6-
from .._db import Session
6+
from .._db import Session, WriteSession, default_session
77
from delphi.epidata.common.logger import get_structured_logger
88

99
from typing import Set, Optional, List
@@ -25,7 +25,7 @@ def _default_date_now():
2525
class User(Base):
2626
__tablename__ = "api_user"
2727
id = Column(Integer, primary_key=True, autoincrement=True)
28-
roles = relationship("UserRole", secondary=association_table)
28+
roles = relationship("UserRole", secondary=association_table, lazy="joined") # last arg does an eager load of this property from foreign tables
2929
api_key = Column(String(50), unique=True, nullable=False)
3030
email = Column(String(320), unique=True, nullable=False)
3131
created = Column(Date, default=_default_date_now)
@@ -35,97 +35,85 @@ def __init__(self, api_key: str, email: str = None) -> None:
3535
self.api_key = api_key
3636
self.email = email
3737

38-
@staticmethod
39-
def list_users() -> List["User"]:
40-
with Session() as session:
41-
return session.query(User).all()
42-
4338
@property
4439
def as_dict(self):
4540
return {
4641
"id": self.id,
4742
"api_key": self.api_key,
4843
"email": self.email,
49-
"roles": User.get_user_roles(self.id),
44+
"roles": set(role.name for role in self.roles),
5045
"created": self.created,
5146
"last_time_used": self.last_time_used
5247
}
5348

54-
@staticmethod
55-
def get_user_roles(user_id: int) -> Set[str]:
56-
with Session() as session:
57-
user = session.query(User).filter(User.id == user_id).first()
58-
return set([role.name for role in user.roles])
59-
6049
def has_role(self, required_role: str) -> bool:
61-
return required_role in User.get_user_roles(self.id)
50+
return required_role in set(role.name for role in self.roles)
6251

6352
@staticmethod
6453
def _assign_roles(user: "User", roles: Optional[Set[str]], session) -> None:
65-
# NOTE: this uses a borrowed/existing `session`, and thus does not do a `session.commit()`...
66-
# that is the responsibility of the caller!
6754
get_structured_logger("api_user_models").info("setting roles", roles=roles, user_id=user.id, api_key=user.api_key)
6855
db_user = session.query(User).filter(User.id == user.id).first()
6956
# TODO: would it be sufficient to use the passed-in `user` instead of looking up this `db_user`?
57+
# or even use this as a bound method instead of a static??
58+
# same goes for `update_user()` and `delete_user()` below...
7059
if roles:
71-
roles_to_assign = session.query(UserRole).filter(UserRole.name.in_(roles)).all()
72-
db_user.roles = roles_to_assign
60+
db_user.roles = session.query(UserRole).filter(UserRole.name.in_(roles)).all()
7361
else:
7462
db_user.roles = []
63+
session.commit()
64+
# retrieve the newly updated User object
65+
return session.query(User).filter(User.id == user.id).first()
7566

7667
@staticmethod
68+
@default_session(Session)
7769
def find_user(*, # asterisk forces explicit naming of all arguments when calling this method
78-
user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None
70+
session,
71+
user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None
7972
) -> "User":
8073
# NOTE: be careful, using multiple arguments could match multiple users, but this will return only one!
81-
with Session() as session:
82-
user = (
83-
session.query(User)
84-
.filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email))
85-
.first()
86-
)
74+
user = (
75+
session.query(User)
76+
.filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email))
77+
.first()
78+
)
8779
return user if user else None
8880

8981
@staticmethod
90-
def create_user(api_key: str, email: str, user_roles: Optional[Set[str]] = None) -> "User":
82+
@default_session(WriteSession)
83+
def create_user(api_key: str, email: str, session, user_roles: Optional[Set[str]] = None) -> "User":
9184
get_structured_logger("api_user_models").info("creating user", api_key=api_key)
92-
with Session() as session:
93-
new_user = User(api_key=api_key, email=email)
94-
# TODO: we may need to populate 'created' field/column here, if the default
95-
# specified above gets bound to the time of when that line of python was evaluated.
96-
session.add(new_user)
97-
session.commit()
98-
User._assign_roles(new_user, user_roles, session)
99-
session.commit()
100-
return new_user
85+
new_user = User(api_key=api_key, email=email)
86+
session.add(new_user)
87+
session.commit()
88+
return User._assign_roles(new_user, user_roles, session)
10189

10290
@staticmethod
91+
@default_session(WriteSession)
10392
def update_user(
10493
user: "User",
10594
email: Optional[str],
10695
api_key: Optional[str],
107-
roles: Optional[Set[str]]
96+
roles: Optional[Set[str]],
97+
session
10898
) -> "User":
10999
get_structured_logger("api_user_models").info("updating user", user_id=user.id, new_api_key=api_key)
110-
with Session() as session:
111-
user = User.find_user(user_id=user.id)
112-
if user:
113-
update_stmt = (
114-
update(User)
115-
.where(User.id == user.id)
116-
.values(api_key=api_key, email=email)
117-
)
118-
session.execute(update_stmt)
119-
User._assign_roles(user, roles, session)
120-
session.commit()
121-
return user
100+
user = User.find_user(user_id=user.id, session=session)
101+
if not user:
102+
raise Exception('user not found')
103+
update_stmt = (
104+
update(User)
105+
.where(User.id == user.id)
106+
.values(api_key=api_key, email=email)
107+
)
108+
session.execute(update_stmt)
109+
return User._assign_roles(user, roles, session)
122110

123111
@staticmethod
124-
def delete_user(user_id: int) -> None:
112+
@default_session(WriteSession)
113+
def delete_user(user_id: int, session) -> None:
125114
get_structured_logger("api_user_models").info("deleting user", user_id=user_id)
126-
with Session() as session:
127-
session.execute(delete(User).where(User.id == user_id))
128-
session.commit()
115+
session.execute(delete(User).where(User.id == user_id))
116+
session.commit()
129117

130118

131119
class UserRole(Base):
@@ -134,23 +122,23 @@ class UserRole(Base):
134122
name = Column(String(50), unique=True)
135123

136124
@staticmethod
137-
def create_role(name: str) -> None:
125+
@default_session(WriteSession)
126+
def create_role(name: str, session) -> None:
138127
get_structured_logger("api_user_models").info("creating user role", role=name)
139-
with Session() as session:
140-
session.execute(
141-
f"""
128+
# TODO: check role doesnt already exist
129+
session.execute(f"""
142130
INSERT INTO user_role (name)
143131
SELECT '{name}'
144132
WHERE NOT EXISTS
145133
(SELECT *
146134
FROM user_role
147135
WHERE name='{name}')
148-
"""
149-
)
150-
session.commit()
136+
""")
137+
session.commit()
138+
return session.query(UserRole).filter(UserRole.name == name).first()
151139

152140
@staticmethod
153-
def list_all_roles():
154-
with Session() as session:
155-
roles = session.query(UserRole).all()
141+
@default_session(Session)
142+
def list_all_roles(session):
143+
roles = session.query(UserRole).all()
156144
return [role.name for role in roles]

0 commit comments

Comments
 (0)