Skip to content

[Regression Fix] Call custom resolve functions if provided #241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 14, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion graphene_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from .fields import SQLAlchemyConnectionField
from .utils import get_query, get_session

__version__ = "2.2.1"
__version__ = "2.2.2"

__all__ = [
"__version__",
Expand Down
18 changes: 7 additions & 11 deletions graphene_sqlalchemy/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
ChoiceType = JSONType = ScalarListType = TSVectorType = object


def _get_attr_resolver(attr_name):
return lambda root, _info: getattr(root, attr_name, None)


def get_column_doc(column):
return getattr(column, "doc", None)

Expand All @@ -28,7 +24,7 @@ def is_column_nullable(column):
return bool(getattr(column, "nullable", True))


def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, **field_kwargs):
def convert_sqlalchemy_relationship(relationship_prop, registry, connection_field_factory, resolver, **field_kwargs):
direction = relationship_prop.direction
model = relationship_prop.mapper.entity

Expand All @@ -40,7 +36,7 @@ def dynamic_type():
if direction == interfaces.MANYTOONE or not relationship_prop.uselist:
return Field(
_type,
resolver=_get_attr_resolver(relationship_prop.key),
resolver=resolver,
**field_kwargs
)
elif direction in (interfaces.ONETOMANY, interfaces.MANYTOMANY):
Expand All @@ -55,18 +51,18 @@ def dynamic_type():
return Dynamic(dynamic_type)


def convert_sqlalchemy_hybrid_method(hybrid_prop, prop_name, **field_kwargs):
def convert_sqlalchemy_hybrid_method(hybrid_prop, resolver, **field_kwargs):
if 'type' not in field_kwargs:
# TODO The default type should be dependent on the type of the property propety.
field_kwargs['type'] = String

return Field(
resolver=_get_attr_resolver(prop_name),
resolver=resolver,
**field_kwargs
)


def convert_sqlalchemy_composite(composite_prop, registry):
def convert_sqlalchemy_composite(composite_prop, registry, resolver):
converter = registry.get_converter_for_composite(composite_prop.composite_class)
if not converter:
try:
Expand Down Expand Up @@ -100,14 +96,14 @@ def inner(fn):
convert_sqlalchemy_composite.register = _register_composite_class


def convert_sqlalchemy_column(column_prop, registry, **field_kwargs):
def convert_sqlalchemy_column(column_prop, registry, resolver, **field_kwargs):
column = column_prop.columns[0]
field_kwargs.setdefault('type', convert_sqlalchemy_type(getattr(column, "type", None), column, registry))
field_kwargs.setdefault('required', not is_column_nullable(column))
field_kwargs.setdefault('description', get_column_doc(column))

return Field(
resolver=_get_attr_resolver(column_prop.key),
resolver=resolver,
**field_kwargs
)

Expand Down
4 changes: 3 additions & 1 deletion graphene_sqlalchemy/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

import graphene

from ..converter import convert_sqlalchemy_composite
from ..registry import reset_global_registry
from .models import Base, CompositeFullName
Expand All @@ -17,7 +19,7 @@ def reset_registry():
# Tests that explicitly depend on this behavior should re-register a converter
@convert_sqlalchemy_composite.register(CompositeFullName)
def convert_composite_class(composite, registry):
pass
return graphene.Field(graphene.Int)


@pytest.yield_fixture(scope="function")
Expand Down
32 changes: 24 additions & 8 deletions graphene_sqlalchemy/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,18 @@
from .models import Article, CompositeFullName, Pet, Reporter


def mock_resolver():
pass


def get_field(sqlalchemy_type, **column_kwargs):
class Model(declarative_base()):
__tablename__ = 'model'
id_ = Column(types.Integer, primary_key=True)
column = Column(sqlalchemy_type, doc="Custom Help Text", **column_kwargs)

column_prop = inspect(Model).column_attrs['column']
return convert_sqlalchemy_column(column_prop, get_global_registry())
return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver)


def get_field_from_column(column_):
Expand All @@ -40,7 +44,7 @@ class Model(declarative_base()):
column = column_

column_prop = inspect(Model).column_attrs['column']
return convert_sqlalchemy_column(column_prop, get_global_registry())
return convert_sqlalchemy_column(column_prop, get_global_registry(), mock_resolver)


def test_should_unknown_sqlalchemy_field_raise_exception():
Expand Down Expand Up @@ -162,7 +166,7 @@ def test_should_jsontype_convert_jsonstring():
def test_should_manytomany_convert_connectionorlist():
registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, registry, default_connection_field_factory
Reporter.pets.property, registry, default_connection_field_factory, mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()
Expand All @@ -174,7 +178,7 @@ class Meta:
model = Pet

dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, A._meta.registry, default_connection_field_factory
Reporter.pets.property, A._meta.registry, default_connection_field_factory, mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -190,7 +194,7 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Reporter.pets.property, A._meta.registry, default_connection_field_factory
Reporter.pets.property, A._meta.registry, default_connection_field_factory, mock_resolver
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert isinstance(dynamic_field.get_type(), UnsortedSQLAlchemyConnectionField)
Expand All @@ -199,7 +203,10 @@ class Meta:
def test_should_manytoone_convert_connectionorlist():
registry = Registry()
dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, registry, default_connection_field_factory
Article.reporter.property,
registry,
default_connection_field_factory,
mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
assert not dynamic_field.get_type()
Expand All @@ -211,7 +218,10 @@ class Meta:
model = Reporter

dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, A._meta.registry, default_connection_field_factory
Article.reporter.property,
A._meta.registry,
default_connection_field_factory,
mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -226,7 +236,10 @@ class Meta:
interfaces = (Node,)

dynamic_field = convert_sqlalchemy_relationship(
Article.reporter.property, A._meta.registry, default_connection_field_factory
Article.reporter.property,
A._meta.registry,
default_connection_field_factory,
mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand All @@ -244,6 +257,7 @@ class Meta:
Reporter.favorite_article.property,
A._meta.registry,
default_connection_field_factory,
mock_resolver,
)
assert isinstance(dynamic_field, graphene.Dynamic)
graphene_type = dynamic_field.get_type()
Expand Down Expand Up @@ -310,6 +324,7 @@ def convert_composite_class(composite, registry):
field = convert_sqlalchemy_composite(
composite(CompositeClass, (Column(types.Unicode(50)), Column(types.Unicode(50))), doc="Custom Help Text"),
registry,
mock_resolver,
)
assert isinstance(field, graphene.String)

Expand All @@ -325,4 +340,5 @@ def __init__(self, col1, col2):
convert_sqlalchemy_composite(
composite(CompositeFullName, (Column(types.Unicode(50)), Column(types.Unicode(50)))),
Registry(),
mock_resolver,
)
72 changes: 70 additions & 2 deletions graphene_sqlalchemy/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import six # noqa F401

from graphene import (Dynamic, Field, GlobalID, Int, List, Node, NonNull,
ObjectType, String)
ObjectType, Schema, String)

from ..converter import convert_sqlalchemy_composite
from ..fields import (SQLAlchemyConnectionField,
Expand Down Expand Up @@ -264,6 +264,7 @@ class Meta:
"column_prop",
"email",
"favorite_pet_kind",
"composite_prop",
"hybrid_prop",
"pets",
"articles",
Expand Down Expand Up @@ -293,6 +294,73 @@ class Meta:
assert first_name_field.type == Int


def test_resolvers(session):
"""Test that the correct resolver functions are called"""

class ReporterMixin(object):
def resolve_id(root, _info):
return 'ID'

class ReporterType(ReporterMixin, SQLAlchemyObjectType):
class Meta:
model = Reporter

email = ORMField()
email_v2 = ORMField(model_attr='email')
favorite_pet_kind = Field(String)
favorite_pet_kind_v2 = Field(String)

def resolve_last_name(root, _info):
return root.last_name.upper()

def resolve_email_v2(root, _info):
return root.email + '_V2'

def resolve_favorite_pet_kind_v2(root, _info):
return str(root.favorite_pet_kind) + '_V2'

class Query(ObjectType):
reporter = Field(ReporterType)

def resolve_reporter(self, _info):
return session.query(Reporter).first()

reporter = Reporter(first_name='first_name', last_name='last_name', email='email', favorite_pet_kind='cat')
session.add(reporter)
session.commit()

schema = Schema(query=Query)
result = schema.execute("""
query {
reporter {
id
firstName
lastName
email
emailV2
favoritePetKind
favoritePetKindV2
}
}
""")

assert not result.errors
# Custom resolver on a base class
assert result.data['reporter']['id'] == 'ID'
# Default field + default resolver
assert result.data['reporter']['firstName'] == 'first_name'
# Default field + custom resolver
assert result.data['reporter']['lastName'] == 'LAST_NAME'
# ORMField + default resolver
assert result.data['reporter']['email'] == 'email'
# ORMField + custom resolver
assert result.data['reporter']['emailV2'] == 'email_V2'
# Field + default resolver
assert result.data['reporter']['favoritePetKind'] == 'cat'
# Field + custom resolver
assert result.data['reporter']['favoritePetKindV2'] == 'cat_V2'


# Test Custom SQLAlchemyObjectType Implementation

def test_custom_objecttype_registered():
Expand All @@ -306,7 +374,7 @@ class Meta:

assert issubclass(CustomReporterType, ObjectType)
assert CustomReporterType._meta.model == Reporter
assert len(CustomReporterType._meta.fields) == 10
assert len(CustomReporterType._meta.fields) == 11


# Test Custom SQLAlchemyObjectType with Custom Options
Expand Down
30 changes: 26 additions & 4 deletions graphene_sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from graphene.relay import Connection, Node
from graphene.types.objecttype import ObjectType, ObjectTypeOptions
from graphene.types.utils import yank_fields_from_attrs
from graphene.utils.get_unbound_function import get_unbound_function
from graphene.utils.orderedtype import OrderedType

from .converter import (convert_sqlalchemy_column,
Expand Down Expand Up @@ -151,20 +152,22 @@ def construct_fields(
for orm_field_name, orm_field in orm_fields.items():
attr_name = orm_field.kwargs.pop('model_attr')
attr = all_model_attrs[attr_name]
resolver = _get_field_resolver(obj_type, orm_field_name, attr_name)

if isinstance(attr, ColumnProperty):
field = convert_sqlalchemy_column(attr, registry, **orm_field.kwargs)
field = convert_sqlalchemy_column(attr, registry, resolver, **orm_field.kwargs)
elif isinstance(attr, RelationshipProperty):
field = convert_sqlalchemy_relationship(attr, registry, connection_field_factory, **orm_field.kwargs)
field = convert_sqlalchemy_relationship(attr, registry, connection_field_factory, resolver,
**orm_field.kwargs)
elif isinstance(attr, CompositeProperty):
if attr_name != orm_field_name or orm_field.kwargs:
# TODO Add a way to override composite property fields
raise ValueError(
"ORMField kwargs for composite fields must be empty. "
"Field: {}.{}".format(obj_type.__name__, orm_field_name))
field = convert_sqlalchemy_composite(attr, registry)
field = convert_sqlalchemy_composite(attr, registry, resolver)
elif isinstance(attr, hybrid_property):
field = convert_sqlalchemy_hybrid_method(attr, attr_name, **orm_field.kwargs)
field = convert_sqlalchemy_hybrid_method(attr, resolver, **orm_field.kwargs)
else:
raise Exception('Property type is not supported') # Should never happen

Expand All @@ -174,6 +177,25 @@ def construct_fields(
return fields


def _get_field_resolver(obj_type, orm_field_name, model_attr):
"""
In order to support field renaming via `ORMField.model_attr`,
we need to define resolver functions for each field.

:param SQLAlchemyObjectType obj_type:
:param model: the SQLAlchemy model
:param str model_attr: the name of SQLAlchemy of the attribute used to resolve the field
:rtype: Callable
"""
# Since `graphene` will call `resolve_<field_name>` on a field only if it
# does not have a `resolver`, we need to re-implement that logic here.
resolver = getattr(obj_type, 'resolve_{}'.format(orm_field_name), None)
if resolver:
return get_unbound_function(resolver)

return lambda root, _info: getattr(root, model_attr, None)


class SQLAlchemyObjectTypeOptions(ObjectTypeOptions):
model = None # type: sqlalchemy.Model
registry = None # type: sqlalchemy.Registry
Expand Down