Skip to content

chore(internal): add reflection helper function #412

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/finch/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,7 @@
maybe_transform as maybe_transform,
async_maybe_transform as async_maybe_transform,
)
from ._reflection import function_has_argument as function_has_argument
from ._reflection import (
function_has_argument as function_has_argument,
assert_signatures_in_sync as assert_signatures_in_sync,
)
34 changes: 34 additions & 0 deletions src/finch/_utils/_reflection.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import inspect
from typing import Any, Callable

Expand All @@ -6,3 +8,35 @@ def function_has_argument(func: Callable[..., Any], arg_name: str) -> bool:
"""Returns whether or not the given function has a specific parameter"""
sig = inspect.signature(func)
return arg_name in sig.parameters


def assert_signatures_in_sync(
source_func: Callable[..., Any],
check_func: Callable[..., Any],
*,
exclude_params: set[str] = set(),
) -> None:
"""Ensure that the signature of the second function matches the first."""

check_sig = inspect.signature(check_func)
source_sig = inspect.signature(source_func)

errors: list[str] = []

for name, source_param in source_sig.parameters.items():
if name in exclude_params:
continue

custom_param = check_sig.parameters.get(name)
if not custom_param:
errors.append(f"the `{name}` param is missing")
continue

if custom_param.annotation != source_param.annotation:
errors.append(
f"types for the `{name}` param are do not match; source={repr(source_param.annotation)} checking={repr(source_param.annotation)}"
)
continue

if errors:
raise AssertionError(f"{len(errors)} errors encountered when comparing signatures:\n\n" + "\n\n".join(errors))