diff --git a/openapi_python_client/__init__.py b/openapi_python_client/__init__.py index b5ad8afeb..9f8fa6834 100644 --- a/openapi_python_client/__init__.py +++ b/openapi_python_client/__init__.py @@ -52,10 +52,12 @@ def __init__( meta: MetaType, custom_template_path: Optional[Path] = None, file_encoding: str = "utf-8", + extra_template_kwargs: Optional[Dict[str, str]] = None, ) -> None: self.openapi: GeneratorData = openapi self.meta: MetaType = meta self.file_encoding = file_encoding + self.extra_template_kwargs = extra_template_kwargs or {} package_loader = PackageLoader(__package__) loader: BaseLoader @@ -146,7 +148,7 @@ def _create_package(self) -> None: package_init_template = self.env.get_template("package_init.py.jinja") package_init.write_text( - package_init_template.render(description=self.package_description), encoding=self.file_encoding + package_init_template.render(description=self.package_description, **self.extra_template_kwargs), encoding=self.file_encoding ) if self.meta != MetaType.NONE: @@ -155,7 +157,7 @@ def _create_package(self) -> None: types_template = self.env.get_template("types.py.jinja") types_path = self.package_dir / "types.py" - types_path.write_text(types_template.render(), encoding=self.file_encoding) + types_path.write_text(types_template.render(**self.extra_template_kwargs), encoding=self.file_encoding) def _build_metadata(self) -> None: if self.meta == MetaType.NONE: @@ -170,7 +172,7 @@ def _build_metadata(self) -> None: readme_template = self.env.get_template("README.md.jinja") readme.write_text( readme_template.render( - project_name=self.project_name, description=self.package_description, package_name=self.package_name + project_name=self.project_name, description=self.package_description, package_name=self.package_name, **self.extra_template_kwargs ), encoding=self.file_encoding, ) @@ -178,7 +180,7 @@ def _build_metadata(self) -> None: # .gitignore git_ignore_path = self.project_dir / ".gitignore" git_ignore_template = self.env.get_template(".gitignore.jinja") - git_ignore_path.write_text(git_ignore_template.render(), encoding=self.file_encoding) + git_ignore_path.write_text(git_ignore_template.render(**self.extra_template_kwargs), encoding=self.file_encoding) def _build_pyproject_toml(self, *, use_poetry: bool) -> None: template = "pyproject.toml.jinja" if use_poetry else "pyproject_no_poetry.toml.jinja" @@ -190,6 +192,7 @@ def _build_pyproject_toml(self, *, use_poetry: bool) -> None: package_name=self.package_name, version=self.version, description=self.package_description, + **self.extra_template_kwargs ), encoding=self.file_encoding, ) @@ -203,6 +206,7 @@ def _build_setup_py(self) -> None: package_name=self.package_name, version=self.version, description=self.package_description, + **self.extra_template_kwargs, ), encoding=self.file_encoding, ) @@ -217,7 +221,7 @@ def _build_models(self) -> None: model_template = self.env.get_template("model.py.jinja") for model in self.openapi.models.values(): module_path = models_dir / f"{model.reference.module_name}.py" - module_path.write_text(model_template.render(model=model), encoding=self.file_encoding) + module_path.write_text(model_template.render(model=model, **self.extra_template_kwargs), encoding=self.file_encoding) imports.append(import_string_from_reference(model.reference)) # Generate enums @@ -226,19 +230,19 @@ def _build_models(self) -> None: for enum in self.openapi.enums.values(): module_path = models_dir / f"{enum.reference.module_name}.py" if enum.value_type is int: - module_path.write_text(int_enum_template.render(enum=enum), encoding=self.file_encoding) + module_path.write_text(int_enum_template.render(enum=enum, **self.extra_template_kwargs), encoding=self.file_encoding) else: - module_path.write_text(str_enum_template.render(enum=enum), encoding=self.file_encoding) + module_path.write_text(str_enum_template.render(enum=enum, **self.extra_template_kwargs), encoding=self.file_encoding) imports.append(import_string_from_reference(enum.reference)) models_init_template = self.env.get_template("models_init.py.jinja") - models_init.write_text(models_init_template.render(imports=imports), encoding=self.file_encoding) + models_init.write_text(models_init_template.render(imports=imports, **self.extra_template_kwargs), encoding=self.file_encoding) def _build_api(self) -> None: # Generate Client client_path = self.package_dir / "client.py" client_template = self.env.get_template("client.py.jinja") - client_path.write_text(client_template.render(), encoding=self.file_encoding) + client_path.write_text(client_template.render(**self.extra_template_kwargs), encoding=self.file_encoding) # Generate endpoints api_dir = self.package_dir / "api" @@ -254,7 +258,7 @@ def _build_api(self) -> None: for endpoint in collection.endpoints: module_path = tag_dir / f"{snake_case(endpoint.name)}.py" - module_path.write_text(endpoint_template.render(endpoint=endpoint), encoding=self.file_encoding) + module_path.write_text(endpoint_template.render(endpoint=endpoint, **self.extra_template_kwargs), encoding=self.file_encoding) def _get_project_for_url_or_path( @@ -263,6 +267,7 @@ def _get_project_for_url_or_path( meta: MetaType, custom_template_path: Optional[Path] = None, file_encoding: str = "utf-8", + extra_template_kwargs: Optional[Dict[str, str]] = None, ) -> Union[Project, GeneratorError]: data_dict = _get_document(url=url, path=path) if isinstance(data_dict, GeneratorError): @@ -270,7 +275,7 @@ def _get_project_for_url_or_path( openapi = GeneratorData.from_dict(data_dict) if isinstance(openapi, GeneratorError): return openapi - return Project(openapi=openapi, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding) + return Project(openapi=openapi, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding, extra_template_kwargs=extra_template_kwargs) def create_new_client( @@ -280,6 +285,7 @@ def create_new_client( meta: MetaType, custom_template_path: Optional[Path] = None, file_encoding: str = "utf-8", + extra_template_kwargs: Optional[Dict[str, str]] = None, ) -> Sequence[GeneratorError]: """ Generate the client library @@ -288,7 +294,7 @@ def create_new_client( A list containing any errors encountered when generating. """ project = _get_project_for_url_or_path( - url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding + url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding, extra_template_kwargs=extra_template_kwargs ) if isinstance(project, GeneratorError): return [project] @@ -302,6 +308,7 @@ def update_existing_client( meta: MetaType, custom_template_path: Optional[Path] = None, file_encoding: str = "utf-8", + extra_template_kwargs: Optional[Dict[str, str]] = None, ) -> Sequence[GeneratorError]: """ Update an existing client library @@ -310,7 +317,7 @@ def update_existing_client( A list containing any errors encountered when generating. """ project = _get_project_for_url_or_path( - url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding + url=url, path=path, custom_template_path=custom_template_path, meta=meta, file_encoding=file_encoding, extra_template_kwargs=extra_template_kwargs ) if isinstance(project, GeneratorError): return [project] diff --git a/openapi_python_client/cli.py b/openapi_python_client/cli.py index 1f94f37ea..d8fb9298c 100644 --- a/openapi_python_client/cli.py +++ b/openapi_python_client/cli.py @@ -1,7 +1,7 @@ import codecs import pathlib from pprint import pformat -from typing import Optional, Sequence +from typing import Optional, Dict, List, Sequence import typer @@ -112,6 +112,12 @@ def handle_errors(errors: Sequence[GeneratorError]) -> None: ) +def _parse_extra_template_kwargs(extra_template_kwargs: List[str]) -> Dict[str, str]: + out = {k: v for k, v in map(lambda s: s.split("="), extra_template_kwargs)} + return out + + + @app.command() def generate( url: Optional[str] = typer.Option(None, help="A URL to read the JSON from"), @@ -119,6 +125,7 @@ def generate( custom_template_path: Optional[pathlib.Path] = typer.Option(None, **custom_template_path_options), # type: ignore file_encoding: str = typer.Option("utf-8", help="Encoding used when writing generated"), meta: MetaType = _meta_option, + extra_template_kwargs: Optional[List[str]] = typer.Option(None), ) -> None: """ Generate a new OpenAPI Client library """ from . import create_new_client @@ -136,8 +143,11 @@ def generate( typer.secho("Unknown encoding : {}".format(file_encoding), fg=typer.colors.RED) raise typer.Exit(code=1) + if extra_template_kwargs is not None: + extra_tpl_kwargs = _parse_extra_template_kwargs(extra_template_kwargs) + errors = create_new_client( - url=url, path=path, meta=meta, custom_template_path=custom_template_path, file_encoding=file_encoding + url=url, path=path, meta=meta, custom_template_path=custom_template_path, file_encoding=file_encoding, extra_template_kwargs=extra_tpl_kwargs, ) handle_errors(errors) @@ -166,7 +176,10 @@ def update( typer.secho("Unknown encoding : {}".format(file_encoding), fg=typer.colors.RED) raise typer.Exit(code=1) + if extra_template_kwargs is not None: + extra_tpl_kwargs = _parse_extra_template_kwargs(extra_template_kwargs) + errors = update_existing_client( - url=url, path=path, meta=meta, custom_template_path=custom_template_path, file_encoding=file_encoding + url=url, path=path, meta=meta, custom_template_path=custom_template_path, file_encoding=file_encoding, extra_template_kwargs=extra_tpl_kwargs, ) handle_errors(errors)