|
4 | 4 | from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Sequence, TypeVar, Union
|
5 | 5 |
|
6 | 6 | import pytest
|
| 7 | +from dagster._check.functions import CheckError |
7 | 8 | from dagster._record import (
|
8 | 9 | _INJECTED_DEFAULT_VALS_LOCAL_VAR,
|
9 | 10 | IHaveNew,
|
@@ -600,3 +601,123 @@ class SubSub(Sub):
|
600 | 601 | assert subsub.b == 0
|
601 | 602 |
|
602 | 603 | assert repr(subsub) == "SubSub(a=0, b=0, c=-1, d=2)"
|
| 604 | + |
| 605 | + |
| 606 | +def test_generic_with_propagate() -> None: |
| 607 | + T = TypeVar("T") |
| 608 | + |
| 609 | + class Base(Generic[T]): ... |
| 610 | + |
| 611 | + @record |
| 612 | + class RecordBase(Base[T]): |
| 613 | + label: Optional[str] = None |
| 614 | + |
| 615 | + @record |
| 616 | + class SubAdditionalArg(RecordBase): |
| 617 | + some_val: str |
| 618 | + |
| 619 | + obj = SubAdditionalArg(some_val="hi") |
| 620 | + assert SubAdditionalArg(some_val="hi").some_val == "hi" |
| 621 | + assert copy(obj, label="...").label == "..." |
| 622 | + assert copy(obj, some_val="new").some_val == "new" |
| 623 | + |
| 624 | + @record |
| 625 | + class SubAdditionalArgRecursive(RecordBase): |
| 626 | + vals: Sequence[RecordBase] |
| 627 | + |
| 628 | + obj = SubAdditionalArgRecursive(vals=[SubAdditionalArg(some_val="hi")]) |
| 629 | + assert len(obj.vals) == 1 |
| 630 | + assert copy(obj, label="...").label == "..." |
| 631 | + |
| 632 | + @record |
| 633 | + class SubAdditionalArgSpecific(RecordBase[int]): |
| 634 | + vals: Sequence[RecordBase[str]] |
| 635 | + |
| 636 | + obj = SubAdditionalArgSpecific(vals=[SubAdditionalArg(some_val="hi")]) |
| 637 | + assert len(obj.vals) == 1 |
| 638 | + assert copy(obj, label="...").label == "..." |
| 639 | + |
| 640 | + @record |
| 641 | + class SubAdditionalArgVariableBase(RecordBase[T]): |
| 642 | + vals: Sequence[RecordBase[T]] |
| 643 | + val: Base[T] |
| 644 | + |
| 645 | + obj = SubAdditionalArgVariableBase( |
| 646 | + vals=[SubAdditionalArg(some_val="hi")], val=SubAdditionalArg(some_val="bye") |
| 647 | + ) |
| 648 | + assert len(obj.vals) == 1 |
| 649 | + assert copy(obj, label="...").label == "..." |
| 650 | + |
| 651 | + @record |
| 652 | + class SubSubVariableBaseSpecific(SubAdditionalArgVariableBase[str]): ... |
| 653 | + |
| 654 | + obj = SubSubVariableBaseSpecific( |
| 655 | + vals=[SubAdditionalArg(some_val="hi")], val=SubAdditionalArg(some_val="bye") |
| 656 | + ) |
| 657 | + assert len(obj.vals) == 1 |
| 658 | + assert copy(obj, label="...").label == "..." |
| 659 | + |
| 660 | + @record |
| 661 | + class SubSubVariableBaseAny(SubAdditionalArgVariableBase): ... |
| 662 | + |
| 663 | + obj = SubSubVariableBaseAny( |
| 664 | + vals=[SubAdditionalArg(some_val="hi")], val=SubAdditionalArg(some_val="bye") |
| 665 | + ) |
| 666 | + assert len(obj.vals) == 1 |
| 667 | + assert copy(obj, label="...").label == "..." |
| 668 | + |
| 669 | + |
| 670 | +def test_generic_with_propagate_type_checking() -> None: |
| 671 | + T = TypeVar("T") |
| 672 | + |
| 673 | + class Base(Generic[T]): ... |
| 674 | + |
| 675 | + @record |
| 676 | + class RecordBase(Base[T]): |
| 677 | + inner: T |
| 678 | + |
| 679 | + @record |
| 680 | + class SpecificRecord(RecordBase): |
| 681 | + val1: RecordBase |
| 682 | + val2: RecordBase[str] |
| 683 | + val3: Base |
| 684 | + val4: Base[int] |
| 685 | + |
| 686 | + valid_record = SpecificRecord( |
| 687 | + inner=..., |
| 688 | + val1=RecordBase(inner=1.23), |
| 689 | + val2=RecordBase(inner="hi"), |
| 690 | + val3=RecordBase(inner=1.23), |
| 691 | + val4=RecordBase(inner=1), |
| 692 | + ) |
| 693 | + |
| 694 | + with pytest.raises(CheckError, match='"val1" is not a RecordBase'): |
| 695 | + copy(valid_record, val1=Base()) |
| 696 | + |
| 697 | + with pytest.raises(CheckError, match='"val2" is not a RecordBase'): |
| 698 | + copy(valid_record, val2=Base()) |
| 699 | + |
| 700 | + with pytest.raises(CheckError, match='"val3" is not a Base'): |
| 701 | + copy(valid_record, val3=3) |
| 702 | + |
| 703 | + with pytest.raises(CheckError, match='"val4" is not a Base'): |
| 704 | + copy(valid_record, val4=4) |
| 705 | + |
| 706 | + |
| 707 | +@pytest.mark.xfail() |
| 708 | +def test_custom_subclass() -> None: |
| 709 | + @record_custom |
| 710 | + class Thing(IHaveNew): |
| 711 | + val: str |
| 712 | + |
| 713 | + def __new__(cls, val_short: str): |
| 714 | + return super().__new__(cls, val=val_short * 2) |
| 715 | + |
| 716 | + assert Thing(val_short="abc").val == "abcabc" |
| 717 | + |
| 718 | + @record |
| 719 | + class SubThing(Thing): |
| 720 | + other_val: int |
| 721 | + |
| 722 | + # this does not work, as we've overridden the wrong __new__ |
| 723 | + SubThing(other_val=1) |
0 commit comments