Skip to content

Commit 4fb920f

Browse files
committed
Fixed expired session by using new session for each User operation
1 parent 51b0ac4 commit 4fb920f

File tree

2 files changed

+66
-59
lines changed

2 files changed

+66
-59
lines changed

src/server/_db.py

-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,5 @@
1414
metadata = MetaData(bind=engine)
1515

1616
Session = sessionmaker(bind=engine)
17-
session = Session()
1817

1918

src/server/admin/models.py

+66-58
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from sqlalchemy.orm import relationship
44
from copy import deepcopy
55

6-
from .._db import session
6+
from .._db import Session
77
from delphi.epidata.common.logger import get_structured_logger
88

99
from typing import Set, Optional, List
@@ -35,61 +35,61 @@ def __init__(self, api_key: str, email: str = None) -> None:
3535

3636
@staticmethod
3737
def list_users() -> List["User"]:
38-
return session.query(User).all()
38+
with Session() as session:
39+
return session.query(User).all()
3940

4041
@property
4142
def as_dict(self):
42-
user_dict = deepcopy(self.__dict__) # NOTE: changed from `self.__dict__.copy()` as it caused issues
43-
# so we dont change the internal representation of self
44-
user_dict["roles"] = self.get_user_roles
45-
try:
46-
return {k: user_dict[k] for k in ["id", "api_key", "email", "roles", "created", "last_time_used"]}
47-
except KeyError:
48-
return {
49-
"id": self.id,
50-
"api_key": self.api_key,
51-
"email": self.email,
52-
"roles": self.get_user_roles,
53-
"created": self.created,
54-
"last_time_used": self.last_time_used
55-
}
43+
return {
44+
"id": self.id,
45+
"api_key": self.api_key,
46+
"email": self.email,
47+
"roles": User.get_user_roles(self.id),
48+
"created": self.created,
49+
"last_time_used": self.last_time_used
50+
}
5651

57-
@property
58-
def get_user_roles(self) -> Set[str]:
59-
return set([role.name for role in self.roles])
52+
@staticmethod
53+
def get_user_roles(user_id: int) -> Set[str]:
54+
with Session() as session:
55+
user = session.query(User).filter(User.id == user_id).first()
56+
return set([role.name for role in user.roles])
6057

6158
def has_role(self, required_role: str) -> bool:
62-
return required_role in self.get_user_roles
59+
return required_role in User.get_user_roles(self.id)
6360

6461
@staticmethod
65-
def assign_roles(user: "User", roles: Optional[Set[str]]) -> None:
62+
def _assign_roles(user: "User", roles: Optional[Set[str]], session) -> None:
6663
get_structured_logger("api_user_models").info("setting roles", roles=roles, user_id=user.id, api_key=user.api_key)
6764
if roles:
65+
db_user = session.query(User).filter(User.id == user.id).first()
6866
roles_to_assign = session.query(UserRole).filter(UserRole.name.in_(roles)).all()
69-
user.roles = roles_to_assign
70-
session.commit()
67+
db_user.roles = roles_to_assign
7168
else:
72-
user.roles = []
73-
session.commit()
69+
db_user.roles = []
7470

7571
@staticmethod
7672
def find_user(*, # asterisk forces explicit naming of all arguments when calling this method
7773
user_id: Optional[int] = None, api_key: Optional[str] = None, user_email: Optional[str] = None
7874
) -> "User":
79-
user = (
80-
session.query(User)
81-
.filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email))
82-
.first()
83-
)
75+
# TODO
76+
with Session() as session:
77+
user = (
78+
session.query(User)
79+
.filter((User.id == user_id) | (User.api_key == api_key) | (User.email == user_email))
80+
.first()
81+
)
8482
return user if user else None
8583

8684
@staticmethod
8785
def create_user(api_key: str, email: str, user_roles: Optional[Set[str]] = None) -> "User":
86+
# TODO
8887
get_structured_logger("api_user_models").info("creating user", api_key=api_key)
89-
new_user = User(api_key=api_key, email=email)
90-
session.add(new_user)
91-
session.commit()
92-
User.assign_roles(new_user, user_roles)
88+
with Session() as session:
89+
new_user = User(api_key=api_key, email=email)
90+
session.add(new_user)
91+
User._assign_roles(new_user, user_roles, session)
92+
session.commit()
9393
return new_user
9494

9595
@staticmethod
@@ -100,23 +100,27 @@ def update_user(
100100
roles: Optional[Set[str]]
101101
) -> "User":
102102
get_structured_logger("api_user_models").info("updating user", user_id=user.id, new_api_key=api_key)
103-
user = User.find_user(user_id=user.id)
104-
if user:
105-
update_stmt = (
106-
update(User)
107-
.where(User.id == user.id)
108-
.values(api_key=api_key, email=email)
109-
)
110-
session.execute(update_stmt)
111-
session.commit()
112-
User.assign_roles(user, roles)
103+
# TODO
104+
with Session() as session:
105+
user = User.find_user(user_id=user.id)
106+
if user:
107+
update_stmt = (
108+
update(User)
109+
.where(User.id == user.id)
110+
.values(api_key=api_key, email=email)
111+
)
112+
session.execute(update_stmt)
113+
User._assign_roles(user, roles, session)
114+
session.commit()
113115
return user
114116

115117
@staticmethod
116118
def delete_user(user_id: int) -> None:
117119
get_structured_logger("api_user_models").info("deleting user", user_id=user_id)
118-
session.execute(delete(User).where(User.id == user_id))
119-
session.commit()
120+
# TODO
121+
with Session() as session:
122+
session.execute(delete(User).where(User.id == user_id))
123+
session.commit()
120124

121125

122126
class UserRole(Base):
@@ -127,19 +131,23 @@ class UserRole(Base):
127131
@staticmethod
128132
def create_role(name: str) -> None:
129133
get_structured_logger("api_user_models").info("creating user role", role=name)
130-
session.execute(
131-
f"""
132-
INSERT INTO user_role (name)
133-
SELECT '{name}'
134-
WHERE NOT EXISTS
135-
(SELECT *
136-
FROM user_role
137-
WHERE name='{name}')
138-
"""
139-
)
140-
session.commit()
134+
# TODO
135+
with Session() as session:
136+
session.execute(
137+
f"""
138+
INSERT INTO user_role (name)
139+
SELECT '{name}'
140+
WHERE NOT EXISTS
141+
(SELECT *
142+
FROM user_role
143+
WHERE name='{name}')
144+
"""
145+
)
146+
session.commit()
141147

142148
@staticmethod
143149
def list_all_roles():
144-
roles = session.query(UserRole).all()
150+
# TODO
151+
with Session() as session:
152+
roles = session.query(UserRole).all()
145153
return [role.name for role in roles]

0 commit comments

Comments
 (0)