Skip to content

Commit 94e3ce0

Browse files
authored
Fix edge cases of generic @records (#24629)
## Summary & Motivation This came about from the PR stacked above this. Realized that there were some remaining issues, particularly with `copy()` and type checking paramters that were generic. Added tests for these cases and fixed them. ## How I Tested These Changes ## Changelog NOCHANGELOG - [ ] `NEW` _(added new feature or capability)_ - [ ] `BUGFIX` _(fixed a bug)_ - [ ] `DOCS` _(added or updated documentation)_
1 parent d25e4fa commit 94e3ce0

File tree

3 files changed

+147
-15
lines changed

3 files changed

+147
-15
lines changed

python_modules/dagster/dagster/_check/builder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,5 +366,11 @@ def build_check_call_str(
366366
if tuple_types is not None:
367367
tt_name = _name(tuple_types)
368368
return f'{name} if isinstance({name}, {tt_name}) else check.inst_param({name}, "{name}", {tt_name})'
369+
# generic
370+
else:
371+
inst_type = _coerce_type(ttype, eval_ctx)
372+
if inst_type:
373+
it = _name(inst_type)
374+
return f'{name} if isinstance({name}, {it}) else check.inst_param({name}, "{name}", {it})'
369375

370376
failed(f"Unhandled {ttype}")

python_modules/dagster/dagster/_record/__init__.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def _namedtuple_record_transform(
128128
(cls, base),
129129
{ # these will override an implementation on the class if it exists
130130
**{n: getattr(base, n) for n in field_set.keys()},
131+
"_fields": base._fields,
131132
"__iter__": _banned_iter,
132133
"__getitem__": _banned_idx,
133134
"__hidden_iter__": base.__iter__,
@@ -364,23 +365,27 @@ def __init__(
364365
self._nt_base = nt_base
365366
self._eval_ctx = eval_ctx
366367
self._new_frames = new_frames # how many frames of __new__ there are
368+
self._compiled_fn = None
367369

368370
def __call__(self, cls, *args, **kwargs):
369-
# update the context with callsite locals/globals to resolve
370-
# ForwardRefs that were unavailable at definition time.
371-
self._eval_ctx.update_from_frame(1 + self._new_frames)
372-
373-
# ensure check is in scope
374-
if "check" not in self._eval_ctx.global_ns:
375-
self._eval_ctx.global_ns["check"] = check
376-
377-
# jit that shit
378-
self._nt_base.__new__ = self._eval_ctx.compile_fn(
379-
self._build_checked_new_str(),
380-
_CHECKED_NEW,
381-
)
382-
383-
return self._nt_base.__new__(cls, *args, **kwargs)
371+
if self._compiled_fn is None:
372+
# update the context with callsite locals/globals to resolve
373+
# ForwardRefs that were unavailable at definition time.
374+
self._eval_ctx.update_from_frame(1 + self._new_frames)
375+
376+
# ensure check is in scope
377+
if "check" not in self._eval_ctx.global_ns:
378+
self._eval_ctx.global_ns["check"] = check
379+
380+
# we are double-memoizing this to handle some confusing mro issues
381+
# in which the _nt_base's __new__ method is not on the critical
382+
# path, causing this to get invoked multiple times
383+
self._compiled_fn = self._eval_ctx.compile_fn(
384+
self._build_checked_new_str(),
385+
_CHECKED_NEW,
386+
)
387+
self._nt_base.__new__ = self._compiled_fn
388+
return self._compiled_fn(cls, *args, **kwargs)
384389

385390
def _build_checked_new_str(self) -> str:
386391
kw_args_str, set_calls_str = build_args_and_assignment_strs(self._field_set, self._defaults)

python_modules/dagster/dagster_tests/general_tests/test_record.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Sequence, TypeVar, Union
55

66
import pytest
7+
from dagster._check.functions import CheckError
78
from dagster._record import (
89
_INJECTED_DEFAULT_VALS_LOCAL_VAR,
910
IHaveNew,
@@ -600,3 +601,123 @@ class SubSub(Sub):
600601
assert subsub.b == 0
601602

602603
assert repr(subsub) == "SubSub(a=0, b=0, c=-1, d=2)"
604+
605+
606+
def test_generic_with_propagate() -> None:
607+
T = TypeVar("T")
608+
609+
class Base(Generic[T]): ...
610+
611+
@record
612+
class RecordBase(Base[T]):
613+
label: Optional[str] = None
614+
615+
@record
616+
class SubAdditionalArg(RecordBase):
617+
some_val: str
618+
619+
obj = SubAdditionalArg(some_val="hi")
620+
assert SubAdditionalArg(some_val="hi").some_val == "hi"
621+
assert copy(obj, label="...").label == "..."
622+
assert copy(obj, some_val="new").some_val == "new"
623+
624+
@record
625+
class SubAdditionalArgRecursive(RecordBase):
626+
vals: Sequence[RecordBase]
627+
628+
obj = SubAdditionalArgRecursive(vals=[SubAdditionalArg(some_val="hi")])
629+
assert len(obj.vals) == 1
630+
assert copy(obj, label="...").label == "..."
631+
632+
@record
633+
class SubAdditionalArgSpecific(RecordBase[int]):
634+
vals: Sequence[RecordBase[str]]
635+
636+
obj = SubAdditionalArgSpecific(vals=[SubAdditionalArg(some_val="hi")])
637+
assert len(obj.vals) == 1
638+
assert copy(obj, label="...").label == "..."
639+
640+
@record
641+
class SubAdditionalArgVariableBase(RecordBase[T]):
642+
vals: Sequence[RecordBase[T]]
643+
val: Base[T]
644+
645+
obj = SubAdditionalArgVariableBase(
646+
vals=[SubAdditionalArg(some_val="hi")], val=SubAdditionalArg(some_val="bye")
647+
)
648+
assert len(obj.vals) == 1
649+
assert copy(obj, label="...").label == "..."
650+
651+
@record
652+
class SubSubVariableBaseSpecific(SubAdditionalArgVariableBase[str]): ...
653+
654+
obj = SubSubVariableBaseSpecific(
655+
vals=[SubAdditionalArg(some_val="hi")], val=SubAdditionalArg(some_val="bye")
656+
)
657+
assert len(obj.vals) == 1
658+
assert copy(obj, label="...").label == "..."
659+
660+
@record
661+
class SubSubVariableBaseAny(SubAdditionalArgVariableBase): ...
662+
663+
obj = SubSubVariableBaseAny(
664+
vals=[SubAdditionalArg(some_val="hi")], val=SubAdditionalArg(some_val="bye")
665+
)
666+
assert len(obj.vals) == 1
667+
assert copy(obj, label="...").label == "..."
668+
669+
670+
def test_generic_with_propagate_type_checking() -> None:
671+
T = TypeVar("T")
672+
673+
class Base(Generic[T]): ...
674+
675+
@record
676+
class RecordBase(Base[T]):
677+
inner: T
678+
679+
@record
680+
class SpecificRecord(RecordBase):
681+
val1: RecordBase
682+
val2: RecordBase[str]
683+
val3: Base
684+
val4: Base[int]
685+
686+
valid_record = SpecificRecord(
687+
inner=...,
688+
val1=RecordBase(inner=1.23),
689+
val2=RecordBase(inner="hi"),
690+
val3=RecordBase(inner=1.23),
691+
val4=RecordBase(inner=1),
692+
)
693+
694+
with pytest.raises(CheckError, match='"val1" is not a RecordBase'):
695+
copy(valid_record, val1=Base())
696+
697+
with pytest.raises(CheckError, match='"val2" is not a RecordBase'):
698+
copy(valid_record, val2=Base())
699+
700+
with pytest.raises(CheckError, match='"val3" is not a Base'):
701+
copy(valid_record, val3=3)
702+
703+
with pytest.raises(CheckError, match='"val4" is not a Base'):
704+
copy(valid_record, val4=4)
705+
706+
707+
@pytest.mark.xfail()
708+
def test_custom_subclass() -> None:
709+
@record_custom
710+
class Thing(IHaveNew):
711+
val: str
712+
713+
def __new__(cls, val_short: str):
714+
return super().__new__(cls, val=val_short * 2)
715+
716+
assert Thing(val_short="abc").val == "abcabc"
717+
718+
@record
719+
class SubThing(Thing):
720+
other_val: int
721+
722+
# this does not work, as we've overridden the wrong __new__
723+
SubThing(other_val=1)

0 commit comments

Comments
 (0)