Skip to content
This repository was archived by the owner on Feb 7, 2024. It is now read-only.

Commit 1524ecb

Browse files
committed
Updated fastapi-users sqlalchemy plugin
There were a large number of follow-on effects that had to be resolved
1 parent 9f345ba commit 1524ecb

29 files changed

+564
-476
lines changed

docs/coding/backend.md

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,10 @@ The backend uses [FastAPI](fastapi.tiangolo.com/), an async [Python](https://www
2020
2. Install [homebrew](https://brew.sh/), a package manager for Mac OS.
2121
3. In a terminal, run:
2222
```
23-
brew install pyenv libpq openssl
23+
brew install pyenv
2424
```
25-
4. Follow the printed instructions to make sure [pyenv](https://github.com/pyenv/pyenv) and libpq are initialized when your terminal starts up. It should be something like:
25+
4. Follow the printed instructions to make sure [pyenv](https://github.com/pyenv/pyenv) is initialized when your terminal starts up. It should be something like:
2626
```
27-
echo 'export PATH="/usr/local/opt/libpq/bin:$PATH"' >> ~/.zshrc
2827
echo 'eval "$(pyenv init --path)"' >> ~/.zprofile
2928
echo 'eval "$(pyenv init -)"' >> ~/.zshrc
3029
```

fanviddb/api_keys/db.py

Lines changed: 0 additions & 24 deletions
This file was deleted.

fanviddb/api_keys/helpers.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,64 +8,73 @@
88
from passlib import pwd # type: ignore
99
from passlib.context import CryptContext # type: ignore
1010
from sqlalchemy import select
11+
from sqlalchemy.exc import IntegrityError
12+
from sqlalchemy.ext.asyncio import AsyncSession
1113
from starlette.exceptions import HTTPException
1214
from starlette.status import HTTP_401_UNAUTHORIZED
1315

14-
from fanviddb.db import database
16+
from fanviddb.db import get_async_session
1517

16-
from . import db
18+
from . import models
1719

1820
X_API_KEY = APIKeyHeader(name="X-API-Key", auto_error=False)
1921
api_key_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
2022

2123

22-
async def generate() -> str:
24+
async def generate(session: AsyncSession) -> str:
2325
api_key = None
2426
attempts = 0
2527
while api_key is None:
2628
pk = pwd.genword(length=8, charset="ascii_50")
2729
secret = pwd.genword(entropy="secure", charset="ascii_50")
2830
api_key = f"{pk}_{secret}"
2931

30-
query = db.api_keys.insert().values(
32+
instance = models.ApiKeyTable(
3133
pk=pk,
3234
hashed_api_key=api_key_context.hash(api_key),
3335
created_timestamp=datetime.datetime.utcnow(),
3436
state="active",
3537
)
36-
transaction = await database.transaction()
37-
try:
38-
await database.execute(query)
39-
except UniqueViolationError:
40-
api_key = None
41-
await transaction.rollback()
42-
else:
43-
await transaction.commit()
38+
async with session.begin_nested() as nested:
39+
try:
40+
session.add(instance)
41+
await nested.commit()
42+
except IntegrityError as exc:
43+
await nested.rollback()
44+
# This is necessary because of how sqlalchemy currently
45+
# nests errors.
46+
if exc.orig.sqlstate != UniqueViolationError.sqlstate:
47+
raise
48+
api_key = None
4449
attempts += 1
4550
if attempts > 10:
4651
raise ValueError(_("Too many collisions"))
4752

4853
return api_key
4954

5055

51-
async def verify(api_key: str):
56+
async def verify(session: AsyncSession, api_key: str):
5257
try:
5358
pk, _ = api_key.split("_")
5459
except ValueError:
5560
return False
56-
query = select([db.api_keys]).where(db.api_keys.c.pk == pk)
57-
result = await database.fetch_one(query)
58-
if result is None:
61+
query = select([models.api_keys]).where(models.api_keys.c.pk == pk)
62+
result = await session.execute(query)
63+
row = result.first()
64+
if row is None:
5965
return False
60-
return api_key_context.verify(api_key, result["hashed_api_key"])
66+
return api_key_context.verify(api_key, row.hashed_api_key)
6167

6268

63-
async def check_api_key_header(api_key: Optional[str] = Depends(X_API_KEY)):
69+
async def check_api_key_header(
70+
session: AsyncSession = Depends(get_async_session),
71+
api_key: Optional[str] = Depends(X_API_KEY),
72+
):
6473
# If no API key is present, that's fine; however, an invalid or revoked api key
6574
# is always an error.
6675
if api_key is None:
6776
return False
68-
is_valid = await verify(api_key)
77+
is_valid = await verify(session, api_key)
6978
if not is_valid:
7079
raise HTTPException(
7180
status_code=HTTP_401_UNAUTHORIZED, detail=_("Invalid api key")

fanviddb/api_keys/models.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,24 @@
1-
from pydantic import BaseModel
1+
from sqlalchemy import Column
2+
from sqlalchemy import DateTime
3+
from sqlalchemy import String
24

5+
from fanviddb.db import Base
36

4-
class ApiKey(BaseModel):
5-
api_key: str
7+
8+
class ApiKeyTable(Base):
9+
10+
__tablename__ = "api_keys"
11+
12+
# pk functions as a public "username" so that we can find the
13+
# correct hashed secret to check.
14+
pk = Column(String(), primary_key=True)
15+
hashed_api_key = Column(String(), nullable=False)
16+
17+
# Admin-only
18+
state = Column(String())
19+
20+
# Internal
21+
created_timestamp = Column(DateTime())
22+
23+
24+
api_keys = ApiKeyTable.__table__

fanviddb/api_keys/router.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
from fastapi import APIRouter
2+
from fastapi import Depends
3+
from sqlalchemy.ext.asyncio import AsyncSession
4+
5+
from fanviddb.db import get_async_session
26

37
from .helpers import generate
4-
from .models import ApiKey
8+
from .schema import ApiKey
59

610
api_key_router = APIRouter()
711

812

913
@api_key_router.post("", response_model=ApiKey)
10-
async def create_api_key():
11-
api_key = await generate()
14+
async def create_api_key(session: AsyncSession = Depends(get_async_session)):
15+
api_key = await generate(session)
1216
return {"api_key": api_key}

fanviddb/api_keys/schema.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from pydantic import BaseModel
2+
3+
4+
class ApiKey(BaseModel):
5+
api_key: str

fanviddb/app.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from .api_keys.router import api_key_router
99
from .auth.routers import auth_router
1010
from .auth.routers import users_router
11-
from .db import database
1211
from .email import EmailSendFailed
1312
from .fanvids.router import router as fanvid_router
1413

@@ -26,42 +25,42 @@ async def check_config(self) -> None:
2625

2726
main_app = Starlette()
2827

29-
api = FastAPI(docs_url=None)
28+
api_app = FastAPI(docs_url=None)
3029

31-
api.include_router(
30+
api_app.include_router(
3231
fanvid_router,
3332
prefix="/fanvids",
3433
tags=["Fanvids"],
3534
)
36-
api.include_router(
35+
api_app.include_router(
3736
auth_router,
3837
prefix="/auth",
3938
tags=["Auth"],
4039
)
41-
api.include_router(
40+
api_app.include_router(
4241
users_router,
4342
prefix="/users",
4443
tags=["Users"],
4544
)
46-
api.include_router(
45+
api_app.include_router(
4746
api_key_router,
4847
prefix="/api_keys",
4948
tags=["API Keys"],
5049
)
5150

5251

53-
frontend = Starlette()
52+
frontend_app = Starlette()
5453

5554

56-
@frontend.middleware("http")
55+
@frontend_app.middleware("http")
5756
async def default_response(request, call_next):
5857
response = await call_next(request)
5958
if response.status_code == 404:
6059
return FileResponse("frontend/build/index.html")
6160
return response
6261

6362

64-
frontend.mount(
63+
frontend_app.mount(
6564
"/", HelpfulStaticFiles(directory="frontend/build/", html=True, check_dir=False)
6665
)
6766
main_app.mount(
@@ -78,27 +77,17 @@ async def default_response(request, call_next):
7877
)
7978
main_app.mount(
8079
"/api",
81-
api,
80+
api_app,
8281
name="api",
8382
)
8483
main_app.mount(
8584
"/",
86-
frontend,
85+
frontend_app,
8786
name="frontend",
8887
)
8988

9089

91-
@main_app.on_event("startup")
92-
async def startup():
93-
await database.connect()
94-
95-
96-
@main_app.on_event("shutdown")
97-
async def shutdown():
98-
await database.disconnect()
99-
100-
101-
@api.exception_handler(EmailSendFailed)
90+
@api_app.exception_handler(EmailSendFailed)
10291
async def email_send_failed_handler(__: Request, exc: EmailSendFailed):
10392
return JSONResponse(
10493
status_code=503,

fanviddb/auth/db.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

fanviddb/auth/helpers.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,23 @@
1212
from fastapi_users.authentication import AuthenticationBackend
1313
from fastapi_users.authentication import CookieTransport
1414
from fastapi_users.authentication import JWTStrategy
15+
from fastapi_users.db import SQLAlchemyUserDatabase
16+
from sqlalchemy.ext.asyncio import AsyncSession
1517
from sqlalchemy.sql import exists
1618
from sqlalchemy.sql import select
1719
from zxcvbn import zxcvbn # type: ignore
1820

1921
from fanviddb import conf
20-
from fanviddb.db import database
22+
from fanviddb.db import get_async_session
2123
from fanviddb.email import send_email
2224
from fanviddb.i18n.utils import get_fluent
2325
from fanviddb.i18n.utils import get_request_locales
2426

25-
from .db import get_user_db
26-
from .db import users
27-
from .models import User
28-
from .models import UserCreate
29-
from .models import UserDB
30-
from .models import UserUpdate
27+
from .models import UserTable
28+
from .schema import User
29+
from .schema import UserCreate
30+
from .schema import UserDB
31+
from .schema import UserUpdate
3132

3233
AUTH_LIFETIME = 60 * 60 * 24 * 14
3334
cookie_transport = CookieTransport(
@@ -53,6 +54,7 @@ class UserManager(BaseUserManager[UserCreate, UserDB]):
5354
reset_password_token_lifetime_seconds = 60 * 5
5455
verification_token_secret = conf.EMAIL_TOKEN_SECRET_KEY
5556
verification_token_lifetime_seconds = 60 * 5
57+
user_db: SQLAlchemyUserDatabase
5658

5759
async def on_after_forgot_password(
5860
self, user: UserDB, token: str, request: Optional[Request] = None
@@ -130,16 +132,21 @@ async def create(
130132
request: Optional[Request] = None,
131133
) -> UserDB:
132134
user = cast(UserCreate, user)
133-
query = select([exists().where(users.c.username == user.username)])
134-
result = await database.fetch_one(query)
135-
if result and result.get("anon_1"):
135+
query = select([exists().where(UserTable.username == user.username)])
136+
result = await self.user_db.session.execute(query)
137+
row = result.first()
138+
if row and row._mapping.get("anon_1"):
136139
raise HTTPException(
137140
status_code=status.HTTP_400_BAD_REQUEST,
138141
detail="REGISTER_USERNAME_ALREADY_EXISTS",
139142
)
140143
return await super().create(user, safe, request)
141144

142145

146+
async def get_user_db(session: AsyncSession = Depends(get_async_session)):
147+
yield SQLAlchemyUserDatabase(UserDB, session, UserTable)
148+
149+
143150
async def get_user_manager(user_db=Depends(get_user_db)):
144151
yield UserManager(user_db)
145152

fanviddb/auth/models.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
1-
from typing import Optional
1+
from fastapi_users.db import SQLAlchemyBaseUserTable
2+
from sqlalchemy import Column
3+
from sqlalchemy import String
24

3-
from fastapi_users import models
5+
from fanviddb.db import Base
46

57

6-
class User(models.BaseUser):
7-
username: Optional[str]
8+
class UserTable(Base, SQLAlchemyBaseUserTable):
9+
username = Column(String(length=40), nullable=False, unique=True)
810

911

10-
class UserCreate(models.BaseUserCreate):
11-
username: str
12-
13-
14-
class UserUpdate(models.BaseUserUpdate):
15-
username: Optional[str]
16-
17-
18-
class UserDB(User, models.BaseUserDB):
19-
username: str
12+
users = UserTable.__table__ # type: ignore

0 commit comments

Comments
 (0)