diff --git a/redis/parsers/__init__.py b/redis/_parsers/__init__.py similarity index 100% rename from redis/parsers/__init__.py rename to redis/_parsers/__init__.py diff --git a/redis/parsers/base.py b/redis/_parsers/base.py similarity index 100% rename from redis/parsers/base.py rename to redis/_parsers/base.py diff --git a/redis/parsers/commands.py b/redis/_parsers/commands.py similarity index 100% rename from redis/parsers/commands.py rename to redis/_parsers/commands.py diff --git a/redis/parsers/encoders.py b/redis/_parsers/encoders.py similarity index 100% rename from redis/parsers/encoders.py rename to redis/_parsers/encoders.py diff --git a/redis/_parsers/helpers.py b/redis/_parsers/helpers.py new file mode 100644 index 0000000000..f27e3b12c0 --- /dev/null +++ b/redis/_parsers/helpers.py @@ -0,0 +1,851 @@ +import datetime + +from redis.utils import str_if_bytes + + +def timestamp_to_datetime(response): + "Converts a unix timestamp to a Python datetime object" + if not response: + return None + try: + response = int(response) + except ValueError: + return None + return datetime.datetime.fromtimestamp(response) + + +def parse_debug_object(response): + "Parse the results of Redis's DEBUG OBJECT command into a Python dict" + # The 'type' of the object is the first item in the response, but isn't + # prefixed with a name + response = str_if_bytes(response) + response = "type:" + response + response = dict(kv.split(":") for kv in response.split()) + + # parse some expected int values from the string response + # note: this cmd isn't spec'd so these may not appear in all redis versions + int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") + for field in int_fields: + if field in response: + response[field] = int(response[field]) + + return response + + +def parse_info(response): + """Parse the result of Redis's INFO command into a Python dict""" + info = {} + response = str_if_bytes(response) + + def get_value(value): + if "," not in value or "=" not in value: + try: + if "." in value: + return float(value) + else: + return int(value) + except ValueError: + return value + else: + sub_dict = {} + for item in value.split(","): + k, v = item.rsplit("=", 1) + sub_dict[k] = get_value(v) + return sub_dict + + for line in response.splitlines(): + if line and not line.startswith("#"): + if line.find(":") != -1: + # Split, the info fields keys and values. + # Note that the value may contain ':'. but the 'host:' + # pseudo-command is the only case where the key contains ':' + key, value = line.split(":", 1) + if key == "cmdstat_host": + key, value = line.rsplit(":", 1) + + if key == "module": + # Hardcode a list for key 'modules' since there could be + # multiple lines that started with 'module' + info.setdefault("modules", []).append(get_value(value)) + else: + info[key] = get_value(value) + else: + # if the line isn't splittable, append it to the "__raw__" key + info.setdefault("__raw__", []).append(line) + + return info + + +def parse_memory_stats(response, **kwargs): + """Parse the results of MEMORY STATS""" + stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) + for key, value in stats.items(): + if key.startswith("db."): + stats[key] = pairs_to_dict( + value, decode_keys=True, decode_string_values=True + ) + return stats + + +SENTINEL_STATE_TYPES = { + "can-failover-its-master": int, + "config-epoch": int, + "down-after-milliseconds": int, + "failover-timeout": int, + "info-refresh": int, + "last-hello-message": int, + "last-ok-ping-reply": int, + "last-ping-reply": int, + "last-ping-sent": int, + "master-link-down-time": int, + "master-port": int, + "num-other-sentinels": int, + "num-slaves": int, + "o-down-time": int, + "pending-commands": int, + "parallel-syncs": int, + "port": int, + "quorum": int, + "role-reported-time": int, + "s-down-time": int, + "slave-priority": int, + "slave-repl-offset": int, + "voted-leader-epoch": int, +} + + +def parse_sentinel_state(item): + result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) + flags = set(result["flags"].split(",")) + for name, flag in ( + ("is_master", "master"), + ("is_slave", "slave"), + ("is_sdown", "s_down"), + ("is_odown", "o_down"), + ("is_sentinel", "sentinel"), + ("is_disconnected", "disconnected"), + ("is_master_down", "master_down"), + ): + result[name] = flag in flags + return result + + +def parse_sentinel_master(response): + return parse_sentinel_state(map(str_if_bytes, response)) + + +def parse_sentinel_state_resp3(response): + result = {} + for key in response: + try: + value = SENTINEL_STATE_TYPES[key](str_if_bytes(response[key])) + result[str_if_bytes(key)] = value + except Exception: + result[str_if_bytes(key)] = response[str_if_bytes(key)] + flags = set(result["flags"].split(",")) + result["flags"] = flags + return result + + +def parse_sentinel_masters(response): + result = {} + for item in response: + state = parse_sentinel_state(map(str_if_bytes, item)) + result[state["name"]] = state + return result + + +def parse_sentinel_masters_resp3(response): + return [parse_sentinel_state(master) for master in response] + + +def parse_sentinel_slaves_and_sentinels(response): + return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] + + +def parse_sentinel_slaves_and_sentinels_resp3(response): + return [parse_sentinel_state_resp3(item) for item in response] + + +def parse_sentinel_get_master(response): + return response and (response[0], int(response[1])) or None + + +def pairs_to_dict(response, decode_keys=False, decode_string_values=False): + """Create a dict given a list of key/value pairs""" + if response is None: + return {} + if decode_keys or decode_string_values: + # the iter form is faster, but I don't know how to make that work + # with a str_if_bytes() map + keys = response[::2] + if decode_keys: + keys = map(str_if_bytes, keys) + values = response[1::2] + if decode_string_values: + values = map(str_if_bytes, values) + return dict(zip(keys, values)) + else: + it = iter(response) + return dict(zip(it, it)) + + +def pairs_to_dict_typed(response, type_info): + it = iter(response) + result = {} + for key, value in zip(it, it): + if key in type_info: + try: + value = type_info[key](value) + except Exception: + # if for some reason the value can't be coerced, just use + # the string value + pass + result[key] = value + return result + + +def zset_score_pairs(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not options.get("withscores"): + return response + score_cast_func = options.get("score_cast_func", float) + it = iter(response) + return list(zip(it, map(score_cast_func, it))) + + +def sort_return_tuples(response, **options): + """ + If ``groups`` is specified, return the response as a list of + n-element tuples with n being the value found in options['groups'] + """ + if not response or not options.get("groups"): + return response + n = options["groups"] + return list(zip(*[response[i::n] for i in range(n)])) + + +def parse_stream_list(response): + if response is None: + return None + data = [] + for r in response: + if r is not None: + data.append((r[0], pairs_to_dict(r[1]))) + else: + data.append((None, None)) + return data + + +def pairs_to_dict_with_str_keys(response): + return pairs_to_dict(response, decode_keys=True) + + +def parse_list_of_dicts(response): + return list(map(pairs_to_dict_with_str_keys, response)) + + +def parse_xclaim(response, **options): + if options.get("parse_justid", False): + return response + return parse_stream_list(response) + + +def parse_xautoclaim(response, **options): + if options.get("parse_justid", False): + return response[1] + response[1] = parse_stream_list(response[1]) + return response + + +def parse_xinfo_stream(response, **options): + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(k): v for k, v in response.items()} + if not options.get("full", False): + first = data.get("first-entry") + if first is not None: + data["first-entry"] = (first[0], pairs_to_dict(first[1])) + last = data["last-entry"] + if last is not None: + data["last-entry"] = (last[0], pairs_to_dict(last[1])) + else: + data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} + if isinstance(data["groups"][0], list): + data["groups"] = [ + pairs_to_dict(group, decode_keys=True) for group in data["groups"] + ] + else: + data["groups"] = [ + {str_if_bytes(k): v for k, v in group.items()} + for group in data["groups"] + ] + return data + + +def parse_xread(response): + if response is None: + return [] + return [[r[0], parse_stream_list(r[1])] for r in response] + + +def parse_xread_resp3(response): + if response is None: + return {} + return {key: [parse_stream_list(value)] for key, value in response.items()} + + +def parse_xpending(response, **options): + if options.get("parse_detail", False): + return parse_xpending_range(response) + consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] + return { + "pending": response[0], + "min": response[1], + "max": response[2], + "consumers": consumers, + } + + +def parse_xpending_range(response): + k = ("message_id", "consumer", "time_since_delivered", "times_delivered") + return [dict(zip(k, r)) for r in response] + + +def float_or_none(response): + if response is None: + return None + return float(response) + + +def bool_ok(response): + return str_if_bytes(response) == "OK" + + +def parse_zadd(response, **options): + if response is None: + return None + if options.get("as_score"): + return float(response) + return int(response) + + +def parse_client_list(response, **options): + clients = [] + for c in str_if_bytes(response).splitlines(): + # Values might contain '=' + clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) + return clients + + +def parse_config_get(response, **options): + response = [str_if_bytes(i) if i is not None else None for i in response] + return response and pairs_to_dict(response) or {} + + +def parse_scan(response, **options): + cursor, r = response + return int(cursor), r + + +def parse_hscan(response, **options): + cursor, r = response + return int(cursor), r and pairs_to_dict(r) or {} + + +def parse_zscan(response, **options): + score_cast_func = options.get("score_cast_func", float) + cursor, r = response + it = iter(r) + return int(cursor), list(zip(it, map(score_cast_func, it))) + + +def parse_zmscore(response, **options): + # zmscore: list of scores (double precision floating point number) or nil + return [float(score) if score is not None else None for score in response] + + +def parse_slowlog_get(response, **options): + space = " " if options.get("decode_responses", False) else b" " + + def parse_item(item): + result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])} + # Redis Enterprise injects another entry at index [3], which has + # the complexity info (i.e. the value N in case the command has + # an O(N) complexity) instead of the command. + if isinstance(item[3], list): + result["command"] = space.join(item[3]) + result["client_address"] = item[4] + result["client_name"] = item[5] + else: + result["complexity"] = item[3] + result["command"] = space.join(item[4]) + result["client_address"] = item[5] + result["client_name"] = item[6] + return result + + return [parse_item(item) for item in response] + + +def parse_stralgo(response, **options): + """ + Parse the response from `STRALGO` command. + Without modifiers the returned value is string. + When LEN is given the command returns the length of the result + (i.e integer). + When IDX is given the command returns a dictionary with the LCS + length and all the ranges in both the strings, start and end + offset for each string, where there are matches. + When WITHMATCHLEN is given, each array representing a match will + also have the length of the match at the beginning of the array. + """ + if options.get("len", False): + return int(response) + if options.get("idx", False): + if options.get("withmatchlen", False): + matches = [ + [(int(match[-1]))] + list(map(tuple, match[:-1])) + for match in response[1] + ] + else: + matches = [list(map(tuple, match)) for match in response[1]] + return { + str_if_bytes(response[0]): matches, + str_if_bytes(response[2]): int(response[3]), + } + return str_if_bytes(response) + + +def parse_cluster_info(response, **options): + response = str_if_bytes(response) + return dict(line.split(":") for line in response.splitlines() if line) + + +def _parse_node_line(line): + line_items = line.split(" ") + node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] + addr = addr.split("@")[0] + node_dict = { + "node_id": node_id, + "flags": flags, + "master_id": master_id, + "last_ping_sent": ping, + "last_pong_rcvd": pong, + "epoch": epoch, + "slots": [], + "migrations": [], + "connected": True if connected == "connected" else False, + } + if len(line_items) >= 9: + slots, migrations = _parse_slots(line_items[8:]) + node_dict["slots"], node_dict["migrations"] = slots, migrations + return addr, node_dict + + +def _parse_slots(slot_ranges): + slots, migrations = [], [] + for s_range in slot_ranges: + if "->-" in s_range: + slot_id, dst_node_id = s_range[1:-1].split("->-", 1) + migrations.append( + {"slot": slot_id, "node_id": dst_node_id, "state": "migrating"} + ) + elif "-<-" in s_range: + slot_id, src_node_id = s_range[1:-1].split("-<-", 1) + migrations.append( + {"slot": slot_id, "node_id": src_node_id, "state": "importing"} + ) + else: + s_range = [sl for sl in s_range.split("-")] + slots.append(s_range) + + return slots, migrations + + +def parse_cluster_nodes(response, **options): + """ + @see: https://redis.io/commands/cluster-nodes # string / bytes + @see: https://redis.io/commands/cluster-replicas # list of string / bytes + """ + if isinstance(response, (str, bytes)): + response = response.splitlines() + return dict(_parse_node_line(str_if_bytes(node)) for node in response) + + +def parse_geosearch_generic(response, **options): + """ + Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' + commands according to 'withdist', 'withhash' and 'withcoord' labels. + """ + try: + if options["store"] or options["store_dist"]: + # `store` and `store_dist` cant be combined + # with other command arguments. + # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' + return response + except KeyError: # it means the command was sent via execute_command + return response + + if type(response) != list: + response_list = [response] + else: + response_list = response + + if not options["withdist"] and not options["withcoord"] and not options["withhash"]: + # just a bunch of places + return response_list + + cast = { + "withdist": float, + "withcoord": lambda ll: (float(ll[0]), float(ll[1])), + "withhash": int, + } + + # zip all output results with each casting function to get + # the properly native Python value. + f = [lambda x: x] + f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] + return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] + + +def parse_command(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = int(command[1]) + cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + commands[cmd_name] = cmd_dict + return commands + + +def parse_command_resp3(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict["name"] = cmd_name + cmd_dict["arity"] = command[1] + cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]} + cmd_dict["first_key_pos"] = command[3] + cmd_dict["last_key_pos"] = command[4] + cmd_dict["step_count"] = command[5] + cmd_dict["acl_categories"] = command[6] + if len(command) > 7: + cmd_dict["tips"] = command[7] + cmd_dict["key_specifications"] = command[8] + cmd_dict["subcommands"] = command[9] + + commands[cmd_name] = cmd_dict + return commands + + +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + +def parse_client_kill(response, **options): + if isinstance(response, int): + return response + return str_if_bytes(response) == "OK" + + +def parse_acl_getuser(response, **options): + if response is None: + return None + if isinstance(response, list): + data = pairs_to_dict(response, decode_keys=True) + else: + data = {str_if_bytes(key): value for key, value in response.items()} + + # convert everything but user-defined data in 'keys' to native strings + data["flags"] = list(map(str_if_bytes, data["flags"])) + data["passwords"] = list(map(str_if_bytes, data["passwords"])) + data["commands"] = str_if_bytes(data["commands"]) + if isinstance(data["keys"], str) or isinstance(data["keys"], bytes): + data["keys"] = list(str_if_bytes(data["keys"]).split(" ")) + if data["keys"] == [""]: + data["keys"] = [] + if "channels" in data: + if isinstance(data["channels"], str) or isinstance(data["channels"], bytes): + data["channels"] = list(str_if_bytes(data["channels"]).split(" ")) + if data["channels"] == [""]: + data["channels"] = [] + if "selectors" in data: + if data["selectors"] != [] and isinstance(data["selectors"][0], list): + data["selectors"] = [ + list(map(str_if_bytes, selector)) for selector in data["selectors"] + ] + elif data["selectors"] != []: + data["selectors"] = [ + {str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()} + for selector in data["selectors"] + ] + + # split 'commands' into separate 'categories' and 'commands' lists + commands, categories = [], [] + for command in data["commands"].split(" "): + categories.append(command) if "@" in command else commands.append(command) + + data["commands"] = commands + data["categories"] = categories + data["enabled"] = "on" in data["flags"] + return data + + +def parse_acl_log(response, **options): + if response is None: + return None + if isinstance(response, list): + data = [] + for log in response: + log_data = pairs_to_dict(log, True, True) + client_info = log_data.get("client-info", "") + log_data["client-info"] = parse_client_info(client_info) + + # float() is lossy comparing to the "double" in C + log_data["age-seconds"] = float(log_data["age-seconds"]) + data.append(log_data) + else: + data = bool_ok(response) + return data + + +def parse_client_info(value): + """ + Parsing client-info in ACL Log in following format. + "key1=value1 key2=value2 key3=value3" + """ + client_info = {} + infos = str_if_bytes(value).split(" ") + for info in infos: + key, value = info.split("=") + client_info[key] = value + + # Those fields are defined as int in networking.c + for int_key in { + "id", + "age", + "idle", + "db", + "sub", + "psub", + "multi", + "qbuf", + "qbuf-free", + "obl", + "argv-mem", + "oll", + "omem", + "tot-mem", + }: + client_info[int_key] = int(client_info[int_key]) + return client_info + + +def parse_set_result(response, **options): + """ + Handle SET result since GET argument is available since Redis 6.2. + Parsing SET result into: + - BOOL + - String when GET argument is used + """ + if options.get("get"): + # Redis will return a getCommand result. + # See `setGenericCommand` in t_string.c + return response + return response and str_if_bytes(response) == "OK" + + +def string_keys_to_dict(key_string, callback): + return dict.fromkeys(key_string.split(), callback) + + +_RedisCallbacks = { + **string_keys_to_dict( + "AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST PSETEX " + "PEXPIRE PEXPIREAT RENAMENX SETEX SETNX SMOVE", + bool, + ), + **string_keys_to_dict("HINCRBYFLOAT INCRBYFLOAT", float), + **string_keys_to_dict( + "ASKING FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " + "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH", + bool_ok, + ), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread), + **string_keys_to_dict( + "GEORADIUS GEORADIUSBYMEMBER GEOSEARCH", + parse_geosearch_generic, + ), + **string_keys_to_dict("XRANGE XREVRANGE", parse_stream_list), + "ACL GETUSER": parse_acl_getuser, + "ACL LOAD": bool_ok, + "ACL LOG": parse_acl_log, + "ACL SETUSER": bool_ok, + "ACL SAVE": bool_ok, + "CLIENT INFO": parse_client_info, + "CLIENT KILL": parse_client_kill, + "CLIENT LIST": parse_client_list, + "CLIENT PAUSE": bool_ok, + "CLIENT SETNAME": bool_ok, + "CLIENT UNBLOCK": bool, + "CLUSTER ADDSLOTS": bool_ok, + "CLUSTER ADDSLOTSRANGE": bool_ok, + "CLUSTER DELSLOTS": bool_ok, + "CLUSTER DELSLOTSRANGE": bool_ok, + "CLUSTER FAILOVER": bool_ok, + "CLUSTER FORGET": bool_ok, + "CLUSTER INFO": parse_cluster_info, + "CLUSTER MEET": bool_ok, + "CLUSTER NODES": parse_cluster_nodes, + "CLUSTER REPLICAS": parse_cluster_nodes, + "CLUSTER REPLICATE": bool_ok, + "CLUSTER RESET": bool_ok, + "CLUSTER SAVECONFIG": bool_ok, + "CLUSTER SET-CONFIG-EPOCH": bool_ok, + "CLUSTER SETSLOT": bool_ok, + "CLUSTER SLAVES": parse_cluster_nodes, + "COMMAND": parse_command, + "CONFIG RESETSTAT": bool_ok, + "CONFIG SET": bool_ok, + "FUNCTION DELETE": bool_ok, + "FUNCTION FLUSH": bool_ok, + "FUNCTION RESTORE": bool_ok, + "GEODIST": float_or_none, + "HSCAN": parse_hscan, + "INFO": parse_info, + "LASTSAVE": timestamp_to_datetime, + "MEMORY PURGE": bool_ok, + "MODULE LOAD": bool, + "MODULE UNLOAD": bool, + "PING": lambda r: str_if_bytes(r) == "PONG", + "PUBSUB NUMSUB": parse_pubsub_numsub, + "QUIT": bool_ok, + "SET": parse_set_result, + "SCAN": parse_scan, + "SCRIPT EXISTS": lambda r: list(map(bool, r)), + "SCRIPT FLUSH": bool_ok, + "SCRIPT KILL": bool_ok, + "SCRIPT LOAD": str_if_bytes, + "SENTINEL CKQUORUM": bool_ok, + "SENTINEL FAILOVER": bool_ok, + "SENTINEL FLUSHCONFIG": bool_ok, + "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, + "SENTINEL MONITOR": bool_ok, + "SENTINEL RESET": bool_ok, + "SENTINEL REMOVE": bool_ok, + "SENTINEL SET": bool_ok, + "SLOWLOG GET": parse_slowlog_get, + "SLOWLOG RESET": bool_ok, + "SORT": sort_return_tuples, + "SSCAN": parse_scan, + "TIME": lambda x: (int(x[0]), int(x[1])), + "XAUTOCLAIM": parse_xautoclaim, + "XCLAIM": parse_xclaim, + "XGROUP CREATE": bool_ok, + "XGROUP DESTROY": bool, + "XGROUP SETID": bool_ok, + "XINFO STREAM": parse_xinfo_stream, + "XPENDING": parse_xpending, + "ZSCAN": parse_zscan, +} + + +_RedisCallbacksRESP2 = { + **string_keys_to_dict( + "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() + ), + **string_keys_to_dict( + "ZDIFF ZINTER ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZRANK ZREVRANGE " + "ZREVRANGEBYSCORE ZREVRANK ZUNION", + zset_score_pairs, + ), + **string_keys_to_dict("ZINCRBY ZSCORE", float_or_none), + **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), + **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), + **string_keys_to_dict( + "BZPOPMAX BZPOPMIN", lambda r: r and (r[0], r[1], float(r[2])) or None + ), + "ACL CAT": lambda r: list(map(str_if_bytes, r)), + "ACL GENPASS": str_if_bytes, + "ACL HELP": lambda r: list(map(str_if_bytes, r)), + "ACL LIST": lambda r: list(map(str_if_bytes, r)), + "ACL USERS": lambda r: list(map(str_if_bytes, r)), + "ACL WHOAMI": str_if_bytes, + "CLIENT GETNAME": str_if_bytes, + "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), + "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), + "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), + "CONFIG GET": parse_config_get, + "DEBUG OBJECT": parse_debug_object, + "GEOHASH": lambda r: list(map(str_if_bytes, r)), + "GEOPOS": lambda r: list( + map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) + ), + "HGETALL": lambda r: r and pairs_to_dict(r) or {}, + "MEMORY STATS": parse_memory_stats, + "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], + "RESET": str_if_bytes, + "SENTINEL MASTER": parse_sentinel_master, + "SENTINEL MASTERS": parse_sentinel_masters, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + "STRALGO": parse_stralgo, + "XINFO CONSUMERS": parse_list_of_dicts, + "XINFO GROUPS": parse_list_of_dicts, + "ZADD": parse_zadd, + "ZMSCORE": parse_zmscore, +} + + +_RedisCallbacksRESP3 = { + **string_keys_to_dict( + "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE " + "ZUNION HGETALL XREADGROUP", + lambda r, **kwargs: r, + ), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), + "ACL LOG": lambda r: [ + {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} for x in r + ] + if isinstance(r, list) + else bool_ok(r), + "COMMAND": parse_command_resp3, + "CONFIG GET": lambda r: { + str_if_bytes(key) + if key is not None + else None: str_if_bytes(value) + if value is not None + else None + for key, value in r.items() + }, + "MEMORY STATS": lambda r: {str_if_bytes(key): value for key, value in r.items()}, + "SENTINEL MASTER": parse_sentinel_state_resp3, + "SENTINEL MASTERS": parse_sentinel_masters_resp3, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels_resp3, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels_resp3, + "STRALGO": lambda r, **options: { + str_if_bytes(key): str_if_bytes(value) for key, value in r.items() + } + if isinstance(r, dict) + else str_if_bytes(r), + "XINFO CONSUMERS": lambda r: [ + {str_if_bytes(key): value for key, value in x.items()} for x in r + ], + "XINFO GROUPS": lambda r: [ + {str_if_bytes(key): value for key, value in d.items()} for d in r + ], +} diff --git a/redis/parsers/hiredis.py b/redis/_parsers/hiredis.py similarity index 100% rename from redis/parsers/hiredis.py rename to redis/_parsers/hiredis.py diff --git a/redis/parsers/resp2.py b/redis/_parsers/resp2.py similarity index 100% rename from redis/parsers/resp2.py rename to redis/_parsers/resp2.py diff --git a/redis/parsers/resp3.py b/redis/_parsers/resp3.py similarity index 100% rename from redis/parsers/resp3.py rename to redis/_parsers/resp3.py diff --git a/redis/parsers/socket.py b/redis/_parsers/socket.py similarity index 100% rename from redis/parsers/socket.py rename to redis/_parsers/socket.py diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py index 849603abb4..111df24185 100644 --- a/redis/asyncio/client.py +++ b/redis/asyncio/client.py @@ -24,6 +24,12 @@ cast, ) +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + bool_ok, +) from redis.asyncio.connection import ( Connection, ConnectionPool, @@ -37,7 +43,6 @@ NEVER_DECODE, AbstractRedis, CaseInsensitiveDict, - bool_ok, ) from redis.commands import ( AsyncCoreCommands, @@ -257,12 +262,12 @@ def __init__( self.single_connection_client = single_connection_client self.connection: Optional[Connection] = None - self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: - self.response_callbacks.update(self.__class__.RESP3_RESPONSE_CALLBACKS) + self.response_callbacks.update(_RedisCallbacksRESP3) else: - self.response_callbacks.update(self.__class__.RESP2_RESPONSE_CALLBACKS) + self.response_callbacks.update(_RedisCallbacksRESP2) # If using a single connection client, we need to lock creation-of and use-of # the client in order to avoid race conditions such as using asyncio.gather diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py index 5c7aecfe23..9e2a40ce1b 100644 --- a/redis/asyncio/cluster.py +++ b/redis/asyncio/cluster.py @@ -18,6 +18,12 @@ Union, ) +from redis._parsers import AsyncCommandsParser, Encoder +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, +) from redis.asyncio.client import ResponseCallbackT from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url from redis.asyncio.lock import Lock @@ -55,7 +61,6 @@ TimeoutError, TryAgainError, ) -from redis.parsers import AsyncCommandsParser, Encoder from redis.typing import AnyKeyT, EncodableT, KeyT from redis.utils import dict_merge, safe_str, str_if_bytes @@ -327,11 +332,11 @@ def __init__( self.retry.update_supported_errors(retry_on_error) kwargs.update({"retry": self.retry}) - kwargs["response_callbacks"] = self.__class__.RESPONSE_CALLBACKS.copy() + kwargs["response_callbacks"] = _RedisCallbacks.copy() if kwargs.get("protocol") in ["3", 3]: - kwargs["response_callbacks"].update(self.__class__.RESP3_RESPONSE_CALLBACKS) + kwargs["response_callbacks"].update(_RedisCallbacksRESP3) else: - kwargs["response_callbacks"].update(self.__class__.RESP2_RESPONSE_CALLBACKS) + kwargs["response_callbacks"].update(_RedisCallbacksRESP2) self.connection_kwargs = kwargs if startup_nodes: diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index fc69b9091a..22c5030e6c 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -51,7 +51,7 @@ from redis.typing import EncodableT from redis.utils import HIREDIS_AVAILABLE, str_if_bytes -from ..parsers import ( +from .._parsers import ( BaseParser, Encoder, _AsyncHiredisParser, diff --git a/redis/client.py b/redis/client.py index 09156bace6..66e2c7b84f 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1,5 +1,4 @@ import copy -import datetime import re import threading import time @@ -7,6 +6,12 @@ from itertools import chain from typing import Optional +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + bool_ok, +) from redis.commands import ( CoreCommands, RedisModuleCommands, @@ -18,7 +23,6 @@ from redis.exceptions import ( ConnectionError, ExecAbortError, - ModuleError, PubSubError, RedisError, ResponseError, @@ -36,21 +40,6 @@ NEVER_DECODE = "NEVER_DECODE" -def timestamp_to_datetime(response): - "Converts a unix timestamp to a Python datetime object" - if not response: - return None - try: - response = int(response) - except ValueError: - return None - return datetime.datetime.fromtimestamp(response) - - -def string_keys_to_dict(key_string, callback): - return dict.fromkeys(key_string.split(), callback) - - class CaseInsensitiveDict(dict): "Case insensitive dict implementation. Assumes string keys only." @@ -78,855 +67,11 @@ def update(self, data): super().update(data) -def parse_debug_object(response): - "Parse the results of Redis's DEBUG OBJECT command into a Python dict" - # The 'type' of the object is the first item in the response, but isn't - # prefixed with a name - response = str_if_bytes(response) - response = "type:" + response - response = dict(kv.split(":") for kv in response.split()) - - # parse some expected int values from the string response - # note: this cmd isn't spec'd so these may not appear in all redis versions - int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") - for field in int_fields: - if field in response: - response[field] = int(response[field]) - - return response - - -def parse_object(response, infotype): - """Parse the results of an OBJECT command""" - if infotype in ("idletime", "refcount"): - return int_or_none(response) - return response - - -def parse_info(response): - """Parse the result of Redis's INFO command into a Python dict""" - info = {} - response = str_if_bytes(response) - - def get_value(value): - if "," not in value or "=" not in value: - try: - if "." in value: - return float(value) - else: - return int(value) - except ValueError: - return value - else: - sub_dict = {} - for item in value.split(","): - k, v = item.rsplit("=", 1) - sub_dict[k] = get_value(v) - return sub_dict - - for line in response.splitlines(): - if line and not line.startswith("#"): - if line.find(":") != -1: - # Split, the info fields keys and values. - # Note that the value may contain ':'. but the 'host:' - # pseudo-command is the only case where the key contains ':' - key, value = line.split(":", 1) - if key == "cmdstat_host": - key, value = line.rsplit(":", 1) - - if key == "module": - # Hardcode a list for key 'modules' since there could be - # multiple lines that started with 'module' - info.setdefault("modules", []).append(get_value(value)) - else: - info[key] = get_value(value) - else: - # if the line isn't splittable, append it to the "__raw__" key - info.setdefault("__raw__", []).append(line) - - return info - - -def parse_memory_stats(response, **kwargs): - """Parse the results of MEMORY STATS""" - stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) - for key, value in stats.items(): - if key.startswith("db."): - stats[key] = pairs_to_dict( - value, decode_keys=True, decode_string_values=True - ) - return stats - - -SENTINEL_STATE_TYPES = { - "can-failover-its-master": int, - "config-epoch": int, - "down-after-milliseconds": int, - "failover-timeout": int, - "info-refresh": int, - "last-hello-message": int, - "last-ok-ping-reply": int, - "last-ping-reply": int, - "last-ping-sent": int, - "master-link-down-time": int, - "master-port": int, - "num-other-sentinels": int, - "num-slaves": int, - "o-down-time": int, - "pending-commands": int, - "parallel-syncs": int, - "port": int, - "quorum": int, - "role-reported-time": int, - "s-down-time": int, - "slave-priority": int, - "slave-repl-offset": int, - "voted-leader-epoch": int, -} - - -def parse_sentinel_state(item): - result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) - flags = set(result["flags"].split(",")) - for name, flag in ( - ("is_master", "master"), - ("is_slave", "slave"), - ("is_sdown", "s_down"), - ("is_odown", "o_down"), - ("is_sentinel", "sentinel"), - ("is_disconnected", "disconnected"), - ("is_master_down", "master_down"), - ): - result[name] = flag in flags - return result - - -def parse_sentinel_master(response): - return parse_sentinel_state(map(str_if_bytes, response)) - - -def parse_sentinel_masters(response): - result = {} - for item in response: - state = parse_sentinel_state(map(str_if_bytes, item)) - result[state["name"]] = state - return result - - -def parse_sentinel_slaves_and_sentinels(response): - return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] - - -def parse_sentinel_get_master(response): - return response and (response[0], int(response[1])) or None - - -def pairs_to_dict(response, decode_keys=False, decode_string_values=False): - """Create a dict given a list of key/value pairs""" - if response is None: - return {} - if decode_keys or decode_string_values: - # the iter form is faster, but I don't know how to make that work - # with a str_if_bytes() map - keys = response[::2] - if decode_keys: - keys = map(str_if_bytes, keys) - values = response[1::2] - if decode_string_values: - values = map(str_if_bytes, values) - return dict(zip(keys, values)) - else: - it = iter(response) - return dict(zip(it, it)) - - -def pairs_to_dict_typed(response, type_info): - it = iter(response) - result = {} - for key, value in zip(it, it): - if key in type_info: - try: - value = type_info[key](value) - except Exception: - # if for some reason the value can't be coerced, just use - # the string value - pass - result[key] = value - return result - - -def zset_score_pairs(response, **options): - """ - If ``withscores`` is specified in the options, return the response as - a list of (value, score) pairs - """ - if not response or not options.get("withscores"): - return response - score_cast_func = options.get("score_cast_func", float) - it = iter(response) - return list(zip(it, map(score_cast_func, it))) - - -def sort_return_tuples(response, **options): - """ - If ``groups`` is specified, return the response as a list of - n-element tuples with n being the value found in options['groups'] - """ - if not response or not options.get("groups"): - return response - n = options["groups"] - return list(zip(*[response[i::n] for i in range(n)])) - - -def int_or_none(response): - if response is None: - return None - return int(response) - - -def parse_stream_list(response): - if response is None: - return None - data = [] - for r in response: - if r is not None: - data.append((r[0], pairs_to_dict(r[1]))) - else: - data.append((None, None)) - return data - - -def pairs_to_dict_with_str_keys(response): - return pairs_to_dict(response, decode_keys=True) - - -def parse_list_of_dicts(response): - return list(map(pairs_to_dict_with_str_keys, response)) - - -def parse_xclaim(response, **options): - if options.get("parse_justid", False): - return response - return parse_stream_list(response) - - -def parse_xautoclaim(response, **options): - if options.get("parse_justid", False): - return response[1] - response[1] = parse_stream_list(response[1]) - return response - - -def parse_xinfo_stream(response, **options): - if isinstance(response, list): - data = pairs_to_dict(response, decode_keys=True) - else: - data = {str_if_bytes(k): v for k, v in response.items()} - if not options.get("full", False): - first = data.get("first-entry") - if first is not None: - data["first-entry"] = (first[0], pairs_to_dict(first[1])) - last = data["last-entry"] - if last is not None: - data["last-entry"] = (last[0], pairs_to_dict(last[1])) - else: - data["entries"] = {_id: pairs_to_dict(entry) for _id, entry in data["entries"]} - if isinstance(data["groups"][0], list): - data["groups"] = [ - pairs_to_dict(group, decode_keys=True) for group in data["groups"] - ] - else: - data["groups"] = [ - {str_if_bytes(k): v for k, v in group.items()} - for group in data["groups"] - ] - return data - - -def parse_xread(response): - if response is None: - return [] - return [[r[0], parse_stream_list(r[1])] for r in response] - - -def parse_xread_resp3(response): - if response is None: - return {} - return {key: [parse_stream_list(value)] for key, value in response.items()} - - -def parse_xpending(response, **options): - if options.get("parse_detail", False): - return parse_xpending_range(response) - consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] - return { - "pending": response[0], - "min": response[1], - "max": response[2], - "consumers": consumers, - } - - -def parse_xpending_range(response): - k = ("message_id", "consumer", "time_since_delivered", "times_delivered") - return [dict(zip(k, r)) for r in response] - - -def float_or_none(response): - if response is None: - return None - return float(response) - - -def bool_ok(response): - return str_if_bytes(response) == "OK" - - -def parse_zadd(response, **options): - if response is None: - return None - if options.get("as_score"): - return float(response) - return int(response) - - -def parse_client_list(response, **options): - clients = [] - for c in str_if_bytes(response).splitlines(): - # Values might contain '=' - clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) - return clients - - -def parse_config_get(response, **options): - response = [str_if_bytes(i) if i is not None else None for i in response] - return response and pairs_to_dict(response) or {} - - -def parse_scan(response, **options): - cursor, r = response - return int(cursor), r - - -def parse_hscan(response, **options): - cursor, r = response - return int(cursor), r and pairs_to_dict(r) or {} - - -def parse_zscan(response, **options): - score_cast_func = options.get("score_cast_func", float) - cursor, r = response - it = iter(r) - return int(cursor), list(zip(it, map(score_cast_func, it))) - - -def parse_zmscore(response, **options): - # zmscore: list of scores (double precision floating point number) or nil - return [float(score) if score is not None else None for score in response] - - -def parse_slowlog_get(response, **options): - space = " " if options.get("decode_responses", False) else b" " - - def parse_item(item): - result = {"id": item[0], "start_time": int(item[1]), "duration": int(item[2])} - # Redis Enterprise injects another entry at index [3], which has - # the complexity info (i.e. the value N in case the command has - # an O(N) complexity) instead of the command. - if isinstance(item[3], list): - result["command"] = space.join(item[3]) - result["client_address"] = item[4] - result["client_name"] = item[5] - else: - result["complexity"] = item[3] - result["command"] = space.join(item[4]) - result["client_address"] = item[5] - result["client_name"] = item[6] - return result - - return [parse_item(item) for item in response] - - -def parse_stralgo(response, **options): - """ - Parse the response from `STRALGO` command. - Without modifiers the returned value is string. - When LEN is given the command returns the length of the result - (i.e integer). - When IDX is given the command returns a dictionary with the LCS - length and all the ranges in both the strings, start and end - offset for each string, where there are matches. - When WITHMATCHLEN is given, each array representing a match will - also have the length of the match at the beginning of the array. - """ - if options.get("len", False): - return int(response) - if options.get("idx", False): - if options.get("withmatchlen", False): - matches = [ - [(int(match[-1]))] + list(map(tuple, match[:-1])) - for match in response[1] - ] - else: - matches = [list(map(tuple, match)) for match in response[1]] - return { - str_if_bytes(response[0]): matches, - str_if_bytes(response[2]): int(response[3]), - } - return str_if_bytes(response) - - -def parse_cluster_info(response, **options): - response = str_if_bytes(response) - return dict(line.split(":") for line in response.splitlines() if line) - - -def _parse_node_line(line): - line_items = line.split(" ") - node_id, addr, flags, master_id, ping, pong, epoch, connected = line.split(" ")[:8] - addr = addr.split("@")[0] - node_dict = { - "node_id": node_id, - "flags": flags, - "master_id": master_id, - "last_ping_sent": ping, - "last_pong_rcvd": pong, - "epoch": epoch, - "slots": [], - "migrations": [], - "connected": True if connected == "connected" else False, - } - if len(line_items) >= 9: - slots, migrations = _parse_slots(line_items[8:]) - node_dict["slots"], node_dict["migrations"] = slots, migrations - return addr, node_dict - - -def _parse_slots(slot_ranges): - slots, migrations = [], [] - for s_range in slot_ranges: - if "->-" in s_range: - slot_id, dst_node_id = s_range[1:-1].split("->-", 1) - migrations.append( - {"slot": slot_id, "node_id": dst_node_id, "state": "migrating"} - ) - elif "-<-" in s_range: - slot_id, src_node_id = s_range[1:-1].split("-<-", 1) - migrations.append( - {"slot": slot_id, "node_id": src_node_id, "state": "importing"} - ) - else: - s_range = [sl for sl in s_range.split("-")] - slots.append(s_range) - - return slots, migrations - - -def parse_cluster_nodes(response, **options): - """ - @see: https://redis.io/commands/cluster-nodes # string / bytes - @see: https://redis.io/commands/cluster-replicas # list of string / bytes - """ - if isinstance(response, (str, bytes)): - response = response.splitlines() - return dict(_parse_node_line(str_if_bytes(node)) for node in response) - - -def parse_geosearch_generic(response, **options): - """ - Parse the response of 'GEOSEARCH', GEORADIUS' and 'GEORADIUSBYMEMBER' - commands according to 'withdist', 'withhash' and 'withcoord' labels. - """ - try: - if options["store"] or options["store_dist"]: - # `store` and `store_dist` cant be combined - # with other command arguments. - # relevant to 'GEORADIUS' and 'GEORADIUSBYMEMBER' - return response - except KeyError: # it means the command was sent via execute_command - return response - - if type(response) != list: - response_list = [response] - else: - response_list = response - - if not options["withdist"] and not options["withcoord"] and not options["withhash"]: - # just a bunch of places - return response_list - - cast = { - "withdist": float, - "withcoord": lambda ll: (float(ll[0]), float(ll[1])), - "withhash": int, - } - - # zip all output results with each casting function to get - # the properly native Python value. - f = [lambda x: x] - f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] - return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] - - -def parse_command(response, **options): - commands = {} - for command in response: - cmd_dict = {} - cmd_name = str_if_bytes(command[0]) - cmd_dict["name"] = cmd_name - cmd_dict["arity"] = int(command[1]) - cmd_dict["flags"] = [str_if_bytes(flag) for flag in command[2]] - cmd_dict["first_key_pos"] = command[3] - cmd_dict["last_key_pos"] = command[4] - cmd_dict["step_count"] = command[5] - if len(command) > 7: - cmd_dict["tips"] = command[7] - cmd_dict["key_specifications"] = command[8] - cmd_dict["subcommands"] = command[9] - commands[cmd_name] = cmd_dict - return commands - - -def parse_command_resp3(response, **options): - commands = {} - for command in response: - cmd_dict = {} - cmd_name = str_if_bytes(command[0]) - cmd_dict["name"] = cmd_name - cmd_dict["arity"] = command[1] - cmd_dict["flags"] = {str_if_bytes(flag) for flag in command[2]} - cmd_dict["first_key_pos"] = command[3] - cmd_dict["last_key_pos"] = command[4] - cmd_dict["step_count"] = command[5] - cmd_dict["acl_categories"] = command[6] - if len(command) > 7: - cmd_dict["tips"] = command[7] - cmd_dict["key_specifications"] = command[8] - cmd_dict["subcommands"] = command[9] - - commands[cmd_name] = cmd_dict - return commands - - -def parse_pubsub_numsub(response, **options): - return list(zip(response[0::2], response[1::2])) - - -def parse_client_kill(response, **options): - if isinstance(response, int): - return response - return str_if_bytes(response) == "OK" - - -def parse_acl_getuser(response, **options): - if response is None: - return None - if isinstance(response, list): - data = pairs_to_dict(response, decode_keys=True) - else: - data = {str_if_bytes(key): value for key, value in response.items()} - - # convert everything but user-defined data in 'keys' to native strings - data["flags"] = list(map(str_if_bytes, data["flags"])) - data["passwords"] = list(map(str_if_bytes, data["passwords"])) - data["commands"] = str_if_bytes(data["commands"]) - if isinstance(data["keys"], str) or isinstance(data["keys"], bytes): - data["keys"] = list(str_if_bytes(data["keys"]).split(" ")) - if data["keys"] == [""]: - data["keys"] = [] - if "channels" in data: - if isinstance(data["channels"], str) or isinstance(data["channels"], bytes): - data["channels"] = list(str_if_bytes(data["channels"]).split(" ")) - if data["channels"] == [""]: - data["channels"] = [] - if "selectors" in data: - if data["selectors"] != [] and isinstance(data["selectors"][0], list): - data["selectors"] = [ - list(map(str_if_bytes, selector)) for selector in data["selectors"] - ] - elif data["selectors"] != []: - data["selectors"] = [ - {str_if_bytes(k): str_if_bytes(v) for k, v in selector.items()} - for selector in data["selectors"] - ] - - # split 'commands' into separate 'categories' and 'commands' lists - commands, categories = [], [] - for command in data["commands"].split(" "): - categories.append(command) if "@" in command else commands.append(command) - - data["commands"] = commands - data["categories"] = categories - data["enabled"] = "on" in data["flags"] - return data - - -def parse_acl_log(response, **options): - if response is None: - return None - if isinstance(response, list): - data = [] - for log in response: - log_data = pairs_to_dict(log, True, True) - client_info = log_data.get("client-info", "") - log_data["client-info"] = parse_client_info(client_info) - - # float() is lossy comparing to the "double" in C - log_data["age-seconds"] = float(log_data["age-seconds"]) - data.append(log_data) - else: - data = bool_ok(response) - return data - - -def parse_client_info(value): - """ - Parsing client-info in ACL Log in following format. - "key1=value1 key2=value2 key3=value3" - """ - client_info = {} - infos = str_if_bytes(value).split(" ") - for info in infos: - key, value = info.split("=") - client_info[key] = value - - # Those fields are defined as int in networking.c - for int_key in { - "id", - "age", - "idle", - "db", - "sub", - "psub", - "multi", - "qbuf", - "qbuf-free", - "obl", - "argv-mem", - "oll", - "omem", - "tot-mem", - }: - client_info[int_key] = int(client_info[int_key]) - return client_info - - -def parse_module_result(response): - if isinstance(response, ModuleError): - raise response - return True - - -def parse_set_result(response, **options): - """ - Handle SET result since GET argument is available since Redis 6.2. - Parsing SET result into: - - BOOL - - String when GET argument is used - """ - if options.get("get"): - # Redis will return a getCommand result. - # See `setGenericCommand` in t_string.c - return response - return response and str_if_bytes(response) == "OK" +class AbstractRedis: + pass -class AbstractRedis: - RESPONSE_CALLBACKS = { - **string_keys_to_dict("EXPIRE EXPIREAT PEXPIRE PEXPIREAT AUTH", bool), - **string_keys_to_dict("EXISTS", int), - **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), - **string_keys_to_dict("READONLY MSET", bool_ok), - "CLUSTER DELSLOTS": bool_ok, - "CLUSTER ADDSLOTS": bool_ok, - "COMMAND": parse_command, - "INFO": parse_info, - "SET": parse_set_result, - "CLIENT ID": int, - "CLIENT KILL": parse_client_kill, - "CLIENT LIST": parse_client_list, - "CLIENT INFO": parse_client_info, - "CLIENT SETNAME": bool_ok, - "CLIENT TRACKINGINFO": lambda r: list(map(str_if_bytes, r)), - "LASTSAVE": timestamp_to_datetime, - "RESET": str_if_bytes, - "SLOWLOG GET": parse_slowlog_get, - "TIME": lambda x: (int(x[0]), int(x[1])), - **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), - "SCAN": parse_scan, - "CLIENT GETNAME": str_if_bytes, - "SSCAN": parse_scan, - "ACL LOG": parse_acl_log, - "ACL WHOAMI": str_if_bytes, - "ACL GENPASS": str_if_bytes, - "ACL CAT": lambda r: list(map(str_if_bytes, r)), - "HSCAN": parse_hscan, - "ZSCAN": parse_zscan, - **string_keys_to_dict( - "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None - ), - "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), - "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), - "CLUSTER FAILOVER": bool_ok, - "CLUSTER FORGET": bool_ok, - "CLUSTER INFO": parse_cluster_info, - "CLUSTER KEYSLOT": lambda x: int(x), - "CLUSTER MEET": bool_ok, - "CLUSTER NODES": parse_cluster_nodes, - "CLUSTER REPLICATE": bool_ok, - "CLUSTER RESET": bool_ok, - "CLUSTER SAVECONFIG": bool_ok, - "CLUSTER SETSLOT": bool_ok, - "CLUSTER SLAVES": parse_cluster_nodes, - **string_keys_to_dict("GEODIST", float_or_none), - "GEOHASH": lambda r: list(map(str_if_bytes, r)), - "GEOPOS": lambda r: list( - map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) - ), - "GEOSEARCH": parse_geosearch_generic, - "GEORADIUS": parse_geosearch_generic, - "GEORADIUSBYMEMBER": parse_geosearch_generic, - "XAUTOCLAIM": parse_xautoclaim, - "XINFO STREAM": parse_xinfo_stream, - "XPENDING": parse_xpending, - **string_keys_to_dict("XREAD XREADGROUP", parse_xread), - "COMMAND GETKEYS": lambda r: list(map(str_if_bytes, r)), - **string_keys_to_dict("SORT", sort_return_tuples), - "PING": lambda r: str_if_bytes(r) == "PONG", - "ACL SETUSER": bool_ok, - "PUBSUB NUMSUB": parse_pubsub_numsub, - "SCRIPT FLUSH": bool_ok, - "SCRIPT LOAD": str_if_bytes, - "ACL GETUSER": parse_acl_getuser, - "CONFIG SET": bool_ok, - **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), - "XCLAIM": parse_xclaim, - "CLUSTER SET-CONFIG-EPOCH": bool_ok, - "CLUSTER REPLICAS": parse_cluster_nodes, - "ACL LIST": lambda r: list(map(str_if_bytes, r)), - } - - RESP2_RESPONSE_CALLBACKS = { - "CONFIG GET": parse_config_get, - **string_keys_to_dict( - "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() - ), - **string_keys_to_dict("READWRITE", bool_ok), - **string_keys_to_dict( - "ZPOPMAX ZPOPMIN ZINTER ZDIFF ZUNION ZRANGE ZRANGEBYSCORE " - "ZREVRANGE ZREVRANGEBYSCORE", - zset_score_pairs, - ), - **string_keys_to_dict("ZSCORE ZINCRBY", float_or_none), - "ZADD": parse_zadd, - "ZMSCORE": parse_zmscore, - "HGETALL": lambda r: r and pairs_to_dict(r) or {}, - "MEMORY STATS": parse_memory_stats, - "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], - "STRALGO": parse_stralgo, - # **string_keys_to_dict( - # "COPY " - # "HEXISTS HMSET MOVE MSETNX PERSIST " - # "PSETEX RENAMENX SMOVE SETEX SETNX", - # bool, - # ), - # **string_keys_to_dict( - # "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " - # "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " - # "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " - # "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", - # int, - # ), - # **string_keys_to_dict( - # "FLUSHALL FLUSHDB LSET LTRIM PFMERGE ASKING " - # "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", - # bool_ok, - # ), - # **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), - # **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), - # "ACL HELP": lambda r: list(map(str_if_bytes, r)), - # "ACL LOAD": bool_ok, - # "ACL SAVE": bool_ok, - # "ACL USERS": lambda r: list(map(str_if_bytes, r)), - # "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, - # "CLIENT PAUSE": bool_ok, - # "CLUSTER ADDSLOTSRANGE": bool_ok, - # "CLUSTER DELSLOTSRANGE": bool_ok, - # "CLUSTER GETKEYSINSLOT": lambda r: list(map(str_if_bytes, r)), - # "CONFIG RESETSTAT": bool_ok, - # "DEBUG OBJECT": parse_debug_object, - # "FUNCTION DELETE": bool_ok, - # "FUNCTION FLUSH": bool_ok, - # "FUNCTION RESTORE": bool_ok, - # "MEMORY PURGE": bool_ok, - # "MEMORY USAGE": int_or_none, - # "MODULE LOAD": parse_module_result, - # "MODULE UNLOAD": parse_module_result, - # "OBJECT": parse_object, - # "QUIT": bool_ok, - # "RANDOMKEY": lambda r: r and r or None, - # "SCRIPT EXISTS": lambda r: list(map(bool, r)), - # "SCRIPT KILL": bool_ok, - # "SENTINEL CKQUORUM": bool_ok, - # "SENTINEL FAILOVER": bool_ok, - # "SENTINEL FLUSHCONFIG": bool_ok, - # "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, - # "SENTINEL MASTER": parse_sentinel_master, - # "SENTINEL MASTERS": parse_sentinel_masters, - # "SENTINEL MONITOR": bool_ok, - # "SENTINEL RESET": bool_ok, - # "SENTINEL REMOVE": bool_ok, - # "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, - # "SENTINEL SET": bool_ok, - # "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, - # "SLOWLOG RESET": bool_ok, - # "XGROUP CREATE": bool_ok, - # "XGROUP DESTROY": bool, - # "XGROUP SETID": bool_ok, - "XINFO CONSUMERS": parse_list_of_dicts, - "XINFO GROUPS": parse_list_of_dicts, - } - - RESP3_RESPONSE_CALLBACKS = { - **string_keys_to_dict( - "ZRANGE ZINTER ZPOPMAX ZPOPMIN ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE " - "ZUNION HGETALL XREADGROUP", - lambda r, **kwargs: r, - ), - "CONFIG GET": lambda r: { - str_if_bytes(key) - if key is not None - else None: str_if_bytes(value) - if value is not None - else None - for key, value in r.items() - }, - "ACL LOG": lambda r: [ - {str_if_bytes(key): str_if_bytes(value) for key, value in x.items()} - for x in r - ] - if isinstance(r, list) - else bool_ok(r), - **string_keys_to_dict("XREAD XREADGROUP", parse_xread_resp3), - "COMMAND": parse_command_resp3, - "STRALGO": lambda r, **options: { - str_if_bytes(key): str_if_bytes(value) for key, value in r.items() - } - if isinstance(r, dict) - else str_if_bytes(r), - "XINFO CONSUMERS": lambda r: [ - {str_if_bytes(key): value for key, value in x.items()} for x in r - ], - "MEMORY STATS": lambda r: { - str_if_bytes(key): value for key, value in r.items() - }, - "XINFO GROUPS": lambda r: [ - {str_if_bytes(key): value for key, value in d.items()} for d in r - ], - } - - -class Redis(AbstractRedis, RedisModuleCommands, CoreCommands, SentinelCommands): +class Redis(RedisModuleCommands, CoreCommands, SentinelCommands): """ Implementation of the Redis protocol. @@ -1125,12 +270,12 @@ def __init__( if single_connection_client: self.connection = self.connection_pool.get_connection("_") - self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + self.response_callbacks = CaseInsensitiveDict(_RedisCallbacks) if self.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: - self.response_callbacks.update(self.__class__.RESP3_RESPONSE_CALLBACKS) + self.response_callbacks.update(_RedisCallbacksRESP3) else: - self.response_callbacks.update(self.__class__.RESP2_RESPONSE_CALLBACKS) + self.response_callbacks.update(_RedisCallbacksRESP2) def __repr__(self): return f"{type(self).__name__}<{repr(self.connection_pool)}>" diff --git a/redis/cluster.py b/redis/cluster.py index 0fc715f838..c179511b0c 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -6,8 +6,10 @@ from collections import OrderedDict from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from redis._parsers import CommandsParser, Encoder +from redis._parsers.helpers import parse_scan from redis.backoff import default_backoff -from redis.client import CaseInsensitiveDict, PubSub, Redis, parse_scan +from redis.client import CaseInsensitiveDict, PubSub, Redis from redis.commands import READ_COMMANDS, RedisClusterCommands from redis.commands.helpers import list_or_args from redis.connection import ConnectionPool, DefaultParser, parse_url @@ -30,7 +32,6 @@ TryAgainError, ) from redis.lock import Lock -from redis.parsers import CommandsParser, Encoder from redis.retry import Retry from redis.utils import ( HIREDIS_AVAILABLE, diff --git a/redis/commands/bf/__init__.py b/redis/commands/bf/__init__.py index 63d866353e..bfa9456879 100644 --- a/redis/commands/bf/__init__.py +++ b/redis/commands/bf/__init__.py @@ -1,4 +1,4 @@ -from redis.client import bool_ok +from redis._parsers.helpers import bool_ok from ..helpers import parse_to_list from .commands import * # noqa @@ -91,7 +91,7 @@ class CMSBloom(CMSCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { CMS_INITBYDIM: bool_ok, CMS_INITBYPROB: bool_ok, # CMS_INCRBY: spaceHolder, @@ -99,21 +99,21 @@ def __init__(self, client, **kwargs): CMS_MERGE: bool_ok, } - RESP2_MODULE_CALLBACKS = { + _RESP2_MODULE_CALLBACKS = { CMS_INFO: CMSInfo, } - RESP3_MODULE_CALLBACKS = {} + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = CMSCommands self.execute_command = client.execute_command if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: - MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) else: - MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) - for k, v in MODULE_CALLBACKS.items(): + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -121,30 +121,30 @@ class TOPKBloom(TOPKCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { TOPK_RESERVE: bool_ok, # TOPK_QUERY: spaceHolder, # TOPK_COUNT: spaceHolder, } - RESP2_MODULE_CALLBACKS = { + _RESP2_MODULE_CALLBACKS = { TOPK_ADD: parse_to_list, TOPK_INCRBY: parse_to_list, - TOPK_LIST: parse_to_list, TOPK_INFO: TopKInfo, + TOPK_LIST: parse_to_list, } - RESP3_MODULE_CALLBACKS = {} + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = TOPKCommands self.execute_command = client.execute_command if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: - MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) else: - MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) - for k, v in MODULE_CALLBACKS.items(): + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -152,7 +152,7 @@ class CFBloom(CFCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { CF_RESERVE: bool_ok, # CF_ADD: spaceHolder, # CF_ADDNX: spaceHolder, @@ -165,21 +165,21 @@ def __init__(self, client, **kwargs): # CF_LOADCHUNK: spaceHolder, } - RESP2_MODULE_CALLBACKS = { + _RESP2_MODULE_CALLBACKS = { CF_INFO: CFInfo, } - RESP3_MODULE_CALLBACKS = {} + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = CFCommands self.execute_command = client.execute_command if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: - MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) else: - MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) - for k, v in MODULE_CALLBACKS.items(): + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -187,35 +187,35 @@ class TDigestBloom(TDigestCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { TDIGEST_CREATE: bool_ok, # TDIGEST_RESET: bool_ok, # TDIGEST_ADD: spaceHolder, # TDIGEST_MERGE: spaceHolder, } - RESP2_MODULE_CALLBACKS = { + _RESP2_MODULE_CALLBACKS = { TDIGEST_BYRANK: parse_to_list, TDIGEST_BYREVRANK: parse_to_list, TDIGEST_CDF: parse_to_list, - TDIGEST_QUANTILE: parse_to_list, + TDIGEST_INFO: TDigestInfo, TDIGEST_MIN: float, TDIGEST_MAX: float, TDIGEST_TRIMMED_MEAN: float, - TDIGEST_INFO: TDigestInfo, + TDIGEST_QUANTILE: parse_to_list, } - RESP3_MODULE_CALLBACKS = {} + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = TDigestCommands self.execute_command = client.execute_command if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: - MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) else: - MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) - for k, v in MODULE_CALLBACKS.items(): + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) @@ -223,7 +223,7 @@ class BFBloom(BFCommands, AbstractBloom): def __init__(self, client, **kwargs): """Create a new RedisBloom client.""" # Set the module commands' callbacks - MODULE_CALLBACKS = { + _MODULE_CALLBACKS = { BF_RESERVE: bool_ok, # BF_ADD: spaceHolder, # BF_MADD: spaceHolder, @@ -235,19 +235,19 @@ def __init__(self, client, **kwargs): # BF_CARD: spaceHolder, } - RESP2_MODULE_CALLBACKS = { + _RESP2_MODULE_CALLBACKS = { BF_INFO: BFInfo, } - RESP3_MODULE_CALLBACKS = {} + _RESP3_MODULE_CALLBACKS = {} self.client = client self.commandmixin = BFCommands self.execute_command = client.execute_command if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: - MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) else: - MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + _MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) - for k, v in MODULE_CALLBACKS.items(): + for k, v in _MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) diff --git a/redis/commands/json/__init__.py b/redis/commands/json/__init__.py index 1980a25c03..e895e6a2ba 100644 --- a/redis/commands/json/__init__.py +++ b/redis/commands/json/__init__.py @@ -31,37 +31,37 @@ def __init__( :type json.JSONEncoder: An instance of json.JSONEncoder """ # Set the module commands' callbacks - self.MODULE_CALLBACKS = { + self._MODULE_CALLBACKS = { "JSON.ARRPOP": self._decode, - "JSON.MGET": bulk_of_jsons(self._decode), - "JSON.SET": lambda r: r and nativestr(r) == "OK", "JSON.DEBUG": self._decode, - "JSON.MSET": lambda r: r and nativestr(r) == "OK", "JSON.MERGE": lambda r: r and nativestr(r) == "OK", - "JSON.TOGGLE": self._decode, + "JSON.MGET": bulk_of_jsons(self._decode), + "JSON.MSET": lambda r: r and nativestr(r) == "OK", "JSON.RESP": self._decode, + "JSON.SET": lambda r: r and nativestr(r) == "OK", + "JSON.TOGGLE": self._decode, } - RESP2_MODULE_CALLBACKS = { - "JSON.ARRTRIM": self._decode, - "JSON.OBJLEN": self._decode, + _RESP2_MODULE_CALLBACKS = { "JSON.ARRAPPEND": self._decode, "JSON.ARRINDEX": self._decode, "JSON.ARRINSERT": self._decode, - "JSON.TOGGLE": self._decode, - "JSON.STRAPPEND": self._decode, - "JSON.STRLEN": self._decode, "JSON.ARRLEN": self._decode, + "JSON.ARRTRIM": self._decode, "JSON.CLEAR": int, "JSON.DEL": int, "JSON.FORGET": int, + "JSON.GET": self._decode, "JSON.NUMINCRBY": self._decode, "JSON.NUMMULTBY": self._decode, "JSON.OBJKEYS": self._decode, - "JSON.GET": self._decode, + "JSON.STRAPPEND": self._decode, + "JSON.OBJLEN": self._decode, + "JSON.STRLEN": self._decode, + "JSON.TOGGLE": self._decode, } - RESP3_MODULE_CALLBACKS = { + _RESP3_MODULE_CALLBACKS = { "JSON.GET": lambda response: [ [self._decode(r) for r in res] for res in response ] @@ -74,11 +74,11 @@ def __init__( self.MODULE_VERSION = version if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: - self.MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) else: - self.MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) - for key, value in self.MODULE_CALLBACKS.items(): + for key, value in self._MODULE_CALLBACKS.items(): self.client.set_response_callback(key, value) self.__encoder__ = encoder @@ -134,7 +134,7 @@ def pipeline(self, transaction=True, shard_hint=None): else: p = Pipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) diff --git a/redis/commands/search/__init__.py b/redis/commands/search/__init__.py index 7a7fdff844..e635f91e99 100644 --- a/redis/commands/search/__init__.py +++ b/redis/commands/search/__init__.py @@ -95,12 +95,12 @@ def __init__(self, client, index_name="idx"): If conn is not None, we employ an already existing redis connection """ - self.MODULE_CALLBACKS = {} + self._MODULE_CALLBACKS = {} self.client = client self.index_name = index_name self.execute_command = client.execute_command self._pipeline = client.pipeline - self.RESP2_MODULE_CALLBACKS = { + self._RESP2_MODULE_CALLBACKS = { INFO_CMD: self._parse_info, SEARCH_CMD: self._parse_search, AGGREGATE_CMD: self._parse_aggregate, @@ -116,7 +116,7 @@ def pipeline(self, transaction=True, shard_hint=None): """ p = Pipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) @@ -174,7 +174,7 @@ def pipeline(self, transaction=True, shard_hint=None): """ p = AsyncPipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) diff --git a/redis/commands/search/commands.py b/redis/commands/search/commands.py index 50ebf8c203..742474523f 100644 --- a/redis/commands/search/commands.py +++ b/redis/commands/search/commands.py @@ -67,7 +67,7 @@ def _parse_results(self, cmd, res, **kwargs): if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: return res else: - return self.RESP2_MODULE_CALLBACKS[cmd](res, **kwargs) + return self._RESP2_MODULE_CALLBACKS[cmd](res, **kwargs) def _parse_info(self, res, **kwargs): it = map(to_string, res) diff --git a/redis/commands/timeseries/__init__.py b/redis/commands/timeseries/__init__.py index 7e085af768..498f5118f1 100644 --- a/redis/commands/timeseries/__init__.py +++ b/redis/commands/timeseries/__init__.py @@ -1,5 +1,5 @@ import redis -from redis.client import bool_ok +from redis._parsers.helpers import bool_ok from ..helpers import parse_to_list from .commands import ( @@ -33,35 +33,35 @@ class TimeSeries(TimeSeriesCommands): def __init__(self, client=None, **kwargs): """Create a new RedisTimeSeries client.""" # Set the module commands' callbacks - self.MODULE_CALLBACKS = { - CREATE_CMD: bool_ok, + self._MODULE_CALLBACKS = { ALTER_CMD: bool_ok, + CREATE_CMD: bool_ok, CREATERULE_CMD: bool_ok, DELETERULE_CMD: bool_ok, } - RESP2_MODULE_CALLBACKS = { + _RESP2_MODULE_CALLBACKS = { DEL_CMD: int, GET_CMD: parse_get, - QUERYINDEX_CMD: parse_to_list, - RANGE_CMD: parse_range, - REVRANGE_CMD: parse_range, + INFO_CMD: TSInfo, MGET_CMD: parse_m_get, MRANGE_CMD: parse_m_range, MREVRANGE_CMD: parse_m_range, - INFO_CMD: TSInfo, + RANGE_CMD: parse_range, + REVRANGE_CMD: parse_range, + QUERYINDEX_CMD: parse_to_list, } - RESP3_MODULE_CALLBACKS = {} + _RESP3_MODULE_CALLBACKS = {} self.client = client self.execute_command = client.execute_command if self.client.connection_pool.connection_kwargs.get("protocol") in ["3", 3]: - self.MODULE_CALLBACKS.update(RESP3_MODULE_CALLBACKS) + self._MODULE_CALLBACKS.update(_RESP3_MODULE_CALLBACKS) else: - self.MODULE_CALLBACKS.update(RESP2_MODULE_CALLBACKS) + self._MODULE_CALLBACKS.update(_RESP2_MODULE_CALLBACKS) - for k, v in self.MODULE_CALLBACKS.items(): + for k, v in self._MODULE_CALLBACKS.items(): self.client.set_response_callback(k, v) def pipeline(self, transaction=True, shard_hint=None): @@ -93,7 +93,7 @@ def pipeline(self, transaction=True, shard_hint=None): else: p = Pipeline( connection_pool=self.client.connection_pool, - response_callbacks=self.MODULE_CALLBACKS, + response_callbacks=self._MODULE_CALLBACKS, transaction=transaction, shard_hint=shard_hint, ) diff --git a/redis/connection.py b/redis/connection.py index 845350df17..66debed2ea 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -12,6 +12,7 @@ from typing import Optional, Type, Union from urllib.parse import parse_qs, unquote, urlparse +from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .backoff import NoBackoff from .credentials import CredentialProvider, UsernamePasswordCredentialProvider from .exceptions import ( @@ -24,7 +25,6 @@ ResponseError, TimeoutError, ) -from .parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, diff --git a/redis/typing.py b/redis/typing.py index e555f57f5b..56a1e99ba7 100644 --- a/redis/typing.py +++ b/redis/typing.py @@ -15,9 +15,9 @@ from redis.compat import Protocol if TYPE_CHECKING: + from redis._parsers import Encoder from redis.asyncio.connection import ConnectionPool as AsyncConnectionPool from redis.connection import ConnectionPool - from redis.parsers import Encoder Number = Union[int, float] diff --git a/setup.py b/setup.py index b68ceaaf18..dce48fc259 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ packages=find_packages( include=[ "redis", + "redis._parsers", "redis.asyncio", "redis.commands", "redis.commands.bf", diff --git a/tests/conftest.py b/tests/conftest.py index 50459420ec..b3c410e51b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -371,7 +371,7 @@ def mock_cluster_resp_ok(request, **kwargs): @pytest.fixture() def mock_cluster_resp_int(request, **kwargs): r = _get_client(redis.Redis, request, **kwargs) - return _gen_cluster_mock_resp(r, "2") + return _gen_cluster_mock_resp(r, 2) @pytest.fixture() diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py index a7d121fa49..e5da3f8f46 100644 --- a/tests/test_asyncio/conftest.py +++ b/tests/test_asyncio/conftest.py @@ -6,11 +6,11 @@ import pytest_asyncio import redis.asyncio as redis from packaging.version import Version +from redis._parsers import _AsyncHiredisParser, _AsyncRESP2Parser from redis.asyncio.client import Monitor from redis.asyncio.connection import parse_url from redis.asyncio.retry import Retry from redis.backoff import NoBackoff -from redis.parsers import _AsyncHiredisParser, _AsyncRESP2Parser from redis.utils import HIREDIS_AVAILABLE from tests.conftest import REDIS_INFO @@ -154,7 +154,7 @@ async def mock_cluster_resp_ok(create_redis, **kwargs): @pytest_asyncio.fixture() async def mock_cluster_resp_int(create_redis, **kwargs): r = await create_redis(**kwargs) - return _gen_cluster_mock_resp(r, "2") + return _gen_cluster_mock_resp(r, 2) @pytest_asyncio.fixture() diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py index 2c722826e1..ee498e71f7 100644 --- a/tests/test_asyncio/test_cluster.py +++ b/tests/test_asyncio/test_cluster.py @@ -8,6 +8,7 @@ import pytest import pytest_asyncio from _pytest.fixtures import FixtureRequest +from redis._parsers import AsyncCommandsParser from redis.asyncio.cluster import ClusterNode, NodesManager, RedisCluster from redis.asyncio.connection import Connection, SSLConnection, async_timeout from redis.asyncio.retry import Retry @@ -26,7 +27,6 @@ RedisError, ResponseError, ) -from redis.parsers import AsyncCommandsParser from redis.utils import str_if_bytes from tests.conftest import ( assert_resp_response, @@ -964,7 +964,7 @@ async def test_client_setname(self, r: RedisCluster) -> None: node = r.get_random_node() await r.client_setname("redis_py_test", target_nodes=node) client_name = await r.client_getname(target_nodes=node) - assert client_name == "redis_py_test" + assert_resp_response(r, client_name, "redis_py_test", b"redis_py_test") async def test_exists(self, r: RedisCluster) -> None: d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} @@ -1443,7 +1443,7 @@ async def test_client_trackinginfo(self, r: RedisCluster) -> None: node = r.get_primaries()[0] res = await r.client_trackinginfo(target_nodes=node) assert len(res) > 2 - assert "prefixes" in res + assert "prefixes" in res or b"prefixes" in res @skip_if_server_version_lt("2.9.50") async def test_client_pause(self, r: RedisCluster) -> None: @@ -1609,24 +1609,68 @@ async def test_cluster_renamenx(self, r: RedisCluster) -> None: async def test_cluster_blpop(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") await r.rpush("{foo}b", "3", "4") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) + assert_resp_response( + r, + await r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) assert await r.blpop(["{foo}b", "{foo}a"], timeout=1) is None await r.rpush("{foo}c", "1") - assert await r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, await r.blpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) async def test_cluster_brpop(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") await r.rpush("{foo}b", "3", "4") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) + assert_resp_response( + r, + await r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) assert await r.brpop(["{foo}b", "{foo}a"], timeout=1) is None await r.rpush("{foo}c", "1") - assert await r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, await r.brpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) async def test_cluster_brpoplpush(self, r: RedisCluster) -> None: await r.rpush("{foo}a", "1", "2") @@ -1811,57 +1855,75 @@ async def test_cluster_zinterstore_with_weight(self, r: RedisCluster) -> None: async def test_cluster_bzpopmax(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2}) await r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b2", - 20, - ) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b1", - 10, - ) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a2", - 2, + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"{foo}b", b"b2", 20], ) - assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a1", - 1, + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"{foo}b", b"b1", 10], + ) + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"{foo}a", b"a2", 2], + ) + assert_resp_response( + r, + await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"{foo}a", b"a1", 1], ) assert await r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None await r.zadd("{foo}c", {"c1": 100}) - assert await r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + await r.bzpopmax("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("4.9.0") async def test_cluster_bzpopmin(self, r: RedisCluster) -> None: await r.zadd("{foo}a", {"a1": 1, "a2": 2}) await r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b1", - 10, - ) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}b", - b"b2", - 20, - ) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a1", - 1, + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"b", b"b1", 10], ) - assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == ( - b"{foo}a", - b"a2", - 2, + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"b", b"b2", 20], + ) + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"a", b"a1", 1], + ) + assert_resp_response( + r, + await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"a", b"a2", 2], ) assert await r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None await r.zadd("{foo}c", {"c1": 100}) - assert await r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + await r.bzpopmin("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("6.2.0") async def test_cluster_zrangestore(self, r: RedisCluster) -> None: diff --git a/tests/test_asyncio/test_commands.py b/tests/test_asyncio/test_commands.py index bcedda80ea..08e66b050f 100644 --- a/tests/test_asyncio/test_commands.py +++ b/tests/test_asyncio/test_commands.py @@ -12,7 +12,13 @@ import pytest_asyncio import redis from redis import exceptions -from redis.client import EMPTY_RESPONSE, NEVER_DECODE, parse_info +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + parse_info, +) +from redis.client import EMPTY_RESPONSE, NEVER_DECODE from tests.conftest import ( assert_resp_response, assert_resp_response_in, @@ -80,13 +86,13 @@ class TestResponseCallbacks: """Tests for the response callback system""" async def test_response_callbacks(self, r: redis.Redis): - callbacks = redis.Redis.RESPONSE_CALLBACKS + callbacks = _RedisCallbacks if is_resp2_connection(r): - callbacks.update(redis.Redis.RESP2_RESPONSE_CALLBACKS) + callbacks.update(_RedisCallbacksRESP2) else: - callbacks.update(redis.Redis.RESP3_RESPONSE_CALLBACKS) + callbacks.update(_RedisCallbacksRESP3) assert r.response_callbacks == callbacks - assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) + assert id(r.response_callbacks) != id(_RedisCallbacks) r.set_response_callback("GET", lambda x: "static") await r.set("a", "foo") assert await r.get("a") == "static" @@ -106,13 +112,13 @@ async def test_command_on_invalid_key_type(self, r: redis.Redis): async def test_acl_cat_no_category(self, r: redis.Redis): categories = await r.acl_cat() assert isinstance(categories, list) - assert "read" in categories + assert "read" in categories or b"read" in categories @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_cat_with_category(self, r: redis.Redis): commands = await r.acl_cat("read") assert isinstance(commands, list) - assert "get" in commands + assert "get" in commands or b"get" in commands @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_deluser(self, r_teardown): @@ -126,7 +132,7 @@ async def test_acl_deluser(self, r_teardown): @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_genpass(self, r: redis.Redis): password = await r.acl_genpass() - assert isinstance(password, str) + assert isinstance(password, (str, bytes)) @skip_if_server_version_lt("7.0.0") async def test_acl_getuser_setuser(self, r_teardown): @@ -307,7 +313,7 @@ async def test_acl_users(self, r: redis.Redis): @skip_if_server_version_lt(REDIS_6_VERSION) async def test_acl_whoami(self, r: redis.Redis): username = await r.acl_whoami() - assert isinstance(username, str) + assert isinstance(username, (str, bytes)) @pytest.mark.onlynoncluster async def test_client_list(self, r: redis.Redis): @@ -345,7 +351,9 @@ async def test_client_getname(self, r: redis.Redis): @pytest.mark.onlynoncluster async def test_client_setname(self, r: redis.Redis): assert await r.client_setname("redis_py_test") - assert await r.client_getname() == "redis_py_test" + assert_resp_response( + r, await r.client_getname(), "redis_py_test", b"redis_py_test" + ) @skip_if_server_version_lt("2.6.9") @pytest.mark.onlynoncluster @@ -1093,25 +1101,45 @@ async def test_type(self, r: redis.Redis): async def test_blpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") - assert await r.blpop(["b", "a"], timeout=1) == (b"b", b"3") - assert await r.blpop(["b", "a"], timeout=1) == (b"b", b"4") - assert await r.blpop(["b", "a"], timeout=1) == (b"a", b"1") - assert await r.blpop(["b", "a"], timeout=1) == (b"a", b"2") + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) + assert_resp_response( + r, await r.blpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) assert await r.blpop(["b", "a"], timeout=1) is None await r.rpush("c", "1") - assert await r.blpop("c", timeout=1) == (b"c", b"1") + assert_resp_response( + r, await r.blpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"] + ) @pytest.mark.onlynoncluster async def test_brpop(self, r: redis.Redis): await r.rpush("a", "1", "2") await r.rpush("b", "3", "4") - assert await r.brpop(["b", "a"], timeout=1) == (b"b", b"4") - assert await r.brpop(["b", "a"], timeout=1) == (b"b", b"3") - assert await r.brpop(["b", "a"], timeout=1) == (b"a", b"2") - assert await r.brpop(["b", "a"], timeout=1) == (b"a", b"1") + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) + assert_resp_response( + r, await r.brpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) assert await r.brpop(["b", "a"], timeout=1) is None await r.rpush("c", "1") - assert await r.brpop("c", timeout=1) == (b"c", b"1") + assert_resp_response( + r, await r.brpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"] + ) @pytest.mark.onlynoncluster async def test_brpoplpush(self, r: redis.Redis): @@ -1626,26 +1654,70 @@ async def test_zpopmin(self, r: redis.Redis): async def test_bzpopmax(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a2", 2) - assert await r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"b", b"b2", 20), + [b"b", b"b2", 20], + ) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"b", b"b1", 10), + [b"b", b"b1", 10], + ) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"a", b"a2", 2), + [b"a", b"a2", 2], + ) + assert_resp_response( + r, + await r.bzpopmax(["b", "a"], timeout=1), + (b"a", b"a1", 1), + [b"a", b"a1", 1], + ) assert await r.bzpopmax(["b", "a"], timeout=1) is None await r.zadd("c", {"c1": 100}) - assert await r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, await r.bzpopmax("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) @skip_if_server_version_lt("4.9.0") @pytest.mark.onlynoncluster async def test_bzpopmin(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2}) await r.zadd("b", {"b1": 10, "b2": 20}) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a1", 1) - assert await r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"b", b"b1", 10), + [b"b", b"b1", 10], + ) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"b", b"b2", 20), + [b"b", b"b2", 20], + ) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"a", b"a1", 1), + [b"a", b"a1", 1], + ) + assert_resp_response( + r, + await r.bzpopmin(["b", "a"], timeout=1), + (b"a", b"a2", 2), + [b"a", b"a2", 2], + ) assert await r.bzpopmin(["b", "a"], timeout=1) is None await r.zadd("c", {"c1": 100}) - assert await r.bzpopmin("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, await r.bzpopmin("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) async def test_zrange(self, r: redis.Redis): await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) @@ -2332,11 +2404,12 @@ async def test_geohash(self, r: redis.Redis): ) await r.geoadd("barcelona", values) - assert await r.geohash("barcelona", "place1", "place2", "place3") == [ - "sp3e9yg3kd0", - "sp3e9cbc3t0", - None, - ] + assert_resp_response( + r, + await r.geohash("barcelona", "place1", "place2", "place3"), + ["sp3e9yg3kd0", "sp3e9cbc3t0", None], + [b"sp3e9yg3kd0", b"sp3e9cbc3t0", None], + ) @skip_if_server_version_lt("3.2.0") async def test_geopos(self, r: redis.Redis): @@ -2348,10 +2421,18 @@ async def test_geopos(self, r: redis.Redis): await r.geoadd("barcelona", values) # redis uses 52 bits precision, hereby small errors may be introduced. - assert await r.geopos("barcelona", "place1", "place2") == [ - (2.19093829393386841, 41.43379028184083523), - (2.18737632036209106, 41.40634178640635099), - ] + assert_resp_response( + r, + await r.geopos("barcelona", "place1", "place2"), + [ + (2.19093829393386841, 41.43379028184083523), + (2.18737632036209106, 41.40634178640635099), + ], + [ + [2.19093829393386841, 41.43379028184083523], + [2.18737632036209106, 41.40634178640635099], + ], + ) @skip_if_server_version_lt("4.0.0") async def test_geopos_no_value(self, r: redis.Redis): diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py index ee4a107566..09960fd7e2 100644 --- a/tests/test_asyncio/test_connection.py +++ b/tests/test_asyncio/test_connection.py @@ -5,17 +5,17 @@ import pytest import redis -from redis.asyncio import Redis -from redis.asyncio.connection import Connection, UnixDomainSocketConnection -from redis.asyncio.retry import Retry -from redis.backoff import NoBackoff -from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError -from redis.parsers import ( +from redis._parsers import ( _AsyncHiredisParser, _AsyncRESP2Parser, _AsyncRESP3Parser, _AsyncRESPBase, ) +from redis.asyncio import Redis +from redis.asyncio.connection import Connection, UnixDomainSocketConnection +from redis.asyncio.retry import Retry +from redis.backoff import NoBackoff +from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError from redis.utils import HIREDIS_AVAILABLE from tests.conftest import skip_if_server_version_lt diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py index 8cac17dac5..858576584f 100644 --- a/tests/test_asyncio/test_pubsub.py +++ b/tests/test_asyncio/test_pubsub.py @@ -1013,9 +1013,9 @@ async def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert pubsub.connection.is_connected - with patch("redis.parsers._AsyncRESP2Parser.read_response") as mock1, patch( - "redis.parsers._AsyncHiredisParser.read_response" - ) as mock2, patch("redis.parsers._AsyncRESP3Parser.read_response") as mock3: + with patch("redis._parsers._AsyncRESP2Parser.read_response") as mock1, patch( + "redis._parsers._AsyncHiredisParser.read_response" + ) as mock2, patch("redis._parsers._AsyncRESP3Parser.read_response") as mock3: mock1.side_effect = BaseException("boom") mock2.side_effect = BaseException("boom") mock3.side_effect = BaseException("boom") diff --git a/tests/test_cluster.py b/tests/test_cluster.py index a3a2a6beab..31c31026be 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -11,6 +11,7 @@ import pytest from redis import Redis +from redis._parsers import CommandsParser from redis.backoff import ExponentialBackoff, NoBackoff, default_backoff from redis.cluster import ( PRIMARY, @@ -35,7 +36,6 @@ ResponseError, TimeoutError, ) -from redis.parsers import CommandsParser from redis.retry import Retry from redis.utils import str_if_bytes from tests.test_pubsub import wait_for_message @@ -1000,7 +1000,7 @@ def test_client_setname(self, r): node = r.get_random_node() r.client_setname("redis_py_test", target_nodes=node) client_name = r.client_getname(target_nodes=node) - assert client_name == "redis_py_test" + assert_resp_response(r, client_name, "redis_py_test", b"redis_py_test") def test_exists(self, r): d = {"a": b"1", "b": b"2", "c": b"3", "d": b"4"} @@ -1595,7 +1595,7 @@ def test_client_trackinginfo(self, r): node = r.get_primaries()[0] res = r.client_trackinginfo(target_nodes=node) assert len(res) > 2 - assert "prefixes" in res + assert "prefixes" in res or b"prefixes" in res @skip_if_server_version_lt("2.9.50") def test_client_pause(self, r): @@ -1757,24 +1757,68 @@ def test_cluster_renamenx(self, r): def test_cluster_blpop(self, r): r.rpush("{foo}a", "1", "2") r.rpush("{foo}b", "3", "4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") - assert r.blpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) + assert_resp_response( + r, + r.blpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) assert r.blpop(["{foo}b", "{foo}a"], timeout=1) is None r.rpush("{foo}c", "1") - assert r.blpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, r.blpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) def test_cluster_brpop(self, r): r.rpush("{foo}a", "1", "2") r.rpush("{foo}b", "3", "4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"4") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"3") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"2") - assert r.brpop(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"1") + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"4"), + [b"{foo}b", b"4"], + ) + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"3"), + [b"{foo}b", b"3"], + ) + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"2"), + [b"{foo}a", b"2"], + ) + assert_resp_response( + r, + r.brpop(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"1"), + [b"{foo}a", b"1"], + ) assert r.brpop(["{foo}b", "{foo}a"], timeout=1) is None r.rpush("{foo}c", "1") - assert r.brpop("{foo}c", timeout=1) == (b"{foo}c", b"1") + assert_resp_response( + r, r.brpop("{foo}c", timeout=1), (b"{foo}c", b"1"), [b"{foo}c", b"1"] + ) def test_cluster_brpoplpush(self, r): r.rpush("{foo}a", "1", "2") @@ -1956,25 +2000,75 @@ def test_cluster_zinterstore_with_weight(self, r): def test_cluster_bzpopmax(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2}) r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) - assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"{foo}b", b"b2", 20], + ) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"{foo}b", b"b1", 10], + ) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"{foo}a", b"a2", 2], + ) + assert_resp_response( + r, + r.bzpopmax(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"{foo}a", b"a1", 1], + ) assert r.bzpopmax(["{foo}b", "{foo}a"], timeout=1) is None r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmax("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + r.bzpopmax("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("4.9.0") def test_cluster_bzpopmin(self, r): r.zadd("{foo}a", {"a1": 1, "a2": 2}) r.zadd("{foo}b", {"b1": 10, "b2": 20}) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b1", 10) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}b", b"b2", 20) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a1", 1) - assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) == (b"{foo}a", b"a2", 2) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b1", 10), + [b"b", b"b1", 10], + ) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}b", b"b2", 20), + [b"b", b"b2", 20], + ) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a1", 1), + [b"a", b"a1", 1], + ) + assert_resp_response( + r, + r.bzpopmin(["{foo}b", "{foo}a"], timeout=1), + (b"{foo}a", b"a2", 2), + [b"a", b"a2", 2], + ) assert r.bzpopmin(["{foo}b", "{foo}a"], timeout=1) is None r.zadd("{foo}c", {"c1": 100}) - assert r.bzpopmin("{foo}c", timeout=1) == (b"{foo}c", b"c1", 100) + assert_resp_response( + r, + r.bzpopmin("{foo}c", timeout=1), + (b"{foo}c", b"c1", 100), + [b"{foo}c", b"c1", 100], + ) @skip_if_server_version_lt("6.2.0") def test_cluster_zrangestore(self, r): diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py index c89a2ab0e5..e3b44a147f 100644 --- a/tests/test_command_parser.py +++ b/tests/test_command_parser.py @@ -1,7 +1,11 @@ import pytest -from redis.parsers import CommandsParser +from redis._parsers import CommandsParser -from .conftest import skip_if_redis_enterprise, skip_if_server_version_lt +from .conftest import ( + assert_resp_response, + skip_if_redis_enterprise, + skip_if_server_version_lt, +) class TestCommandsParser: @@ -50,13 +54,40 @@ def test_get_moveable_keys(self, r): ] args7 = ["MIGRATE", "192.168.1.34", 6379, "key1", 0, 5000] - assert sorted(commands_parser.get_keys(r, *args1)) == ["key1", "key2"] - assert sorted(commands_parser.get_keys(r, *args2)) == ["mystream", "writers"] - assert sorted(commands_parser.get_keys(r, *args3)) == ["out", "zset1", "zset2"] - assert sorted(commands_parser.get_keys(r, *args4)) == ["Sicily", "out"] - assert sorted(commands_parser.get_keys(r, *args5)) == ["foo"] - assert sorted(commands_parser.get_keys(r, *args6)) == ["key1", "key2", "key3"] - assert sorted(commands_parser.get_keys(r, *args7)) == ["key1"] + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args1)), + ["key1", "key2"], + [b"key1", b"key2"], + ) + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args2)), + ["mystream", "writers"], + [b"mystream", b"writers"], + ) + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args3)), + ["out", "zset1", "zset2"], + [b"out", b"zset1", b"zset2"], + ) + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args4)), + ["Sicily", "out"], + [b"Sicily", b"out"], + ) + assert sorted(commands_parser.get_keys(r, *args5)) in [["foo"], [b"foo"]] + assert_resp_response( + r, + sorted(commands_parser.get_keys(r, *args6)), + ["key1", "key2", "key3"], + [b"key1", b"key2", b"key3"], + ) + assert_resp_response( + r, sorted(commands_parser.get_keys(r, *args7)), ["key1"], [b"key1"] + ) # A bug in redis<7.0 causes this to fail: https://github.com/redis/redis/issues/9493 @skip_if_server_version_lt("7.0.0") diff --git a/tests/test_commands.py b/tests/test_commands.py index 1f17552c15..fdf41dc5fa 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -11,7 +11,13 @@ import pytest import redis from redis import exceptions -from redis.client import EMPTY_RESPONSE, NEVER_DECODE, parse_info +from redis._parsers.helpers import ( + _RedisCallbacks, + _RedisCallbacksRESP2, + _RedisCallbacksRESP3, + parse_info, +) +from redis.client import EMPTY_RESPONSE, NEVER_DECODE from .conftest import ( _get_client, @@ -60,13 +66,13 @@ class TestResponseCallbacks: "Tests for the response callback system" def test_response_callbacks(self, r): - callbacks = redis.Redis.RESPONSE_CALLBACKS + callbacks = _RedisCallbacks if is_resp2_connection(r): - callbacks.update(redis.Redis.RESP2_RESPONSE_CALLBACKS) + callbacks.update(_RedisCallbacksRESP2) else: - callbacks.update(redis.Redis.RESP3_RESPONSE_CALLBACKS) + callbacks.update(_RedisCallbacksRESP3) assert r.response_callbacks == callbacks - assert id(r.response_callbacks) != id(redis.Redis.RESPONSE_CALLBACKS) + assert id(r.response_callbacks) != id(_RedisCallbacks) r.set_response_callback("GET", lambda x: "static") r["a"] = "foo" assert r["a"] == "static" @@ -136,13 +142,13 @@ def test_command_on_invalid_key_type(self, r): def test_acl_cat_no_category(self, r): categories = r.acl_cat() assert isinstance(categories, list) - assert "read" in categories + assert "read" in categories or b"read" in categories @skip_if_server_version_lt("6.0.0") def test_acl_cat_with_category(self, r): commands = r.acl_cat("read") assert isinstance(commands, list) - assert "get" in commands + assert "get" in commands or b"get" in commands @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() @@ -188,7 +194,7 @@ def teardown(): @skip_if_redis_enterprise() def test_acl_genpass(self, r): password = r.acl_genpass() - assert isinstance(password, str) + assert isinstance(password, (str, bytes)) with pytest.raises(exceptions.DataError): r.acl_genpass("value") @@ -196,7 +202,7 @@ def test_acl_genpass(self, r): r.acl_genpass(5555) r.acl_genpass(555) - assert isinstance(password, str) + assert isinstance(password, (str, bytes)) @skip_if_server_version_lt("7.0.0") @skip_if_redis_enterprise() @@ -449,7 +455,7 @@ def test_acl_users(self, r): @skip_if_server_version_lt("6.0.0") def test_acl_whoami(self, r): username = r.acl_whoami() - assert isinstance(username, str) + assert isinstance(username, (str, bytes)) @pytest.mark.onlynoncluster def test_client_list(self, r): @@ -504,7 +510,7 @@ def test_client_id(self, r): def test_client_trackinginfo(self, r): res = r.client_trackinginfo() assert len(res) > 2 - assert "prefixes" in res + assert "prefixes" in res or b"prefixes" in res @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") @@ -546,7 +552,7 @@ def test_client_getname(self, r): @skip_if_server_version_lt("2.6.9") def test_client_setname(self, r): assert r.client_setname("redis_py_test") - assert r.client_getname() == "redis_py_test" + assert_resp_response(r, r.client_getname(), "redis_py_test", b"redis_py_test") @pytest.mark.onlynoncluster @skip_if_server_version_lt("2.6.9") @@ -849,7 +855,7 @@ def test_lolwut(self, r): @skip_if_server_version_lt("6.2.0") @skip_if_redis_enterprise() def test_reset(self, r): - assert r.reset() == "RESET" + assert_resp_response(r, r.reset(), "RESET", b"RESET") def test_object(self, r): r["a"] = "foo" @@ -1816,25 +1822,41 @@ def test_type(self, r): def test_blpop(self, r): r.rpush("a", "1", "2") r.rpush("b", "3", "4") - assert r.blpop(["b", "a"], timeout=1) == (b"b", b"3") - assert r.blpop(["b", "a"], timeout=1) == (b"b", b"4") - assert r.blpop(["b", "a"], timeout=1) == (b"a", b"1") - assert r.blpop(["b", "a"], timeout=1) == (b"a", b"2") + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) + assert_resp_response( + r, r.blpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) assert r.blpop(["b", "a"], timeout=1) is None r.rpush("c", "1") - assert r.blpop("c", timeout=1) == (b"c", b"1") + assert_resp_response(r, r.blpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"]) @pytest.mark.onlynoncluster def test_brpop(self, r): r.rpush("a", "1", "2") r.rpush("b", "3", "4") - assert r.brpop(["b", "a"], timeout=1) == (b"b", b"4") - assert r.brpop(["b", "a"], timeout=1) == (b"b", b"3") - assert r.brpop(["b", "a"], timeout=1) == (b"a", b"2") - assert r.brpop(["b", "a"], timeout=1) == (b"a", b"1") + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"b", b"4"), [b"b", b"4"] + ) + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"b", b"3"), [b"b", b"3"] + ) + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"a", b"2"), [b"a", b"2"] + ) + assert_resp_response( + r, r.brpop(["b", "a"], timeout=1), (b"a", b"1"), [b"a", b"1"] + ) assert r.brpop(["b", "a"], timeout=1) is None r.rpush("c", "1") - assert r.brpop("c", timeout=1) == (b"c", b"1") + assert_resp_response(r, r.brpop("c", timeout=1), (b"c", b"1"), [b"c", b"1"]) @pytest.mark.onlynoncluster def test_brpoplpush(self, r): @@ -2533,26 +2555,46 @@ def test_zrandemember(self, r): def test_bzpopmax(self, r): r.zadd("a", {"a1": 1, "a2": 2}) r.zadd("b", {"b1": 10, "b2": 20}) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a2", 2) - assert r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"b", b"b2", 20), [b"b", b"b2", 20] + ) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"b", b"b1", 10), [b"b", b"b1", 10] + ) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"a", b"a2", 2), [b"a", b"a2", 2] + ) + assert_resp_response( + r, r.bzpopmax(["b", "a"], timeout=1), (b"a", b"a1", 1), [b"a", b"a1", 1] + ) assert r.bzpopmax(["b", "a"], timeout=1) is None r.zadd("c", {"c1": 100}) - assert r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, r.bzpopmax("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("4.9.0") def test_bzpopmin(self, r): r.zadd("a", {"a1": 1, "a2": 2}) r.zadd("b", {"b1": 10, "b2": 20}) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b1", 10) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b2", 20) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a1", 1) - assert r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"b", b"b1", 10), [b"b", b"b1", 10] + ) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"b", b"b2", 20), [b"b", b"b2", 20] + ) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"a", b"a1", 1), [b"a", b"a1", 1] + ) + assert_resp_response( + r, r.bzpopmin(["b", "a"], timeout=1), (b"a", b"a2", 2), [b"a", b"a2", 2] + ) assert r.bzpopmin(["b", "a"], timeout=1) is None r.zadd("c", {"c1": 100}) - assert r.bzpopmin("c", timeout=1) == (b"c", b"c1", 100) + assert_resp_response( + r, r.bzpopmin("c", timeout=1), (b"c", b"c1", 100), [b"c", b"c1", 100] + ) @pytest.mark.onlynoncluster @skip_if_server_version_lt("7.0.0") @@ -3448,11 +3490,12 @@ def test_geohash(self, r): "place2", ) r.geoadd("barcelona", values) - assert r.geohash("barcelona", "place1", "place2", "place3") == [ - "sp3e9yg3kd0", - "sp3e9cbc3t0", - None, - ] + assert_resp_response( + r, + r.geohash("barcelona", "place1", "place2", "place3"), + ["sp3e9yg3kd0", "sp3e9cbc3t0", None], + [b"sp3e9yg3kd0", b"sp3e9cbc3t0", None], + ) @skip_unless_arch_bits(64) @skip_if_server_version_lt("3.2.0") @@ -3464,10 +3507,18 @@ def test_geopos(self, r): ) r.geoadd("barcelona", values) # redis uses 52 bits precision, hereby small errors may be introduced. - assert r.geopos("barcelona", "place1", "place2") == [ - (2.19093829393386841, 41.43379028184083523), - (2.18737632036209106, 41.40634178640635099), - ] + assert_resp_response( + r, + r.geopos("barcelona", "place1", "place2"), + [ + (2.19093829393386841, 41.43379028184083523), + (2.18737632036209106, 41.40634178640635099), + ], + [ + [2.19093829393386841, 41.43379028184083523], + [2.18737632036209106, 41.40634178640635099], + ], + ) @skip_if_server_version_lt("4.0.0") def test_geopos_no_value(self, r): @@ -4832,7 +4883,7 @@ def test_command_list(self, r: redis.Redis): @skip_if_redis_enterprise() def test_command_getkeys(self, r): res = r.command_getkeys("MSET", "a", "b", "c", "d", "e", "f") - assert res == ["a", "c", "e"] + assert_resp_response(r, res, ["a", "c", "e"], [b"a", b"c", b"e"]) res = r.command_getkeys( "EVAL", '"not consulted"', @@ -4845,7 +4896,9 @@ def test_command_getkeys(self, r): "arg3", "argN", ) - assert res == ["key1", "key2", "key3"] + assert_resp_response( + r, res, ["key1", "key2", "key3"], [b"key1", b"key2", b"key3"] + ) @skip_if_server_version_lt("2.8.13") def test_command(self, r): diff --git a/tests/test_connection.py b/tests/test_connection.py index 64ae4c5d1f..760b23c9c1 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -5,10 +5,10 @@ import pytest import redis +from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.backoff import NoBackoff from redis.connection import Connection, SSLConnection, UnixDomainSocketConnection from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError -from redis.parsers import _HiredisParser, _RESP2Parser, _RESP3Parser from redis.retry import Retry from redis.utils import HIREDIS_AVAILABLE diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 9c10740ae8..ba097e3194 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -1143,9 +1143,9 @@ def get_msg(): assert msg is not None # timeout waiting for another message which never arrives assert is_connected() - with patch("redis.parsers._RESP2Parser.read_response") as mock1, patch( - "redis.parsers._HiredisParser.read_response" - ) as mock2, patch("redis.parsers._RESP3Parser.read_response") as mock3: + with patch("redis._parsers._RESP2Parser.read_response") as mock1, patch( + "redis._parsers._HiredisParser.read_response" + ) as mock2, patch("redis._parsers._RESP3Parser.read_response") as mock3: mock1.side_effect = BaseException("boom") mock2.side_effect = BaseException("boom") mock3.side_effect = BaseException("boom")