Skip to content

Implement API key scrubbing and structured logging #806

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Feb 23, 2022
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ newrelic==6.8.0.163
requests==2.26.0
epiweeks==2.1.2
Flask-Limiter==1.4
redis==3.5.3
redis==3.5.3
structlog
94 changes: 94 additions & 0 deletions src/server/_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""Structured logger utility for creating JSON logs in web server logs"""
import logging
import sys
import threading
import structlog


def handle_exceptions(logger):
"""Handle exceptions using the provided logger."""
def exception_handler(etype, value, traceback):
logger.exception("Top-level exception occurred",
exc_info=(etype, value, traceback))

def multithread_exception_handler(args):
exception_handler(args.exc_type, args.exc_value, args.exc_traceback)

sys.excepthook = exception_handler
threading.excepthook = multithread_exception_handler


def get_structured_logger(name=__name__,
filename=None,
log_exceptions=True):
"""Create a new structlog logger.

Use the logger returned from this in server code using the standard
wrapper calls, e.g.:

logger = get_structured_logger(__name__)
logger.warning("Error", type="Signal too low").

The output will be rendered as JSON which can easily be consumed by logs
processors.

See the structlog documentation for details.

Parameters
---------
name: Name to use for logger (included in log lines), __name__ from caller
is a good choice.
filename: An (optional) file to write log output.
"""
# Configure the underlying logging configuration
handlers = [logging.StreamHandler()]
if filename:
handlers.append(logging.FileHandler(filename))

logging.basicConfig(
format="%(message)s",
level=logging.INFO,
handlers=handlers
)



# Configure structlog. This uses many of the standard suggestions from
# the structlog documentation.
structlog.configure(
processors=[
# Filter out log levels we are not tracking.
structlog.stdlib.filter_by_level,
# Include logger name in output.
structlog.stdlib.add_logger_name,
# Include log level in output.
structlog.stdlib.add_log_level,
# Allow formatting into arguments e.g., logger.info("Hello, %s",
# name)
structlog.stdlib.PositionalArgumentsFormatter(),
# Add timestamps.
structlog.processors.TimeStamper(fmt="iso"),
# Match support for exception logging in the standard logger.
structlog.processors.StackInfoRenderer(),
structlog.processors.format_exc_info,
# Decode unicode characters
structlog.processors.UnicodeDecoder(),
# Render as JSON
structlog.processors.JSONRenderer()
],
# Use a dict class for keeping track of data.
context_class=dict,
# Use a standard logger for the actual log call.
logger_factory=structlog.stdlib.LoggerFactory(),
# Use a standard wrapper class for utilities like log.warning()
wrapper_class=structlog.stdlib.BoundLogger,
# Cache the logger
cache_logger_on_first_use=True,
)

logger = structlog.get_logger(name)

if log_exceptions:
handle_exceptions(logger)

return logger
2 changes: 1 addition & 1 deletion src/server/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def run_query(p: APrinter, query_tuple: Tuple[str, Dict[str, Any]]):
query, params = query_tuple
# limit rows + 1 for detecting whether we would have more
full_query = text(f"{query} LIMIT {p.remaining_rows + 1}")
current_user.log_info("full_query: %s, params: %s", full_query, params)
current_user.log_info("Running query", full_query=str(full_query), params=params)
return db.execution_options(stream_results=True).execute(full_query, **params)


Expand Down
49 changes: 37 additions & 12 deletions src/server/_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from ._common import app, request, db
from ._exceptions import MissingAPIKeyException, UnAuthenticatedException
from ._db import metadata, TABLE_OPTIONS
from ._logger import get_structured_logger
import re

API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14)
API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14)
Expand Down Expand Up @@ -130,8 +132,8 @@ class UserRole(str, Enum):

class User:
api_key: str
roles: Set[UserRole]
authenticated: bool
roles: Set[UserRole]
tracking: bool = True
registered: bool = True

Expand All @@ -141,18 +143,31 @@ def __init__(self, api_key: str, authenticated: bool, roles: Set[UserRole], trac
self.roles = roles
self.tracking = tracking
self.registered = registered

def get_apikey(self) -> str:
return self.api_key

def is_authenticated(self) -> bool:
return self.authenticated

def has_role(self, role: UserRole) -> bool:
return role in self.roles

def log_info(self, msg: str, *args, **kwargs) -> None:
if self.authenticated and self.tracking:
app.logger.info(f"apikey: {self.api_key}, {msg}", *args, **kwargs)
else:
app.logger.info(msg, *args, **kwargs)

def is_rate_limited(self) -> bool:
return not self.registered

def is_tracking(self) -> bool:
return self.tracking

def log_info(self, msg: str, *args, **kwargs) -> None:
logger = get_structured_logger("api_key_logs", filename="api_key_logs.log")
if self.is_authenticated():
if self.is_tracking():
logger.info(msg, *args, **dict(kwargs, apikey=self.get_apikey()))
else:
logger.info(msg, *args, **dict(kwargs, apikey="*****"))
else:
logger.info(msg, *args, **kwargs)


ANONYMOUS_USER = User("anonymous", False, set())
Expand All @@ -169,7 +184,6 @@ def _find_user(api_key: Optional[str]) -> User:
return User(user.api_key, True, set(user.roles.split(",")), user.tracking, user.registered)

def resolve_auth_token() -> Optional[str]:
# auth request param
for name in ('auth', 'api_key', 'token'):
if name in request.values:
return request.values[name]
Expand All @@ -187,12 +201,23 @@ def _get_current_user() -> User:
if "user" not in g:
api_key = resolve_auth_token()
user = _find_user(api_key)
if not user.authenticated and require_api_key():
request_path = request.full_path
if not user.is_authenticated() and require_api_key():
raise MissingAPIKeyException()
user.log_info(request.full_path)
# If the user configured no-track option, mask the API key
if not user.is_tracking():
request_path = mask_apikey(request_path)
user.log_info("Get path", path=request_path)
g.user = user
return g.user

def mask_apikey(path: str) -> str:
# Function to mask API key query string from a request path
regexp = re.compile(r'[\\?&]api_key=([^&#]*)')
if regexp.search(path):
path = re.sub(regexp, "&api_key=*****", path)
return path


current_user: User = cast(User, LocalProxy(_get_current_user))

Expand All @@ -204,12 +229,12 @@ def require_api_key() -> bool:

def show_soft_api_key_warning() -> bool:
n = date.today()
return not current_user.authenticated and not app.config.get('TESTING', False) and n > API_KEY_SOFT_WARNING and n < API_KEY_HARD_WARNING
return not current_user.is_authenticated() and not app.config.get('TESTING', False) and n > API_KEY_SOFT_WARNING and n < API_KEY_HARD_WARNING


def show_hard_api_key_warning() -> bool:
n = date.today()
return not current_user.authenticated and not app.config.get('TESTING', False) and n > API_KEY_HARD_WARNING
return not current_user.is_authenticated() and not app.config.get('TESTING', False) and n > API_KEY_HARD_WARNING


def _is_public_route() -> bool:
Expand Down
3 changes: 2 additions & 1 deletion src/server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,5 @@ def send_lib_file(path: str):
app.logger.setLevel(gunicorn_logger.level)
sqlalchemy_logger = logging.getLogger("sqlalchemy")
sqlalchemy_logger.handlers = gunicorn_logger.handlers
sqlalchemy_logger.setLevel(gunicorn_logger.level)
# Change SQLAlchemy logging level to "ERROR" in order to prevent query details of API keys from being logged
sqlalchemy_logger.setLevel(logging.ERROR)