Skip to content

Commit a0924bc

Browse files
authored
Fix alias resolution to use preferred key. (#481)
1 parent 6fe3bd1 commit a0924bc

File tree

2 files changed

+45
-6
lines changed

2 files changed

+45
-6
lines changed

pydantic_settings/sources.py

+13-5
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,9 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
661661
a flag to determine whether value is complex.
662662
"""
663663

664-
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
664+
field_infos = self._extract_field_info(field, field_name)
665+
preferred_key, *_ = field_infos[0]
666+
for field_key, env_name, value_is_complex in field_infos:
665667
# paths reversed to match the last-wins behaviour of `env_file`
666668
for secrets_path in reversed(self.secrets_paths):
667669
path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
@@ -670,14 +672,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
670672
continue
671673

672674
if path.is_file():
673-
return path.read_text().strip(), field_key, value_is_complex
675+
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
676+
preferred_key = field_key
677+
return path.read_text().strip(), preferred_key, value_is_complex
674678
else:
675679
warnings.warn(
676680
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
677681
stacklevel=4,
678682
)
679683

680-
return None, field_key, value_is_complex
684+
return None, preferred_key, value_is_complex
681685

682686
def __repr__(self) -> str:
683687
return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'
@@ -725,12 +729,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
725729
"""
726730

727731
env_val: str | None = None
728-
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
732+
field_infos = self._extract_field_info(field, field_name)
733+
preferred_key, *_ = field_infos[0]
734+
for field_key, env_name, value_is_complex in field_infos:
729735
env_val = self.env_vars.get(env_name)
730736
if env_val is not None:
737+
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
738+
preferred_key = field_key
731739
break
732740

733-
return env_val, field_key, value_is_complex
741+
return env_val, preferred_key, value_is_complex
734742

735743
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
736744
"""

tests/test_source_cli.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import typing_extensions
1111
from pydantic import (
1212
AliasChoices,
13+
AliasGenerator,
1314
AliasPath,
1415
BaseModel,
1516
ConfigDict,
@@ -107,7 +108,7 @@ def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:
107108
return self.parser.parse_args(*args, **kwargs)
108109

109110

110-
def test_validation_alias_with_cli_prefix():
111+
def test_cli_validation_alias_with_cli_prefix():
111112
class Settings(BaseSettings, cli_exit_on_error=False):
112113
foobar: str = Field(validation_alias='foo')
113114

@@ -119,6 +120,36 @@ class Settings(BaseSettings, cli_exit_on_error=False):
119120
assert CliApp.run(Settings, cli_args=['--p.foo', 'bar']).foobar == 'bar'
120121

121122

123+
@pytest.mark.parametrize(
124+
'alias_generator',
125+
[
126+
AliasGenerator(validation_alias=lambda s: AliasChoices(s, s.replace('_', '-'))),
127+
AliasGenerator(validation_alias=lambda s: AliasChoices(s.replace('_', '-'), s)),
128+
],
129+
)
130+
def test_cli_alias_resolution_consistency_with_env(env, alias_generator):
131+
class SubModel(BaseModel):
132+
v1: str = 'model default'
133+
134+
class Settings(BaseSettings):
135+
model_config = SettingsConfigDict(
136+
env_nested_delimiter='__',
137+
nested_model_default_partial_update=True,
138+
alias_generator=alias_generator,
139+
)
140+
141+
sub_model: SubModel = SubModel(v1='top default')
142+
143+
assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'top default'}}
144+
145+
env.set('SUB_MODEL__V1', 'env default')
146+
assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'env default'}}
147+
148+
assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli default']).model_dump() == {
149+
'sub_model': {'v1': 'cli default'}
150+
}
151+
152+
122153
def test_cli_nested_arg():
123154
class SubSubValue(BaseModel):
124155
v6: str

0 commit comments

Comments
 (0)