Skip to content

Commit 52a80c7

Browse files
committed
refactor(python): extract database connection code to a shared method
See https://fastapi.tiangolo.com/tutorial/dependencies/ Part of #16
1 parent edb72c8 commit 52a80c7

File tree

2 files changed

+23
-36
lines changed

2 files changed

+23
-36
lines changed

examples/python/routes.py

+11-27
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,21 @@
22
import psycopg2
33
import psycopg2.extras
44

5-
from fastapi import APIRouter, HTTPException
5+
from fastapi import APIRouter, Depends, HTTPException
66

77
router = APIRouter()
88

9-
10-
@router.get('/v1/categories/count')
11-
def get_v1_categories_count():
12-
conn = psycopg2.connect(
9+
async def db_connection():
10+
return psycopg2.connect(
1311
database = os.getenv('DB_NAME'),
1412
user = os.getenv('DB_USER'),
1513
password = os.getenv('DB_PASSWORD'),
1614
host = os.getenv('DB_HOST', 'localhost'),
1715
port = 5432)
16+
17+
18+
@router.get('/v1/categories/count')
19+
def get_v1_categories_count(conn = Depends(db_connection)):
1820
try:
1921
with conn:
2022
with conn.cursor(cursor_factory = psycopg2.extras.RealDictCursor) as cur:
@@ -27,13 +29,7 @@ def get_v1_categories_count():
2729
conn.close()
2830

2931
@router.get('/v1/categories/stat')
30-
def get_v1_categories_stat():
31-
conn = psycopg2.connect(
32-
database = os.getenv('DB_NAME'),
33-
user = os.getenv('DB_USER'),
34-
password = os.getenv('DB_PASSWORD'),
35-
host = os.getenv('DB_HOST', 'localhost'),
36-
port = 5432)
32+
def get_v1_categories_stat(conn = Depends(db_connection)):
3733
try:
3834
with conn:
3935
with conn.cursor(cursor_factory = psycopg2.extras.DictCursor) as cur:
@@ -51,13 +47,7 @@ def get_v1_categories_stat():
5147
conn.close()
5248

5349
@router.get('/v1/collections/{collectionId}/categories/count')
54-
def get_v1_collections_collection_id_categories_count(collectionId):
55-
conn = psycopg2.connect(
56-
database = os.getenv('DB_NAME'),
57-
user = os.getenv('DB_USER'),
58-
password = os.getenv('DB_PASSWORD'),
59-
host = os.getenv('DB_HOST', 'localhost'),
60-
port = 5432)
50+
def get_v1_collections_collection_id_categories_count(collectionId, conn = Depends(db_connection)):
6151
try:
6252
with conn:
6353
with conn.cursor(cursor_factory = psycopg2.extras.RealDictCursor) as cur:
@@ -70,21 +60,15 @@ def get_v1_collections_collection_id_categories_count(collectionId):
7060
conn.close()
7161

7262
@router.get('/v1/categories')
73-
def get_list_v1_categories():
63+
def get_list_v1_categories(conn = Depends(db_connection)):
7464
pass
7565

7666
@router.post('/v1/categories')
7767
def post_v1_categories():
7868
pass
7969

8070
@router.get('/v1/categories/{categoryId}')
81-
def get_v1_categories_category_id(categoryId):
82-
conn = psycopg2.connect(
83-
database = os.getenv('DB_NAME'),
84-
user = os.getenv('DB_USER'),
85-
password = os.getenv('DB_PASSWORD'),
86-
host = os.getenv('DB_HOST', 'localhost'),
87-
port = 5432)
71+
def get_v1_categories_category_id(categoryId, conn = Depends(db_connection)):
8872
try:
8973
with conn:
9074
with conn.cursor(cursor_factory = psycopg2.extras.RealDictCursor) as cur:

src/templates/routes.py.ejs

+12-9
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,18 @@ import os
22
import psycopg2
33
import psycopg2.extras
44

5-
from fastapi import APIRouter, HTTPException
5+
from fastapi import APIRouter, Depends, HTTPException
66

77
router = APIRouter()
88

9+
async def db_connection():
10+
return psycopg2.connect(
11+
database = os.getenv('DB_NAME'),
12+
user = os.getenv('DB_USER'),
13+
password = os.getenv('DB_PASSWORD'),
14+
host = os.getenv('DB_HOST', 'localhost'),
15+
port = 5432)
16+
917
<%
1018
// { "get", "/v1/categories/:categoryId" } => "get_v1_categories_category_id"
1119
function generate_method_name(method, path) {
@@ -27,7 +35,8 @@ function convertToFastApiPath(path) {
2735
2836
endpoints.forEach(function(endpoint) {
2937
const path = convertToFastApiPath(endpoint.path)
30-
const paramsFromPath = extractParamsFromPath(endpoint.path)
38+
const methodArgs = extractParamsFromPath(endpoint.path)
39+
methodArgs.push('conn = Depends(db_connection)')
3140
3241
endpoint.methods.forEach(function(method) {
3342
const hasGetOne = method.name === 'get'
@@ -59,14 +68,8 @@ endpoints.forEach(function(endpoint) {
5968
if (hasGetOne || hasGetMany) {
6069
%>
6170
@router.get('<%- path %>')
62-
def <%- pythonMethodName %>(<%- paramsFromPath.join(', ') %>):
71+
def <%- pythonMethodName %>(<%- methodArgs.join(', ') %>):
6372
<% if (hasGetOne) { -%>
64-
conn = psycopg2.connect(
65-
database = os.getenv('DB_NAME'),
66-
user = os.getenv('DB_USER'),
67-
password = os.getenv('DB_PASSWORD'),
68-
host = os.getenv('DB_HOST', 'localhost'),
69-
port = 5432)
7073
try:
7174
with conn:
7275
<% if (queries.length > 1) { /* we can omit cursor_factory but in this case we might get an unused import */-%>

0 commit comments

Comments
 (0)