Skip to content

Commit b4ece52

Browse files
authored
Add support for CliMutuallyExclusiveGroup. (#473)
1 parent 11a817c commit b4ece52

File tree

4 files changed

+236
-19
lines changed

4 files changed

+236
-19
lines changed

docs/index.md

+38
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,44 @@ For `BaseModel` and `pydantic.dataclasses.dataclass` types, `CliApp.run` will in
969969
The alias generator for kebab case does not propagate to subcommands or submodels and will have to be manually set
970970
in these cases.
971971

972+
### Mutually Exclusive Groups
973+
974+
CLI mutually exclusive groups can be created by inheriting from the `CliMutuallyExclusiveGroup` class.
975+
976+
!!! note
977+
A `CliMutuallyExclusiveGroup` cannot be used in a union or contain nested models.
978+
979+
```py
980+
from typing import Optional
981+
982+
from pydantic import BaseModel
983+
984+
from pydantic_settings import CliApp, CliMutuallyExclusiveGroup, SettingsError
985+
986+
987+
class Circle(CliMutuallyExclusiveGroup):
988+
radius: Optional[float] = None
989+
diameter: Optional[float] = None
990+
perimeter: Optional[float] = None
991+
992+
993+
class Settings(BaseModel):
994+
circle: Circle
995+
996+
997+
try:
998+
CliApp.run(
999+
Settings,
1000+
cli_args=['--circle.radius=1', '--circle.diameter=2'],
1001+
cli_exit_on_error=False,
1002+
)
1003+
except SettingsError as e:
1004+
print(e)
1005+
"""
1006+
error parsing CLI: argument --circle.diameter: not allowed with argument --circle.radius
1007+
"""
1008+
```
1009+
9721010
### Customizing the CLI Experience
9731011

9741012
The below flags can be used to customise the CLI experience to your needs.

pydantic_settings/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
AzureKeyVaultSettingsSource,
55
CliExplicitFlag,
66
CliImplicitFlag,
7+
CliMutuallyExclusiveGroup,
78
CliPositionalArg,
89
CliSettingsSource,
910
CliSubCommand,
@@ -34,6 +35,7 @@
3435
'CliPositionalArg',
3536
'CliExplicitFlag',
3637
'CliImplicitFlag',
38+
'CliMutuallyExclusiveGroup',
3739
'InitSettingsSource',
3840
'JsonConfigSettingsSource',
3941
'PyprojectTomlConfigSettingsSource',

pydantic_settings/sources.py

+45-6
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ def error(self, message: str) -> NoReturn:
149149
super().error(message)
150150

151151

152+
class CliMutuallyExclusiveGroup(BaseModel):
153+
pass
154+
155+
152156
T = TypeVar('T')
153157
CliSubCommand = Annotated[Union[T, None], _CliSubCommand]
154158
CliPositionalArg = Annotated[T, _CliPositionalArg]
@@ -1483,7 +1487,7 @@ def _connect_parser_method(
14831487
if (
14841488
parser_method is not None
14851489
and self.case_sensitive is False
1486-
and method_name == 'parsed_args_method'
1490+
and method_name == 'parse_args_method'
14871491
and isinstance(self._root_parser, _CliInternalArgParser)
14881492
):
14891493

@@ -1515,6 +1519,26 @@ def none_parser_method(*args: Any, **kwargs: Any) -> Any:
15151519
else:
15161520
return parser_method
15171521

1522+
def _connect_group_method(self, add_argument_group_method: Callable[..., Any] | None) -> Callable[..., Any]:
1523+
add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')
1524+
1525+
def add_group_method(parser: Any, **kwargs: Any) -> Any:
1526+
if not kwargs.pop('_is_cli_mutually_exclusive_group'):
1527+
kwargs.pop('required')
1528+
return add_argument_group(parser, **kwargs)
1529+
else:
1530+
main_group_kwargs = {arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs}
1531+
main_group_kwargs['title'] += ' (mutually exclusive)'
1532+
group = add_argument_group(parser, **main_group_kwargs)
1533+
if not hasattr(group, 'add_mutually_exclusive_group'):
1534+
raise SettingsError(
1535+
'cannot connect CLI settings source root parser: '
1536+
'group object is missing add_mutually_exclusive_group but is needed for connecting'
1537+
)
1538+
return group.add_mutually_exclusive_group(**kwargs)
1539+
1540+
return add_group_method
1541+
15181542
def _connect_root_parser(
15191543
self,
15201544
root_parser: T,
@@ -1531,9 +1555,9 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
15311555
self._root_parser = root_parser
15321556
if parse_args_method is None:
15331557
parse_args_method = _parse_known_args if self.cli_ignore_unknown_args else ArgumentParser.parse_args
1534-
self._parse_args = self._connect_parser_method(parse_args_method, 'parsed_args_method')
1558+
self._parse_args = self._connect_parser_method(parse_args_method, 'parse_args_method')
15351559
self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method')
1536-
self._add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')
1560+
self._add_group = self._connect_group_method(add_argument_group_method)
15371561
self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method')
15381562
self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method')
15391563
self._formatter_class = formatter_class
@@ -1665,6 +1689,7 @@ def _add_parser_args(
16651689
if is_parser_submodel:
16661690
self._add_parser_submodels(
16671691
parser,
1692+
model,
16681693
sub_models,
16691694
added_args,
16701695
arg_prefix,
@@ -1680,7 +1705,7 @@ def _add_parser_args(
16801705
elif not is_alias_path_only:
16811706
if group is not None:
16821707
if isinstance(group, dict):
1683-
group = self._add_argument_group(parser, **group)
1708+
group = self._add_group(parser, **group)
16841709
added_args += list(arg_names)
16851710
self._add_argument(group, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs)
16861711
else:
@@ -1724,6 +1749,7 @@ def _get_arg_names(
17241749
def _add_parser_submodels(
17251750
self,
17261751
parser: Any,
1752+
model: type[BaseModel],
17271753
sub_models: list[type[BaseModel]],
17281754
added_args: list[str],
17291755
arg_prefix: str,
@@ -1736,10 +1762,23 @@ def _add_parser_submodels(
17361762
alias_names: tuple[str, ...],
17371763
model_default: Any,
17381764
) -> None:
1765+
if issubclass(model, CliMutuallyExclusiveGroup):
1766+
# Argparse has deprecated "calling add_argument_group() or add_mutually_exclusive_group() on a
1767+
# mutually exclusive group" (https://docs.python.org/3/library/argparse.html#mutual-exclusion).
1768+
# Since nested models result in a group add, raise an exception for nested models in a mutually
1769+
# exclusive group.
1770+
raise SettingsError('cannot have nested models in a CliMutuallyExclusiveGroup')
1771+
17391772
model_group: Any = None
17401773
model_group_kwargs: dict[str, Any] = {}
17411774
model_group_kwargs['title'] = f'{arg_names[0]} options'
17421775
model_group_kwargs['description'] = field_info.description
1776+
model_group_kwargs['required'] = kwargs['required']
1777+
model_group_kwargs['_is_cli_mutually_exclusive_group'] = any(
1778+
issubclass(model, CliMutuallyExclusiveGroup) for model in sub_models
1779+
)
1780+
if model_group_kwargs['_is_cli_mutually_exclusive_group'] and len(sub_models) > 1:
1781+
raise SettingsError('cannot use union with CliMutuallyExclusiveGroup')
17431782
if self.cli_use_class_docs_for_groups and len(sub_models) == 1:
17441783
model_group_kwargs['description'] = None if sub_models[0].__doc__ is None else dedent(sub_models[0].__doc__)
17451784

@@ -1762,7 +1801,7 @@ def _add_parser_submodels(
17621801
if not self.cli_avoid_json:
17631802
added_args.append(arg_names[0])
17641803
kwargs['help'] = f'set {arg_names[0]} from JSON string'
1765-
model_group = self._add_argument_group(parser, **model_group_kwargs)
1804+
model_group = self._add_group(parser, **model_group_kwargs)
17661805
self._add_argument(model_group, *(f'{flag_prefix}{name}' for name in arg_names), **kwargs)
17671806
for model in sub_models:
17681807
self._add_parser_args(
@@ -1788,7 +1827,7 @@ def _add_parser_alias_paths(
17881827
if alias_path_args:
17891828
context = parser
17901829
if group is not None:
1791-
context = self._add_argument_group(parser, **group) if isinstance(group, dict) else group
1830+
context = self._add_group(parser, **group) if isinstance(group, dict) else group
17921831
is_nested_alias_path = arg_prefix.endswith('.')
17931832
arg_prefix = arg_prefix[:-1] if is_nested_alias_path else arg_prefix
17941833
for name, metavar in alias_path_args.items():

0 commit comments

Comments
 (0)