3
3
from sqlalchemy .orm import relationship
4
4
from copy import deepcopy
5
5
6
- from .._db import session
6
+ from .._db import Session
7
7
from delphi .epidata .common .logger import get_structured_logger
8
8
9
9
from typing import Set , Optional , List
@@ -35,61 +35,61 @@ def __init__(self, api_key: str, email: str = None) -> None:
35
35
36
36
@staticmethod
37
37
def list_users () -> List ["User" ]:
38
- return session .query (User ).all ()
38
+ with Session () as session :
39
+ return session .query (User ).all ()
39
40
40
41
@property
41
42
def as_dict (self ):
42
- user_dict = deepcopy (self .__dict__ ) # NOTE: changed from `self.__dict__.copy()` as it caused issues
43
- # so we dont change the internal representation of self
44
- user_dict ["roles" ] = self .get_user_roles
45
- try :
46
- return {k : user_dict [k ] for k in ["id" , "api_key" , "email" , "roles" , "created" , "last_time_used" ]}
47
- except KeyError :
48
- return {
49
- "id" : self .id ,
50
- "api_key" : self .api_key ,
51
- "email" : self .email ,
52
- "roles" : self .get_user_roles ,
53
- "created" : self .created ,
54
- "last_time_used" : self .last_time_used
55
- }
43
+ return {
44
+ "id" : self .id ,
45
+ "api_key" : self .api_key ,
46
+ "email" : self .email ,
47
+ "roles" : User .get_user_roles (self .id ),
48
+ "created" : self .created ,
49
+ "last_time_used" : self .last_time_used
50
+ }
56
51
57
- @property
58
- def get_user_roles (self ) -> Set [str ]:
59
- return set ([role .name for role in self .roles ])
52
+ @staticmethod
53
+ def get_user_roles (user_id : int ) -> Set [str ]:
54
+ with Session () as session :
55
+ user = session .query (User ).filter (User .id == user_id ).first ()
56
+ return set ([role .name for role in user .roles ])
60
57
61
58
def has_role (self , required_role : str ) -> bool :
62
- return required_role in self .get_user_roles
59
+ return required_role in User .get_user_roles ( self . id )
63
60
64
61
@staticmethod
65
- def assign_roles (user : "User" , roles : Optional [Set [str ]]) -> None :
62
+ def _assign_roles (user : "User" , roles : Optional [Set [str ]], session ) -> None :
66
63
get_structured_logger ("api_user_models" ).info ("setting roles" , roles = roles , user_id = user .id , api_key = user .api_key )
67
64
if roles :
65
+ db_user = session .query (User ).filter (User .id == user .id ).first ()
68
66
roles_to_assign = session .query (UserRole ).filter (UserRole .name .in_ (roles )).all ()
69
- user .roles = roles_to_assign
70
- session .commit ()
67
+ db_user .roles = roles_to_assign
71
68
else :
72
- user .roles = []
73
- session .commit ()
69
+ db_user .roles = []
74
70
75
71
@staticmethod
76
72
def find_user (* , # asterisk forces explicit naming of all arguments when calling this method
77
73
user_id : Optional [int ] = None , api_key : Optional [str ] = None , user_email : Optional [str ] = None
78
74
) -> "User" :
79
- user = (
80
- session .query (User )
81
- .filter ((User .id == user_id ) | (User .api_key == api_key ) | (User .email == user_email ))
82
- .first ()
83
- )
75
+ # TODO
76
+ with Session () as session :
77
+ user = (
78
+ session .query (User )
79
+ .filter ((User .id == user_id ) | (User .api_key == api_key ) | (User .email == user_email ))
80
+ .first ()
81
+ )
84
82
return user if user else None
85
83
86
84
@staticmethod
87
85
def create_user (api_key : str , email : str , user_roles : Optional [Set [str ]] = None ) -> "User" :
86
+ # TODO
88
87
get_structured_logger ("api_user_models" ).info ("creating user" , api_key = api_key )
89
- new_user = User (api_key = api_key , email = email )
90
- session .add (new_user )
91
- session .commit ()
92
- User .assign_roles (new_user , user_roles )
88
+ with Session () as session :
89
+ new_user = User (api_key = api_key , email = email )
90
+ session .add (new_user )
91
+ User ._assign_roles (new_user , user_roles , session )
92
+ session .commit ()
93
93
return new_user
94
94
95
95
@staticmethod
@@ -100,23 +100,27 @@ def update_user(
100
100
roles : Optional [Set [str ]]
101
101
) -> "User" :
102
102
get_structured_logger ("api_user_models" ).info ("updating user" , user_id = user .id , new_api_key = api_key )
103
- user = User .find_user (user_id = user .id )
104
- if user :
105
- update_stmt = (
106
- update (User )
107
- .where (User .id == user .id )
108
- .values (api_key = api_key , email = email )
109
- )
110
- session .execute (update_stmt )
111
- session .commit ()
112
- User .assign_roles (user , roles )
103
+ # TODO
104
+ with Session () as session :
105
+ user = User .find_user (user_id = user .id )
106
+ if user :
107
+ update_stmt = (
108
+ update (User )
109
+ .where (User .id == user .id )
110
+ .values (api_key = api_key , email = email )
111
+ )
112
+ session .execute (update_stmt )
113
+ User ._assign_roles (user , roles , session )
114
+ session .commit ()
113
115
return user
114
116
115
117
@staticmethod
116
118
def delete_user (user_id : int ) -> None :
117
119
get_structured_logger ("api_user_models" ).info ("deleting user" , user_id = user_id )
118
- session .execute (delete (User ).where (User .id == user_id ))
119
- session .commit ()
120
+ # TODO
121
+ with Session () as session :
122
+ session .execute (delete (User ).where (User .id == user_id ))
123
+ session .commit ()
120
124
121
125
122
126
class UserRole (Base ):
@@ -127,19 +131,23 @@ class UserRole(Base):
127
131
@staticmethod
128
132
def create_role (name : str ) -> None :
129
133
get_structured_logger ("api_user_models" ).info ("creating user role" , role = name )
130
- session .execute (
131
- f"""
132
- INSERT INTO user_role (name)
133
- SELECT '{ name } '
134
- WHERE NOT EXISTS
135
- (SELECT *
136
- FROM user_role
137
- WHERE name='{ name } ')
138
- """
139
- )
140
- session .commit ()
134
+ # TODO
135
+ with Session () as session :
136
+ session .execute (
137
+ f"""
138
+ INSERT INTO user_role (name)
139
+ SELECT '{ name } '
140
+ WHERE NOT EXISTS
141
+ (SELECT *
142
+ FROM user_role
143
+ WHERE name='{ name } ')
144
+ """
145
+ )
146
+ session .commit ()
141
147
142
148
@staticmethod
143
149
def list_all_roles ():
144
- roles = session .query (UserRole ).all ()
150
+ # TODO
151
+ with Session () as session :
152
+ roles = session .query (UserRole ).all ()
145
153
return [role .name for role in roles ]
0 commit comments