Skip to content

Commit a2de778

Browse files
committed
Add 0 boilerplate sync and async client objects
With the current approach, for every API call, the user has to import a a module from a tag package and call the `asyncio` or `sync` method with the "client=client" argument. This patch adds a wrapper around tag packages and two wrapper `Client` objects to spare the need for that boilerplate code. Check this issue for more information: openapi-generators#224
1 parent d579025 commit a2de778

File tree

5 files changed

+96
-5
lines changed

5 files changed

+96
-5
lines changed

openapi_python_client/__init__.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
class Project:
29-
TEMPLATE_FILTERS = {"snakecase": utils.snake_case, "kebabcase": utils.kebab_case}
29+
TEMPLATE_FILTERS = {"snakecase": utils.snake_case, "kebabcase": utils.kebab_case, "pascalcase": utils.pascal_case}
3030
project_name_override: Optional[str] = None
3131
package_name_override: Optional[str] = None
3232

@@ -76,7 +76,7 @@ def update(self) -> Sequence[GeneratorError]:
7676

7777
def _reformat(self) -> None:
7878
subprocess.run(
79-
"autoflake -i -r --remove-all-unused-imports --remove-unused-variables .",
79+
"autoflake -i -r --remove-all-unused-imports --remove-unused-variables . --ignore-init-module-imports",
8080
cwd=self.package_dir,
8181
shell=True,
8282
stdout=subprocess.PIPE,
@@ -173,6 +173,13 @@ def _build_api(self) -> None:
173173
client_template = self.env.get_template("client.pyi")
174174
client_path.write_text(client_template.render())
175175

176+
# Generate wrapper
177+
wrapper = self.package_dir / "wrapper.py"
178+
wrapper_template = self.env.get_template("wrapper.pyi")
179+
wrapper.write_text(wrapper_template.render(
180+
models=self.openapi.schemas.models.values(),
181+
endpoint_collections=self.openapi.endpoint_collections_by_tag))
182+
176183
# Generate endpoints
177184
api_dir = self.package_dir / "api"
178185
api_dir.mkdir()
@@ -184,13 +191,17 @@ def _build_api(self) -> None:
184191
tag = utils.snake_case(tag)
185192
tag_dir = api_dir / tag
186193
tag_dir.mkdir()
187-
(tag_dir / "__init__.py").touch()
194+
tag_init = tag_dir / "__init__.py"
195+
tag_init_template = self.env.get_template("tag_init.pyi")
196+
tag_init.write_text(tag_init_template.render(
197+
tag=tag, collection=collection))
188198

189199
for endpoint in collection.endpoints:
190200
module_path = tag_dir / f"{snake_case(endpoint.name)}.py"
191201
module_path.write_text(endpoint_template.render(endpoint=endpoint))
192202

193203

204+
194205
def _get_project_for_url_or_path(url: Optional[str], path: Optional[Path]) -> Union[Project, GeneratorError]:
195206
data_dict = _get_document(url=url, path=path)
196207
if isinstance(data_dict, GeneratorError):

openapi_python_client/templates/endpoint_macros.pyi

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,14 +70,19 @@ Union[
7070
{% endmacro %}
7171

7272
{# The all the kwargs passed into an endpoint (and variants thereof)) #}
73-
{% macro arguments(endpoint) %}
73+
{% macro arguments(endpoint, client=True) %}
74+
75+
{% if endpoint.path_parameters.__len__() > 1 %}
7476
*,
77+
{% endif %}
7578
{# Proper client based on whether or not the endpoint requires authentication #}
79+
{% if client %}
7680
{% if endpoint.requires_security %}
7781
client: AuthenticatedClient,
7882
{% else %}
7983
client: Client,
8084
{% endif %}
85+
{% endif %}
8186
{# path parameters #}
8287
{% for parameter in endpoint.path_parameters %}
8388
{{ parameter.to_string() }},
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
""" {{ description }} """
2-
from .client import AuthenticatedClient, Client
2+
3+
from .wrapper import Client, SyncClient
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from . import {% for e in collection.endpoints %} {{e.name | snakecase }}, {% endfor %}
2+
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
from typing import Any, Dict, Optional, Union, cast
2+
from .client import Client as InnerClient, AuthenticatedClient
3+
4+
from .models import (
5+
{% for model in models %}
6+
{{ model.reference.class_name }},
7+
{% endfor %}
8+
)
9+
10+
from .api import (
11+
{% for tag, collection in endpoint_collections.items() %}
12+
{{ tag | snakecase }},
13+
{% endfor %}
14+
)
15+
16+
{% from "endpoint_macros.pyi" import arguments, client, kwargs %}
17+
18+
{% for tag, collection in endpoint_collections.items() %}
19+
20+
class {{ tag | pascalcase }}Api:
21+
22+
def __init__(self, client: InnerClient):
23+
self._client = client
24+
25+
{% for endpoint in collection.endpoints %}
26+
async def {{ endpoint.name | snakecase }}(self, {{ arguments(endpoint, False) }}):
27+
{% if endpoint.requires_security %}
28+
client = cast(AuthenticatedClient, self._client)
29+
{% else %}
30+
client = self._client
31+
{% endif %}
32+
return await {{ tag }}.{{ endpoint.name | snakecase }}.asyncio({{ kwargs(endpoint) }})
33+
34+
{% endfor %}
35+
36+
37+
class Sync{{ tag | pascalcase }}Api:
38+
39+
def __init__(self, client: InnerClient):
40+
self._client = client
41+
42+
{% for endpoint in collection.endpoints %}
43+
def {{ endpoint.name | snakecase }}(self, {{ arguments(endpoint, False) }}):
44+
{% if endpoint.requires_security %}
45+
client = cast(AuthenticatedClient, self._client)
46+
{% else %}
47+
client = self._client
48+
{% endif %}
49+
return {{ tag }}.{{ endpoint.name | snakecase }}.sync({{ kwargs(endpoint) }})
50+
51+
{% endfor %}
52+
53+
{% endfor %}
54+
55+
{% for i in '', 'Sync' %}
56+
57+
class {{ i }}Client:
58+
def __init__(self, base_url: str, timeout: float = 5.0, token: Optional[str] = None):
59+
if token is None:
60+
self.connection = InnerClient(
61+
base_url=base_url,
62+
timeout=timeout)
63+
else:
64+
self.connection = AuthenticatedClient(
65+
base_url=base_url,
66+
timeout=timeout,
67+
token=token)
68+
{% for tag, collection in endpoint_collections.items() %}
69+
self.{{ tag | snakecase }} = {{ i }}{{ tag | pascalcase }}Api(self.connection)
70+
{% endfor %}
71+
72+
{% endfor %}

0 commit comments

Comments
 (0)