Commit dfc7898

general codebase refactor

Signed-off-by: wiseaidev <[email protected]>

1 parent dcd84e0 · commit dfc7898

File tree: 6 files changed, +39 -41 lines

aredis_om/checks.py (+2 -2)

@@ -12,15 +12,15 @@ async def check_for_command(conn, cmd):
 
 @lru_cache(maxsize=None)
 async def has_redis_json(conn=None):
-    if conn is None:
+    if not conn:
         conn = get_redis_connection()
     command_exists = await check_for_command(conn, "json.set")
     return command_exists
 
 
 @lru_cache(maxsize=None)
 async def has_redisearch(conn=None):
-    if conn is None:
+    if not conn:
         conn = get_redis_connection()
     if has_redis_json(conn):
         return True
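
A minimal standalone sketch (not part of this commit) of the behavioral difference between the old `conn is None` check and the new truthiness check when filling in a default argument. `FakeConnection` is a hypothetical stand-in, not an aredis_om class.

# Illustration only; FakeConnection is hypothetical.
class FakeConnection:
    def __bool__(self) -> bool:
        # A connection-like object that happens to report falsy.
        return False


def pick_default_is_none(conn=None):
    # Old style: substitute the default only when no argument was passed.
    return "default" if conn is None else "caller-supplied"


def pick_default_truthiness(conn=None):
    # New style: also substitutes the default for any falsy argument.
    return "default" if not conn else "caller-supplied"


conn = FakeConnection()
print(pick_default_is_none(conn))     # caller-supplied
print(pick_default_truthiness(conn))  # default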

aredis_om/model/encoders.py (+3 -3)

@@ -64,9 +64,9 @@ def jsonable_encoder(
     custom_encoder: Dict[Any, Callable[[Any], Any]] = {},
     sqlalchemy_safe: bool = True,
 ) -> Any:
-    if include is not None and not isinstance(include, (set, dict)):
+    if include and not isinstance(include, (set, dict)):
         include = set(include)
-    if exclude is not None and not isinstance(exclude, (set, dict)):
+    if exclude and not isinstance(exclude, (set, dict)):
         exclude = set(exclude)
 
     if isinstance(obj, BaseModel):
@@ -107,7 +107,7 @@ def jsonable_encoder(
                 or (not isinstance(key, str))
                 or (not key.startswith("_sa"))
             )
-            and (value is not None or not exclude_none)
+            and (value or not exclude_none)
             and ((include and key in include) or not exclude or key not in exclude)
         ):
             encoded_key = jsonable_encoder(
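
A minimal sketch, not taken from encoders.py, of the normalization pattern the first hunk touches: an `include` argument given as a list or tuple is coerced to a set before membership tests. The function name and arguments below are illustrative only.

from typing import Any, Dict, Optional, Sequence, Set, Union


def filter_fields(
    data: Dict[str, Any],
    include: Optional[Union[Set[str], Sequence[str]]] = None,
) -> Dict[str, Any]:
    # Coerce list/tuple input to a set so "key in include" is a fast lookup.
    if include is not None and not isinstance(include, (set, dict)):
        include = set(include)
    return {k: v for k, v in data.items() if include is None or k in include}


print(filter_fields({"a": 1, "b": 2}, include=["a"]))  # {'a': 1}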

aredis_om/model/model.py (+24 -24)

@@ -117,7 +117,7 @@ def is_supported_container_type(typ: Optional[type]) -> bool:
 
 
 def validate_model_fields(model: Type["RedisModel"], field_values: Dict[str, Any]):
-    for field_name in field_values.keys():
+    for field_name in field_values:
         if "__" in field_name:
             obj = model
             for sub_field in field_name.split("__"):
@@ -432,11 +432,11 @@ def validate_sort_fields(self, sort_fields: List[str]):
 
     @staticmethod
     def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes:
-        if getattr(field.field_info, "primary_key", None) is True:
+        if getattr(field.field_info, "primary_key", None):
             return RediSearchFieldTypes.TAG
         elif op is Operators.LIKE:
             fts = getattr(field.field_info, "full_text_search", None)
-            if fts is not True:  # Could be PydanticUndefined
+            if not fts:  # Could be PydanticUndefined
                 raise QuerySyntaxError(
                     f"You tried to do a full-text search on the field '{field.name}', "
                     f"but the field is not indexed for full-text search. Use the "
@@ -464,7 +464,7 @@ def resolve_field_type(field: ModelField, op: Operators) -> RediSearchFieldTypes
             # is not itself directly indexed, but instead, we index any fields
             # within the model inside the list marked as `index=True`.
             return RediSearchFieldTypes.TAG
-        elif container_type is not None:
+        elif container_type:
             raise QuerySyntaxError(
                 "Only lists and tuples are supported for multi-value fields. "
                 f"Docs: {ERRORS_URL}#E4"
@@ -567,7 +567,7 @@ def resolve_value(
                 # The value contains the TAG field separator. We can work
                 # around this by breaking apart the values and unioning them
                 # with multiple field:{} queries.
-                values: filter = filter(None, value.split(separator_char))
+                values: List[str] = [val for val in value.split(separator_char) if val]
                 for value in values:
                     value = escaper.escape(value)
                     result += f"@{field_name}:{{{value}}}"
@@ -1131,7 +1131,7 @@ async def save(self, pipeline: Optional[Pipeline] = None) -> "RedisModel":
         raise NotImplementedError
 
     async def expire(self, num_seconds: int, pipeline: Optional[Pipeline] = None):
-        if pipeline is None:
+        if not pipeline:
             db = self.db()
         else:
             db = pipeline
@@ -1195,12 +1195,12 @@ def to_string(s):
         step = 2  # Because the result has content
         offset = 1  # The first item is the count of total matches.
 
-        for i in xrange(1, len(res), step):
+        for i in range(1, len(res), step):
             fields_offset = offset
 
             fields = dict(
                 dict(
-                    izip(
+                    zip(
                         map(to_string, res[i + fields_offset][::2]),
                         map(to_string, res[i + fields_offset][1::2]),
                     )
@@ -1244,7 +1244,7 @@ async def add(
         pipeline: Optional[Pipeline] = None,
         pipeline_verifier: Callable[..., Any] = verify_pipeline_response,
     ) -> Sequence["RedisModel"]:
-        if pipeline is None:
+        if not pipeline:
             # By default, send commands in a pipeline. Saving each model will
             # be atomic, but Redis may process other commands in between
             # these saves.
@@ -1261,7 +1261,7 @@ async def add(
 
         # If the user didn't give us a pipeline, then we need to execute
         # the one we just created.
-        if pipeline is None:
+        if not pipeline:
             result = await db.execute()
             pipeline_verifier(result, expected_responses=len(models))
 
@@ -1303,7 +1303,7 @@ def __init_subclass__(cls, **kwargs):
 
     async def save(self, pipeline: Optional[Pipeline] = None) -> "HashModel":
         self.check()
-        if pipeline is None:
+        if not pipeline:
             db = self.db()
         else:
             db = pipeline
@@ -1356,7 +1356,7 @@ def _get_value(cls, *args, **kwargs) -> Any:
         values. Is there a better way?
         """
         val = super()._get_value(*args, **kwargs)
-        if val is None:
+        if not val:
             return ""
         return val
 
@@ -1392,7 +1392,7 @@ def schema_for_fields(cls):
                         name, _type, field.field_info
                     )
                 schema_parts.append(redisearch_field)
-            elif getattr(field.field_info, "index", None) is True:
+            elif getattr(field.field_info, "index", None):
                 schema_parts.append(cls.schema_for_type(name, _type, field.field_info))
             elif is_subscripted_type:
                 # Ignore subscripted types (usually containers!) that we don't
@@ -1437,7 +1437,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
         elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
             schema = f"{name} NUMERIC"
         elif issubclass(typ, str):
-            if getattr(field_info, "full_text_search", False) is True:
+            if getattr(field_info, "full_text_search", False):
                 schema = (
                     f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} "
                     f"{name} AS {name}_fts TEXT"
@@ -1455,7 +1455,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
                 schema = " ".join(sub_fields)
             else:
                 schema = f"{name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
-        if schema and sortable is True:
+        if schema and sortable:
             schema += " SORTABLE"
         return schema
 
@@ -1475,7 +1475,7 @@ def __init__(self, *args, **kwargs):
 
     async def save(self, pipeline: Optional[Pipeline] = None) -> "JsonModel":
         self.check()
-        if pipeline is None:
+        if not pipeline:
             db = self.db()
         else:
             db = pipeline
@@ -1633,7 +1633,7 @@ def schema_for_type(
                         parent_type=typ,
                     )
                 )
-            return " ".join(filter(None, sub_fields))
+            return " ".join([sub_field for sub_field in sub_fields if sub_field])
        # NOTE: This is the termination point for recursion. We've descended
        # into models and lists until we found an actual value to index.
        elif should_index:
@@ -1655,28 +1655,28 @@ def schema_for_type(
 
         # TODO: GEO field
         if parent_is_container_type or parent_is_model_in_container:
-            if typ is not str:
+            if not isinstance(typ, str):
                 raise RedisModelError(
                     "In this Preview release, list and tuple fields can only "
                     f"contain strings. Problem field: {name}. See docs: TODO"
                 )
-            if full_text_search is True:
+            if full_text_search:
                 raise RedisModelError(
                     "List and tuple fields cannot be indexed for full-text "
                     f"search. Problem field: {name}. See docs: TODO"
                 )
             schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
-            if sortable is True:
+            if sortable:
                 raise sortable_tag_error
         elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
             schema = f"{path} AS {index_field_name} NUMERIC"
         elif issubclass(typ, str):
-            if full_text_search is True:
+            if full_text_search:
                 schema = (
                     f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR} "
                     f"{path} AS {index_field_name}_fts TEXT"
                 )
-            if sortable is True:
+            if sortable:
                 # NOTE: With the current preview release, making a field
                 # full-text searchable and sortable only makes the TEXT
                 # field sortable. This means that results for full-text
@@ -1685,11 +1685,11 @@ def schema_for_type(
                 schema += " SORTABLE"
             else:
                 schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
-                if sortable is True:
+                if sortable:
                     raise sortable_tag_error
         else:
             schema = f"{path} AS {index_field_name} TAG SEPARATOR {SINGLE_VALUE_TAG_FIELD_SEPARATOR}"
-            if sortable is True:
+            if sortable:
                 raise sortable_tag_error
         return schema
     return ""
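
A minimal sketch, separate from model.py, of what the resolve_value hunk above (@@ -567,7 +567,7) does: when a query value contains the TAG field separator, it is split on that separator and each non-empty piece becomes its own @field:{...} clause. The separator character and field name below are made up for illustration, and escaping is omitted for brevity.

separator_char = "|"   # hypothetical TAG separator
field_name = "bio"     # hypothetical indexed field
value = "runner|cyclist||swimmer"

# Split on the separator and drop empty pieces, mirroring the new list comprehension.
values = [val for val in value.split(separator_char) if val]

result = ""
for piece in values:
    result += f"@{field_name}:{{{piece}}}"

print(result)  # @bio:{runner}@bio:{cyclist}@bio:{swimmer}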

aredis_om/model/render_tree.py (+8 -10)

@@ -21,7 +21,7 @@ def render_tree(
     write to a StringIO buffer, then use that buffer to accumulate written lines
     during recursive calls to render_tree().
     """
-    if buffer is None:
+    if not buffer:
         buffer = io.StringIO()
     if hasattr(current_node, nameattr):
         name = lambda node: getattr(node, nameattr)  # noqa: E731
@@ -31,11 +31,9 @@ def render_tree(
     up = getattr(current_node, left_child, None)
     down = getattr(current_node, right_child, None)
 
-    if up is not None:
+    if up:
         next_last = "up"
-        next_indent = "{0}{1}{2}".format(
-            indent, " " if "up" in last else "|", " " * len(str(name(current_node)))
-        )
+        next_indent = f'{indent}{" " if "up" in last else "|"}{" " * len(str(name(current_node)))}'
         render_tree(
             up, nameattr, left_child, right_child, next_indent, next_last, buffer
         )
@@ -49,7 +47,7 @@ def render_tree(
     else:
         start_shape = "├"
 
-    if up is not None and down is not None:
+    if up and down:
         end_shape = "┤"
     elif up:
         end_shape = "┘"
@@ -59,14 +57,14 @@ def render_tree(
         end_shape = ""
 
     print(
-        "{0}{1}{2}{3}".format(indent, start_shape, name(current_node), end_shape),
+        f"{indent}{start_shape}{name(current_node)}{end_shape}",
         file=buffer,
     )
 
-    if down is not None:
+    if down:
         next_last = "down"
-        next_indent = "{0}{1}{2}".format(
-            indent, " " if "down" in last else "|", " " * len(str(name(current_node)))
+        next_indent = (
+            f'{indent}{" " if "down" in last else "|"}{len(str(name(current_node)))}'
         )
         render_tree(
             down, nameattr, left_child, right_child, next_indent, next_last, buffer
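
A minimal sketch, independent of render_tree.py, of the buffer-accumulation pattern these hunks touch: the first call creates an io.StringIO buffer, recursive calls print into the same buffer with print(..., file=buffer), and the caller reads the result once at the end. The Node class and render function are hypothetical.

import io
from typing import Optional


class Node:
    # Hypothetical binary-tree node for the illustration.
    def __init__(self, name: str, left: "Optional[Node]" = None, right: "Optional[Node]" = None):
        self.name = name
        self.left = left
        self.right = right


def render(node: Optional[Node], indent: str = "", buffer: Optional[io.StringIO] = None) -> io.StringIO:
    # Create the buffer once, then thread it through every recursive call.
    if buffer is None:
        buffer = io.StringIO()
    if node is not None:
        print(f"{indent}{node.name}", file=buffer)
        render(node.left, indent + "  ", buffer)
        render(node.right, indent + "  ", buffer)
    return buffer


tree = Node("root", Node("left"), Node("right"))
print(render(tree).getvalue(), end="")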

aredis_om/unasync_util.py (+1 -1)

@@ -14,7 +14,7 @@ async def f():
         return None
 
     obj = f()
-    if obj is None:
+    if not obj:
         return False
     else:
         obj.close()  # prevent unawaited coroutine warning
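
A minimal sketch, separate from unasync_util.py, of why the touched check works: calling an async def function returns a coroutine object (truthy and not None) without running its body, while the unasync-rewritten synchronous version would return None. Closing the coroutine avoids the "never awaited" warning.

import inspect


async def f():
    return None


obj = f()                        # calling an async function yields a coroutine
print(inspect.iscoroutine(obj))  # True
print(obj is None)               # False -> async mode detected
obj.close()                      # prevent the "coroutine was never awaited" warning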

tests/test_oss_redis_features.py (+1 -1)

@@ -80,7 +80,7 @@ async def members(m):
 async def test_all_keys(members, m):
     pks = sorted([pk async for pk in await m.Member.all_pks()])
     assert len(pks) == 3
-    assert pks == sorted([m.pk for m in members])
+    assert pks == sorted(m.pk for m in members)
 
 
 @py_test_mark_asyncio
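
For reference, a tiny standalone example (not from the test file) showing that sorted() accepts any iterable, so the generator expression in the new assertion produces the same result as the old list comprehension without building an intermediate list.

pks = ["03", "01", "02"]
assert sorted([pk for pk in pks]) == sorted(pk for pk in pks) == ["01", "02", "03"]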
