Skip to content

feat: Add extra template arguments option #344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions openapi_python_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ def __init__(
meta: MetaType,
custom_template_path: Optional[Path] = None,
file_encoding: str = "utf-8",
extra_template_kwargs: Optional[Dict[str, str]] = None,
) -> None:
self.openapi: GeneratorData = openapi
self.meta: MetaType = meta
self.file_encoding = file_encoding
self.extra_template_kwargs = extra_template_kwargs or {}

package_loader = PackageLoader(__package__)
loader: BaseLoader
Expand Down Expand Up @@ -146,7 +148,7 @@ def _create_package(self) -> None:

package_init_template = self.env.get_template("package_init.py.jinja")
package_init.write_text(
package_init_template.render(description=self.package_description), encoding=self.file_encoding
package_init_template.render(description=self.package_description, **self.extra_template_kwargs), encoding=self.file_encoding
)

if self.meta != MetaType.NONE:
Expand All @@ -155,7 +157,7 @@ def _create_package(self) -> None:

types_template = self.env.get_template("types.py.jinja")
types_path = self.package_dir / "types.py"
types_path.write_text(types_template.render(), encoding=self.file_encoding)
types_path.write_text(types_template.render(**self.extra_template_kwargs), encoding=self.file_encoding)

def _build_metadata(self) -> None:
if self.meta == MetaType.NONE:
Expand All @@ -170,15 +172,15 @@ def _build_metadata(self) -> None:
readme_template = self.env.get_template("README.md.jinja")
readme.write_text(
readme_template.render(
project_name=self.project_name, description=self.package_description, package_name=self.package_name
project_name=self.project_name, description=self.package_description, package_name=self.package_name, **self.extra_template_kwargs
),
encoding=self.file_encoding,
)

# .gitignore
git_ignore_path = self.project_dir / ".gitignore"
git_ignore_template = self.env.get_template(".gitignore.jinja")
git_ignore_path.write_text(git_ignore_template.render(), encoding=self.file_encoding)
git_ignore_path.write_text(git_ignore_template.render(**self.extra_template_kwargs), encoding=self.file_encoding)

def _build_pyproject_toml(self, *, use_poetry: bool) -> None:
template = "pyproject.toml.jinja" if use_poetry else "pyproject_no_poetry.toml.jinja"
Expand All @@ -190,6 +192,7 @@ def _build_pyproject_toml(self, *, use_poetry: bool) -> None:
package_name=self.package_name,
version=self.version,
description=self.package_description,
**self.extra_template_kwargs
),
encoding=self.file_encoding,
)
Expand All @@ -203,6 +206,7 @@ def _build_setup_py(self) -> None:
package_name=self.package_name,
version=self.version,
description=self.package_description,
**self.extra_template_kwargs,
),
encoding=self.file_encoding,
)
Expand All @@ -217,7 +221,7 @@ def _build_models(self) -> None:
model_template = self.env.get_template("model.py.jinja")
for model in self.openapi.models.values():
module_path = models_dir / f"{model.reference.module_name}.py"
module_path.write_text(model_template.render(model=model), encoding=self.file_encoding)
module_path.write_text(model_template.render(model=model, **self.extra_template_kwargs), encoding=self.file_encoding)
imports.append(import_string_from_reference(model.reference))

# Generate enums
Expand All @@ -226,19 +230,19 @@ def _build_models(self) -> None:
for enum in self.openapi.enums.values():
module_path = models_dir / f"{enum.reference.module_name}.py"
if enum.value_type is int:
module_path.write_text(int_enum_template.render(enum=enum), encoding=self.file_encoding)
module_path.write_text(int_enum_template.render(enum=enum, **self.extra_template_kwargs), encoding=self.file_encoding)
else:
module_path.write_text(str_enum_template.render(enum=enum), encoding=self.file_encoding)
module_path.write_text(str_enum_template.render(enum=enum, **self.extra_template_kwargs), encoding=self.file_encoding)
imports.append(import_string_from_reference(enum.reference))

models_init_template = self.env.get_template("models_init.py.jinja")
models_init.write_text(models_init_template.render(imports=imports), encoding=self.file_encoding)
models_init.write_text(models_init_template.render(imports=imports, **self.extra_template_kwargs), encoding=self.file_encoding)

def _build_api(self) -> None:
# Generate Client
client_path = self.package_dir / "client.py"
client_template = self.env.get_template("client.py.jinja")
client_path.write_text(client_template.render(), encoding=self.file_encoding)
client_path.write_text(client_template.render(**self.extra_template_kwargs), encoding=self.file_encoding)

# Generate endpoints
api_dir = self.package_dir / "api"
Expand All @@ -254,7 +258,7 @@ def _build_api(self) -> None:

for endpoint in collection.endpoints:
module_path = tag_dir / f"{snake_case(endpoint.name)}.py"
module_path.write_text(endpoint_template.render(endpoint=endpoint), encoding=self.file_encoding)
module_path.write_text(endpoint_template.render(endpoint=endpoint, **self.extra_template_kwargs), encoding=self.file_encoding)


def _get_project_for_url_or_path(
Expand All @@ -263,14 +267,15 @@ def _get_project_for_url_or_path(
meta: MetaType,
custom_template_path: Optional[Path] = None,
file_encoding: str = "utf-8",
extra_template_kwargs: Optional[Dict[str, str]] = None,
) -> Union[Project, GeneratorError]:
data_dict = _get_document(url=url, path=path)
if isinstance(data_dict, GeneratorError):
return data_dict
openapi = GeneratorData.from_dict(data_dict)
if isinstance(openapi, GeneratorError):
return openapi
return Project(openapi=openapi, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding)
return Project(openapi=openapi, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding, extra_template_kwargs=extra_template_kwargs)


def create_new_client(
Expand All @@ -280,6 +285,7 @@ def create_new_client(
meta: MetaType,
custom_template_path: Optional[Path] = None,
file_encoding: str = "utf-8",
extra_template_kwargs: Optional[Dict[str, str]] = None,
) -> Sequence[GeneratorError]:
"""
Generate the client library
Expand All @@ -288,7 +294,7 @@ def create_new_client(
A list containing any errors encountered when generating.
"""
project = _get_project_for_url_or_path(
url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding
url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding, extra_template_kwargs=extra_template_kwargs
)
if isinstance(project, GeneratorError):
return [project]
Expand All @@ -302,6 +308,7 @@ def update_existing_client(
meta: MetaType,
custom_template_path: Optional[Path] = None,
file_encoding: str = "utf-8",
extra_template_kwargs: Optional[Dict[str, str]] = None,
) -> Sequence[GeneratorError]:
"""
Update an existing client library
Expand All @@ -310,7 +317,7 @@ def update_existing_client(
A list containing any errors encountered when generating.
"""
project = _get_project_for_url_or_path(
url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding
url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding, extra_template_kwargs=extra_template_kwargs
)
if isinstance(project, GeneratorError):
return [project]
Expand Down
19 changes: 16 additions & 3 deletions openapi_python_client/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import codecs
import pathlib
from pprint import pformat
from typing import Optional, Sequence
from typing import Optional, Dict, List, Sequence

import typer

Expand Down Expand Up @@ -112,13 +112,20 @@ def handle_errors(errors: Sequence[GeneratorError]) -> None:
)


def _parse_extra_template_kwargs(extra_template_kwargs: List[str]) -> Dict[str, str]:
out = {k: v for k, v in map(lambda s: s.split("="), extra_template_kwargs)}
return out



@app.command()
def generate(
url: Optional[str] = typer.Option(None, help="A URL to read the JSON from"),
path: Optional[pathlib.Path] = typer.Option(None, help="A path to the JSON file"),
custom_template_path: Optional[pathlib.Path] = typer.Option(None, **custom_template_path_options), # type: ignore
file_encoding: str = typer.Option("utf-8", help="Encoding used when writing generated"),
meta: MetaType = _meta_option,
extra_template_kwargs: Optional[List[str]] = typer.Option(None),
) -> None:
""" Generate a new OpenAPI Client library """
from . import create_new_client
Expand All @@ -136,8 +143,11 @@ def generate(
typer.secho("Unknown encoding : {}".format(file_encoding), fg=typer.colors.RED)
raise typer.Exit(code=1)

if extra_template_kwargs is not None:
extra_tpl_kwargs = _parse_extra_template_kwargs(extra_template_kwargs)

errors = create_new_client(
url=url, path=path, meta=meta, custom_template_path=custom_template_path, file_encoding=file_encoding
url=url, path=path, meta=meta, custom_template_path=custom_template_path, file_encoding=file_encoding, extra_template_kwargs=extra_tpl_kwargs,
)
handle_errors(errors)

Expand Down Expand Up @@ -166,7 +176,10 @@ def update(
typer.secho("Unknown encoding : {}".format(file_encoding), fg=typer.colors.RED)
raise typer.Exit(code=1)

if extra_template_kwargs is not None:
extra_tpl_kwargs = _parse_extra_template_kwargs(extra_template_kwargs)

errors = update_existing_client(
url=url, path=path, meta=meta, custom_template_path=custom_template_path, file_encoding=file_encoding
url=url, path=path, meta=meta, custom_template_path=custom_template_path, file_encoding=file_encoding, extra_template_kwargs=extra_tpl_kwargs,
)
handle_errors(errors)