diff --git a/redis/commands/graph/__init__.py b/redis/commands/graph/__init__.py index 3736195007..a882dd514d 100644 --- a/redis/commands/graph/__init__.py +++ b/redis/commands/graph/__init__.py @@ -1,9 +1,13 @@ from ..helpers import quote_string, random_string, stringify_param_value -from .commands import GraphCommands +from .commands import AsyncGraphCommands, GraphCommands from .edge import Edge # noqa from .node import Node # noqa from .path import Path # noqa +DB_LABELS = "DB.LABELS" +DB_RAELATIONSHIPTYPES = "DB.RELATIONSHIPTYPES" +DB_PROPERTYKEYS = "DB.PROPERTYKEYS" + class Graph(GraphCommands): """ @@ -44,25 +48,19 @@ def _refresh_labels(self): lbls = self.labels() # Unpack data. - self._labels = [None] * len(lbls) - for i, l in enumerate(lbls): - self._labels[i] = l[0] + self._labels = [l[0] for _, l in enumerate(lbls)] def _refresh_relations(self): rels = self.relationship_types() # Unpack data. - self._relationship_types = [None] * len(rels) - for i, r in enumerate(rels): - self._relationship_types[i] = r[0] + self._relationship_types = [r[0] for _, r in enumerate(rels)] def _refresh_attributes(self): props = self.property_keys() # Unpack data. - self._properties = [None] * len(props) - for i, p in enumerate(props): - self._properties[i] = p[0] + self._properties = [p[0] for _, p in enumerate(props)] def get_label(self, idx): """ @@ -108,12 +106,12 @@ def get_property(self, idx): The index of the property """ try: - propertie = self._properties[idx] + p = self._properties[idx] except IndexError: # Refresh properties. 
self._refresh_attributes() - propertie = self._properties[idx] - return propertie + p = self._properties[idx] + return p def add_node(self, node): """ @@ -133,6 +131,8 @@ def add_edge(self, edge): self.edges.append(edge) def _build_params_header(self, params): + if params is None: + return "" if not isinstance(params, dict): raise TypeError("'params' must be a dict") # Header starts with "CYPHER" @@ -147,16 +147,109 @@ def call_procedure(self, procedure, *args, read_only=False, **kwagrs): q = f"CALL {procedure}({','.join(args)})" y = kwagrs.get("y", None) - if y: - q += f" YIELD {','.join(y)}" + if y is not None: + q += f"YIELD {','.join(y)}" return self.query(q, read_only=read_only) def labels(self): - return self.call_procedure("db.labels", read_only=True).result_set + return self.call_procedure(DB_LABELS, read_only=True).result_set def relationship_types(self): - return self.call_procedure("db.relationshipTypes", read_only=True).result_set + return self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True).result_set def property_keys(self): - return self.call_procedure("db.propertyKeys", read_only=True).result_set + return self.call_procedure(DB_PROPERTYKEYS, read_only=True).result_set + + +class AsyncGraph(Graph, AsyncGraphCommands): + """Async version for Graph""" + + async def _refresh_labels(self): + lbls = await self.labels() + + # Unpack data. + self._labels = [l[0] for _, l in enumerate(lbls)] + + async def _refresh_attributes(self): + props = await self.property_keys() + + # Unpack data. + self._properties = [p[0] for _, p in enumerate(props)] + + async def _refresh_relations(self): + rels = await self.relationship_types() + + # Unpack data. + self._relationship_types = [r[0] for _, r in enumerate(rels)] + + async def get_label(self, idx): + """ + Returns a label by it's index + + Args: + + idx: + The index of the label + """ + try: + label = self._labels[idx] + except IndexError: + # Refresh labels. 
+ await self._refresh_labels() + label = self._labels[idx] + return label + + async def get_property(self, idx): + """ + Returns a property by it's index + + Args: + + idx: + The index of the property + """ + try: + p = self._properties[idx] + except IndexError: + # Refresh properties. + await self._refresh_attributes() + p = self._properties[idx] + return p + + async def get_relation(self, idx): + """ + Returns a relationship type by it's index + + Args: + + idx: + The index of the relation + """ + try: + relationship_type = self._relationship_types[idx] + except IndexError: + # Refresh relationship types. + await self._refresh_relations() + relationship_type = self._relationship_types[idx] + return relationship_type + + async def call_procedure(self, procedure, *args, read_only=False, **kwagrs): + args = [quote_string(arg) for arg in args] + q = f"CALL {procedure}({','.join(args)})" + + y = kwagrs.get("y", None) + if y is not None: + f"YIELD {','.join(y)}" + return await self.query(q, read_only=read_only) + + async def labels(self): + return ((await self.call_procedure(DB_LABELS, read_only=True))).result_set + + async def property_keys(self): + return (await self.call_procedure(DB_PROPERTYKEYS, read_only=True)).result_set + + async def relationship_types(self): + return ( + await self.call_procedure(DB_RAELATIONSHIPTYPES, read_only=True) + ).result_set diff --git a/redis/commands/graph/commands.py b/redis/commands/graph/commands.py index fe4224b5cf..762ab42e16 100644 --- a/redis/commands/graph/commands.py +++ b/redis/commands/graph/commands.py @@ -3,7 +3,16 @@ from .exceptions import VersionMismatchException from .execution_plan import ExecutionPlan -from .query_result import QueryResult +from .query_result import AsyncQueryResult, QueryResult + +PROFILE_CMD = "GRAPH.PROFILE" +RO_QUERY_CMD = "GRAPH.RO_QUERY" +QUERY_CMD = "GRAPH.QUERY" +DELETE_CMD = "GRAPH.DELETE" +SLOWLOG_CMD = "GRAPH.SLOWLOG" +CONFIG_CMD = "GRAPH.CONFIG" +LIST_CMD = "GRAPH.LIST" +EXPLAIN_CMD = 
"GRAPH.EXPLAIN" class GraphCommands: @@ -52,33 +61,28 @@ def query(self, q, params=None, timeout=None, read_only=False, profile=False): query = q # handle query parameters - if params is not None: - query = self._build_params_header(params) + query + query = self._build_params_header(params) + query # construct query command # ask for compact result-set format # specify known graph version if profile: - cmd = "GRAPH.PROFILE" + cmd = PROFILE_CMD else: - cmd = "GRAPH.RO_QUERY" if read_only else "GRAPH.QUERY" + cmd = RO_QUERY_CMD if read_only else QUERY_CMD command = [cmd, self.name, query, "--compact"] # include timeout is specified - if timeout: - if not isinstance(timeout, int): - raise Exception("Timeout argument must be a positive integer") - command += ["timeout", timeout] + if isinstance(timeout, int): + command.extend(["timeout", timeout]) + elif timeout is not None: + raise Exception("Timeout argument must be a positive integer") # issue query try: response = self.execute_command(*command) return QueryResult(self, response, profile) except ResponseError as e: - if "wrong number of arguments" in str(e): - print( - "Note: RedisGraph Python requires server version 2.2.8 or above" - ) # noqa if "unknown command" in str(e) and read_only: # `GRAPH.RO_QUERY` is unavailable in older versions. return self.query(q, params, timeout, read_only=False) @@ -106,7 +110,7 @@ def delete(self): For more information see `DELETE `_. # noqa """ self._clear_schema() - return self.execute_command("GRAPH.DELETE", self.name) + return self.execute_command(DELETE_CMD, self.name) # declared here, to override the built in redis.db.flush() def flush(self): @@ -146,7 +150,7 @@ def slowlog(self): 3. The issued query. 4. The amount of time needed for its execution, in milliseconds. 
""" - return self.execute_command("GRAPH.SLOWLOG", self.name) + return self.execute_command(SLOWLOG_CMD, self.name) def config(self, name, value=None, set=False): """ @@ -170,14 +174,14 @@ def config(self, name, value=None, set=False): raise DataError( "``value`` can be provided only when ``set`` is True" ) # noqa - return self.execute_command("GRAPH.CONFIG", *params) + return self.execute_command(CONFIG_CMD, *params) def list_keys(self): """ Lists all graph keys in the keyspace. For more information see `GRAPH.LIST `_. # noqa """ - return self.execute_command("GRAPH.LIST") + return self.execute_command(LIST_CMD) def execution_plan(self, query, params=None): """ @@ -188,10 +192,9 @@ def execution_plan(self, query, params=None): query: the query that will be executed params: query parameters """ - if params is not None: - query = self._build_params_header(params) + query + query = self._build_params_header(params) + query - plan = self.execute_command("GRAPH.EXPLAIN", self.name, query) + plan = self.execute_command(EXPLAIN_CMD, self.name, query) if isinstance(plan[0], bytes): plan = [b.decode() for b in plan] return "\n".join(plan) @@ -206,8 +209,105 @@ def explain(self, query, params=None): query: the query that will be executed params: query parameters """ - if params is not None: - query = self._build_params_header(params) + query + query = self._build_params_header(params) + query + + plan = self.execute_command(EXPLAIN_CMD, self.name, query) + return ExecutionPlan(plan) + + +class AsyncGraphCommands(GraphCommands): + async def query(self, q, params=None, timeout=None, read_only=False, profile=False): + """ + Executes a query against the graph. + For more information see `GRAPH.QUERY `_. # noqa + + Args: + + q : str + The query. + params : dict + Query parameters. + timeout : int + Maximum runtime for read queries in milliseconds. + read_only : bool + Executes a readonly query if set to True. 
+ profile : bool + Return details on results produced by and time + spent in each operation. + """ + + # maintain original 'q' + query = q + + # handle query parameters + query = self._build_params_header(params) + query + + # construct query command + # ask for compact result-set format + # specify known graph version + if profile: + cmd = PROFILE_CMD + else: + cmd = RO_QUERY_CMD if read_only else QUERY_CMD + command = [cmd, self.name, query, "--compact"] + + # include timeout is specified + if isinstance(timeout, int): + command.extend(["timeout", timeout]) + elif timeout is not None: + raise Exception("Timeout argument must be a positive integer") + + # issue query + try: + response = await self.execute_command(*command) + return await AsyncQueryResult().initialize(self, response, profile) + except ResponseError as e: + if "unknown command" in str(e) and read_only: + # `GRAPH.RO_QUERY` is unavailable in older versions. + return await self.query(q, params, timeout, read_only=False) + raise e + except VersionMismatchException as e: + # client view over the graph schema is out of sync + # set client version and refresh local schema + self.version = e.version + self._refresh_schema() + # re-issue query + return await self.query(q, params, timeout, read_only) + + async def execution_plan(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns an array of operations. + + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query - plan = self.execute_command("GRAPH.EXPLAIN", self.name, query) + plan = await self.execute_command(EXPLAIN_CMD, self.name, query) + if isinstance(plan[0], bytes): + plan = [b.decode() for b in plan] + return "\n".join(plan) + + async def explain(self, query, params=None): + """ + Get the execution plan for given query, + GRAPH.EXPLAIN returns ExecutionPlan object. 
+ + Args: + query: the query that will be executed + params: query parameters + """ + query = self._build_params_header(params) + query + + plan = await self.execute_command(EXPLAIN_CMD, self.name, query) return ExecutionPlan(plan) + + async def flush(self): + """ + Commit the graph and reset the edges and the nodes to zero length. + """ + await self.commit() + self.nodes = {} + self.edges = [] diff --git a/redis/commands/graph/query_result.py b/redis/commands/graph/query_result.py index 3ffa664791..b88b4b656c 100644 --- a/redis/commands/graph/query_result.py +++ b/redis/commands/graph/query_result.py @@ -1,4 +1,6 @@ +import sys from collections import OrderedDict +from distutils.util import strtobool # from prettytable import PrettyTable from redis import ResponseError @@ -90,6 +92,9 @@ def __init__(self, graph, response, profile=False): self.parse_results(response) def _check_for_errors(self, response): + """ + Check if the response contains an error. + """ if isinstance(response[0], ResponseError): error = response[0] if str(error) == "version mismatch": @@ -103,6 +108,9 @@ def _check_for_errors(self, response): raise response[-1] def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ self.header = self.parse_header(raw_result_set) # Empty header. @@ -112,6 +120,9 @@ def parse_results(self, raw_result_set): self.result_set = self.parse_records(raw_result_set) def parse_statistics(self, raw_statistics): + """ + Parse the statistics returned in the response. + """ self.statistics = {} # decode statistics @@ -125,31 +136,31 @@ def parse_statistics(self, raw_statistics): self.statistics[s] = v def parse_header(self, raw_result_set): + """ + Parse the header of the result. + """ # An array of column name/column type pairs. 
header = raw_result_set[0] return header def parse_records(self, raw_result_set): - records = [] - result_set = raw_result_set[1] - for row in result_set: - record = [] - for idx, cell in enumerate(row): - if self.header[idx][0] == ResultSetColumnTypes.COLUMN_SCALAR: # noqa - record.append(self.parse_scalar(cell)) - elif self.header[idx][0] == ResultSetColumnTypes.COLUMN_NODE: # noqa - record.append(self.parse_node(cell)) - elif ( - self.header[idx][0] == ResultSetColumnTypes.COLUMN_RELATION - ): # noqa - record.append(self.parse_edge(cell)) - else: - print("Unknown column type.\n") - records.append(record) + """ + Parses the result set and returns a list of records. + """ + records = [ + [ + self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + for row in raw_result_set[1] + ] return records def parse_entity_properties(self, props): + """ + Parse node / edge properties. + """ # [[name, value type, value] X N] properties = {} for prop in props: @@ -160,6 +171,9 @@ def parse_entity_properties(self, props): return properties def parse_string(self, cell): + """ + Parse the cell as a string. + """ if isinstance(cell, bytes): return cell.decode() elif not isinstance(cell, str): @@ -168,6 +182,9 @@ def parse_string(self, cell): return cell def parse_node(self, cell): + """ + Parse the cell to a node. + """ # Node ID (integer), # [label string offset (integer)], # [[name, value type, value] X N] @@ -182,6 +199,9 @@ def parse_node(self, cell): return Node(node_id=node_id, label=labels, properties=properties) def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ # Edge ID (integer), # reltype string offset (integer), # src node ID offset (integer), @@ -198,11 +218,17 @@ def parse_edge(self, cell): ) def parse_path(self, cell): + """ + Parse the cell to a path. + """ nodes = self.parse_scalar(cell[0]) edges = self.parse_scalar(cell[1]) return Path(nodes, edges) def parse_map(self, cell): + """ + Parse the cell as a map. 
+ """ m = OrderedDict() n_entries = len(cell) @@ -216,6 +242,9 @@ def parse_map(self, cell): return m def parse_point(self, cell): + """ + Parse the cell to point. + """ p = {} # A point is received an array of the form: [latitude, longitude] # It is returned as a map of the form: {"latitude": latitude, "longitude": longitude} # noqa @@ -223,94 +252,63 @@ def parse_point(self, cell): p["longitude"] = float(cell[1]) return p - def parse_scalar(self, cell): - scalar_type = int(cell[0]) - value = cell[1] - scalar = None - - if scalar_type == ResultSetScalarTypes.VALUE_NULL: - scalar = None - - elif scalar_type == ResultSetScalarTypes.VALUE_STRING: - scalar = self.parse_string(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_INTEGER: - scalar = int(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_BOOLEAN: - value = value.decode() if isinstance(value, bytes) else value - if value == "true": - scalar = True - elif value == "false": - scalar = False - else: - print("Unknown boolean type\n") - - elif scalar_type == ResultSetScalarTypes.VALUE_DOUBLE: - scalar = float(value) - - elif scalar_type == ResultSetScalarTypes.VALUE_ARRAY: - # array variable is introduced only for readability - scalar = array = value - for i in range(len(array)): - scalar[i] = self.parse_scalar(array[i]) + def parse_null(self, cell): + """ + Parse a null value. + """ + return None - elif scalar_type == ResultSetScalarTypes.VALUE_NODE: - scalar = self.parse_node(value) + def parse_integer(self, cell): + """ + Parse the integer value from the cell. + """ + return int(cell) - elif scalar_type == ResultSetScalarTypes.VALUE_EDGE: - scalar = self.parse_edge(value) + def parse_boolean(self, value): + """ + Parse the cell value as a boolean. 
+ """ + value = value.decode() if isinstance(value, bytes) else value + try: + scalar = strtobool(value) + except ValueError: + sys.stderr.write("unknown boolean type\n") + scalar = None + return scalar - elif scalar_type == ResultSetScalarTypes.VALUE_PATH: - scalar = self.parse_path(value) + def parse_double(self, cell): + """ + Parse the cell as a double. + """ + return float(cell) - elif scalar_type == ResultSetScalarTypes.VALUE_MAP: - scalar = self.parse_map(value) + def parse_array(self, value): + """ + Parse an array of values. + """ + scalar = [self.parse_scalar(value[i]) for i in range(len(value))] + return scalar - elif scalar_type == ResultSetScalarTypes.VALUE_POINT: - scalar = self.parse_point(value) + def parse_unknown(self, cell): + """ + Parse a cell of unknown type. + """ + sys.stderr.write("Unknown type\n") + return None - elif scalar_type == ResultSetScalarTypes.VALUE_UNKNOWN: - print("Unknown scalar type\n") + def parse_scalar(self, cell): + """ + Parse a scalar value from a cell in the result set. + """ + scalar_type = int(cell[0]) + value = cell[1] + scalar = self.parse_scalar_types[scalar_type](value) return scalar def parse_profile(self, response): self.result_set = [x[0 : x.index(",")].strip() for x in response] - # """Prints the data from the query response: - # 1. First row result_set contains the columns names. - # Thus the first row in PrettyTable will contain the - # columns. - # 2. The row after that will contain the data returned, - # or 'No Data returned' if there is none. - # 3. Prints the statistics of the query. 
- # """ - - # def pretty_print(self): - # if not self.is_empty(): - # header = [col[1] for col in self.header] - # tbl = PrettyTable(header) - - # for row in self.result_set: - # record = [] - # for idx, cell in enumerate(row): - # if type(cell) is Node: - # record.append(cell.to_string()) - # elif type(cell) is Edge: - # record.append(cell.to_string()) - # else: - # record.append(cell) - # tbl.add_row(record) - - # if len(self.result_set) == 0: - # tbl.add_row(['No data returned.']) - - # print(str(tbl) + '\n') - - # for stat in self.statistics: - # print("%s %s" % (stat, self.statistics[stat])) - def is_empty(self): return len(self.result_set) == 0 @@ -384,3 +382,192 @@ def cached_execution(self): def run_time_ms(self): """Returns the server execution time of the query""" return self._get_stat(INTERNAL_EXECUTION_TIME) + + @property + def parse_scalar_types(self): + return { + ResultSetScalarTypes.VALUE_NULL: self.parse_null, + ResultSetScalarTypes.VALUE_STRING: self.parse_string, + ResultSetScalarTypes.VALUE_INTEGER: self.parse_integer, + ResultSetScalarTypes.VALUE_BOOLEAN: self.parse_boolean, + ResultSetScalarTypes.VALUE_DOUBLE: self.parse_double, + ResultSetScalarTypes.VALUE_ARRAY: self.parse_array, + ResultSetScalarTypes.VALUE_NODE: self.parse_node, + ResultSetScalarTypes.VALUE_EDGE: self.parse_edge, + ResultSetScalarTypes.VALUE_PATH: self.parse_path, + ResultSetScalarTypes.VALUE_MAP: self.parse_map, + ResultSetScalarTypes.VALUE_POINT: self.parse_point, + ResultSetScalarTypes.VALUE_UNKNOWN: self.parse_unknown, + } + + @property + def parse_record_types(self): + return { + ResultSetColumnTypes.COLUMN_SCALAR: self.parse_scalar, + ResultSetColumnTypes.COLUMN_NODE: self.parse_node, + ResultSetColumnTypes.COLUMN_RELATION: self.parse_edge, + ResultSetColumnTypes.COLUMN_UNKNOWN: self.parse_unknown, + } + + +class AsyncQueryResult(QueryResult): + """ + Async version for the QueryResult class - a class that + represents a result of the query operation. 
+ """ + + def __init__(self): + """ + To init the class you must call self.initialize() + """ + pass + + async def initialize(self, graph, response, profile=False): + """ + Initializes the class. + Args: + + graph: + The graph on which the query was executed. + response: + The response from the server. + profile: + A boolean indicating if the query command was "GRAPH.PROFILE" + """ + self.graph = graph + self.header = [] + self.result_set = [] + + # in case of an error an exception will be raised + self._check_for_errors(response) + + if len(response) == 1: + self.parse_statistics(response[0]) + elif profile: + self.parse_profile(response) + else: + # start by parsing statistics, matches the one we have + self.parse_statistics(response[-1]) # Last element. + await self.parse_results(response) + + return self + + async def parse_node(self, cell): + """ + Parses a node from the cell. + """ + # Node ID (integer), + # [label string offset (integer)], + # [[name, value type, value] X N] + + labels = None + if len(cell[1]) > 0: + labels = [] + for inner_label in cell[1]: + labels.append(await self.graph.get_label(inner_label)) + properties = await self.parse_entity_properties(cell[2]) + node_id = int(cell[0]) + return Node(node_id=node_id, label=labels, properties=properties) + + async def parse_scalar(self, cell): + """ + Parses a scalar value from the server response. + """ + scalar_type = int(cell[0]) + value = cell[1] + try: + scalar = await self.parse_scalar_types[scalar_type](value) + except TypeError: + # Not all of the functions are async + scalar = self.parse_scalar_types[scalar_type](value) + + return scalar + + async def parse_records(self, raw_result_set): + """ + Parses the result set and returns a list of records. 
+ """ + records = [] + for row in raw_result_set[1]: + record = [ + await self.parse_record_types[self.header[idx][0]](cell) + for idx, cell in enumerate(row) + ] + records.append(record) + + return records + + async def parse_results(self, raw_result_set): + """ + Parse the query execution result returned from the server. + """ + self.header = self.parse_header(raw_result_set) + + # Empty header. + if len(self.header) == 0: + return + + self.result_set = await self.parse_records(raw_result_set) + + async def parse_entity_properties(self, props): + """ + Parse node / edge properties. + """ + # [[name, value type, value] X N] + properties = {} + for prop in props: + prop_name = await self.graph.get_property(prop[0]) + prop_value = await self.parse_scalar(prop[1:]) + properties[prop_name] = prop_value + + return properties + + async def parse_edge(self, cell): + """ + Parse the cell to an edge. + """ + # Edge ID (integer), + # reltype string offset (integer), + # src node ID offset (integer), + # dest node ID offset (integer), + # [[name, value, value type] X N] + + edge_id = int(cell[0]) + relation = await self.graph.get_relation(cell[1]) + src_node_id = int(cell[2]) + dest_node_id = int(cell[3]) + properties = await self.parse_entity_properties(cell[4]) + return Edge( + src_node_id, relation, dest_node_id, edge_id=edge_id, properties=properties + ) + + async def parse_path(self, cell): + """ + Parse the cell to a path. + """ + nodes = await self.parse_scalar(cell[0]) + edges = await self.parse_scalar(cell[1]) + return Path(nodes, edges) + + async def parse_map(self, cell): + """ + Parse the cell to a map. + """ + m = OrderedDict() + n_entries = len(cell) + + # A map is an array of key value pairs. + # 1. key (string) + # 2. array: (value type, value) + for i in range(0, n_entries, 2): + key = self.parse_string(cell[i]) + m[key] = await self.parse_scalar(cell[i + 1]) + + return m + + async def parse_array(self, value): + """ + Parse array value. 
+ """ + scalar = [await self.parse_scalar(value[i]) for i in range(len(value))] + return scalar diff --git a/redis/commands/redismodules.py b/redis/commands/redismodules.py index 875f3fca25..7e2045a722 100644 --- a/redis/commands/redismodules.py +++ b/redis/commands/redismodules.py @@ -73,8 +73,8 @@ def tdigest(self): return tdigest def graph(self, index_name="idx"): - """Access the timeseries namespace, providing support for - redis timeseries data. + """Access the graph namespace, providing support for + redis graph data. """ from .graph import Graph @@ -91,3 +91,13 @@ def ft(self, index_name="idx"): s = AsyncSearch(client=self, index_name=index_name) return s + + def graph(self, index_name="idx"): + """Access the graph namespace, providing support for + redis graph data. + """ + + from .graph import AsyncGraph + + g = AsyncGraph(client=self, name=index_name) + return g diff --git a/tests/test_asyncio/test_graph.py b/tests/test_asyncio/test_graph.py new file mode 100644 index 0000000000..8a8f9cf953 --- /dev/null +++ b/tests/test_asyncio/test_graph.py @@ -0,0 +1,503 @@ +import pytest + +import redis.asyncio as redis +from redis.commands.graph import Edge, Node, Path +from redis.commands.graph.execution_plan import Operation +from redis.exceptions import ResponseError +from tests.conftest import skip_if_redis_enterprise + + +@pytest.mark.redismod +async def test_bulk(modclient): + with pytest.raises(NotImplementedError): + await modclient.graph().bulk() + await modclient.graph().bulk(foo="bar!") + + +@pytest.mark.redismod +async def test_graph_creation(modclient: redis.Redis): + graph = modclient.graph() + + john = Node( + label="person", + properties={ + "name": "John Doe", + "age": 33, + "gender": "male", + "status": "single", + }, + ) + graph.add_node(john) + japan = Node(label="country", properties={"name": "Japan"}) + + graph.add_node(japan) + edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) + graph.add_edge(edge) + + await 
graph.commit() + + query = ( + 'MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) ' + "RETURN p, v, c" + ) + + result = await graph.query(query) + + person = result.result_set[0][0] + visit = result.result_set[0][1] + country = result.result_set[0][2] + + assert person == john + assert visit.properties == edge.properties + assert country == japan + + query = """RETURN [1, 2.3, "4", true, false, null]""" + result = await graph.query(query) + assert [1, 2.3, "4", True, False, None] == result.result_set[0][0] + + # All done, remove graph. + await graph.delete() + + +@pytest.mark.redismod +async def test_array_functions(modclient: redis.Redis): + graph = modclient.graph() + + query = """CREATE (p:person{name:'a',age:32, array:[0,1,2]})""" + await graph.query(query) + + query = """WITH [0,1,2] as x return x""" + result = await graph.query(query) + assert [0, 1, 2] == result.result_set[0][0] + + query = """MATCH(n) return collect(n)""" + result = await graph.query(query) + + a = Node( + node_id=0, + label="person", + properties={"name": "a", "age": 32, "array": [0, 1, 2]}, + ) + + assert [a] == result.result_set[0][0] + + +@pytest.mark.redismod +async def test_path(modclient: redis.Redis): + node0 = Node(node_id=0, label="L1") + node1 = Node(node_id=1, label="L1") + edge01 = Edge(node0, "R1", node1, edge_id=0, properties={"value": 1}) + + graph = modclient.graph() + graph.add_node(node0) + graph.add_node(node1) + graph.add_edge(edge01) + await graph.flush() + + path01 = Path.new_empty_path().add_node(node0).add_edge(edge01).add_node(node1) + expected_results = [[path01]] + + query = "MATCH p=(:L1)-[:R1]->(:L1) RETURN p ORDER BY p" + result = await graph.query(query) + assert expected_results == result.result_set + + +@pytest.mark.redismod +async def test_param(modclient: redis.Redis): + params = [1, 2.3, "str", True, False, None, [0, 1, 2]] + query = "RETURN $param" + for param in params: + result = await modclient.graph().query(query, {"param": param}) + 
expected_results = [[param]] + assert expected_results == result.result_set + + +@pytest.mark.redismod +async def test_map(modclient: redis.Redis): + query = "RETURN {a:1, b:'str', c:NULL, d:[1,2,3], e:True, f:{x:1, y:2}}" + + actual = (await modclient.graph().query(query)).result_set[0][0] + expected = { + "a": 1, + "b": "str", + "c": None, + "d": [1, 2, 3], + "e": True, + "f": {"x": 1, "y": 2}, + } + + assert actual == expected + + +@pytest.mark.redismod +async def test_point(modclient: redis.Redis): + query = "RETURN point({latitude: 32.070794860, longitude: 34.820751118})" + expected_lat = 32.070794860 + expected_lon = 34.820751118 + actual = (await modclient.graph().query(query)).result_set[0][0] + assert abs(actual["latitude"] - expected_lat) < 0.001 + assert abs(actual["longitude"] - expected_lon) < 0.001 + + query = "RETURN point({latitude: 32, longitude: 34.0})" + expected_lat = 32 + expected_lon = 34 + actual = (await modclient.graph().query(query)).result_set[0][0] + assert abs(actual["latitude"] - expected_lat) < 0.001 + assert abs(actual["longitude"] - expected_lon) < 0.001 + + +@pytest.mark.redismod +async def test_index_response(modclient: redis.Redis): + result_set = await modclient.graph().query("CREATE INDEX ON :person(age)") + assert 1 == result_set.indices_created + + result_set = await modclient.graph().query("CREATE INDEX ON :person(age)") + assert 0 == result_set.indices_created + + result_set = await modclient.graph().query("DROP INDEX ON :person(age)") + assert 1 == result_set.indices_deleted + + with pytest.raises(ResponseError): + await modclient.graph().query("DROP INDEX ON :person(age)") + + +@pytest.mark.redismod +async def test_stringify_query_result(modclient: redis.Redis): + graph = modclient.graph() + + john = Node( + alias="a", + label="person", + properties={ + "name": "John Doe", + "age": 33, + "gender": "male", + "status": "single", + }, + ) + graph.add_node(john) + + japan = Node(alias="b", label="country", properties={"name": 
"Japan"}) + graph.add_node(japan) + + edge = Edge(john, "visited", japan, properties={"purpose": "pleasure"}) + graph.add_edge(edge) + + assert ( + str(john) + == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa + ) + assert ( + str(edge) + == """(a:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa + + """-[:visited{purpose:"pleasure"}]->""" + + """(b:country{name:"Japan"})""" + ) + assert str(japan) == """(b:country{name:"Japan"})""" + + await graph.commit() + + query = """MATCH (p:person)-[v:visited {purpose:"pleasure"}]->(c:country) + RETURN p, v, c""" + + result = await graph.query(query) + person = result.result_set[0][0] + visit = result.result_set[0][1] + country = result.result_set[0][2] + + assert ( + str(person) + == """(:person{age:33,gender:"male",name:"John Doe",status:"single"})""" # noqa + ) + assert str(visit) == """()-[:visited{purpose:"pleasure"}]->()""" + assert str(country) == """(:country{name:"Japan"})""" + + await graph.delete() + + +@pytest.mark.redismod +async def test_optional_match(modclient: redis.Redis): + # Build a graph of form (a)-[R]->(b) + node0 = Node(node_id=0, label="L1", properties={"value": "a"}) + node1 = Node(node_id=1, label="L1", properties={"value": "b"}) + + edge01 = Edge(node0, "R", node1, edge_id=0) + + graph = modclient.graph() + graph.add_node(node0) + graph.add_node(node1) + graph.add_edge(edge01) + await graph.flush() + + # Issue a query that collects all outgoing edges from both nodes + # (the second has none) + query = """MATCH (a) OPTIONAL MATCH (a)-[e]->(b) RETURN a, e, b ORDER BY a.value""" # noqa + expected_results = [[node0, edge01, node1], [node1, None, None]] + + result = await graph.query(query) + assert expected_results == result.result_set + + await graph.delete() + + +@pytest.mark.redismod +async def test_cached_execution(modclient: redis.Redis): + await modclient.graph().query("CREATE ()") + + uncached_result = await modclient.graph().query( + "MATCH 
(n) RETURN n, $param", {"param": [0]} + ) + assert uncached_result.cached_execution is False + + # loop to make sure the query is cached on each thread on server + for x in range(0, 64): + cached_result = await modclient.graph().query( + "MATCH (n) RETURN n, $param", {"param": [0]} + ) + assert uncached_result.result_set == cached_result.result_set + + # should be cached on all threads by now + assert cached_result.cached_execution + + +@pytest.mark.redismod +async def test_slowlog(modclient: redis.Redis): + create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}), + (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}), + (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})""" + await modclient.graph().query(create_query) + + results = await modclient.graph().slowlog() + assert results[0][1] == "GRAPH.QUERY" + assert results[0][2] == create_query + + +@pytest.mark.redismod +async def test_query_timeout(modclient: redis.Redis): + # Build a sample graph with 1000 nodes. + await modclient.graph().query("UNWIND range(0,1000) as val CREATE ({v: val})") + # Issue a long-running query with a 1-millisecond timeout. + with pytest.raises(ResponseError): + await modclient.graph().query("MATCH (a), (b), (c), (d) RETURN *", timeout=1) + assert False is False + + with pytest.raises(Exception): + await modclient.graph().query("RETURN 1", timeout="str") + assert False is False + + +@pytest.mark.redismod +async def test_read_only_query(modclient: redis.Redis): + with pytest.raises(Exception): + # Issue a write query, specifying read-only true, + # this call should fail. 
@pytest.mark.redismod
async def test_read_only_query(modclient: redis.Redis):
    """A write query issued with read_only=True must be rejected by the server."""
    with pytest.raises(Exception):
        # Issue a write query, specifying read-only true,
        # this call should fail.
        await modclient.graph().query("CREATE (p:person {name:'a'})", read_only=True)


@pytest.mark.redismod
async def test_profile(modclient: redis.Redis):
    """GRAPH.PROFILE reports record counts for each operation in the plan."""
    q = """UNWIND range(1, 3) AS x CREATE (p:Person {v:x})"""
    profile = (await modclient.graph().profile(q)).result_set
    assert "Create | Records produced: 3" in profile
    assert "Unwind | Records produced: 3" in profile

    q = "MATCH (p:Person) WHERE p.v > 1 RETURN p"
    profile = (await modclient.graph().profile(q)).result_set
    assert "Results | Records produced: 2" in profile
    assert "Project | Records produced: 2" in profile
    assert "Filter | Records produced: 2" in profile
    assert "Node By Label Scan | (p:Person) | Records produced: 3" in profile


@pytest.mark.redismod
@skip_if_redis_enterprise()
async def test_config(modclient: redis.Redis):
    """GRAPH.CONFIG SET/GET round-trips configuration values."""
    config_name = "RESULTSET_SIZE"
    config_value = 3

    # Set configuration
    response = await modclient.graph().config(config_name, config_value, set=True)
    assert response == "OK"

    # Make sure config been updated.
    response = await modclient.graph().config(config_name, set=False)
    expected_response = [config_name, config_value]
    assert response == expected_response

    config_name = "QUERY_MEM_CAPACITY"
    config_value = 1 << 20  # 1MB

    # Set configuration
    response = await modclient.graph().config(config_name, config_value, set=True)
    assert response == "OK"

    # Make sure config been updated.
    response = await modclient.graph().config(config_name, set=False)
    expected_response = [config_name, config_value]
    assert response == expected_response

    # reset to default so later tests are unaffected
    await modclient.graph().config("QUERY_MEM_CAPACITY", 0, set=True)
    await modclient.graph().config("RESULTSET_SIZE", -100, set=True)


@pytest.mark.redismod
@pytest.mark.onlynoncluster
async def test_list_keys(modclient: redis.Redis):
    """GRAPH.LIST tracks graph keys through create, rename and delete."""
    result = await modclient.graph().list_keys()
    assert result == []

    await modclient.execute_command("GRAPH.EXPLAIN", "G", "RETURN 1")
    result = await modclient.graph().list_keys()
    assert result == ["G"]

    await modclient.execute_command("GRAPH.EXPLAIN", "X", "RETURN 1")
    result = await modclient.graph().list_keys()
    assert result == ["G", "X"]

    await modclient.delete("G")
    await modclient.rename("X", "Z")
    result = await modclient.graph().list_keys()
    assert result == ["Z"]

    await modclient.delete("Z")
    result = await modclient.graph().list_keys()
    assert result == []


@pytest.mark.redismod
async def test_multi_label(modclient: redis.Redis):
    """Nodes accept a list of labels; invalid label types raise AssertionError."""
    redis_graph = modclient.graph("g")

    node = Node(label=["l", "ll"])
    redis_graph.add_node(node)
    await redis_graph.commit()

    query = "MATCH (n) RETURN n"
    result = await redis_graph.query(query)
    result_node = result.result_set[0][0]
    assert result_node == node

    # The original try/`assert False`/except AssertionError pattern was broken:
    # `assert False` itself raises AssertionError, which the except clause
    # caught, so the check could never fail. pytest.raises verifies the
    # constructor actually rejects non-string labels.
    with pytest.raises(AssertionError):
        Node(label=1)

    with pytest.raises(AssertionError):
        Node(label=["l", 1])
@pytest.mark.redismod
async def test_execution_plan(modclient: redis.Redis):
    """execution_plan() returns the server's textual plan for a query."""
    redis_graph = modclient.graph("execution_plan")
    create_query = """CREATE (:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}),
    (:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}),
    (:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})"""
    await redis_graph.query(create_query)

    result = await redis_graph.execution_plan(
        "MATCH (r:Rider)-[:rides]->(t:Team) WHERE t.name = $name RETURN r.name, t.name, $params",  # noqa
        {"name": "Yehuda"},
    )
    expected = "Results\n Project\n Conditional Traverse | (t:Team)->(r:Rider)\n Filter\n Node By Label Scan | (t:Team)"  # noqa
    assert result == expected

    await redis_graph.delete()


@pytest.mark.redismod
async def test_explain(modclient: redis.Redis):
    """explain() exposes both the string plan and the structured Operation tree."""
    redis_graph = modclient.graph("execution_plan")
    # graph creation / population
    create_query = """CREATE
(:Rider {name:'Valentino Rossi'})-[:rides]->(:Team {name:'Yamaha'}),
(:Rider {name:'Dani Pedrosa'})-[:rides]->(:Team {name:'Honda'}),
(:Rider {name:'Andrea Dovizioso'})-[:rides]->(:Team {name:'Ducati'})"""
    await redis_graph.query(create_query)

    result = await redis_graph.explain(
        """MATCH (r:Rider)-[:rides]->(t:Team)
WHERE t.name = $name
RETURN r.name, t.name
UNION
MATCH (r:Rider)-[:rides]->(t:Team)
WHERE t.name = $name
RETURN r.name, t.name""",
        {"name": "Yamaha"},
    )
    expected = """\
Results
Distinct
    Join
        Project
            Conditional Traverse | (t:Team)->(r:Rider)
                Filter
                    Node By Label Scan | (t:Team)
        Project
            Conditional Traverse | (t:Team)->(r:Rider)
                Filter
                    Node By Label Scan | (t:Team)"""
    # The comparison strips all spaces and newlines, so the literal's exact
    # indentation is irrelevant to the assertion.
    assert str(result).replace(" ", "").replace("\n", "") == expected.replace(
        " ", ""
    ).replace("\n", "")

    def build_union_branch():
        # Project -> Conditional Traverse -> Filter -> Node By Label Scan:
        # the identical subtree that hangs off both sides of the Join.
        return Operation("Project").append_child(
            Operation("Conditional Traverse", "(t:Team)->(r:Rider)").append_child(
                Operation("Filter").append_child(
                    Operation("Node By Label Scan", "(t:Team)")
                )
            )
        )

    expected = Operation("Results").append_child(
        Operation("Distinct").append_child(
            Operation("Join")
            .append_child(build_union_branch())
            .append_child(build_union_branch())
        )
    )

    assert result.structured_plan == expected

    result = await redis_graph.explain(
        """MATCH (r:Rider), (t:Team)
    RETURN r.name, t.name"""
    )
    expected = """\
Results
Project
    Cartesian Product
        Node By Label Scan | (r:Rider)
        Node By Label Scan | (t:Team)"""
    assert str(result).replace(" ", "").replace("\n", "") == expected.replace(
        " ", ""
    ).replace("\n", "")

    expected = Operation("Results").append_child(
        Operation("Project").append_child(
            Operation("Cartesian Product")
            .append_child(Operation("Node By Label Scan"))
            .append_child(Operation("Node By Label Scan"))
        )
    )

    assert result.structured_plan == expected

    await redis_graph.delete()