3
3
from sqlalchemy .orm import relationship
4
4
from copy import deepcopy
5
5
6
- from .._db import Session
6
+ from .._db import Session , WriteSession , default_session
7
7
from delphi .epidata .common .logger import get_structured_logger
8
8
9
9
from typing import Set , Optional , List
@@ -25,7 +25,7 @@ def _default_date_now():
25
25
class User (Base ):
26
26
__tablename__ = "api_user"
27
27
id = Column (Integer , primary_key = True , autoincrement = True )
28
- roles = relationship ("UserRole" , secondary = association_table )
28
+ roles = relationship ("UserRole" , secondary = association_table , lazy = "joined" ) # last arg does an eager load of this property from foreign tables
29
29
api_key = Column (String (50 ), unique = True , nullable = False )
30
30
email = Column (String (320 ), unique = True , nullable = False )
31
31
created = Column (Date , default = _default_date_now )
@@ -35,97 +35,85 @@ def __init__(self, api_key: str, email: str = None) -> None:
35
35
self .api_key = api_key
36
36
self .email = email
37
37
38
- @staticmethod
39
- def list_users () -> List ["User" ]:
40
- with Session () as session :
41
- return session .query (User ).all ()
42
-
43
38
@property
44
39
def as_dict (self ):
45
40
return {
46
41
"id" : self .id ,
47
42
"api_key" : self .api_key ,
48
43
"email" : self .email ,
49
- "roles" : User . get_user_roles ( self .id ),
44
+ "roles" : set ( role . name for role in self .roles ),
50
45
"created" : self .created ,
51
46
"last_time_used" : self .last_time_used
52
47
}
53
48
54
- @staticmethod
55
- def get_user_roles (user_id : int ) -> Set [str ]:
56
- with Session () as session :
57
- user = session .query (User ).filter (User .id == user_id ).first ()
58
- return set ([role .name for role in user .roles ])
59
-
60
49
def has_role (self , required_role : str ) -> bool :
61
- return required_role in User . get_user_roles ( self .id )
50
+ return required_role in set ( role . name for role in self .roles )
62
51
63
52
@staticmethod
64
53
def _assign_roles (user : "User" , roles : Optional [Set [str ]], session ) -> None :
65
- # NOTE: this uses a borrowed/existing `session`, and thus does not do a `session.commit()`...
66
- # that is the responsibility of the caller!
67
54
get_structured_logger ("api_user_models" ).info ("setting roles" , roles = roles , user_id = user .id , api_key = user .api_key )
68
55
db_user = session .query (User ).filter (User .id == user .id ).first ()
69
56
# TODO: would it be sufficient to use the passed-in `user` instead of looking up this `db_user`?
57
+ # or even use this as a bound method instead of a static??
58
+ # same goes for `update_user()` and `delete_user()` below...
70
59
if roles :
71
- roles_to_assign = session .query (UserRole ).filter (UserRole .name .in_ (roles )).all ()
72
- db_user .roles = roles_to_assign
60
+ db_user .roles = session .query (UserRole ).filter (UserRole .name .in_ (roles )).all ()
73
61
else :
74
62
db_user .roles = []
63
+ session .commit ()
64
+ # retrieve the newly updated User object
65
+ return session .query (User ).filter (User .id == user .id ).first ()
75
66
76
67
@staticmethod
68
+ @default_session (Session )
77
69
def find_user (* , # asterisk forces explicit naming of all arguments when calling this method
78
- user_id : Optional [int ] = None , api_key : Optional [str ] = None , user_email : Optional [str ] = None
70
+ session ,
71
+ user_id : Optional [int ] = None , api_key : Optional [str ] = None , user_email : Optional [str ] = None
79
72
) -> "User" :
80
73
# NOTE: be careful, using multiple arguments could match multiple users, but this will return only one!
81
- with Session () as session :
82
- user = (
83
- session .query (User )
84
- .filter ((User .id == user_id ) | (User .api_key == api_key ) | (User .email == user_email ))
85
- .first ()
86
- )
74
+ user = (
75
+ session .query (User )
76
+ .filter ((User .id == user_id ) | (User .api_key == api_key ) | (User .email == user_email ))
77
+ .first ()
78
+ )
87
79
return user if user else None
88
80
89
81
@staticmethod
90
- def create_user (api_key : str , email : str , user_roles : Optional [Set [str ]] = None ) -> "User" :
82
+ @default_session (WriteSession )
83
+ def create_user (api_key : str , email : str , session , user_roles : Optional [Set [str ]] = None ) -> "User" :
91
84
get_structured_logger ("api_user_models" ).info ("creating user" , api_key = api_key )
92
- with Session () as session :
93
- new_user = User (api_key = api_key , email = email )
94
- # TODO: we may need to populate 'created' field/column here, if the default
95
- # specified above gets bound to the time of when that line of python was evaluated.
96
- session .add (new_user )
97
- session .commit ()
98
- User ._assign_roles (new_user , user_roles , session )
99
- session .commit ()
100
- return new_user
85
+ new_user = User (api_key = api_key , email = email )
86
+ session .add (new_user )
87
+ session .commit ()
88
+ return User ._assign_roles (new_user , user_roles , session )
101
89
102
90
@staticmethod
91
+ @default_session (WriteSession )
103
92
def update_user (
104
93
user : "User" ,
105
94
email : Optional [str ],
106
95
api_key : Optional [str ],
107
- roles : Optional [Set [str ]]
96
+ roles : Optional [Set [str ]],
97
+ session
108
98
) -> "User" :
109
99
get_structured_logger ("api_user_models" ).info ("updating user" , user_id = user .id , new_api_key = api_key )
110
- with Session () as session :
111
- user = User .find_user (user_id = user .id )
112
- if user :
113
- update_stmt = (
114
- update (User )
115
- .where (User .id == user .id )
116
- .values (api_key = api_key , email = email )
117
- )
118
- session .execute (update_stmt )
119
- User ._assign_roles (user , roles , session )
120
- session .commit ()
121
- return user
100
+ user = User .find_user (user_id = user .id , session = session )
101
+ if not user :
102
+ raise Exception ('user not found' )
103
+ update_stmt = (
104
+ update (User )
105
+ .where (User .id == user .id )
106
+ .values (api_key = api_key , email = email )
107
+ )
108
+ session .execute (update_stmt )
109
+ return User ._assign_roles (user , roles , session )
122
110
123
111
@staticmethod
124
- def delete_user (user_id : int ) -> None :
112
+ @default_session (WriteSession )
113
+ def delete_user (user_id : int , session ) -> None :
125
114
get_structured_logger ("api_user_models" ).info ("deleting user" , user_id = user_id )
126
- with Session () as session :
127
- session .execute (delete (User ).where (User .id == user_id ))
128
- session .commit ()
115
+ session .execute (delete (User ).where (User .id == user_id ))
116
+ session .commit ()
129
117
130
118
131
119
class UserRole (Base ):
@@ -134,23 +122,23 @@ class UserRole(Base):
134
122
name = Column (String (50 ), unique = True )
135
123
136
124
@staticmethod
137
- def create_role (name : str ) -> None :
125
+ @default_session (WriteSession )
126
+ def create_role (name : str , session ) -> None :
138
127
get_structured_logger ("api_user_models" ).info ("creating user role" , role = name )
139
- with Session () as session :
140
- session .execute (
141
- f"""
128
+ # TODO: check role doesnt already exist
129
+ session .execute (f"""
142
130
INSERT INTO user_role (name)
143
131
SELECT '{ name } '
144
132
WHERE NOT EXISTS
145
133
(SELECT *
146
134
FROM user_role
147
135
WHERE name='{ name } ')
148
- """
149
- )
150
- session . commit ()
136
+ """ )
137
+ session . commit ( )
138
+ return session . query ( UserRole ). filter ( UserRole . name == name ). first ()
151
139
152
140
@staticmethod
153
- def list_all_roles ():
154
- with Session () as session :
155
- roles = session .query (UserRole ).all ()
141
+ @ default_session ( Session )
142
+ def list_all_roles ( session ) :
143
+ roles = session .query (UserRole ).all ()
156
144
return [role .name for role in roles ]
0 commit comments