diff --git a/bson/__init__.py b/bson/__init__.py index 1efb1f7ff5..366ba98b7b 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -62,35 +62,37 @@ import struct import sys import uuid - -from codecs import (utf_8_decode as _utf_8_decode, - utf_8_encode as _utf_8_encode) +from codecs import utf_8_decode as _utf_8_decode +from codecs import utf_8_encode as _utf_8_encode from collections import abc as _abc -from bson.binary import (Binary, UuidRepresentation, ALL_UUID_SUBTYPES, - OLD_UUID_SUBTYPE, - JAVA_LEGACY, CSHARP_LEGACY, STANDARD, - UUID_SUBTYPE) +from bson.binary import ( + ALL_UUID_SUBTYPES, + CSHARP_LEGACY, + JAVA_LEGACY, + OLD_UUID_SUBTYPE, + STANDARD, + UUID_SUBTYPE, + Binary, + UuidRepresentation, +) from bson.code import Code -from bson.codec_options import ( - CodecOptions, DEFAULT_CODEC_OPTIONS, _raw_document_class) +from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, _raw_document_class from bson.dbref import DBRef from bson.decimal128 import Decimal128 -from bson.errors import (InvalidBSON, - InvalidDocument, - InvalidStringData) +from bson.errors import InvalidBSON, InvalidDocument, InvalidStringData from bson.int64 import Int64 from bson.max_key import MaxKey from bson.min_key import MinKey from bson.objectid import ObjectId from bson.regex import Regex -from bson.son import SON, RE_TYPE +from bson.son import RE_TYPE, SON from bson.timestamp import Timestamp from bson.tz_util import utc - try: from bson import _cbson + _USE_C = True except ImportError: _USE_C = False @@ -100,27 +102,27 @@ EPOCH_NAIVE = datetime.datetime.utcfromtimestamp(0) -BSONNUM = b"\x01" # Floating point -BSONSTR = b"\x02" # UTF-8 string -BSONOBJ = b"\x03" # Embedded document -BSONARR = b"\x04" # Array -BSONBIN = b"\x05" # Binary -BSONUND = b"\x06" # Undefined -BSONOID = b"\x07" # ObjectId -BSONBOO = b"\x08" # Boolean -BSONDAT = b"\x09" # UTC Datetime -BSONNUL = b"\x0A" # Null -BSONRGX = b"\x0B" # Regex -BSONREF = b"\x0C" # DBRef -BSONCOD = b"\x0D" # Javascript code -BSONSYM = b"\x0E" # Symbol -BSONCWS = b"\x0F" # Javascript code with scope -BSONINT = b"\x10" # 32bit int -BSONTIM = b"\x11" # Timestamp -BSONLON = b"\x12" # 64bit int -BSONDEC = b"\x13" # Decimal128 -BSONMIN = b"\xFF" # Min key -BSONMAX = b"\x7F" # Max key +BSONNUM = b"\x01" # Floating point +BSONSTR = b"\x02" # UTF-8 string +BSONOBJ = b"\x03" # Embedded document +BSONARR = b"\x04" # Array +BSONBIN = b"\x05" # Binary +BSONUND = b"\x06" # Undefined +BSONOID = b"\x07" # ObjectId +BSONBOO = b"\x08" # Boolean +BSONDAT = b"\x09" # UTC Datetime +BSONNUL = b"\x0A" # Null +BSONRGX = b"\x0B" # Regex +BSONREF = b"\x0C" # DBRef +BSONCOD = b"\x0D" # Javascript code +BSONSYM = b"\x0E" # Symbol +BSONCWS = b"\x0F" # Javascript code with scope +BSONINT = b"\x10" # 32bit int +BSONTIM = b"\x11" # Timestamp +BSONLON = b"\x12" # 64bit int +BSONDEC = b"\x13" # Decimal128 +BSONMIN = b"\xFF" # Min key +BSONMAX = b"\x7F" # Max key _UNPACK_FLOAT_FROM = struct.Struct(" obj_end: - raise InvalidBSON('bad binary object length') + raise InvalidBSON("bad binary object length") # Convert UUID subtypes to native UUIDs. 
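# An illustrative aside (not from the patch itself): the branch below is what
# lets subtype-4 Binary values round-trip through the public bson.binary API
# that this module already imports. A minimal sketch:
import uuid

from bson.binary import Binary, UuidRepresentation

val = uuid.uuid4()
bin4 = Binary.from_uuid(val, UuidRepresentation.STANDARD)  # stored as subtype 4
# With STANDARD the decoder hands back a uuid.UUID; with UNSPECIFIED it keeps
# the raw Binary, which is exactly the first arm of the check that follows.
assert bin4.as_uuid(UuidRepresentation.STANDARD) == val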
     if subtype in ALL_UUID_SUBTYPES:
         uuid_rep = opts.uuid_representation
         binary_value = Binary(data[position:end], subtype)
-        if ((uuid_rep == UuidRepresentation.UNSPECIFIED) or
-                (subtype == UUID_SUBTYPE and uuid_rep != STANDARD) or
-                (subtype == OLD_UUID_SUBTYPE and uuid_rep == STANDARD)):
+        if (
+            (uuid_rep == UuidRepresentation.UNSPECIFIED)
+            or (subtype == UUID_SUBTYPE and uuid_rep != STANDARD)
+            or (subtype == OLD_UUID_SUBTYPE and uuid_rep == STANDARD)
+        ):
             return binary_value, end
         return binary_value.as_uuid(uuid_rep), end
@@ -293,17 +297,16 @@ def _get_boolean(data, view, position, dummy0, dummy1, dummy2):
     """Decode a BSON true/false to python True/False."""
     end = position + 1
     boolean_byte = data[position:end]
-    if boolean_byte == b'\x00':
+    if boolean_byte == b"\x00":
         return False, end
-    elif boolean_byte == b'\x01':
+    elif boolean_byte == b"\x01":
         return True, end
-    raise InvalidBSON('invalid boolean value: %r' % boolean_byte)
+    raise InvalidBSON("invalid boolean value: %r" % boolean_byte)


 def _get_date(data, view, position, dummy0, opts, dummy1):
     """Decode a BSON datetime to python datetime.datetime."""
-    return _millis_to_datetime(
-        _UNPACK_LONG_FROM(data, position)[0], opts), position + 8
+    return _millis_to_datetime(_UNPACK_LONG_FROM(data, position)[0], opts), position + 8


 def _get_code(data, view, position, obj_end, opts, element_name):
@@ -315,11 +318,10 @@ def _get_code_w_scope(data, view, position, obj_end, opts, element_name):
     """Decode a BSON code_w_scope to bson.code.Code."""
     code_end = position + _UNPACK_INT_FROM(data, position)[0]
-    code, position = _get_string(
-        data, view, position + 4, code_end, opts, element_name)
+    code, position = _get_string(data, view, position + 4, code_end, opts, element_name)
     scope, position = _get_object(data, view, position, code_end, opts, element_name)
     if position != code_end:
-        raise InvalidBSON('scope outside of javascript code boundaries')
+        raise InvalidBSON("scope outside of javascript code boundaries")
     return Code(code, scope), position
@@ -333,8 +335,7 @@ def _get_regex(data, view, position, dummy0, opts, dummy1):

 def _get_ref(data, view, position, obj_end, opts, element_name):
     """Decode (deprecated) BSON DBPointer to bson.dbref.DBRef."""
-    collection, position = _get_string(
-        data, view, position, obj_end, opts, element_name)
+    collection, position = _get_string(data, view, position, obj_end, opts, element_name)
     oid, position = _get_oid(data, view, position, obj_end, opts, element_name)
     return DBRef(collection, oid), position
@@ -383,22 +384,26 @@ def _get_decimal128(data, view, position, dummy0, dummy1, dummy2):
     ord(BSONLON): _get_int64,
     ord(BSONDEC): _get_decimal128,
     ord(BSONMIN): lambda u, v, w, x, y, z: (MinKey(), w),
-    ord(BSONMAX): lambda u, v, w, x, y, z: (MaxKey(), w)}
+    ord(BSONMAX): lambda u, v, w, x, y, z: (MaxKey(), w),
+}


 if _USE_C:
+
     def _element_to_dict(data, view, position, obj_end, opts):
         return _cbson._element_to_dict(data, position, obj_end, opts)
+
 else:
+
     def _element_to_dict(data, view, position, obj_end, opts):
         """Decode a single key, value pair."""
         element_type = data[position]
         position += 1
         element_name, position = _get_c_string(data, view, position, opts)
         try:
-            value, position = _ELEMENT_GETTER[element_type](data, view, position,
-                                                            obj_end, opts,
-                                                            element_name)
+            value, position = _ELEMENT_GETTER[element_type](
+                data, view, position, obj_end, opts, element_name
+            )
         except KeyError:
             _raise_unknown_type(element_type, element_name)
@@ -424,7 +429,7 @@ def _elements_to_dict(data, view, position, obj_end, opts, result=None):
     key, value, position = _element_to_dict(data, view, position, obj_end, opts)
     result[key] = value
     if position != obj_end:
-        raise InvalidBSON('bad object or element length')
+        raise InvalidBSON("bad object or element length")
     return result
@@ -442,6 +447,8 @@ def _bson_to_dict(data, opts):
         # Change exception type to InvalidBSON but preserve traceback.
         _, exc_value, exc_tb = sys.exc_info()
         raise InvalidBSON(str(exc_value)).with_traceback(exc_tb)
+
+
 if _USE_C:
     _bson_to_dict = _cbson._bson_to_dict
@@ -451,7 +458,7 @@ def _bson_to_dict(data, opts):
 _PACK_LENGTH_SUBTYPE = struct.Struct("> 49 == 1:
-        high = high & 0x7fffffffffff
+        high = high & 0x7FFFFFFFFFFF
         high |= _EXPONENT_MASK
-        high |= (biased_exponent & 0x3fff) << 47
+        high |= (biased_exponent & 0x3FFF) << 47
     else:
         high |= biased_exponent << 49
@@ -213,7 +210,8 @@ class Decimal128(object):
     >>> Decimal('NaN') == Decimal('NaN')
     False
     """
-    __slots__ = ('__high', '__low')
+
+    __slots__ = ("__high", "__low")

     _type_marker = 19
@@ -222,9 +220,11 @@ def __init__(self, value):
             self.__high, self.__low = _decimal_to_128(value)
         elif isinstance(value, (list, tuple)):
             if len(value) != 2:
-                raise ValueError('Invalid size for creation of Decimal128 '
-                                 'from list or tuple. Must have exactly 2 '
-                                 'elements.')
+                raise ValueError(
+                    "Invalid size for creation of Decimal128 "
+                    "from list or tuple. Must have exactly 2 "
+                    "elements."
+                )
             self.__high, self.__low = value
         else:
             raise TypeError("Cannot convert %r to Decimal128" % (value,))
@@ -238,25 +238,25 @@ def to_decimal(self):
         sign = 1 if (high & _SIGN) else 0

         if (high & _SNAN) == _SNAN:
-            return decimal.Decimal((sign, (), 'N'))
+            return decimal.Decimal((sign, (), "N"))
         elif (high & _NAN) == _NAN:
-            return decimal.Decimal((sign, (), 'n'))
+            return decimal.Decimal((sign, (), "n"))
         elif (high & _INF) == _INF:
-            return decimal.Decimal((sign, (), 'F'))
+            return decimal.Decimal((sign, (), "F"))

         if (high & _EXPONENT_MASK) == _EXPONENT_MASK:
-            exponent = ((high & 0x1fffe00000000000) >> 47) - _EXPONENT_BIAS
+            exponent = ((high & 0x1FFFE00000000000) >> 47) - _EXPONENT_BIAS
             return decimal.Decimal((sign, (0,), exponent))
         else:
-            exponent = ((high & 0x7fff800000000000) >> 49) - _EXPONENT_BIAS
+            exponent = ((high & 0x7FFF800000000000) >> 49) - _EXPONENT_BIAS

         arr = bytearray(15)
-        mask = 0x00000000000000ff
+        mask = 0x00000000000000FF
         for i in range(14, 6, -1):
             arr[i] = (low & mask) >> ((14 - i) << 3)
             mask = mask << 8

-        mask = 0x00000000000000ff
+        mask = 0x00000000000000FF
         for i in range(6, 0, -1):
             arr[i] = (high & mask) >> ((6 - i) << 3)
             mask = mask << 8
@@ -265,8 +265,7 @@ def to_decimal(self):
         arr[0] = (high & mask) >> 48

         # cdecimal only accepts a tuple for digits.
-        digits = tuple(
-            int(digit) for digit in str(int.from_bytes(arr, 'big')))
+        digits = tuple(int(digit) for digit in str(int.from_bytes(arr, "big")))

         with decimal.localcontext(_DEC128_CTX) as ctx:
             return ctx.create_decimal((sign, digits, exponent))
diff --git a/bson/errors.py b/bson/errors.py
index 9bdb741371..7333b27b58 100644
--- a/bson/errors.py
+++ b/bson/errors.py
@@ -16,25 +16,20 @@


 class BSONError(Exception):
-    """Base class for all BSON exceptions.
-    """
+    """Base class for all BSON exceptions."""


 class InvalidBSON(BSONError):
-    """Raised when trying to create a BSON object from invalid data.
- """ + """Raised when trying to create a BSON object from invalid data.""" class InvalidStringData(BSONError): - """Raised when trying to encode a string containing non-UTF8 data. - """ + """Raised when trying to encode a string containing non-UTF8 data.""" class InvalidDocument(BSONError): - """Raised when trying to create a BSON object from an invalid document. - """ + """Raised when trying to create a BSON object from an invalid document.""" class InvalidId(BSONError): - """Raised when trying to create an ObjectId from invalid data. - """ + """Raised when trying to create an ObjectId from invalid data.""" diff --git a/bson/int64.py b/bson/int64.py index fb9bfe9143..d40b1953fe 100644 --- a/bson/int64.py +++ b/bson/int64.py @@ -14,6 +14,7 @@ """A BSON wrapper for long (int in python3)""" + class Int64(int): """Representation of the BSON int64 type. @@ -24,6 +25,7 @@ class Int64(int): :Parameters: - `value`: the numeric value to represent """ + __slots__ = () _type_marker = 18 diff --git a/bson/json_util.py b/bson/json_util.py index ed67d9a36c..0dab51b10d 100644 --- a/bson/json_util.py +++ b/bson/json_util.py @@ -95,8 +95,7 @@ import bson from bson import EPOCH_AWARE, RE_TYPE, SON -from bson.binary import (Binary, UuidRepresentation, ALL_UUID_SUBTYPES, - UUID_SUBTYPE) +from bson.binary import ALL_UUID_SUBTYPES, UUID_SUBTYPE, Binary, UuidRepresentation from bson.code import Code from bson.codec_options import CodecOptions from bson.dbref import DBRef @@ -109,7 +108,6 @@ from bson.timestamp import Timestamp from bson.tz_util import utc - _RE_OPT_TABLE = { "i": re.I, "l": re.L, @@ -247,57 +245,58 @@ class JSONOptions(CodecOptions): Changed default value of `tz_aware` to False. """ - def __new__(cls, strict_number_long=None, - datetime_representation=None, - strict_uuid=None, json_mode=JSONMode.RELAXED, - *args, **kwargs): + def __new__( + cls, + strict_number_long=None, + datetime_representation=None, + strict_uuid=None, + json_mode=JSONMode.RELAXED, + *args, + **kwargs + ): kwargs["tz_aware"] = kwargs.get("tz_aware", False) if kwargs["tz_aware"]: kwargs["tzinfo"] = kwargs.get("tzinfo", utc) - if datetime_representation not in (DatetimeRepresentation.LEGACY, - DatetimeRepresentation.NUMBERLONG, - DatetimeRepresentation.ISO8601, - None): + if datetime_representation not in ( + DatetimeRepresentation.LEGACY, + DatetimeRepresentation.NUMBERLONG, + DatetimeRepresentation.ISO8601, + None, + ): raise ValueError( "JSONOptions.datetime_representation must be one of LEGACY, " - "NUMBERLONG, or ISO8601 from DatetimeRepresentation.") + "NUMBERLONG, or ISO8601 from DatetimeRepresentation." + ) self = super(JSONOptions, cls).__new__(cls, *args, **kwargs) - if json_mode not in (JSONMode.LEGACY, - JSONMode.RELAXED, - JSONMode.CANONICAL): + if json_mode not in (JSONMode.LEGACY, JSONMode.RELAXED, JSONMode.CANONICAL): raise ValueError( "JSONOptions.json_mode must be one of LEGACY, RELAXED, " - "or CANONICAL from JSONMode.") + "or CANONICAL from JSONMode." + ) self.json_mode = json_mode if self.json_mode == JSONMode.RELAXED: if strict_number_long: - raise ValueError( - "Cannot specify strict_number_long=True with" - " JSONMode.RELAXED") - if datetime_representation not in (None, - DatetimeRepresentation.ISO8601): + raise ValueError("Cannot specify strict_number_long=True with" " JSONMode.RELAXED") + if datetime_representation not in (None, DatetimeRepresentation.ISO8601): raise ValueError( "datetime_representation must be DatetimeRepresentation." 
- "ISO8601 or omitted with JSONMode.RELAXED") + "ISO8601 or omitted with JSONMode.RELAXED" + ) if strict_uuid not in (None, True): - raise ValueError( - "Cannot specify strict_uuid=False with JSONMode.RELAXED") + raise ValueError("Cannot specify strict_uuid=False with JSONMode.RELAXED") self.strict_number_long = False self.datetime_representation = DatetimeRepresentation.ISO8601 self.strict_uuid = True elif self.json_mode == JSONMode.CANONICAL: if strict_number_long not in (None, True): - raise ValueError( - "Cannot specify strict_number_long=False with" - " JSONMode.RELAXED") - if datetime_representation not in ( - None, DatetimeRepresentation.NUMBERLONG): + raise ValueError("Cannot specify strict_number_long=False with" " JSONMode.RELAXED") + if datetime_representation not in (None, DatetimeRepresentation.NUMBERLONG): raise ValueError( "datetime_representation must be DatetimeRepresentation." - "NUMBERLONG or omitted with JSONMode.RELAXED") + "NUMBERLONG or omitted with JSONMode.RELAXED" + ) if strict_uuid not in (None, True): - raise ValueError( - "Cannot specify strict_uuid=False with JSONMode.RELAXED") + raise ValueError("Cannot specify strict_uuid=False with JSONMode.RELAXED") self.strict_number_long = True self.datetime_representation = DatetimeRepresentation.NUMBERLONG self.strict_uuid = True @@ -314,23 +313,30 @@ def __new__(cls, strict_number_long=None, return self def _arguments_repr(self): - return ('strict_number_long=%r, ' - 'datetime_representation=%r, ' - 'strict_uuid=%r, json_mode=%r, %s' % ( - self.strict_number_long, - self.datetime_representation, - self.strict_uuid, - self.json_mode, - super(JSONOptions, self)._arguments_repr())) + return ( + "strict_number_long=%r, " + "datetime_representation=%r, " + "strict_uuid=%r, json_mode=%r, %s" + % ( + self.strict_number_long, + self.datetime_representation, + self.strict_uuid, + self.json_mode, + super(JSONOptions, self)._arguments_repr(), + ) + ) def _options_dict(self): # TODO: PYTHON-2442 use _asdict() instead options_dict = super(JSONOptions, self)._options_dict() - options_dict.update({ - 'strict_number_long': self.strict_number_long, - 'datetime_representation': self.datetime_representation, - 'strict_uuid': self.strict_uuid, - 'json_mode': self.json_mode}) + options_dict.update( + { + "strict_number_long": self.strict_number_long, + "datetime_representation": self.datetime_representation, + "strict_uuid": self.strict_uuid, + "json_mode": self.json_mode, + } + ) return options_dict def with_options(self, **kwargs): @@ -347,8 +353,7 @@ def with_options(self, **kwargs): .. versionadded:: 3.12 """ opts = self._options_dict() - for opt in ('strict_number_long', 'datetime_representation', - 'strict_uuid', 'json_mode'): + for opt in ("strict_number_long", "datetime_representation", "strict_uuid", "json_mode"): opts[opt] = kwargs.get(opt, getattr(self, opt)) opts.update(kwargs) return JSONOptions(**opts) @@ -435,8 +440,7 @@ def loads(s, *args, **kwargs): Accepts optional parameter `json_options`. See :class:`JSONOptions`. """ json_options = kwargs.pop("json_options", DEFAULT_JSON_OPTIONS) - kwargs["object_pairs_hook"] = lambda pairs: object_pairs_hook( - pairs, json_options) + kwargs["object_pairs_hook"] = lambda pairs: object_pairs_hook(pairs, json_options) return json.loads(s, *args, **kwargs) @@ -444,10 +448,9 @@ def _json_convert(obj, json_options=DEFAULT_JSON_OPTIONS): """Recursive helper method that converts BSON types so they can be converted into json. 
""" - if hasattr(obj, 'items'): - return SON(((k, _json_convert(v, json_options)) - for k, v in obj.items())) - elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes)): + if hasattr(obj, "items"): + return SON(((k, _json_convert(v, json_options)) for k, v in obj.items())) + elif hasattr(obj, "__iter__") and not isinstance(obj, (str, bytes)): return list((_json_convert(v, json_options) for v in obj)) try: return default(obj, json_options) @@ -462,9 +465,11 @@ def object_pairs_hook(pairs, json_options=DEFAULT_JSON_OPTIONS): def object_hook(dct, json_options=DEFAULT_JSON_OPTIONS): if "$oid" in dct: return _parse_canonical_oid(dct) - if (isinstance(dct.get('$ref'), str) and - "$id" in dct and - isinstance(dct.get('$db'), (str, type(None)))): + if ( + isinstance(dct.get("$ref"), str) + and "$id" in dct + and isinstance(dct.get("$db"), (str, type(None))) + ): return _parse_canonical_dbref(dct) if "$date" in dct: return _parse_canonical_datetime(dct, json_options) @@ -520,9 +525,9 @@ def _parse_legacy_regex(doc): def _parse_legacy_uuid(doc, json_options): """Decode a JSON legacy $uuid to Python UUID.""" if len(doc) != 1: - raise TypeError('Bad $uuid, extra field(s): %s' % (doc,)) + raise TypeError("Bad $uuid, extra field(s): %s" % (doc,)) if not isinstance(doc["$uuid"], str): - raise TypeError('$uuid must be a string: %s' % (doc,)) + raise TypeError("$uuid must be a string: %s" % (doc,)) if json_options.uuid_representation == UuidRepresentation.UNSPECIFIED: return Binary.from_uuid(uuid.UUID(doc["$uuid"])) else: @@ -554,7 +559,7 @@ def _parse_legacy_binary(doc, json_options): if isinstance(doc["$type"], int): doc["$type"] = "%02x" % doc["$type"] subtype = int(doc["$type"], 16) - if subtype >= 0xffffff80: # Handle mongoexport values + if subtype >= 0xFFFFFF80: # Handle mongoexport values subtype = int(doc["$type"][6:], 16) data = base64.b64decode(doc["$binary"].encode()) return _binary_or_uuid(data, subtype, json_options) @@ -565,13 +570,13 @@ def _parse_canonical_binary(doc, json_options): b64 = binary["base64"] subtype = binary["subType"] if not isinstance(b64, str): - raise TypeError('$binary base64 must be a string: %s' % (doc,)) + raise TypeError("$binary base64 must be a string: %s" % (doc,)) if not isinstance(subtype, str) or len(subtype) > 2: - raise TypeError('$binary subType must be a string at most 2 ' - 'characters: %s' % (doc,)) + raise TypeError("$binary subType must be a string at most 2 " "characters: %s" % (doc,)) if len(binary) != 2: - raise TypeError('$binary must include only "base64" and "subType" ' - 'components: %s' % (doc,)) + raise TypeError( + '$binary must include only "base64" and "subType" ' "components: %s" % (doc,) + ) data = base64.b64decode(b64.encode()) return _binary_or_uuid(data, int(subtype, 16), json_options) @@ -581,46 +586,46 @@ def _parse_canonical_datetime(doc, json_options): """Decode a JSON datetime to python datetime.datetime.""" dtm = doc["$date"] if len(doc) != 1: - raise TypeError('Bad $date, extra field(s): %s' % (doc,)) + raise TypeError("Bad $date, extra field(s): %s" % (doc,)) # mongoexport 2.6 and newer if isinstance(dtm, str): # Parse offset - if dtm[-1] == 'Z': + if dtm[-1] == "Z": dt = dtm[:-1] - offset = 'Z' - elif dtm[-6] in ('+', '-') and dtm[-3] == ':': + offset = "Z" + elif dtm[-6] in ("+", "-") and dtm[-3] == ":": # (+|-)HH:MM dt = dtm[:-6] offset = dtm[-6:] - elif dtm[-5] in ('+', '-'): + elif dtm[-5] in ("+", "-"): # (+|-)HHMM dt = dtm[:-5] offset = dtm[-5:] - elif dtm[-3] in ('+', '-'): + elif dtm[-3] in ("+", "-"): # 
(+|-)HH dt = dtm[:-3] offset = dtm[-3:] else: dt = dtm - offset = '' + offset = "" # Parse the optional factional seconds portion. - dot_index = dt.rfind('.') + dot_index = dt.rfind(".") microsecond = 0 if dot_index != -1: microsecond = int(float(dt[dot_index:]) * 1000000) dt = dt[:dot_index] - aware = datetime.datetime.strptime( - dt, "%Y-%m-%dT%H:%M:%S").replace(microsecond=microsecond, - tzinfo=utc) + aware = datetime.datetime.strptime(dt, "%Y-%m-%dT%H:%M:%S").replace( + microsecond=microsecond, tzinfo=utc + ) - if offset and offset != 'Z': + if offset and offset != "Z": if len(offset) == 6: - hours, minutes = offset[1:].split(':') - secs = (int(hours) * 3600 + int(minutes) * 60) + hours, minutes = offset[1:].split(":") + secs = int(hours) * 3600 + int(minutes) * 60 elif len(offset) == 5: - secs = (int(offset[1:3]) * 3600 + int(offset[3:]) * 60) + secs = int(offset[1:3]) * 3600 + int(offset[3:]) * 60 elif len(offset) == 3: secs = int(offset[1:3]) * 3600 if offset[0] == "-": @@ -639,133 +644,130 @@ def _parse_canonical_datetime(doc, json_options): def _parse_canonical_oid(doc): """Decode a JSON ObjectId to bson.objectid.ObjectId.""" if len(doc) != 1: - raise TypeError('Bad $oid, extra field(s): %s' % (doc,)) - return ObjectId(doc['$oid']) + raise TypeError("Bad $oid, extra field(s): %s" % (doc,)) + return ObjectId(doc["$oid"]) def _parse_canonical_symbol(doc): """Decode a JSON symbol to Python string.""" - symbol = doc['$symbol'] + symbol = doc["$symbol"] if len(doc) != 1: - raise TypeError('Bad $symbol, extra field(s): %s' % (doc,)) + raise TypeError("Bad $symbol, extra field(s): %s" % (doc,)) return str(symbol) def _parse_canonical_code(doc): """Decode a JSON code to bson.code.Code.""" for key in doc: - if key not in ('$code', '$scope'): - raise TypeError('Bad $code, extra field(s): %s' % (doc,)) - return Code(doc['$code'], scope=doc.get('$scope')) + if key not in ("$code", "$scope"): + raise TypeError("Bad $code, extra field(s): %s" % (doc,)) + return Code(doc["$code"], scope=doc.get("$scope")) def _parse_canonical_regex(doc): """Decode a JSON regex to bson.regex.Regex.""" - regex = doc['$regularExpression'] + regex = doc["$regularExpression"] if len(doc) != 1: - raise TypeError('Bad $regularExpression, extra field(s): %s' % (doc,)) + raise TypeError("Bad $regularExpression, extra field(s): %s" % (doc,)) if len(regex) != 2: - raise TypeError('Bad $regularExpression must include only "pattern"' - 'and "options" components: %s' % (doc,)) - opts = regex['options'] + raise TypeError( + 'Bad $regularExpression must include only "pattern"' + 'and "options" components: %s' % (doc,) + ) + opts = regex["options"] if not isinstance(opts, str): - raise TypeError('Bad $regularExpression options, options must be ' - 'string, was type %s' % (type(opts))) - return Regex(regex['pattern'], opts) + raise TypeError( + "Bad $regularExpression options, options must be " "string, was type %s" % (type(opts)) + ) + return Regex(regex["pattern"], opts) def _parse_canonical_dbref(doc): """Decode a JSON DBRef to bson.dbref.DBRef.""" - return DBRef(doc.pop('$ref'), doc.pop('$id'), - database=doc.pop('$db', None), **doc) + return DBRef(doc.pop("$ref"), doc.pop("$id"), database=doc.pop("$db", None), **doc) def _parse_canonical_dbpointer(doc): """Decode a JSON (deprecated) DBPointer to bson.dbref.DBRef.""" - dbref = doc['$dbPointer'] + dbref = doc["$dbPointer"] if len(doc) != 1: - raise TypeError('Bad $dbPointer, extra field(s): %s' % (doc,)) + raise TypeError("Bad $dbPointer, extra field(s): %s" % (doc,)) if 
isinstance(dbref, DBRef):
         dbref_doc = dbref.as_doc()
         # DBPointer must not contain $db in its value.
         if dbref.database is not None:
-            raise TypeError(
-                'Bad $dbPointer, extra field $db: %s' % (dbref_doc,))
+            raise TypeError("Bad $dbPointer, extra field $db: %s" % (dbref_doc,))
         if not isinstance(dbref.id, ObjectId):
-            raise TypeError(
-                'Bad $dbPointer, $id must be an ObjectId: %s' % (dbref_doc,))
+            raise TypeError("Bad $dbPointer, $id must be an ObjectId: %s" % (dbref_doc,))
         if len(dbref_doc) != 2:
-            raise TypeError(
-                'Bad $dbPointer, extra field(s) in DBRef: %s' % (dbref_doc,))
+            raise TypeError("Bad $dbPointer, extra field(s) in DBRef: %s" % (dbref_doc,))
         return dbref
     else:
-        raise TypeError('Bad $dbPointer, expected a DBRef: %s' % (doc,))
+        raise TypeError("Bad $dbPointer, expected a DBRef: %s" % (doc,))


 def _parse_canonical_int32(doc):
     """Decode a JSON int32 to python int."""
-    i_str = doc['$numberInt']
+    i_str = doc["$numberInt"]
     if len(doc) != 1:
-        raise TypeError('Bad $numberInt, extra field(s): %s' % (doc,))
+        raise TypeError("Bad $numberInt, extra field(s): %s" % (doc,))
     if not isinstance(i_str, str):
-        raise TypeError('$numberInt must be string: %s' % (doc,))
+        raise TypeError("$numberInt must be string: %s" % (doc,))
     return int(i_str)


 def _parse_canonical_int64(doc):
     """Decode a JSON int64 to bson.int64.Int64."""
-    l_str = doc['$numberLong']
+    l_str = doc["$numberLong"]
     if len(doc) != 1:
-        raise TypeError('Bad $numberLong, extra field(s): %s' % (doc,))
+        raise TypeError("Bad $numberLong, extra field(s): %s" % (doc,))
     return Int64(l_str)


 def _parse_canonical_double(doc):
     """Decode a JSON double to python float."""
-    d_str = doc['$numberDouble']
+    d_str = doc["$numberDouble"]
     if len(doc) != 1:
-        raise TypeError('Bad $numberDouble, extra field(s): %s' % (doc,))
+        raise TypeError("Bad $numberDouble, extra field(s): %s" % (doc,))
     if not isinstance(d_str, str):
-        raise TypeError('$numberDouble must be string: %s' % (doc,))
+        raise TypeError("$numberDouble must be string: %s" % (doc,))
     return float(d_str)


 def _parse_canonical_decimal128(doc):
     """Decode a JSON decimal128 to bson.decimal128.Decimal128."""
-    d_str = doc['$numberDecimal']
+    d_str = doc["$numberDecimal"]
     if len(doc) != 1:
-        raise TypeError('Bad $numberDecimal, extra field(s): %s' % (doc,))
+        raise TypeError("Bad $numberDecimal, extra field(s): %s" % (doc,))
     if not isinstance(d_str, str):
-        raise TypeError('$numberDecimal must be string: %s' % (doc,))
+        raise TypeError("$numberDecimal must be string: %s" % (doc,))
     return Decimal128(d_str)


 def _parse_canonical_minkey(doc):
     """Decode a JSON MinKey to bson.min_key.MinKey."""
-    if type(doc['$minKey']) is not int or doc['$minKey'] != 1:
-        raise TypeError('$minKey value must be 1: %s' % (doc,))
+    if type(doc["$minKey"]) is not int or doc["$minKey"] != 1:
+        raise TypeError("$minKey value must be 1: %s" % (doc,))
     if len(doc) != 1:
-        raise TypeError('Bad $minKey, extra field(s): %s' % (doc,))
+        raise TypeError("Bad $minKey, extra field(s): %s" % (doc,))
     return MinKey()


 def _parse_canonical_maxkey(doc):
     """Decode a JSON MaxKey to bson.max_key.MaxKey."""
-    if type(doc['$maxKey']) is not int or doc['$maxKey'] != 1:
-        raise TypeError('$maxKey value must be 1: %s', (doc,))
+    if type(doc["$maxKey"]) is not int or doc["$maxKey"] != 1:
+        raise TypeError("$maxKey value must be 1: %s" % (doc,))
     if len(doc) != 1:
-        raise TypeError('Bad $minKey, extra field(s): %s' % (doc,))
+        raise TypeError("Bad $maxKey, extra field(s): %s" % (doc,))
     return MaxKey()


 def _encode_binary(data, subtype,
json_options): if json_options.json_mode == JSONMode.LEGACY: - return SON([ - ('$binary', base64.b64encode(data).decode()), - ('$type', "%02x" % subtype)]) - return {'$binary': SON([ - ('base64', base64.b64encode(data).decode()), - ('subType', "%02x" % subtype)])} + return SON([("$binary", base64.b64encode(data).decode()), ("$type", "%02x" % subtype)]) + return { + "$binary": SON([("base64", base64.b64encode(data).decode()), ("subType", "%02x" % subtype)]) + } def default(obj, json_options=DEFAULT_JSON_OPTIONS): @@ -776,24 +778,23 @@ def default(obj, json_options=DEFAULT_JSON_OPTIONS): if isinstance(obj, DBRef): return _json_convert(obj.as_doc(), json_options=json_options) if isinstance(obj, datetime.datetime): - if (json_options.datetime_representation == - DatetimeRepresentation.ISO8601): + if json_options.datetime_representation == DatetimeRepresentation.ISO8601: if not obj.tzinfo: obj = obj.replace(tzinfo=utc) if obj >= EPOCH_AWARE: off = obj.tzinfo.utcoffset(obj) if (off.days, off.seconds, off.microseconds) == (0, 0, 0): - tz_string = 'Z' + tz_string = "Z" else: - tz_string = obj.strftime('%z') + tz_string = obj.strftime("%z") millis = int(obj.microsecond / 1000) fracsecs = ".%03d" % (millis,) if millis else "" - return {"$date": "%s%s%s" % ( - obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string)} + return { + "$date": "%s%s%s" % (obj.strftime("%Y-%m-%dT%H:%M:%S"), fracsecs, tz_string) + } millis = bson._datetime_to_millis(obj) - if (json_options.datetime_representation == - DatetimeRepresentation.LEGACY): + if json_options.datetime_representation == DatetimeRepresentation.LEGACY: return {"$date": millis} return {"$date": {"$numberLong": str(millis)}} if json_options.strict_number_long and isinstance(obj, Int64): @@ -815,11 +816,10 @@ def default(obj, json_options=DEFAULT_JSON_OPTIONS): if isinstance(obj.pattern, str): pattern = obj.pattern else: - pattern = obj.pattern.decode('utf-8') + pattern = obj.pattern.decode("utf-8") if json_options.json_mode == JSONMode.LEGACY: return SON([("$regex", pattern), ("$options", flags)]) - return {'$regularExpression': SON([("pattern", pattern), - ("options", flags)])} + return {"$regularExpression": SON([("pattern", pattern), ("options", flags)])} if isinstance(obj, MinKey): return {"$minKey": 1} if isinstance(obj, MaxKey): @@ -828,18 +828,15 @@ def default(obj, json_options=DEFAULT_JSON_OPTIONS): return {"$timestamp": SON([("t", obj.time), ("i", obj.inc)])} if isinstance(obj, Code): if obj.scope is None: - return {'$code': str(obj)} - return SON([ - ('$code', str(obj)), - ('$scope', _json_convert(obj.scope, json_options))]) + return {"$code": str(obj)} + return SON([("$code", str(obj)), ("$scope", _json_convert(obj.scope, json_options))]) if isinstance(obj, Binary): return _encode_binary(obj, obj.subtype, json_options) if isinstance(obj, bytes): return _encode_binary(obj, 0, json_options) if isinstance(obj, uuid.UUID): if json_options.strict_uuid: - binval = Binary.from_uuid( - obj, uuid_representation=json_options.uuid_representation) + binval = Binary.from_uuid(obj, uuid_representation=json_options.uuid_representation) return _encode_binary(binval, binval.subtype, json_options) else: return {"$uuid": obj.hex} @@ -847,19 +844,18 @@ def default(obj, json_options=DEFAULT_JSON_OPTIONS): return {"$numberDecimal": str(obj)} if isinstance(obj, bool): return obj - if (json_options.json_mode == JSONMode.CANONICAL and - isinstance(obj, int)): - if -2 ** 31 <= obj < 2 ** 31: - return {'$numberInt': str(obj)} - return {'$numberLong': str(obj)} + if 
json_options.json_mode == JSONMode.CANONICAL and isinstance(obj, int): + if -(2**31) <= obj < 2**31: + return {"$numberInt": str(obj)} + return {"$numberLong": str(obj)} if json_options.json_mode != JSONMode.LEGACY and isinstance(obj, float): if math.isnan(obj): - return {'$numberDouble': 'NaN'} + return {"$numberDouble": "NaN"} elif math.isinf(obj): - representation = 'Infinity' if obj > 0 else '-Infinity' - return {'$numberDouble': representation} + representation = "Infinity" if obj > 0 else "-Infinity" + return {"$numberDouble": representation} elif json_options.json_mode == JSONMode.CANONICAL: # repr() will return the shortest string guaranteed to produce the # original value, when float() is called on it. - return {'$numberDouble': str(repr(obj))} + return {"$numberDouble": str(repr(obj))} raise TypeError("%r is not JSON serializable" % obj) diff --git a/bson/max_key.py b/bson/max_key.py index afd7fcb1b3..322c539312 100644 --- a/bson/max_key.py +++ b/bson/max_key.py @@ -18,6 +18,7 @@ class MaxKey(object): """MongoDB internal MaxKey type.""" + __slots__ = () _type_marker = 127 diff --git a/bson/min_key.py b/bson/min_key.py index bcb7f9e60f..7520dd2654 100644 --- a/bson/min_key.py +++ b/bson/min_key.py @@ -18,6 +18,7 @@ class MinKey(object): """MongoDB internal MinKey type.""" + __slots__ = () _type_marker = 255 diff --git a/bson/objectid.py b/bson/objectid.py index faf8910edc..38a189cac0 100644 --- a/bson/objectid.py +++ b/bson/objectid.py @@ -23,20 +23,19 @@ import struct import threading import time - from random import SystemRandom from bson.errors import InvalidId from bson.tz_util import utc - _MAX_COUNTER_VALUE = 0xFFFFFF def _raise_invalid_id(oid): raise InvalidId( "%r is not a valid ObjectId, it must be a 12-byte input" - " or a 24-character hex string" % oid) + " or a 24-character hex string" % oid + ) def _random_bytes(): @@ -45,8 +44,7 @@ def _random_bytes(): class ObjectId(object): - """A MongoDB ObjectId. - """ + """A MongoDB ObjectId.""" _pid = os.getpid() @@ -55,7 +53,7 @@ class ObjectId(object): __random = _random_bytes() - __slots__ = ('__id',) + __slots__ = ("__id",) _type_marker = 7 @@ -135,8 +133,7 @@ def from_datetime(cls, generation_time): if generation_time.utcoffset() is not None: generation_time = generation_time - generation_time.utcoffset() timestamp = calendar.timegm(generation_time.timetuple()) - oid = struct.pack( - ">I", int(timestamp)) + b"\x00\x00\x00\x00\x00\x00\x00\x00" + oid = struct.pack(">I", int(timestamp)) + b"\x00\x00\x00\x00\x00\x00\x00\x00" return cls(oid) @classmethod @@ -159,8 +156,7 @@ def is_valid(cls, oid): @classmethod def _random(cls): - """Generate a 5-byte random number once per process. - """ + """Generate a 5-byte random number once per process.""" pid = os.getpid() if pid != cls._pid: cls._pid = pid @@ -168,8 +164,7 @@ def _random(cls): return cls.__random def __generate(self): - """Generate a new value for this ObjectId. - """ + """Generate a new value for this ObjectId.""" # 4 bytes current time oid = struct.pack(">I", int(time.time())) @@ -206,13 +201,13 @@ def __validate(self, oid): else: _raise_invalid_id(oid) else: - raise TypeError("id must be an instance of (bytes, str, ObjectId), " - "not %s" % (type(oid),)) + raise TypeError( + "id must be an instance of (bytes, str, ObjectId), " "not %s" % (type(oid),) + ) @property def binary(self): - """12-byte binary representation of this ObjectId. 
- """ + """12-byte binary representation of this ObjectId.""" return self.__id @property @@ -234,8 +229,7 @@ def __getstate__(self): return self.__id def __setstate__(self, value): - """explicit state set from pickling - """ + """explicit state set from pickling""" # Provide backwards compatability with OIDs # pickled with pymongo-1.9 or older. if isinstance(value, dict): @@ -246,7 +240,7 @@ def __setstate__(self, value): # In python 3.x this has to be converted to `bytes` # by encoding latin-1. if isinstance(oid, str): - self.__id = oid.encode('latin-1') + self.__id = oid.encode("latin-1") else: self.__id = oid diff --git a/bson/raw_bson.py b/bson/raw_bson.py index bfe888b6b7..a2566e69ff 100644 --- a/bson/raw_bson.py +++ b/bson/raw_bson.py @@ -53,9 +53,9 @@ from collections.abc import Mapping as _Mapping -from bson import _raw_to_dict, _get_object_size -from bson.codec_options import ( - DEFAULT_CODEC_OPTIONS as DEFAULT, _RAW_BSON_DOCUMENT_MARKER) +from bson import _get_object_size, _raw_to_dict +from bson.codec_options import _RAW_BSON_DOCUMENT_MARKER +from bson.codec_options import DEFAULT_CODEC_OPTIONS as DEFAULT from bson.son import SON @@ -67,7 +67,7 @@ class RawBSONDocument(_Mapping): RawBSONDocument decode its bytes. """ - __slots__ = ('__raw', '__inflated_doc', '__codec_options') + __slots__ = ("__raw", "__inflated_doc", "__codec_options") _type_marker = _RAW_BSON_DOCUMENT_MARKER def __init__(self, bson_bytes, codec_options=None): @@ -113,7 +113,8 @@ class from the standard library so it can be used like a read-only elif codec_options.document_class is not RawBSONDocument: raise TypeError( "RawBSONDocument cannot use CodecOptions with document " - "class %s" % (codec_options.document_class, )) + "class %s" % (codec_options.document_class,) + ) self.__codec_options = codec_options # Validate the bson object size. _get_object_size(bson_bytes, 0, len(bson_bytes)) @@ -133,8 +134,7 @@ def __inflated(self): # We already validated the object's size when this document was # created, so no need to do that again. # Use SON to preserve ordering of elements. - self.__inflated_doc = _inflate_bson( - self.__raw, self.__codec_options) + self.__inflated_doc = _inflate_bson(self.__raw, self.__codec_options) return self.__inflated_doc def __getitem__(self, item): @@ -152,8 +152,7 @@ def __eq__(self, other): return NotImplemented def __repr__(self): - return ("RawBSONDocument(%r, codec_options=%r)" - % (self.raw, self.__codec_options)) + return "RawBSONDocument(%r, codec_options=%r)" % (self.raw, self.__codec_options) def _inflate_bson(bson_bytes, codec_options): @@ -166,8 +165,7 @@ def _inflate_bson(bson_bytes, codec_options): must be :class:`RawBSONDocument`. """ # Use SON to preserve ordering of elements. - return _raw_to_dict( - bson_bytes, 4, len(bson_bytes)-1, codec_options, SON()) + return _raw_to_dict(bson_bytes, 4, len(bson_bytes) - 1, codec_options, SON()) DEFAULT_RAW_BSON_OPTIONS = DEFAULT.with_options(document_class=RawBSONDocument) diff --git a/bson/regex.py b/bson/regex.py index 5cf097f08c..95cd5c566b 100644 --- a/bson/regex.py +++ b/bson/regex.py @@ -17,8 +17,8 @@ import re -from bson.son import RE_TYPE from bson._helpers import _getstate_slots, _setstate_slots +from bson.son import RE_TYPE def str_flags_to_int(str_flags): @@ -41,6 +41,7 @@ def str_flags_to_int(str_flags): class Regex(object): """BSON regular expression data.""" + __slots__ = ("pattern", "flags") __getstate__ = _getstate_slots @@ -74,9 +75,7 @@ def from_native(cls, regex): .. 
_PCRE: http://www.pcre.org/ """ if not isinstance(regex, RE_TYPE): - raise TypeError( - "regex must be a compiled regular expression, not %s" - % type(regex)) + raise TypeError("regex must be a compiled regular expression, not %s" % type(regex)) return Regex(regex.pattern, regex.flags) @@ -100,8 +99,7 @@ def __init__(self, pattern, flags=0): elif isinstance(flags, int): self.flags = flags else: - raise TypeError( - "flags must be a string or int, not %s" % type(flags)) + raise TypeError("flags must be a string or int, not %s" % type(flags)) def __eq__(self, other): if isinstance(other, Regex): diff --git a/bson/son.py b/bson/son.py index 5a3210fcdb..9ad37b5732 100644 --- a/bson/son.py +++ b/bson/son.py @@ -20,7 +20,6 @@ import copy import re - from collections.abc import Mapping as _Mapping # This sort of sucks, but seems to be as good as it gets... @@ -101,8 +100,7 @@ def setdefault(self, key, default=None): def pop(self, key, *args): if len(args) > 1: - raise TypeError("pop expected at most 2 arguments, got "\ - + repr(1 + len(args))) + raise TypeError("pop expected at most 2 arguments, got " + repr(1 + len(args))) try: value = self[key] except KeyError: @@ -116,7 +114,7 @@ def popitem(self): try: k, v = next(iter(self.items())) except StopIteration: - raise KeyError('container is empty') + raise KeyError("container is empty") del self[k] return (k, v) @@ -124,10 +122,10 @@ def update(self, other=None, **kwargs): # Make progressively weaker assumptions about "other" if other is None: pass - elif hasattr(other, 'items'): + elif hasattr(other, "items"): for k, v in other.items(): self[k] = v - elif hasattr(other, 'keys'): + elif hasattr(other, "keys"): for k in other.keys(): self[k] = other[k] else: @@ -147,8 +145,7 @@ def __eq__(self, other): regular dictionary is order-insensitive. """ if isinstance(other, SON): - return len(self) == len(other) and list(self.items()) == \ - list(other.items()) + return len(self) == len(other) and list(self.items()) == list(other.items()) return self.to_dict() == other def __ne__(self, other): @@ -168,9 +165,7 @@ def transform_value(value): if isinstance(value, list): return [transform_value(v) for v in value] elif isinstance(value, _Mapping): - return dict([ - (k, transform_value(v)) - for k, v in value.items()]) + return dict([(k, transform_value(v)) for k, v in value.items()]) else: return value diff --git a/bson/timestamp.py b/bson/timestamp.py index 69c061d2a5..e522159bb3 100644 --- a/bson/timestamp.py +++ b/bson/timestamp.py @@ -18,15 +18,15 @@ import calendar import datetime -from bson.tz_util import utc from bson._helpers import _getstate_slots, _setstate_slots +from bson.tz_util import utc UPPERBOUND = 4294967296 class Timestamp(object): - """MongoDB internal timestamps used in the opLog. - """ + """MongoDB internal timestamps used in the opLog.""" + __slots__ = ("__time", "__inc") __getstate__ = _getstate_slots @@ -70,19 +70,17 @@ def __init__(self, time, inc): @property def time(self): - """Get the time portion of this :class:`Timestamp`. - """ + """Get the time portion of this :class:`Timestamp`.""" return self.__time @property def inc(self): - """Get the inc portion of this :class:`Timestamp`. 
- """ + """Get the inc portion of this :class:`Timestamp`.""" return self.__inc def __eq__(self, other): if isinstance(other, Timestamp): - return (self.__time == other.time and self.__inc == other.inc) + return self.__time == other.time and self.__inc == other.inc else: return NotImplemented diff --git a/bson/tz_util.py b/bson/tz_util.py index 6ec918fb2b..6cfb230cc5 100644 --- a/bson/tz_util.py +++ b/bson/tz_util.py @@ -14,8 +14,7 @@ """Timezone related utilities for BSON.""" -from datetime import (timedelta, - tzinfo) +from datetime import timedelta, tzinfo ZERO = timedelta(0) diff --git a/doc/conf.py b/doc/conf.py index facb74f470..47debcf14c 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -4,8 +4,10 @@ # # This file is execfile()d with the current directory set to its containing dir. -import sys, os -sys.path[0:0] = [os.path.abspath('..')] +import os +import sys + +sys.path[0:0] = [os.path.abspath("..")] import pymongo @@ -13,21 +15,26 @@ # Add any Sphinx extension module names here, as strings. They can be extensions # coming with Sphinx (named 'sphinx.ext.*') or your custom ones. -extensions = ['sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.coverage', - 'sphinx.ext.todo', 'sphinx.ext.intersphinx'] +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.doctest", + "sphinx.ext.coverage", + "sphinx.ext.todo", + "sphinx.ext.intersphinx", +] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # The suffix of source filenames. -source_suffix = '.rst' +source_suffix = ".rst" # The master toctree document. -master_doc = 'index' +master_doc = "index" # General information about the project. -project = 'PyMongo' -copyright = 'MongoDB, Inc. 2008-present. MongoDB, Mongo, and the leaf logo are registered trademarks of MongoDB, Inc' +project = "PyMongo" +copyright = "MongoDB, Inc. 2008-present. MongoDB, Mongo, and the leaf logo are registered trademarks of MongoDB, Inc" html_show_sphinx = False # The version info for the project you're documenting, acts as replacement for @@ -44,31 +51,31 @@ # List of directories, relative to source directory, that shouldn't be searched # for source files. -exclude_trees = ['_build'] +exclude_trees = ["_build"] # The reST default role (used for this markup: `text`) to use for all documents. -#default_role = None +# default_role = None # If true, sectionauthor and moduleauthor directives will be shown in the # output. They are ignored by default. -#show_authors = False +# show_authors = False # If true, the current module name will be prepended to all description # unit titles (such as .. function::). add_module_names = True # The name of the Pygments (syntax highlighting) style to use. -pygments_style = 'sphinx' +pygments_style = "sphinx" # A list of ignored prefixes for module index sorting. -#modindex_common_prefix = [] +# modindex_common_prefix = [] # -- Options for extensions ---------------------------------------------------- -autoclass_content = 'init' +autoclass_content = "init" -doctest_path = [os.path.abspath('..')] +doctest_path = [os.path.abspath("..")] -doctest_test_doctest_blocks = '' +doctest_test_doctest_blocks = "" doctest_global_setup = """ from pymongo.mongo_client import MongoClient @@ -82,91 +89,87 @@ # Theme gratefully vendored from CPython source. 
html_theme = "pydoctheme" html_theme_path = ["."] -html_theme_options = { - 'collapsiblesidebar': True, - 'googletag': False -} +html_theme_options = {"collapsiblesidebar": True, "googletag": False} # Additional static files. -html_static_path = ['static'] +html_static_path = ["static"] # The name for this set of Sphinx documents. If None, it defaults to # " v documentation". -#html_title = None +# html_title = None # A shorter title for the navigation bar. Default is the same as html_title. -#html_short_title = None +# html_short_title = None # The name of an image file (relative to this directory) to place at the top # of the sidebar. -#html_logo = None +# html_logo = None # The name of an image file (within the static path) to use as favicon of the # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # pixels large. -#html_favicon = None +# html_favicon = None # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -#html_static_path = ['_static'] +# html_static_path = ['_static'] # Custom sidebar templates, maps document names to template names. -#html_sidebars = {} +# html_sidebars = {} # Additional templates that should be rendered to pages, maps page names to # template names. -#html_additional_pages = {} +# html_additional_pages = {} # If true, links to the reST sources are added to the pages. -#html_show_sourcelink = True +# html_show_sourcelink = True # If true, an OpenSearch description file will be output, and all pages will # contain a tag referring to it. The value of this option must be the # base URL from which the finished HTML is served. -#html_use_opensearch = '' +# html_use_opensearch = '' # If nonempty, this is the file name suffix for HTML files (e.g. ".xhtml"). -#html_file_suffix = '' +# html_file_suffix = '' # Output file base name for HTML help builder. -htmlhelp_basename = 'PyMongo' + release.replace('.', '_') +htmlhelp_basename = "PyMongo" + release.replace(".", "_") # -- Options for LaTeX output -------------------------------------------------- # The paper size ('letter' or 'a4'). -#latex_paper_size = 'letter' +# latex_paper_size = 'letter' # The font size ('10pt', '11pt' or '12pt'). -#latex_font_size = '10pt' +# latex_font_size = '10pt' # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, author, documentclass [howto/manual]). latex_documents = [ - ('index', 'PyMongo.tex', 'PyMongo Documentation', - 'Michael Dirolf', 'manual'), + ("index", "PyMongo.tex", "PyMongo Documentation", "Michael Dirolf", "manual"), ] # The name of an image file (relative to this directory) to place at the top of # the title page. -#latex_logo = None +# latex_logo = None # For "manual" documents, if this is true, then toplevel headings are parts, # not chapters. -#latex_use_parts = False +# latex_use_parts = False # Additional stuff for the LaTeX preamble. -#latex_preamble = '' +# latex_preamble = '' # Documents to append as an appendix to all manuals. -#latex_appendices = [] +# latex_appendices = [] # If false, no module index is generated. 
-#latex_use_modindex = True +# latex_use_modindex = True intersphinx_mapping = { - 'gevent': ('http://www.gevent.org/', None), - 'py': ('https://docs.python.org/3/', None), + "gevent": ("http://www.gevent.org/", None), + "py": ("https://docs.python.org/3/", None), } diff --git a/green_framework_test.py b/green_framework_test.py index baffe21b15..610845a9f6 100644 --- a/green_framework_test.py +++ b/green_framework_test.py @@ -21,30 +21,35 @@ def run_gevent(): """Prepare to run tests with Gevent. Can raise ImportError.""" from gevent import monkey + monkey.patch_all() def run_eventlet(): """Prepare to run tests with Eventlet. Can raise ImportError.""" import eventlet + # https://github.com/eventlet/eventlet/issues/401 eventlet.sleep() eventlet.monkey_patch() FRAMEWORKS = { - 'gevent': run_gevent, - 'eventlet': run_eventlet, + "gevent": run_gevent, + "eventlet": run_eventlet, } def list_frameworks(): """Tell the user what framework names are valid.""" - sys.stdout.write("""Testable frameworks: %s + sys.stdout.write( + """Testable frameworks: %s Note that membership in this list means the framework can be tested with PyMongo, not necessarily that it is officially supported. -""" % ", ".join(sorted(FRAMEWORKS))) +""" + % ", ".join(sorted(FRAMEWORKS)) + ) def run(framework_name, *args): @@ -53,7 +58,7 @@ def run(framework_name, *args): FRAMEWORKS[framework_name]() # Run the tests. - sys.argv[:] = ['setup.py', 'test'] + list(args) + sys.argv[:] = ["setup.py", "test"] + list(args) import setup @@ -62,11 +67,13 @@ def main(): usage = """python %s FRAMEWORK_NAME Test PyMongo with a variety of greenlet-based monkey-patching frameworks. See -python %s --help-frameworks.""" % (sys.argv[0], sys.argv[0]) +python %s --help-frameworks.""" % ( + sys.argv[0], + sys.argv[0], + ) try: - opts, args = getopt.getopt( - sys.argv[1:], "h", ["help", "help-frameworks"]) + opts, args = getopt.getopt(sys.argv[1:], "h", ["help", "help-frameworks"]) except getopt.GetoptError as err: print(str(err)) print(usage) @@ -87,13 +94,14 @@ def main(): sys.exit(1) if args[0] not in FRAMEWORKS: - print('%r is not a testable framework.\n' % args[0]) + print("%r is not a testable framework.\n" % args[0]) list_frameworks() sys.exit(1) - run(args[0], # Framework name. - *args[1:]) # Command line args to setup.py, like what test to run. + run( + args[0], *args[1:] # Framework name. + ) # Command line args to setup.py, like what test to run. -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/gridfs/__init__.py b/gridfs/__init__.py index c36d921e8c..2480e6c2b5 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -22,23 +22,24 @@ from collections import abc -from pymongo import (ASCENDING, - DESCENDING) +from gridfs.errors import NoFile +from gridfs.grid_file import ( + DEFAULT_CHUNK_SIZE, + GridIn, + GridOut, + GridOutCursor, + _clear_entity_type_registry, + _disallow_transactions, +) +from pymongo import ASCENDING, DESCENDING from pymongo.common import UNAUTHORIZED_CODES, validate_string from pymongo.database import Database from pymongo.errors import ConfigurationError, OperationFailure -from gridfs.errors import NoFile -from gridfs.grid_file import (GridIn, - GridOut, - GridOutCursor, - DEFAULT_CHUNK_SIZE, - _clear_entity_type_registry, - _disallow_transactions) class GridFS(object): - """An instance of GridFS on top of a single Database. 
- """ + """An instance of GridFS on top of a single Database.""" + def __init__(self, database, collection="fs"): """Create a new instance of :class:`GridFS`. @@ -75,8 +76,7 @@ def __init__(self, database, collection="fs"): database = _clear_entity_type_registry(database) if not database.write_concern.acknowledged: - raise ConfigurationError('database must use ' - 'acknowledged write_concern') + raise ConfigurationError("database must use " "acknowledged write_concern") self.__collection = database[collection] self.__files = self.__collection.files @@ -204,8 +204,7 @@ def get_version(self, filename=None, version=-1, session=None, **kwargs): cursor.limit(-1).skip(version).sort("uploadDate", ASCENDING) try: doc = next(cursor) - return GridOut( - self.__collection, file_document=doc, session=session) + return GridOut(self.__collection, file_document=doc, session=session) except StopIteration: raise NoFile("no version %d for filename %r" % (version, filename)) @@ -275,8 +274,8 @@ def list(self, session=None): # With an index, distinct includes documents with no filename # as None. return [ - name for name in self.__files.distinct("filename", session=session) - if name is not None] + name for name in self.__files.distinct("filename", session=session) if name is not None + ] def find_one(self, filter=None, session=None, *args, **kwargs): """Get a single file from gridfs. @@ -422,9 +421,14 @@ def exists(self, document_or_id=None, session=None, **kwargs): class GridFSBucket(object): """An instance of GridFS on top of a single Database.""" - def __init__(self, db, bucket_name="fs", - chunk_size_bytes=DEFAULT_CHUNK_SIZE, write_concern=None, - read_preference=None): + def __init__( + self, + db, + bucket_name="fs", + chunk_size_bytes=DEFAULT_CHUNK_SIZE, + write_concern=None, + read_preference=None, + ): """Create a new instance of :class:`GridFSBucket`. Raises :exc:`TypeError` if `database` is not an instance of @@ -466,22 +470,21 @@ def __init__(self, db, bucket_name="fs", wtc = write_concern if write_concern is not None else db.write_concern if not wtc.acknowledged: - raise ConfigurationError('write concern must be acknowledged') + raise ConfigurationError("write concern must be acknowledged") self._bucket_name = bucket_name self._collection = db[bucket_name] self._chunks = self._collection.chunks.with_options( - write_concern=write_concern, - read_preference=read_preference) + write_concern=write_concern, read_preference=read_preference + ) self._files = self._collection.files.with_options( - write_concern=write_concern, - read_preference=read_preference) + write_concern=write_concern, read_preference=read_preference + ) self._chunk_size_bytes = chunk_size_bytes - def open_upload_stream(self, filename, chunk_size_bytes=None, - metadata=None, session=None): + def open_upload_stream(self, filename, chunk_size_bytes=None, metadata=None, session=None): """Opens a Stream that the application can write the contents of the file to. 
@@ -519,17 +522,20 @@ def open_upload_stream(self, filename, chunk_size_bytes=None, """ validate_string("filename", filename) - opts = {"filename": filename, - "chunk_size": (chunk_size_bytes if chunk_size_bytes - is not None else self._chunk_size_bytes)} + opts = { + "filename": filename, + "chunk_size": ( + chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes + ), + } if metadata is not None: opts["metadata"] = metadata return GridIn(self._collection, session=session, **opts) def open_upload_stream_with_id( - self, file_id, filename, chunk_size_bytes=None, metadata=None, - session=None): + self, file_id, filename, chunk_size_bytes=None, metadata=None, session=None + ): """Opens a Stream that the application can write the contents of the file to. @@ -571,17 +577,21 @@ def open_upload_stream_with_id( """ validate_string("filename", filename) - opts = {"_id": file_id, - "filename": filename, - "chunk_size": (chunk_size_bytes if chunk_size_bytes - is not None else self._chunk_size_bytes)} + opts = { + "_id": file_id, + "filename": filename, + "chunk_size": ( + chunk_size_bytes if chunk_size_bytes is not None else self._chunk_size_bytes + ), + } if metadata is not None: opts["metadata"] = metadata return GridIn(self._collection, session=session, **opts) - def upload_from_stream(self, filename, source, chunk_size_bytes=None, - metadata=None, session=None): + def upload_from_stream( + self, filename, source, chunk_size_bytes=None, metadata=None, session=None + ): """Uploads a user file to a GridFS bucket. Reads the contents of the user file from `source` and uploads @@ -617,15 +627,14 @@ def upload_from_stream(self, filename, source, chunk_size_bytes=None, .. versionchanged:: 3.6 Added ``session`` parameter. """ - with self.open_upload_stream( - filename, chunk_size_bytes, metadata, session=session) as gin: + with self.open_upload_stream(filename, chunk_size_bytes, metadata, session=session) as gin: gin.write(source) return gin._id - def upload_from_stream_with_id(self, file_id, filename, source, - chunk_size_bytes=None, metadata=None, - session=None): + def upload_from_stream_with_id( + self, file_id, filename, source, chunk_size_bytes=None, metadata=None, session=None + ): """Uploads a user file to a GridFS bucket with a custom file id. Reads the contents of the user file from `source` and uploads @@ -663,8 +672,8 @@ def upload_from_stream_with_id(self, file_id, filename, source, Added ``session`` parameter. 
""" with self.open_upload_stream_with_id( - file_id, filename, chunk_size_bytes, metadata, - session=session) as gin: + file_id, filename, chunk_size_bytes, metadata, session=session + ) as gin: gin.write(source) def open_download_stream(self, file_id, session=None): @@ -755,8 +764,7 @@ def delete(self, file_id, session=None): res = self._files.delete_one({"_id": file_id}, session=session) self._chunks.delete_many({"files_id": file_id}, session=session) if not res.deleted_count: - raise NoFile( - "no file could be deleted because none matched %s" % file_id) + raise NoFile("no file could be deleted because none matched %s" % file_id) def find(self, *args, **kwargs): """Find and return the files collection documents that match ``filter`` @@ -855,14 +863,11 @@ def open_download_stream_by_name(self, filename, revision=-1, session=None): cursor.limit(-1).skip(revision).sort("uploadDate", ASCENDING) try: grid_file = next(cursor) - return GridOut( - self._collection, file_document=grid_file, session=session) + return GridOut(self._collection, file_document=grid_file, session=session) except StopIteration: - raise NoFile( - "no version %d for filename %r" % (revision, filename)) + raise NoFile("no version %d for filename %r" % (revision, filename)) - def download_to_stream_by_name(self, filename, destination, revision=-1, - session=None): + def download_to_stream_by_name(self, filename, destination, revision=-1, session=None): """Write the contents of `filename` (with optional `revision`) to `destination`. @@ -900,8 +905,7 @@ def download_to_stream_by_name(self, filename, destination, revision=-1, .. versionchanged:: 3.6 Added ``session`` parameter. """ - with self.open_download_stream_by_name( - filename, revision, session=session) as gout: + with self.open_download_stream_by_name(filename, revision, session=session) as gout: for chunk in gout: destination.write(chunk) @@ -928,9 +932,11 @@ def rename(self, file_id, new_filename, session=None): Added ``session`` parameter. 
""" _disallow_transactions(session) - result = self._files.update_one({"_id": file_id}, - {"$set": {"filename": new_filename}}, - session=session) + result = self._files.update_one( + {"_id": file_id}, {"$set": {"filename": new_filename}}, session=session + ) if not result.matched_count: - raise NoFile("no files could be renamed %r because none " - "matched file_id %i" % (new_filename, file_id)) + raise NoFile( + "no files could be renamed %r because none " + "matched file_id %i" % (new_filename, file_id) + ) diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index fc01d88d24..feec325b0c 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -18,22 +18,23 @@ import math import os -from bson.int64 import Int64 -from bson.son import SON from bson.binary import Binary +from bson.int64 import Int64 from bson.objectid import ObjectId +from bson.son import SON +from gridfs.errors import CorruptGridFile, FileExists, NoFile from pymongo import ASCENDING from pymongo.collection import Collection from pymongo.cursor import Cursor -from pymongo.errors import (ConfigurationError, - CursorNotFound, - DuplicateKeyError, - InvalidOperation, - OperationFailure) +from pymongo.errors import ( + ConfigurationError, + CursorNotFound, + DuplicateKeyError, + InvalidOperation, + OperationFailure, +) from pymongo.read_preferences import ReadPreference -from gridfs.errors import CorruptGridFile, FileExists, NoFile - try: _SEEK_SET = os.SEEK_SET _SEEK_CUR = os.SEEK_CUR @@ -55,30 +56,31 @@ _F_INDEX = SON([("filename", ASCENDING), ("uploadDate", ASCENDING)]) -def _grid_in_property(field_name, docstring, read_only=False, - closed_only=False): +def _grid_in_property(field_name, docstring, read_only=False, closed_only=False): """Create a GridIn property.""" + def getter(self): if closed_only and not self._closed: - raise AttributeError("can only get %r on a closed file" % - field_name) + raise AttributeError("can only get %r on a closed file" % field_name) # Protect against PHP-237 - if field_name == 'length': + if field_name == "length": return self._file.get(field_name, 0) return self._file.get(field_name, None) def setter(self, value): if self._closed: - self._coll.files.update_one({"_id": self._file["_id"]}, - {"$set": {field_name: value}}) + self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {field_name: value}}) self._file[field_name] = value if read_only: docstring += "\n\nThis attribute is read-only." elif closed_only: - docstring = "%s\n\n%s" % (docstring, "This attribute is read-only and " - "can only be read after :meth:`close` " - "has been called.") + docstring = "%s\n\n%s" % ( + docstring, + "This attribute is read-only and " + "can only be read after :meth:`close` " + "has been called.", + ) if not read_only and not closed_only: return property(getter, setter, doc=docstring) @@ -87,11 +89,12 @@ def setter(self, value): def _grid_out_property(field_name, docstring): """Create a GridOut property.""" + def getter(self): self._ensure_file() # Protect against PHP-237 - if field_name == 'length': + if field_name == "length": return self._file.get(field_name, 0) return self._file.get(field_name, None) @@ -107,13 +110,12 @@ def _clear_entity_type_registry(entity, **kwargs): def _disallow_transactions(session): if session and session.in_transaction: - raise InvalidOperation( - 'GridFS does not support multi-document transactions') + raise InvalidOperation("GridFS does not support multi-document transactions") class GridIn(object): - """Class to write data to GridFS. 
- """ + """Class to write data to GridFS.""" + def __init__(self, root_collection, session=None, **kwargs): """Write a file to GridFS @@ -167,12 +169,10 @@ def __init__(self, root_collection, session=None, **kwargs): :attr:`~pymongo.collection.Collection.write_concern` """ if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an " - "instance of Collection") + raise TypeError("root_collection must be an " "instance of Collection") if not root_collection.write_concern.acknowledged: - raise ConfigurationError('root_collection must use ' - 'acknowledged write_concern') + raise ConfigurationError("root_collection must use " "acknowledged write_concern") _disallow_transactions(session) # Handle alternative naming @@ -181,8 +181,7 @@ def __init__(self, root_collection, session=None, **kwargs): if "chunk_size" in kwargs: kwargs["chunkSize"] = kwargs.pop("chunk_size") - coll = _clear_entity_type_registry( - root_collection, read_preference=ReadPreference.PRIMARY) + coll = _clear_entity_type_registry(root_collection, read_preference=ReadPreference.PRIMARY) # Defaults kwargs["_id"] = kwargs.get("_id", ObjectId()) @@ -201,13 +200,14 @@ def __create_index(self, collection, index_key, unique): doc = collection.find_one(projection={"_id": 1}, session=self._session) if doc is None: try: - index_keys = [index_spec['key'] for index_spec in - collection.list_indexes(session=self._session)] + index_keys = [ + index_spec["key"] + for index_spec in collection.list_indexes(session=self._session) + ] except OperationFailure: index_keys = [] if index_key not in index_keys: - collection.create_index( - index_key.items(), unique=unique, session=self._session) + collection.create_index(index_key.items(), unique=unique, session=self._session) def __ensure_indexes(self): if not object.__getattribute__(self, "_ensured_index"): @@ -217,35 +217,28 @@ def __ensure_indexes(self): object.__setattr__(self, "_ensured_index", True) def abort(self): - """Remove all chunks/files that may have been uploaded and close. - """ - self._coll.chunks.delete_many( - {"files_id": self._file['_id']}, session=self._session) - self._coll.files.delete_one( - {"_id": self._file['_id']}, session=self._session) + """Remove all chunks/files that may have been uploaded and close.""" + self._coll.chunks.delete_many({"files_id": self._file["_id"]}, session=self._session) + self._coll.files.delete_one({"_id": self._file["_id"]}, session=self._session) object.__setattr__(self, "_closed", True) @property def closed(self): - """Is this file closed? 
- """ + """Is this file closed?""" return self._closed - _id = _grid_in_property("_id", "The ``'_id'`` value for this file.", - read_only=True) + _id = _grid_in_property("_id", "The ``'_id'`` value for this file.", read_only=True) filename = _grid_in_property("filename", "Name of this file.") name = _grid_in_property("filename", "Alias for `filename`.") content_type = _grid_in_property("contentType", "Mime-type for this file.") - length = _grid_in_property("length", "Length (in bytes) of this file.", - closed_only=True) - chunk_size = _grid_in_property("chunkSize", "Chunk size for this file.", - read_only=True) - upload_date = _grid_in_property("uploadDate", - "Date that this file was uploaded.", - closed_only=True) - md5 = _grid_in_property("md5", "MD5 of the contents of this file " - "if an md5 sum was created.", - closed_only=True) + length = _grid_in_property("length", "Length (in bytes) of this file.", closed_only=True) + chunk_size = _grid_in_property("chunkSize", "Chunk size for this file.", read_only=True) + upload_date = _grid_in_property( + "uploadDate", "Date that this file was uploaded.", closed_only=True + ) + md5 = _grid_in_property( + "md5", "MD5 of the contents of this file " "if an md5 sum was created.", closed_only=True + ) def __getattr__(self, name): if name in self._file: @@ -263,46 +256,39 @@ def __setattr__(self, name, value): # them now. self._file[name] = value if self._closed: - self._coll.files.update_one({"_id": self._file["_id"]}, - {"$set": {name: value}}) + self._coll.files.update_one({"_id": self._file["_id"]}, {"$set": {name: value}}) def __flush_data(self, data): - """Flush `data` to a chunk. - """ + """Flush `data` to a chunk.""" self.__ensure_indexes() if not data: return - assert(len(data) <= self.chunk_size) + assert len(data) <= self.chunk_size - chunk = {"files_id": self._file["_id"], - "n": self._chunk_number, - "data": Binary(data)} + chunk = {"files_id": self._file["_id"], "n": self._chunk_number, "data": Binary(data)} try: self._chunks.insert_one(chunk, session=self._session) except DuplicateKeyError: - self._raise_file_exists(self._file['_id']) + self._raise_file_exists(self._file["_id"]) self._chunk_number += 1 self._position += len(data) def __flush_buffer(self): - """Flush the buffer contents out to a chunk. - """ + """Flush the buffer contents out to a chunk.""" self.__flush_data(self._buffer.getvalue()) self._buffer.close() self._buffer = io.BytesIO() def __flush(self): - """Flush the file to the database. - """ + """Flush the file to the database.""" try: self.__flush_buffer() # The GridFS spec says length SHOULD be an Int64. 
self._file["length"] = Int64(self._position) self._file["uploadDate"] = datetime.datetime.utcnow() - return self._coll.files.insert_one( - self._file, session=self._session) + return self._coll.files.insert_one(self._file, session=self._session) except DuplicateKeyError: self._raise_file_exists(self._id) @@ -321,7 +307,7 @@ def close(self): object.__setattr__(self, "_closed", True) def read(self, size=-1): - raise io.UnsupportedOperation('read') + raise io.UnsupportedOperation("read") def readable(self): return False @@ -364,8 +350,7 @@ def write(self, data): try: data = data.encode(self.encoding) except AttributeError: - raise TypeError("must specify an encoding for file in " - "order to write str") + raise TypeError("must specify an encoding for file in " "order to write str") read = io.BytesIO(data).read if self._buffer.tell() > 0: @@ -399,8 +384,7 @@ def writeable(self): return True def __enter__(self): - """Support for the context manager protocol. - """ + """Support for the context manager protocol.""" return self def __exit__(self, exc_type, exc_val, exc_tb): @@ -415,10 +399,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): class GridOut(io.IOBase): - """Class to read data out of GridFS. - """ - def __init__(self, root_collection, file_id=None, file_document=None, - session=None): + """Class to read data out of GridFS.""" + + def __init__(self, root_collection, file_id=None, file_document=None, session=None): """Read a file from GridFS Application developers should generally not need to @@ -452,8 +435,7 @@ def __init__(self, root_collection, file_id=None, file_document=None, from the server. Metadata is fetched when first needed. """ if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an " - "instance of Collection") + raise TypeError("root_collection must be an " "instance of Collection") _disallow_transactions(session) root_collection = _clear_entity_type_registry(root_collection) @@ -475,21 +457,21 @@ def __init__(self, root_collection, file_id=None, file_document=None, content_type = _grid_out_property("contentType", "Mime-type for this file.") length = _grid_out_property("length", "Length (in bytes) of this file.") chunk_size = _grid_out_property("chunkSize", "Chunk size for this file.") - upload_date = _grid_out_property("uploadDate", - "Date that this file was first uploaded.") + upload_date = _grid_out_property("uploadDate", "Date that this file was first uploaded.") aliases = _grid_out_property("aliases", "List of aliases for this file.") metadata = _grid_out_property("metadata", "Metadata attached to this file.") - md5 = _grid_out_property("md5", "MD5 of the contents of this file " - "if an md5 sum was created.") + md5 = _grid_out_property( + "md5", "MD5 of the contents of this file " "if an md5 sum was created." 
+ ) def _ensure_file(self): if not self._file: _disallow_transactions(self._session) - self._file = self.__files.find_one({"_id": self.__file_id}, - session=self._session) + self._file = self.__files.find_one({"_id": self.__file_id}, session=self._session) if not self._file: - raise NoFile("no file in gridfs collection %r with _id %r" % - (self.__files, self.__file_id)) + raise NoFile( + "no file in gridfs collection %r with _id %r" % (self.__files, self.__file_id) + ) def __getattr__(self, name): self._ensure_file() @@ -514,10 +496,11 @@ def readchunk(self): chunk_number = int((received + self.__position) / chunk_size) if self.__chunk_iter is None: self.__chunk_iter = _GridOutChunkIterator( - self, self.__chunks, self._session, chunk_number) + self, self.__chunks, self._session, chunk_number + ) chunk = self.__chunk_iter.next() - chunk_data = chunk["data"][self.__position % chunk_size:] + chunk_data = chunk["data"][self.__position % chunk_size :] if not chunk_data: raise CorruptGridFile("truncated chunk") @@ -607,8 +590,7 @@ def readline(self, size=-1): return data.read(size) def tell(self): - """Return the current position of this file. - """ + """Return the current position of this file.""" return self.__position def seek(self, pos, whence=_SEEK_SET): @@ -677,10 +659,10 @@ def close(self): super().close() def write(self, value): - raise io.UnsupportedOperation('write') + raise io.UnsupportedOperation("write") def writelines(self, lines): - raise io.UnsupportedOperation('writelines') + raise io.UnsupportedOperation("writelines") def writable(self): return False @@ -699,7 +681,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False def fileno(self): - raise io.UnsupportedOperation('fileno') + raise io.UnsupportedOperation("fileno") def flush(self): # GridOut is read-only, so flush does nothing. @@ -711,7 +693,7 @@ def isatty(self): def truncate(self, size=None): # See https://docs.python.org/3/library/io.html#io.IOBase.writable # for why truncate has to raise. - raise io.UnsupportedOperation('truncate') + raise io.UnsupportedOperation("truncate") # Override IOBase.__del__ otherwise it will lead to __getattr__ on # __IOBase_closed which calls _ensure_file and potentially performs I/O. @@ -726,6 +708,7 @@ class _GridOutChunkIterator(object): Raises CorruptGridFile when encountering any truncated, missing, or extra chunk in a file. """ + def __init__(self, grid_out, chunks, session, next_chunk): self._id = grid_out._id self._chunk_size = int(grid_out.chunk_size) @@ -749,8 +732,7 @@ def _create_cursor(self): if self._next_chunk > 0: filter["n"] = {"$gte": self._next_chunk} _disallow_transactions(self._session) - self._cursor = self._chunks.find(filter, sort=[("n", 1)], - session=self._session) + self._cursor = self._chunks.find(filter, sort=[("n", 1)], session=self._session) def _next_with_retry(self): """Return the next chunk and retry once on CursorNotFound. @@ -781,7 +763,8 @@ def next(self): self.close() raise CorruptGridFile( "Missing chunk: expected chunk #%d but found " - "chunk with n=%d" % (self._next_chunk, chunk["n"])) + "chunk with n=%d" % (self._next_chunk, chunk["n"]) + ) if chunk["n"] >= self._num_chunks: # According to spec, ignore extra chunks if they are empty. 
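            # Illustration (hedged; mirrors the ceil arithmetic this iterator
            # uses for _num_chunks): a 1 MiB file stored with the default
            # 255 KiB chunk size spans five chunks:
            #   >>> import math
            #   >>> math.ceil(1048576 / 261120)
            #   5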
@@ -789,15 +772,16 @@ def next(self): self.close() raise CorruptGridFile( "Extra chunk found: expected %d chunks but found " - "chunk with n=%d" % (self._num_chunks, chunk["n"])) + "chunk with n=%d" % (self._num_chunks, chunk["n"]) + ) expected_length = self.expected_chunk_length(chunk["n"]) if len(chunk["data"]) != expected_length: self.close() raise CorruptGridFile( "truncated chunk #%d: expected chunk length to be %d but " - "found chunk with length %d" % ( - chunk["n"], expected_length, len(chunk["data"]))) + "found chunk with length %d" % (chunk["n"], expected_length, len(chunk["data"])) + ) self._next_chunk += 1 return chunk @@ -828,9 +812,18 @@ class GridOutCursor(Cursor): """A cursor / iterator for returning GridOut objects as the result of an arbitrary query against the GridFS files collection. """ - def __init__(self, collection, filter=None, skip=0, limit=0, - no_cursor_timeout=False, sort=None, batch_size=0, - session=None): + + def __init__( + self, + collection, + filter=None, + skip=0, + limit=0, + no_cursor_timeout=False, + sort=None, + batch_size=0, + session=None, + ): """Create a new cursor, similar to the normal :class:`~pymongo.cursor.Cursor`. @@ -848,18 +841,22 @@ def __init__(self, collection, filter=None, skip=0, limit=0, self.__root_collection = collection super(GridOutCursor, self).__init__( - collection.files, filter, skip=skip, limit=limit, - no_cursor_timeout=no_cursor_timeout, sort=sort, - batch_size=batch_size, session=session) + collection.files, + filter, + skip=skip, + limit=limit, + no_cursor_timeout=no_cursor_timeout, + sort=sort, + batch_size=batch_size, + session=session, + ) def next(self): - """Get next GridOut object from cursor. - """ + """Get next GridOut object from cursor.""" _disallow_transactions(self.session) # Work around "super is not iterable" issue in Python 3.x next_file = super(GridOutCursor, self).next() - return GridOut(self.__root_collection, file_document=next_file, - session=self.session) + return GridOut(self.__root_collection, file_document=next_file, session=self.session) __next__ = next @@ -870,6 +867,5 @@ def remove_option(self, *args, **kwargs): raise NotImplementedError("Method does not exist for GridOutCursor") def _clone_base(self, session): - """Creates an empty GridOutCursor for information to be copied into. - """ + """Creates an empty GridOutCursor for information to be copied into.""" return GridOutCursor(self.__root_collection, session=session) diff --git a/pymongo/__init__.py b/pymongo/__init__.py index 5cb10fcbb0..8a1d3464ff 100644 --- a/pymongo/__init__.py +++ b/pymongo/__init__.py @@ -53,35 +53,40 @@ .. 
_text index: http://docs.mongodb.org/manual/core/index-text/ """ -version_tuple = (4, 0, 2, '.dev0') +version_tuple = (4, 0, 2, ".dev0") + def get_version_string(): if isinstance(version_tuple[-1], str): - return '.'.join(map(str, version_tuple[:-1])) + version_tuple[-1] - return '.'.join(map(str, version_tuple)) + return ".".join(map(str, version_tuple[:-1])) + version_tuple[-1] + return ".".join(map(str, version_tuple)) + __version__ = version = get_version_string() """Current version of PyMongo.""" from pymongo.collection import ReturnDocument -from pymongo.common import (MIN_SUPPORTED_WIRE_VERSION, - MAX_SUPPORTED_WIRE_VERSION) +from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION from pymongo.cursor import CursorType from pymongo.mongo_client import MongoClient -from pymongo.operations import (IndexModel, - InsertOne, - DeleteOne, - DeleteMany, - UpdateOne, - UpdateMany, - ReplaceOne) +from pymongo.operations import ( + DeleteMany, + DeleteOne, + IndexModel, + InsertOne, + ReplaceOne, + UpdateMany, + UpdateOne, +) from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern + def has_c(): """Is the C extension installed?""" try: from pymongo import _cmessage + return True except ImportError: return False diff --git a/pymongo/aggregation.py b/pymongo/aggregation.py index 2a34a05d3a..346b394e19 100644 --- a/pymongo/aggregation.py +++ b/pymongo/aggregation.py @@ -15,7 +15,6 @@ """Perform aggregation operations on a collection or database.""" from bson.son import SON - from pymongo import common from pymongo.collation import validate_collation_or_none from pymongo.errors import ConfigurationError @@ -29,27 +28,38 @@ class _AggregationCommand(object): :meth:`pymongo.collection.Collection.aggregate`, or :meth:`pymongo.database.Database.aggregate` instead. """ - def __init__(self, target, cursor_class, pipeline, options, - explicit_session, user_fields=None, result_processor=None): + + def __init__( + self, + target, + cursor_class, + pipeline, + options, + explicit_session, + user_fields=None, + result_processor=None, + ): if "explain" in options: - raise ConfigurationError("The explain option is not supported. " - "Use Database.command instead.") + raise ConfigurationError( + "The explain option is not supported. " "Use Database.command instead." + ) self._target = target - common.validate_list('pipeline', pipeline) + common.validate_list("pipeline", pipeline) self._pipeline = pipeline self._performs_write = False if pipeline and ("$out" in pipeline[-1] or "$merge" in pipeline[-1]): self._performs_write = True - common.validate_is_mapping('options', options) + common.validate_is_mapping("options", options) self._options = options # This is the batchSize that will be used for setting the initial # batchSize for the cursor, as well as the subsequent getMores. self._batch_size = common.validate_non_negative_integer_or_none( - "batchSize", self._options.pop("batchSize", None)) + "batchSize", self._options.pop("batchSize", None) + ) # If the cursor option is already specified, avoid overriding it. 
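        # For context (a hedged sketch; field values hypothetical): the
        # command this class assembles ultimately looks like
        #   {"aggregate": <target>, "pipeline": [...],
        #    "cursor": {"batchSize": 10}}
        # where the batchSize also drives the subsequent getMores.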
self._options.setdefault("cursor", {}) @@ -63,10 +73,9 @@ def __init__(self, target, cursor_class, pipeline, options, self._user_fields = user_fields self._result_processor = result_processor - self._collation = validate_collation_or_none( - options.pop('collation', None)) + self._collation = validate_collation_or_none(options.pop("collation", None)) - self._max_await_time_ms = options.pop('maxAwaitTimeMS', None) + self._max_await_time_ms = options.pop("maxAwaitTimeMS", None) @property def _aggregation_target(self): @@ -90,8 +99,7 @@ def _database(self): def _process_result(self, result, session, server, sock_info, secondary_ok): if self._result_processor: - self._result_processor( - result, session, server, sock_info, secondary_ok) + self._result_processor(result, session, server, sock_info, secondary_ok) def get_read_preference(self, session): if self._performs_write: @@ -100,17 +108,16 @@ def get_read_preference(self, session): def get_cursor(self, session, server, sock_info, secondary_ok): # Serialize command. - cmd = SON([("aggregate", self._aggregation_target), - ("pipeline", self._pipeline)]) + cmd = SON([("aggregate", self._aggregation_target), ("pipeline", self._pipeline)]) cmd.update(self._options) # Apply this target's read concern if: # readConcern has not been specified as a kwarg and either # - server version is >= 4.2 or # - server version is >= 3.2 and pipeline doesn't use $out - if (('readConcern' not in cmd) and - (not self._performs_write or - (sock_info.max_wire_version >= 8))): + if ("readConcern" not in cmd) and ( + not self._performs_write or (sock_info.max_wire_version >= 8) + ): read_concern = self._target.read_concern else: read_concern = None @@ -118,7 +125,7 @@ def get_cursor(self, session, server, sock_info, secondary_ok): # Apply this target's write concern if: # writeConcern has not been specified as a kwarg and pipeline doesn't # perform a write operation - if 'writeConcern' not in cmd and self._performs_write: + if "writeConcern" not in cmd and self._performs_write: write_concern = self._target._write_concern_for(session) else: write_concern = None @@ -136,13 +143,14 @@ def get_cursor(self, session, server, sock_info, secondary_ok): collation=self._collation, session=session, client=self._database.client, - user_fields=self._user_fields) + user_fields=self._user_fields, + ) self._process_result(result, session, server, sock_info, secondary_ok) # Extract cursor from result or mock/fake one if necessary. - if 'cursor' in result: - cursor = result['cursor'] + if "cursor" in result: + cursor = result["cursor"] else: # Pre-MongoDB 2.6 or unacknowledged write. Fake a cursor. cursor = { @@ -153,16 +161,19 @@ def get_cursor(self, session, server, sock_info, secondary_ok): # Create and return cursor instance. 
cmd_cursor = self._cursor_class( - self._cursor_collection(cursor), cursor, sock_info.address, + self._cursor_collection(cursor), + cursor, + sock_info.address, batch_size=self._batch_size or 0, max_await_time_ms=self._max_await_time_ms, - session=session, explicit_session=self._explicit_session) + session=session, + explicit_session=self._explicit_session, + ) cmd_cursor._maybe_pin_connection(sock_info) return cmd_cursor class _CollectionAggregationCommand(_AggregationCommand): - @property def _aggregation_target(self): return self._target.name diff --git a/pymongo/auth.py b/pymongo/auth.py index 17f3a32fe8..072df0ba02 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -19,7 +19,6 @@ import hmac import os import socket - from base64 import standard_b64decode, standard_b64encode from collections import namedtuple from urllib.parse import quote @@ -34,7 +33,8 @@ _USE_PRINCIPAL = False try: import winkerberos as kerberos - if tuple(map(int, kerberos.__version__.split('.')[:2])) >= (0, 5): + + if tuple(map(int, kerberos.__version__.split(".")[:2])) >= (0, 5): _USE_PRINCIPAL = True except ImportError: try: @@ -44,21 +44,24 @@ MECHANISMS = frozenset( - ['GSSAPI', - 'MONGODB-CR', - 'MONGODB-X509', - 'MONGODB-AWS', - 'PLAIN', - 'SCRAM-SHA-1', - 'SCRAM-SHA-256', - 'DEFAULT']) + [ + "GSSAPI", + "MONGODB-CR", + "MONGODB-X509", + "MONGODB-AWS", + "PLAIN", + "SCRAM-SHA-1", + "SCRAM-SHA-256", + "DEFAULT", + ] +) """The authentication mechanisms supported by PyMongo.""" class _Cache(object): __slots__ = ("data",) - _hash_val = hash('_Cache') + _hash_val = hash("_Cache") def __init__(self): self.data = None @@ -78,80 +81,69 @@ def __hash__(self): return self._hash_val - MongoCredential = namedtuple( - 'MongoCredential', - ['mechanism', - 'source', - 'username', - 'password', - 'mechanism_properties', - 'cache']) + "MongoCredential", + ["mechanism", "source", "username", "password", "mechanism_properties", "cache"], +) """A hashable namedtuple of values used for authentication.""" -GSSAPIProperties = namedtuple('GSSAPIProperties', - ['service_name', - 'canonicalize_host_name', - 'service_realm']) +GSSAPIProperties = namedtuple( + "GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"] +) """Mechanism properties for GSSAPI authentication.""" -_AWSProperties = namedtuple('AWSProperties', ['aws_session_token']) +_AWSProperties = namedtuple("AWSProperties", ["aws_session_token"]) """Mechanism properties for MONGODB-AWS authentication.""" def _build_credentials_tuple(mech, source, user, passwd, extra, database): - """Build and return a mechanism specific credentials tuple. - """ - if mech not in ('MONGODB-X509', 'MONGODB-AWS') and user is None: + """Build and return a mechanism specific credentials tuple.""" + if mech not in ("MONGODB-X509", "MONGODB-AWS") and user is None: raise ConfigurationError("%s requires a username." 
% (mech,)) - if mech == 'GSSAPI': - if source is not None and source != '$external': - raise ValueError( - "authentication source must be $external or None for GSSAPI") - properties = extra.get('authmechanismproperties', {}) - service_name = properties.get('SERVICE_NAME', 'mongodb') - canonicalize = properties.get('CANONICALIZE_HOST_NAME', False) - service_realm = properties.get('SERVICE_REALM') - props = GSSAPIProperties(service_name=service_name, - canonicalize_host_name=canonicalize, - service_realm=service_realm) + if mech == "GSSAPI": + if source is not None and source != "$external": + raise ValueError("authentication source must be $external or None for GSSAPI") + properties = extra.get("authmechanismproperties", {}) + service_name = properties.get("SERVICE_NAME", "mongodb") + canonicalize = properties.get("CANONICALIZE_HOST_NAME", False) + service_realm = properties.get("SERVICE_REALM") + props = GSSAPIProperties( + service_name=service_name, + canonicalize_host_name=canonicalize, + service_realm=service_realm, + ) # Source is always $external. - return MongoCredential(mech, '$external', user, passwd, props, None) - elif mech == 'MONGODB-X509': + return MongoCredential(mech, "$external", user, passwd, props, None) + elif mech == "MONGODB-X509": if passwd is not None: - raise ConfigurationError( - "Passwords are not supported by MONGODB-X509") - if source is not None and source != '$external': - raise ValueError( - "authentication source must be " - "$external or None for MONGODB-X509") + raise ConfigurationError("Passwords are not supported by MONGODB-X509") + if source is not None and source != "$external": + raise ValueError("authentication source must be " "$external or None for MONGODB-X509") # Source is always $external, user can be None. - return MongoCredential(mech, '$external', user, None, None, None) - elif mech == 'MONGODB-AWS': + return MongoCredential(mech, "$external", user, None, None, None) + elif mech == "MONGODB-AWS": if user is not None and passwd is None: + raise ConfigurationError("username without a password is not supported by MONGODB-AWS") + if source is not None and source != "$external": raise ConfigurationError( - "username without a password is not supported by MONGODB-AWS") - if source is not None and source != '$external': - raise ConfigurationError( - "authentication source must be " - "$external or None for MONGODB-AWS") + "authentication source must be " "$external or None for MONGODB-AWS" + ) - properties = extra.get('authmechanismproperties', {}) - aws_session_token = properties.get('AWS_SESSION_TOKEN') + properties = extra.get("authmechanismproperties", {}) + aws_session_token = properties.get("AWS_SESSION_TOKEN") props = _AWSProperties(aws_session_token=aws_session_token) # user can be None for temporary link-local EC2 credentials. 
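        # Illustration (hypothetical URI, not part of this change): with
        # MONGODB-AWS the username/password correspond to an AWS key pair,
        #   mongodb://<access_key_id>:<secret_access_key>@host/?authMechanism=MONGODB-AWS
        # and both may be omitted when the environment supplies credentials.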
- return MongoCredential(mech, '$external', user, passwd, props, None) - elif mech == 'PLAIN': - source_database = source or database or '$external' + return MongoCredential(mech, "$external", user, passwd, props, None) + elif mech == "PLAIN": + source_database = source or database or "$external" return MongoCredential(mech, source_database, user, passwd, None, None) else: - source_database = source or database or 'admin' + source_database = source or database or "admin" if passwd is None: raise ConfigurationError("A password is required.") - return MongoCredential( - mech, source_database, user, passwd, None, _Cache()) + return MongoCredential(mech, source_database, user, passwd, None, _Cache()) def _xor(fir, sec): @@ -170,18 +162,22 @@ def _authenticate_scram_start(credentials, mechanism): nonce = standard_b64encode(os.urandom(32)) first_bare = b"n=" + user + b",r=" + nonce - cmd = SON([('saslStart', 1), - ('mechanism', mechanism), - ('payload', Binary(b"n,," + first_bare)), - ('autoAuthorize', 1), - ('options', {'skipEmptyExchange': True})]) + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", mechanism), + ("payload", Binary(b"n,," + first_bare)), + ("autoAuthorize", 1), + ("options", {"skipEmptyExchange": True}), + ] + ) return nonce, first_bare, cmd def _authenticate_scram(credentials, sock_info, mechanism): """Authenticate using SCRAM.""" username = credentials.username - if mechanism == 'SCRAM-SHA-256': + if mechanism == "SCRAM-SHA-256": digest = "sha256" digestmod = hashlib.sha256 data = saslprep(credentials.password).encode("utf-8") @@ -200,17 +196,16 @@ def _authenticate_scram(credentials, sock_info, mechanism): nonce, first_bare = ctx.scram_data res = ctx.speculative_authenticate else: - nonce, first_bare, cmd = _authenticate_scram_start( - credentials, mechanism) + nonce, first_bare, cmd = _authenticate_scram_start(credentials, mechanism) res = sock_info.command(source, cmd) - server_first = res['payload'] + server_first = res["payload"] parsed = _parse_scram_response(server_first) - iterations = int(parsed[b'i']) + iterations = int(parsed[b"i"]) if iterations < 4096: raise OperationFailure("Server returned an invalid iteration count.") - salt = parsed[b's'] - rnonce = parsed[b'r'] + salt = parsed[b"s"] + rnonce = parsed[b"r"] if not rnonce.startswith(nonce): raise OperationFailure("Server returned an invalid nonce.") @@ -223,8 +218,7 @@ def _authenticate_scram(credentials, sock_info, mechanism): # Salt and / or iterations could change for a number of different # reasons. Either changing invalidates the cache. 
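            # RFC 5802 refresher (hedged, matching the calls below):
            #   SaltedPassword = PBKDF2(password, salt, iterations)
            #   ClientKey      = HMAC(SaltedPassword, "Client Key")
            #   ServerKey      = HMAC(SaltedPassword, "Server Key")
            # Caching (client_key, server_key, salt, iterations) is what lets
            # repeated authentications skip the expensive PBKDF2 step.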
if not client_key or salt != csalt or iterations != citerations: - salted_pass = hashlib.pbkdf2_hmac( - digest, data, standard_b64decode(salt), iterations) + salted_pass = hashlib.pbkdf2_hmac(digest, data, standard_b64decode(salt), iterations) client_key = _hmac(salted_pass, b"Client Key", digestmod).digest() server_key = _hmac(salted_pass, b"Server Key", digestmod).digest() cache.data = (client_key, server_key, salt, iterations) @@ -234,32 +228,38 @@ def _authenticate_scram(credentials, sock_info, mechanism): client_proof = b"p=" + standard_b64encode(_xor(client_key, client_sig)) client_final = b",".join((without_proof, client_proof)) - server_sig = standard_b64encode( - _hmac(server_key, auth_msg, digestmod).digest()) + server_sig = standard_b64encode(_hmac(server_key, auth_msg, digestmod).digest()) - cmd = SON([('saslContinue', 1), - ('conversationId', res['conversationId']), - ('payload', Binary(client_final))]) + cmd = SON( + [ + ("saslContinue", 1), + ("conversationId", res["conversationId"]), + ("payload", Binary(client_final)), + ] + ) res = sock_info.command(source, cmd) - parsed = _parse_scram_response(res['payload']) - if not hmac.compare_digest(parsed[b'v'], server_sig): + parsed = _parse_scram_response(res["payload"]) + if not hmac.compare_digest(parsed[b"v"], server_sig): raise OperationFailure("Server returned an invalid signature.") # A third empty challenge may be required if the server does not support # skipEmptyExchange: SERVER-44857. - if not res['done']: - cmd = SON([('saslContinue', 1), - ('conversationId', res['conversationId']), - ('payload', Binary(b''))]) + if not res["done"]: + cmd = SON( + [ + ("saslContinue", 1), + ("conversationId", res["conversationId"]), + ("payload", Binary(b"")), + ] + ) res = sock_info.command(source, cmd) - if not res['done']: - raise OperationFailure('SASL conversation failed to complete.') + if not res["done"]: + raise OperationFailure("SASL conversation failed to complete.") def _password_digest(username, password): - """Get a password digest to use for authentication. - """ + """Get a password digest to use for authentication.""" if not isinstance(password, str): raise TypeError("password must be an instance of str") if len(password) == 0: @@ -269,17 +269,16 @@ def _password_digest(username, password): md5hash = hashlib.md5() data = "%s:mongo:%s" % (username, password) - md5hash.update(data.encode('utf-8')) + md5hash.update(data.encode("utf-8")) return md5hash.hexdigest() def _auth_key(nonce, username, password): - """Get an auth key to use for authentication. - """ + """Get an auth key to use for authentication.""" digest = _password_digest(username, password) md5hash = hashlib.md5() data = "%s%s%s" % (nonce, username, digest) - md5hash.update(data.encode('utf-8')) + md5hash.update(data.encode("utf-8")) return md5hash.hexdigest() @@ -287,7 +286,8 @@ def _canonicalize_hostname(hostname): """Canonicalize hostname following MIT-krb5 behavior.""" # https://github.com/krb5/krb5/blob/d406afa363554097ac48646a29249c04f498c88e/src/util/k5test.py#L505-L520 af, socktype, proto, canonname, sockaddr = socket.getaddrinfo( - hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME)[0] + hostname, None, 0, 0, socket.IPPROTO_TCP, socket.AI_CANONNAME + )[0] try: name = socket.getnameinfo(sockaddr, socket.NI_NAMEREQD) @@ -298,11 +298,11 @@ def _canonicalize_hostname(hostname): def _authenticate_gssapi(credentials, sock_info): - """Authenticate using GSSAPI. 
- """ + """Authenticate using GSSAPI.""" if not HAVE_KERBEROS: - raise ConfigurationError('The "kerberos" module must be ' - 'installed to use GSSAPI authentication.') + raise ConfigurationError( + 'The "kerberos" module must be ' "installed to use GSSAPI authentication." + ) try: username = credentials.username @@ -313,9 +313,9 @@ def _authenticate_gssapi(credentials, sock_info): host = sock_info.address[0] if props.canonicalize_host_name: host = _canonicalize_hostname(host) - service = props.service_name + '@' + host + service = props.service_name + "@" + host if props.service_realm is not None: - service = service + '@' + props.service_realm + service = service + "@" + props.service_realm if password is not None: if _USE_PRINCIPAL: @@ -324,81 +324,88 @@ def _authenticate_gssapi(credentials, sock_info): # by WinKerberos) doesn't support +. principal = ":".join((quote(username), quote(password))) result, ctx = kerberos.authGSSClientInit( - service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG) + service, principal, gssflags=kerberos.GSS_C_MUTUAL_FLAG + ) else: - if '@' in username: - user, domain = username.split('@', 1) + if "@" in username: + user, domain = username.split("@", 1) else: user, domain = username, None result, ctx = kerberos.authGSSClientInit( - service, gssflags=kerberos.GSS_C_MUTUAL_FLAG, - user=user, domain=domain, password=password) + service, + gssflags=kerberos.GSS_C_MUTUAL_FLAG, + user=user, + domain=domain, + password=password, + ) else: - result, ctx = kerberos.authGSSClientInit( - service, gssflags=kerberos.GSS_C_MUTUAL_FLAG) + result, ctx = kerberos.authGSSClientInit(service, gssflags=kerberos.GSS_C_MUTUAL_FLAG) if result != kerberos.AUTH_GSS_COMPLETE: - raise OperationFailure('Kerberos context failed to initialize.') + raise OperationFailure("Kerberos context failed to initialize.") try: # pykerberos uses a weird mix of exceptions and return values # to indicate errors. # 0 == continue, 1 == complete, -1 == error # Only authGSSClientStep can return 0. - if kerberos.authGSSClientStep(ctx, '') != 0: - raise OperationFailure('Unknown kerberos ' - 'failure in step function.') + if kerberos.authGSSClientStep(ctx, "") != 0: + raise OperationFailure("Unknown kerberos " "failure in step function.") # Start a SASL conversation with mongod/s # Note: pykerberos deals with base64 encoded byte strings. # Since mongo accepts base64 strings as the payload we don't # have to use bson.binary.Binary. 
payload = kerberos.authGSSClientResponse(ctx) - cmd = SON([('saslStart', 1), - ('mechanism', 'GSSAPI'), - ('payload', payload), - ('autoAuthorize', 1)]) - response = sock_info.command('$external', cmd) + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "GSSAPI"), + ("payload", payload), + ("autoAuthorize", 1), + ] + ) + response = sock_info.command("$external", cmd) # Limit how many times we loop to catch protocol / library issues for _ in range(10): - result = kerberos.authGSSClientStep(ctx, - str(response['payload'])) + result = kerberos.authGSSClientStep(ctx, str(response["payload"])) if result == -1: - raise OperationFailure('Unknown kerberos ' - 'failure in step function.') + raise OperationFailure("Unknown kerberos " "failure in step function.") - payload = kerberos.authGSSClientResponse(ctx) or '' + payload = kerberos.authGSSClientResponse(ctx) or "" - cmd = SON([('saslContinue', 1), - ('conversationId', response['conversationId']), - ('payload', payload)]) - response = sock_info.command('$external', cmd) + cmd = SON( + [ + ("saslContinue", 1), + ("conversationId", response["conversationId"]), + ("payload", payload), + ] + ) + response = sock_info.command("$external", cmd) if result == kerberos.AUTH_GSS_COMPLETE: break else: - raise OperationFailure('Kerberos ' - 'authentication failed to complete.') + raise OperationFailure("Kerberos " "authentication failed to complete.") # Once the security context is established actually authenticate. # See RFC 4752, Section 3.1, last two paragraphs. - if kerberos.authGSSClientUnwrap(ctx, - str(response['payload'])) != 1: - raise OperationFailure('Unknown kerberos ' - 'failure during GSS_Unwrap step.') + if kerberos.authGSSClientUnwrap(ctx, str(response["payload"])) != 1: + raise OperationFailure("Unknown kerberos " "failure during GSS_Unwrap step.") - if kerberos.authGSSClientWrap(ctx, - kerberos.authGSSClientResponse(ctx), - username) != 1: - raise OperationFailure('Unknown kerberos ' - 'failure during GSS_Wrap step.') + if kerberos.authGSSClientWrap(ctx, kerberos.authGSSClientResponse(ctx), username) != 1: + raise OperationFailure("Unknown kerberos " "failure during GSS_Wrap step.") payload = kerberos.authGSSClientResponse(ctx) - cmd = SON([('saslContinue', 1), - ('conversationId', response['conversationId']), - ('payload', payload)]) - sock_info.command('$external', cmd) + cmd = SON( + [ + ("saslContinue", 1), + ("conversationId", response["conversationId"]), + ("payload", payload), + ] + ) + sock_info.command("$external", cmd) finally: kerberos.authGSSClientClean(ctx) @@ -408,47 +415,45 @@ def _authenticate_gssapi(credentials, sock_info): def _authenticate_plain(credentials, sock_info): - """Authenticate using SASL PLAIN (RFC 4616) - """ + """Authenticate using SASL PLAIN (RFC 4616)""" source = credentials.source username = credentials.username password = credentials.password - payload = ('\x00%s\x00%s' % (username, password)).encode('utf-8') - cmd = SON([('saslStart', 1), - ('mechanism', 'PLAIN'), - ('payload', Binary(payload)), - ('autoAuthorize', 1)]) + payload = ("\x00%s\x00%s" % (username, password)).encode("utf-8") + cmd = SON( + [ + ("saslStart", 1), + ("mechanism", "PLAIN"), + ("payload", Binary(payload)), + ("autoAuthorize", 1), + ] + ) sock_info.command(source, cmd) def _authenticate_x509(credentials, sock_info): - """Authenticate using MONGODB-X509. 
- """ + """Authenticate using MONGODB-X509.""" ctx = sock_info.auth_ctx.get(credentials) if ctx and ctx.speculate_succeeded(): # MONGODB-X509 is done after the speculative auth step. return cmd = _X509Context(credentials).speculate_command() - sock_info.command('$external', cmd) + sock_info.command("$external", cmd) def _authenticate_mongo_cr(credentials, sock_info): - """Authenticate using MONGODB-CR. - """ + """Authenticate using MONGODB-CR.""" source = credentials.source username = credentials.username password = credentials.password # Get a nonce - response = sock_info.command(source, {'getnonce': 1}) - nonce = response['nonce'] + response = sock_info.command(source, {"getnonce": 1}) + nonce = response["nonce"] key = _auth_key(nonce, username, password) # Actually authenticate - query = SON([('authenticate', 1), - ('user', username), - ('nonce', nonce), - ('key', key)]) + query = SON([("authenticate", 1), ("user", username), ("nonce", nonce), ("key", key)]) sock_info.command(source, query) @@ -459,29 +464,27 @@ def _authenticate_default(credentials, sock_info): else: source = credentials.source cmd = sock_info.hello_cmd() - cmd['saslSupportedMechs'] = source + '.' + credentials.username - mechs = sock_info.command( - source, cmd, publish_events=False).get( - 'saslSupportedMechs', []) - if 'SCRAM-SHA-256' in mechs: - return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-256') + cmd["saslSupportedMechs"] = source + "." + credentials.username + mechs = sock_info.command(source, cmd, publish_events=False).get( + "saslSupportedMechs", [] + ) + if "SCRAM-SHA-256" in mechs: + return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-256") else: - return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1') + return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1") else: - return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-1') + return _authenticate_scram(credentials, sock_info, "SCRAM-SHA-1") _AUTH_MAP = { - 'GSSAPI': _authenticate_gssapi, - 'MONGODB-CR': _authenticate_mongo_cr, - 'MONGODB-X509': _authenticate_x509, - 'MONGODB-AWS': _authenticate_aws, - 'PLAIN': _authenticate_plain, - 'SCRAM-SHA-1': functools.partial( - _authenticate_scram, mechanism='SCRAM-SHA-1'), - 'SCRAM-SHA-256': functools.partial( - _authenticate_scram, mechanism='SCRAM-SHA-256'), - 'DEFAULT': _authenticate_default, + "GSSAPI": _authenticate_gssapi, + "MONGODB-CR": _authenticate_mongo_cr, + "MONGODB-X509": _authenticate_x509, + "MONGODB-AWS": _authenticate_aws, + "PLAIN": _authenticate_plain, + "SCRAM-SHA-1": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-1"), + "SCRAM-SHA-256": functools.partial(_authenticate_scram, mechanism="SCRAM-SHA-256"), + "DEFAULT": _authenticate_default, } @@ -514,10 +517,9 @@ def __init__(self, credentials, mechanism): self.mechanism = mechanism def speculate_command(self): - nonce, first_bare, cmd = _authenticate_scram_start( - self.credentials, self.mechanism) + nonce, first_bare, cmd = _authenticate_scram_start(self.credentials, self.mechanism) # The 'db' field is included only on the speculative command. - cmd['db'] = self.credentials.source + cmd["db"] = self.credentials.source # Save for later use. 
self.scram_data = (nonce, first_bare) return cmd @@ -525,19 +527,17 @@ def speculate_command(self): class _X509Context(_AuthContext): def speculate_command(self): - cmd = SON([('authenticate', 1), - ('mechanism', 'MONGODB-X509')]) + cmd = SON([("authenticate", 1), ("mechanism", "MONGODB-X509")]) if self.credentials.username is not None: - cmd['user'] = self.credentials.username + cmd["user"] = self.credentials.username return cmd _SPECULATIVE_AUTH_MAP = { - 'MONGODB-X509': _X509Context, - 'SCRAM-SHA-1': functools.partial(_ScramContext, mechanism='SCRAM-SHA-1'), - 'SCRAM-SHA-256': functools.partial(_ScramContext, - mechanism='SCRAM-SHA-256'), - 'DEFAULT': functools.partial(_ScramContext, mechanism='SCRAM-SHA-256'), + "MONGODB-X509": _X509Context, + "SCRAM-SHA-1": functools.partial(_ScramContext, mechanism="SCRAM-SHA-1"), + "SCRAM-SHA-256": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"), + "DEFAULT": functools.partial(_ScramContext, mechanism="SCRAM-SHA-256"), } @@ -546,4 +546,3 @@ def authenticate(credentials, sock_info): mechanism = credentials.mechanism auth_func = _AUTH_MAP.get(mechanism) auth_func(credentials, sock_info) - diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index ff07a12e7f..37189aa69f 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -16,14 +16,15 @@ try: import pymongo_auth_aws - from pymongo_auth_aws import (AwsCredential, - AwsSaslContext, - PyMongoAuthAwsError) + from pymongo_auth_aws import AwsCredential, AwsSaslContext, PyMongoAuthAwsError + _HAVE_MONGODB_AWS = True except ImportError: + class AwsSaslContext(object): def __init__(self, credentials): pass + _HAVE_MONGODB_AWS = False import bson @@ -48,38 +49,46 @@ def bson_decode(self, data): def _authenticate_aws(credentials, sock_info): - """Authenticate using MONGODB-AWS. 
- """ + """Authenticate using MONGODB-AWS.""" if not _HAVE_MONGODB_AWS: raise ConfigurationError( "MONGODB-AWS authentication requires pymongo-auth-aws: " - "install with: python -m pip install 'pymongo[aws]'") + "install with: python -m pip install 'pymongo[aws]'" + ) if sock_info.max_wire_version < 9: - raise ConfigurationError( - "MONGODB-AWS authentication requires MongoDB version 4.4 or later") + raise ConfigurationError("MONGODB-AWS authentication requires MongoDB version 4.4 or later") try: - ctx = _AwsSaslContext(AwsCredential( - credentials.username, credentials.password, - credentials.mechanism_properties.aws_session_token)) + ctx = _AwsSaslContext( + AwsCredential( + credentials.username, + credentials.password, + credentials.mechanism_properties.aws_session_token, + ) + ) client_payload = ctx.step(None) - client_first = SON([('saslStart', 1), - ('mechanism', 'MONGODB-AWS'), - ('payload', client_payload)]) - server_first = sock_info.command('$external', client_first) + client_first = SON( + [("saslStart", 1), ("mechanism", "MONGODB-AWS"), ("payload", client_payload)] + ) + server_first = sock_info.command("$external", client_first) res = server_first # Limit how many times we loop to catch protocol / library issues for _ in range(10): - client_payload = ctx.step(res['payload']) - cmd = SON([('saslContinue', 1), - ('conversationId', server_first['conversationId']), - ('payload', client_payload)]) - res = sock_info.command('$external', cmd) - if res['done']: + client_payload = ctx.step(res["payload"]) + cmd = SON( + [ + ("saslContinue", 1), + ("conversationId", server_first["conversationId"]), + ("payload", client_payload), + ] + ) + res = sock_info.command("$external", cmd) + if res["done"]: # SASL complete. break except PyMongoAuthAwsError as exc: # Convert to OperationFailure and include pymongo-auth-aws version. - raise OperationFailure('%s (pymongo-auth-aws version %s)' % ( - exc, pymongo_auth_aws.__version__)) + raise OperationFailure( + "%s (pymongo-auth-aws version %s)" % (exc, pymongo_auth_aws.__version__) + ) diff --git a/pymongo/bulk.py b/pymongo/bulk.py index 1bb8edf943..e394a8d598 100644 --- a/pymongo/bulk.py +++ b/pymongo/bulk.py @@ -17,31 +17,37 @@ .. 
versionadded:: 2.7 """ import copy - from itertools import islice from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from bson.son import SON from pymongo.client_session import _validate_session_write_concern -from pymongo.common import (validate_is_mapping, - validate_is_document_type, - validate_ok_for_replace, - validate_ok_for_update) -from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc from pymongo.collation import validate_collation_or_none -from pymongo.errors import (BulkWriteError, - ConfigurationError, - InvalidOperation, - OperationFailure) -from pymongo.message import (_INSERT, _UPDATE, _DELETE, - _randint, - _BulkWriteContext, - _EncryptedBulkWriteContext) +from pymongo.common import ( + validate_is_document_type, + validate_is_mapping, + validate_ok_for_replace, + validate_ok_for_update, +) +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + InvalidOperation, + OperationFailure, +) +from pymongo.helpers import _RETRYABLE_ERROR_CODES, _get_wce_doc +from pymongo.message import ( + _DELETE, + _INSERT, + _UPDATE, + _BulkWriteContext, + _EncryptedBulkWriteContext, + _randint, +) from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern - _DELETE_ALL = 0 _DELETE_ONE = 1 @@ -50,15 +56,14 @@ _UNKNOWN_ERROR = 8 _WRITE_CONCERN_ERROR = 64 -_COMMANDS = ('insert', 'update', 'delete') +_COMMANDS = ("insert", "update", "delete") class _Run(object): - """Represents a batch of write operations. - """ + """Represents a batch of write operations.""" + def __init__(self, op_type): - """Initialize a new Run object. - """ + """Initialize a new Run object.""" self.op_type = op_type self.index_map = [] self.ops = [] @@ -85,8 +90,7 @@ def add(self, original_index, operation): def _merge_command(run, full_result, offset, result): - """Merge a write command result into the full bulk result. - """ + """Merge a write command result into the full bulk result.""" affected = result.get("n", 0) if run.op_type == _INSERT: @@ -103,7 +107,7 @@ def _merge_command(run, full_result, offset, result): doc["index"] = run.index(doc["index"] + offset) full_result["upserted"].extend(upserted) full_result["nUpserted"] += n_upserted - full_result["nMatched"] += (affected - n_upserted) + full_result["nMatched"] += affected - n_upserted else: full_result["nMatched"] += affected full_result["nModified"] += result["nModified"] @@ -125,24 +129,22 @@ def _merge_command(run, full_result, offset, result): def _raise_bulk_write_error(full_result): - """Raise a BulkWriteError from the full bulk api result. - """ + """Raise a BulkWriteError from the full bulk api result.""" if full_result["writeErrors"]: - full_result["writeErrors"].sort( - key=lambda error: error["index"]) + full_result["writeErrors"].sort(key=lambda error: error["index"]) raise BulkWriteError(full_result) class _Bulk(object): - """The private guts of the bulk write API. - """ + """The private guts of the bulk write API.""" + def __init__(self, collection, ordered, bypass_document_validation): - """Initialize a _Bulk instance. 
- """ + """Initialize a _Bulk instance.""" self.collection = collection.with_options( codec_options=collection.codec_options._replace( - unicode_decode_error_handler='replace', - document_class=dict)) + unicode_decode_error_handler="replace", document_class=dict + ) + ) self.ordered = ordered self.ops = [] self.executed = False @@ -165,63 +167,64 @@ def bulk_ctx_class(self): return _BulkWriteContext def add_insert(self, document): - """Add an insert document to the list of ops. - """ + """Add an insert document to the list of ops.""" validate_is_document_type("document", document) # Generate ObjectId client side. - if not (isinstance(document, RawBSONDocument) or '_id' in document): - document['_id'] = ObjectId() + if not (isinstance(document, RawBSONDocument) or "_id" in document): + document["_id"] = ObjectId() self.ops.append((_INSERT, document)) - def add_update(self, selector, update, multi=False, upsert=False, - collation=None, array_filters=None, hint=None): - """Create an update document and add it to the list of ops. - """ + def add_update( + self, + selector, + update, + multi=False, + upsert=False, + collation=None, + array_filters=None, + hint=None, + ): + """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) - cmd = SON([('q', selector), ('u', update), - ('multi', multi), ('upsert', upsert)]) + cmd = SON([("q", selector), ("u", update), ("multi", multi), ("upsert", upsert)]) collation = validate_collation_or_none(collation) if collation is not None: self.uses_collation = True - cmd['collation'] = collation + cmd["collation"] = collation if array_filters is not None: self.uses_array_filters = True - cmd['arrayFilters'] = array_filters + cmd["arrayFilters"] = array_filters if hint is not None: self.uses_hint = True - cmd['hint'] = hint + cmd["hint"] = hint if multi: # A bulk_write containing an update_many is not retryable. self.is_retryable = False self.ops.append((_UPDATE, cmd)) - def add_replace(self, selector, replacement, upsert=False, - collation=None, hint=None): - """Create a replace document and add it to the list of ops. - """ + def add_replace(self, selector, replacement, upsert=False, collation=None, hint=None): + """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) - cmd = SON([('q', selector), ('u', replacement), - ('multi', False), ('upsert', upsert)]) + cmd = SON([("q", selector), ("u", replacement), ("multi", False), ("upsert", upsert)]) collation = validate_collation_or_none(collation) if collation is not None: self.uses_collation = True - cmd['collation'] = collation + cmd["collation"] = collation if hint is not None: self.uses_hint = True - cmd['hint'] = hint + cmd["hint"] = hint self.ops.append((_UPDATE, cmd)) def add_delete(self, selector, limit, collation=None, hint=None): - """Create a delete document and add it to the list of ops. - """ - cmd = SON([('q', selector), ('limit', limit)]) + """Create a delete document and add it to the list of ops.""" + cmd = SON([("q", selector), ("limit", limit)]) collation = validate_collation_or_none(collation) if collation is not None: self.uses_collation = True - cmd['collation'] = collation + cmd["collation"] = collation if hint is not None: self.uses_hint = True - cmd['hint'] = hint + cmd["hint"] = hint if limit == _DELETE_ALL: # A bulk_write containing a delete_many is not retryable. 
self.is_retryable = False @@ -253,8 +256,9 @@ def gen_unordered(self): if run.ops: yield run - def _execute_command(self, generator, write_concern, session, - sock_info, op_id, retryable, full_result): + def _execute_command( + self, generator, write_concern, session, sock_info, op_id, retryable, full_result + ): db_name = self.collection.database.name client = self.collection.database.client listeners = client._event_listeners @@ -269,24 +273,29 @@ def _execute_command(self, generator, write_concern, session, while run: cmd_name = _COMMANDS[run.op_type] bwc = self.bulk_ctx_class( - db_name, cmd_name, sock_info, op_id, listeners, session, - run.op_type, self.collection.codec_options) + db_name, + cmd_name, + sock_info, + op_id, + listeners, + session, + run.op_type, + self.collection.codec_options, + ) while run.idx_offset < len(run.ops): - cmd = SON([(cmd_name, self.collection.name), - ('ordered', self.ordered)]) + cmd = SON([(cmd_name, self.collection.name), ("ordered", self.ordered)]) if not write_concern.is_server_default: - cmd['writeConcern'] = write_concern.document + cmd["writeConcern"] = write_concern.document if self.bypass_doc_val: - cmd['bypassDocumentValidation'] = True + cmd["bypassDocumentValidation"] = True if session: # Start a new retryable write unless one was already # started for this command. if retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, - sock_info) + session._apply_to(cmd, retryable, ReadPreference.PRIMARY, sock_info) sock_info.send_cluster_time(cmd, session, client) sock_info.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) @@ -294,8 +303,8 @@ def _execute_command(self, generator, write_concern, session, result, to_send = bwc.execute(cmd, ops, client) # Retryable writeConcernErrors halt the execution of this run. - wce = result.get('writeConcernError', {}) - if wce.get('code', 0) in _RETRYABLE_ERROR_CODES: + wce = result.get("writeConcernError", {}) + if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: # Synthesize the full bulk result without modifying the # current one because this write operation may be retried. full = copy.deepcopy(full_result) @@ -313,14 +322,13 @@ def _execute_command(self, generator, write_concern, session, # We're supposed to continue if errors are # at the write concern level (e.g. wtimeout) - if self.ordered and full_result['writeErrors']: + if self.ordered and full_result["writeErrors"]: break # Reset our state self.current_run = run = next(generator, None) def execute_command(self, generator, write_concern, session): - """Execute using write commands. - """ + """Execute using write commands.""" # nModified is only reported for write commands, not legacy ops. 
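        # Illustration (hypothetical counts): after two inserts and one
        # matched update, the merged result resembles
        #   {"writeErrors": [], "writeConcernErrors": [], "nInserted": 2,
        #    "nUpserted": 0, "nMatched": 1, "nModified": 1, "nRemoved": 0,
        #    "upserted": []}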
full_result = { "writeErrors": [], @@ -336,21 +344,19 @@ def execute_command(self, generator, write_concern, session): def retryable_bulk(session, sock_info, retryable): self._execute_command( - generator, write_concern, session, sock_info, op_id, - retryable, full_result) + generator, write_concern, session, sock_info, op_id, retryable, full_result + ) client = self.collection.database.client with client._tmp_session(session) as s: - client._retry_with_session( - self.is_retryable, retryable_bulk, s, self) + client._retry_with_session(self.is_retryable, retryable_bulk, s, self) if full_result["writeErrors"] or full_result["writeConcernErrors"]: _raise_bulk_write_error(full_result) return full_result def execute_op_msg_no_results(self, sock_info, generator): - """Execute write commands with OP_MSG and w=0 writeConcern, unordered. - """ + """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client listeners = client._event_listeners @@ -363,13 +369,24 @@ def execute_op_msg_no_results(self, sock_info, generator): while run: cmd_name = _COMMANDS[run.op_type] bwc = self.bulk_ctx_class( - db_name, cmd_name, sock_info, op_id, listeners, None, - run.op_type, self.collection.codec_options) + db_name, + cmd_name, + sock_info, + op_id, + listeners, + None, + run.op_type, + self.collection.codec_options, + ) while run.idx_offset < len(run.ops): - cmd = SON([(cmd_name, self.collection.name), - ('ordered', False), - ('writeConcern', {'w': 0})]) + cmd = SON( + [ + (cmd_name, self.collection.name), + ("ordered", False), + ("writeConcern", {"w": 0}), + ] + ) sock_info.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. @@ -378,8 +395,7 @@ def execute_op_msg_no_results(self, sock_info, generator): self.current_run = run = next(generator, None) def execute_command_no_results(self, sock_info, generator): - """Execute write commands with OP_MSG and w=0 WriteConcern, ordered. - """ + """Execute write commands with OP_MSG and w=0 WriteConcern, ordered.""" full_result = { "writeErrors": [], "writeConcernErrors": [], @@ -397,40 +413,35 @@ def execute_command_no_results(self, sock_info, generator): op_id = _randint() try: self._execute_command( - generator, write_concern, None, - sock_info, op_id, False, full_result) + generator, write_concern, None, sock_info, op_id, False, full_result + ) except OperationFailure: pass def execute_no_results(self, sock_info, generator): - """Execute all operations, returning no results (w=0). - """ + """Execute all operations, returning no results (w=0).""" if self.uses_collation: - raise ConfigurationError( - 'Collation is unsupported for unacknowledged writes.') + raise ConfigurationError("Collation is unsupported for unacknowledged writes.") if self.uses_array_filters: - raise ConfigurationError( - 'arrayFilters is unsupported for unacknowledged writes.') + raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.") if self.uses_hint: - raise ConfigurationError( - 'hint is unsupported for unacknowledged writes.') + raise ConfigurationError("hint is unsupported for unacknowledged writes.") # Cannot have both unacknowledged writes and bypass document validation. 
if self.bypass_doc_val: - raise OperationFailure("Cannot set bypass_document_validation with" - " unacknowledged write concern") + raise OperationFailure( + "Cannot set bypass_document_validation with" " unacknowledged write concern" + ) if self.ordered: return self.execute_command_no_results(sock_info, generator) return self.execute_op_msg_no_results(sock_info, generator) def execute(self, write_concern, session): - """Execute operations. - """ + """Execute operations.""" if not self.ops: - raise InvalidOperation('No operations to execute') + raise InvalidOperation("No operations to execute") if self.executed: - raise InvalidOperation('Bulk operations can ' - 'only be executed once.') + raise InvalidOperation("Bulk operations can " "only be executed once.") self.executed = True write_concern = write_concern or self.collection.write_concern session = _validate_session_write_concern(session, write_concern) diff --git a/pymongo/change_stream.py b/pymongo/change_stream.py index 00d049a838..7266696416 100644 --- a/pymongo/change_stream.py +++ b/pymongo/change_stream.py @@ -18,41 +18,45 @@ from bson import _bson_to_dict from bson.raw_bson import RawBSONDocument - from pymongo import common -from pymongo.aggregation import (_CollectionAggregationCommand, - _DatabaseAggregationCommand) +from pymongo.aggregation import ( + _CollectionAggregationCommand, + _DatabaseAggregationCommand, +) from pymongo.collation import validate_collation_or_none from pymongo.command_cursor import CommandCursor -from pymongo.errors import (ConnectionFailure, - CursorNotFound, - InvalidOperation, - OperationFailure, - PyMongoError) - +from pymongo.errors import ( + ConnectionFailure, + CursorNotFound, + InvalidOperation, + OperationFailure, + PyMongoError, +) # The change streams spec considers the following server errors from the # getMore command non-resumable. All other getMore errors are resumable. -_RESUMABLE_GETMORE_ERRORS = frozenset([ - 6, # HostUnreachable - 7, # HostNotFound - 89, # NetworkTimeout - 91, # ShutdownInProgress - 189, # PrimarySteppedDown - 262, # ExceededTimeLimit - 9001, # SocketException - 10107, # NotWritablePrimary - 11600, # InterruptedAtShutdown - 11602, # InterruptedDueToReplStateChange - 13435, # NotPrimaryNoSecondaryOk - 13436, # NotPrimaryOrSecondary - 63, # StaleShardVersion - 150, # StaleEpoch - 13388, # StaleConfig - 234, # RetryChangeStream - 133, # FailedToSatisfyReadPreference - 216, # ElectionInProgress -]) +_RESUMABLE_GETMORE_ERRORS = frozenset( + [ + 6, # HostUnreachable + 7, # HostNotFound + 89, # NetworkTimeout + 91, # ShutdownInProgress + 189, # PrimarySteppedDown + 262, # ExceededTimeLimit + 9001, # SocketException + 10107, # NotWritablePrimary + 11600, # InterruptedAtShutdown + 11602, # InterruptedDueToReplStateChange + 13435, # NotPrimaryNoSecondaryOk + 13436, # NotPrimaryOrSecondary + 63, # StaleShardVersion + 150, # StaleEpoch + 13388, # StaleConfig + 234, # RetryChangeStream + 133, # FailedToSatisfyReadPreference + 216, # ElectionInProgress + ] +) class ChangeStream(object): @@ -66,15 +70,26 @@ class ChangeStream(object): .. versionadded:: 3.6 .. seealso:: The MongoDB documentation on `changeStreams `_. 
""" - def __init__(self, target, pipeline, full_document, resume_after, - max_await_time_ms, batch_size, collation, - start_at_operation_time, session, start_after): + + def __init__( + self, + target, + pipeline, + full_document, + resume_after, + max_await_time_ms, + batch_size, + collation, + start_at_operation_time, + session, + start_after, + ): if pipeline is None: pipeline = [] elif not isinstance(pipeline, list): raise TypeError("pipeline must be a list") - common.validate_string_or_none('full_document', full_document) + common.validate_string_or_none("full_document", full_document) validate_collation_or_none(collation) common.validate_non_negative_integer_or_none("batchSize", batch_size) @@ -85,8 +100,8 @@ def __init__(self, target, pipeline, full_document, resume_after, # Keep the type registry so that we support encoding custom types # in the pipeline. self._target = target.with_options( - codec_options=target.codec_options.with_options( - document_class=RawBSONDocument)) + codec_options=target.codec_options.with_options(document_class=RawBSONDocument) + ) else: self._target = target @@ -112,24 +127,24 @@ def _aggregation_command_class(self): @property def _client(self): """The client against which the aggregation commands for - this ChangeStream will be run. """ + this ChangeStream will be run.""" raise NotImplementedError def _change_stream_options(self): """Return the options dict for the $changeStream pipeline stage.""" options = {} if self._full_document is not None: - options['fullDocument'] = self._full_document + options["fullDocument"] = self._full_document resume_token = self.resume_token if resume_token is not None: if self._uses_start_after: - options['startAfter'] = resume_token + options["startAfter"] = resume_token else: - options['resumeAfter'] = resume_token + options["resumeAfter"] = resume_token if self._start_at_operation_time is not None: - options['startAtOperationTime'] = self._start_at_operation_time + options["startAtOperationTime"] = self._start_at_operation_time return options def _command_options(self): @@ -144,7 +159,7 @@ def _command_options(self): def _aggregation_pipeline(self): """Return the full aggregation pipeline for this ChangeStream.""" options = self._change_stream_options() - full_pipeline = [{'$changeStream': options}] + full_pipeline = [{"$changeStream": options}] full_pipeline.extend(self._pipeline) return full_pipeline @@ -156,38 +171,43 @@ def _process_result(self, result, session, server, sock_info, secondary_ok): This is implemented as a callback because we need access to the wire version in order to determine whether to cache this value. """ - if not result['cursor']['firstBatch']: - if 'postBatchResumeToken' in result['cursor']: - self._resume_token = result['cursor']['postBatchResumeToken'] - elif (self._start_at_operation_time is None and - self._uses_resume_after is False and - self._uses_start_after is False and - sock_info.max_wire_version >= 7): + if not result["cursor"]["firstBatch"]: + if "postBatchResumeToken" in result["cursor"]: + self._resume_token = result["cursor"]["postBatchResumeToken"] + elif ( + self._start_at_operation_time is None + and self._uses_resume_after is False + and self._uses_start_after is False + and sock_info.max_wire_version >= 7 + ): self._start_at_operation_time = result.get("operationTime") # PYTHON-2181: informative error on missing operationTime. 
if self._start_at_operation_time is None: raise OperationFailure( "Expected field 'operationTime' missing from command " - "response : %r" % (result, )) + "response : %r" % (result,) + ) def _run_aggregation_cmd(self, session, explicit_session): """Run the full aggregation pipeline for this ChangeStream and return the corresponding CommandCursor. """ cmd = self._aggregation_command_class( - self._target, CommandCursor, self._aggregation_pipeline(), - self._command_options(), explicit_session, - result_processor=self._process_result) + self._target, + CommandCursor, + self._aggregation_pipeline(), + self._command_options(), + explicit_session, + result_processor=self._process_result, + ) return self._client._retryable_read( - cmd.get_cursor, self._target._read_preference_for(session), - session) + cmd.get_cursor, self._target._read_preference_for(session), session + ) def _create_cursor(self): with self._client._tmp_session(self._session, close=False) as s: - return self._run_aggregation_cmd( - session=s, - explicit_session=self._session is not None) + return self._run_aggregation_cmd(session=s, explicit_session=self._session is not None) def _resume(self): """Reestablish this change stream after a resumable error.""" @@ -307,10 +327,9 @@ def try_next(self): except OperationFailure as exc: if exc._max_wire_version is None: raise - is_resumable = ((exc._max_wire_version >= 9 and - exc.has_error_label("ResumableChangeStreamError")) or - (exc._max_wire_version < 9 and - exc.code in _RESUMABLE_GETMORE_ERRORS)) + is_resumable = ( + exc._max_wire_version >= 9 and exc.has_error_label("ResumableChangeStreamError") + ) or (exc._max_wire_version < 9 and exc.code in _RESUMABLE_GETMORE_ERRORS) if not is_resumable: raise self._resume() @@ -329,17 +348,16 @@ def try_next(self): # Else, changes are available. try: - resume_token = change['_id'] + resume_token = change["_id"] except KeyError: self.close() raise InvalidOperation( - "Cannot provide resume functionality when the resume " - "token is missing.") + "Cannot provide resume functionality when the resume " "token is missing." + ) # If this is the last change document from the current batch, cache the # postBatchResumeToken. - if (not self._cursor._has_next() and - self._cursor._post_batch_resume_token): + if not self._cursor._has_next() and self._cursor._post_batch_resume_token: resume_token = self._cursor._post_batch_resume_token # Hereafter, don't use startAfter; instead use resumeAfter. @@ -369,6 +387,7 @@ class CollectionChangeStream(ChangeStream): .. versionadded:: 3.7 """ + @property def _aggregation_command_class(self): return _CollectionAggregationCommand @@ -386,6 +405,7 @@ class DatabaseChangeStream(ChangeStream): .. versionadded:: 3.7 """ + @property def _aggregation_command_class(self): return _DatabaseAggregationCommand @@ -403,6 +423,7 @@ class ClusterChangeStream(DatabaseChangeStream): .. 
versionadded:: 3.7 """ + def _change_stream_options(self): options = super(ClusterChangeStream, self)._change_stream_options() options["allChangesForCluster"] = True diff --git a/pymongo/client_options.py b/pymongo/client_options.py index f7dbf255bc..d197fb8816 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -15,16 +15,15 @@ """Tools to parse mongo client options.""" from bson.codec_options import _parse_codec_options +from pymongo import common from pymongo.auth import _build_credentials_tuple from pymongo.common import validate_boolean -from pymongo import common from pymongo.compression_support import CompressionSettings from pymongo.errors import ConfigurationError from pymongo.monitoring import _EventListeners from pymongo.pool import PoolOptions from pymongo.read_concern import ReadConcern -from pymongo.read_preferences import (make_read_preference, - read_pref_mode_from_name) +from pymongo.read_preferences import make_read_preference, read_pref_mode_from_name from pymongo.server_selectors import any_server_selector from pymongo.ssl_support import get_ssl_context from pymongo.write_concern import WriteConcern @@ -32,63 +31,69 @@ def _parse_credentials(username, password, database, options): """Parse authentication credentials.""" - mechanism = options.get('authmechanism', 'DEFAULT' if username else None) - source = options.get('authsource') + mechanism = options.get("authmechanism", "DEFAULT" if username else None) + source = options.get("authsource") if username or mechanism: - return _build_credentials_tuple( - mechanism, source, username, password, options, database) + return _build_credentials_tuple(mechanism, source, username, password, options, database) return None def _parse_read_preference(options): """Parse read preference options.""" - if 'read_preference' in options: - return options['read_preference'] + if "read_preference" in options: + return options["read_preference"] - name = options.get('readpreference', 'primary') + name = options.get("readpreference", "primary") mode = read_pref_mode_from_name(name) - tags = options.get('readpreferencetags') - max_staleness = options.get('maxstalenessseconds', -1) + tags = options.get("readpreferencetags") + max_staleness = options.get("maxstalenessseconds", -1) return make_read_preference(mode, tags, max_staleness) def _parse_write_concern(options): """Parse write concern options.""" - concern = options.get('w') - wtimeout = options.get('wtimeoutms') - j = options.get('journal') - fsync = options.get('fsync') + concern = options.get("w") + wtimeout = options.get("wtimeoutms") + j = options.get("journal") + fsync = options.get("fsync") return WriteConcern(concern, wtimeout, j, fsync) def _parse_read_concern(options): """Parse read concern options.""" - concern = options.get('readconcernlevel') + concern = options.get("readconcernlevel") return ReadConcern(concern) def _parse_ssl_options(options): """Parse ssl options.""" - use_tls = options.get('tls') + use_tls = options.get("tls") if use_tls is not None: - validate_boolean('tls', use_tls) + validate_boolean("tls", use_tls) - certfile = options.get('tlscertificatekeyfile') - passphrase = options.get('tlscertificatekeyfilepassword') - ca_certs = options.get('tlscafile') - crlfile = options.get('tlscrlfile') - allow_invalid_certificates = options.get('tlsallowinvalidcertificates', False) - allow_invalid_hostnames = options.get('tlsallowinvalidhostnames', False) - disable_ocsp_endpoint_check = options.get('tlsdisableocspendpointcheck', False) + certfile = 
options.get("tlscertificatekeyfile") + passphrase = options.get("tlscertificatekeyfilepassword") + ca_certs = options.get("tlscafile") + crlfile = options.get("tlscrlfile") + allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False) + allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False) + disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False) enabled_tls_opts = [] - for opt in ('tlscertificatekeyfile', 'tlscertificatekeyfilepassword', - 'tlscafile', 'tlscrlfile'): + for opt in ( + "tlscertificatekeyfile", + "tlscertificatekeyfilepassword", + "tlscafile", + "tlscrlfile", + ): # Any non-null value of these options implies tls=True. if opt in options and options[opt]: enabled_tls_opts.append(opt) - for opt in ('tlsallowinvalidcertificates', 'tlsallowinvalidhostnames', - 'tlsdisableocspendpointcheck'): + for opt in ( + "tlsallowinvalidcertificates", + "tlsallowinvalidhostnames", + "tlsdisableocspendpointcheck", + ): # A value of False for these options implies tls=True. if opt in options and not options[opt]: enabled_tls_opts.append(opt) @@ -99,10 +104,11 @@ def _parse_ssl_options(options): use_tls = True elif not use_tls: # Error since tls is explicitly disabled but a tls option is set. - raise ConfigurationError("TLS has not been enabled but the " - "following tls parameters have been set: " - "%s. Please set `tls=True` or remove." - % ', '.join(enabled_tls_opts)) + raise ConfigurationError( + "TLS has not been enabled but the " + "following tls parameters have been set: " + "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts) + ) if use_tls: ctx = get_ssl_context( @@ -112,46 +118,49 @@ def _parse_ssl_options(options): crlfile, allow_invalid_certificates, allow_invalid_hostnames, - disable_ocsp_endpoint_check) + disable_ocsp_endpoint_check, + ) return ctx, allow_invalid_hostnames return None, allow_invalid_hostnames def _parse_pool_options(options): """Parse connection pool options.""" - max_pool_size = options.get('maxpoolsize', common.MAX_POOL_SIZE) - min_pool_size = options.get('minpoolsize', common.MIN_POOL_SIZE) - max_idle_time_seconds = options.get( - 'maxidletimems', common.MAX_IDLE_TIME_SEC) + max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE) + min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE) + max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC) if max_pool_size is not None and min_pool_size > max_pool_size: raise ValueError("minPoolSize must be smaller or equal to maxPoolSize") - connect_timeout = options.get('connecttimeoutms', common.CONNECT_TIMEOUT) - socket_timeout = options.get('sockettimeoutms') - wait_queue_timeout = options.get( - 'waitqueuetimeoutms', common.WAIT_QUEUE_TIMEOUT) - event_listeners = options.get('event_listeners') - appname = options.get('appname') - driver = options.get('driver') - server_api = options.get('server_api') + connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT) + socket_timeout = options.get("sockettimeoutms") + wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT) + event_listeners = options.get("event_listeners") + appname = options.get("appname") + driver = options.get("driver") + server_api = options.get("server_api") compression_settings = CompressionSettings( - options.get('compressors', []), - options.get('zlibcompressionlevel', -1)) + options.get("compressors", []), options.get("zlibcompressionlevel", -1) + ) ssl_context, tls_allow_invalid_hostnames = 
_parse_ssl_options(options) - load_balanced = options.get('loadbalanced') - max_connecting = options.get('maxconnecting', common.MAX_CONNECTING) - return PoolOptions(max_pool_size, - min_pool_size, - max_idle_time_seconds, - connect_timeout, socket_timeout, - wait_queue_timeout, - ssl_context, tls_allow_invalid_hostnames, - _EventListeners(event_listeners), - appname, - driver, - compression_settings, - max_connecting=max_connecting, - server_api=server_api, - load_balanced=load_balanced) + load_balanced = options.get("loadbalanced") + max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) + return PoolOptions( + max_pool_size, + min_pool_size, + max_idle_time_seconds, + connect_timeout, + socket_timeout, + wait_queue_timeout, + ssl_context, + tls_allow_invalid_hostnames, + _EventListeners(event_listeners), + appname, + driver, + compression_settings, + max_connecting=max_connecting, + server_api=server_api, + load_balanced=load_balanced, + ) class ClientOptions(object): @@ -166,29 +175,26 @@ def __init__(self, username, password, database, options): self.__options = options self.__codec_options = _parse_codec_options(options) - self.__credentials = _parse_credentials( - username, password, database, options) - self.__direct_connection = options.get('directconnection') - self.__local_threshold_ms = options.get( - 'localthresholdms', common.LOCAL_THRESHOLD_MS) + self.__credentials = _parse_credentials(username, password, database, options) + self.__direct_connection = options.get("directconnection") + self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS) # self.__server_selection_timeout is in seconds. Must use full name for # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. self.__server_selection_timeout = options.get( - 'serverselectiontimeoutms', common.SERVER_SELECTION_TIMEOUT) + "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT + ) self.__pool_options = _parse_pool_options(options) self.__read_preference = _parse_read_preference(options) - self.__replica_set_name = options.get('replicaset') + self.__replica_set_name = options.get("replicaset") self.__write_concern = _parse_write_concern(options) self.__read_concern = _parse_read_concern(options) - self.__connect = options.get('connect') - self.__heartbeat_frequency = options.get( - 'heartbeatfrequencyms', common.HEARTBEAT_FREQUENCY) - self.__retry_writes = options.get('retrywrites', common.RETRY_WRITES) - self.__retry_reads = options.get('retryreads', common.RETRY_READS) - self.__server_selector = options.get( - 'server_selector', any_server_selector) - self.__auto_encryption_opts = options.get('auto_encryption_opts') - self.__load_balanced = options.get('loadbalanced') + self.__connect = options.get("connect") + self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY) + self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES) + self.__retry_reads = options.get("retryreads", common.RETRY_READS) + self.__server_selector = options.get("server_selector", any_server_selector) + self.__auto_encryption_opts = options.get("auto_encryption_opts") + self.__load_balanced = options.get("loadbalanced") @property def _options(self): diff --git a/pymongo/client_session.py b/pymongo/client_session.py index f8071e5f2b..6c10d83eee 100644 --- a/pymongo/client_session.py +++ b/pymongo/client_session.py @@ -136,21 +136,21 @@ import collections import time import uuid - from collections.abc import Mapping as _Mapping from 
bson.binary import Binary from bson.int64 import Int64 from bson.son import SON from bson.timestamp import Timestamp - from pymongo.cursor import _SocketManager -from pymongo.errors import (ConfigurationError, - ConnectionFailure, - InvalidOperation, - OperationFailure, - PyMongoError, - WTimeoutError) +from pymongo.errors import ( + ConfigurationError, + ConnectionFailure, + InvalidOperation, + OperationFailure, + PyMongoError, + WTimeoutError, +) from pymongo.helpers import _RETRYABLE_ERROR_CODES from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference, _ServerMode @@ -174,14 +174,11 @@ class SessionOptions(object): .. versionchanged:: 3.12 Added the ``snapshot`` parameter. """ - def __init__(self, - causal_consistency=None, - default_transaction_options=None, - snapshot=False): + + def __init__(self, causal_consistency=None, default_transaction_options=None, snapshot=False): if snapshot: if causal_consistency: - raise ConfigurationError('snapshot reads do not support ' - 'causal_consistency=True') + raise ConfigurationError("snapshot reads do not support " "causal_consistency=True") causal_consistency = False elif causal_consistency is None: causal_consistency = True @@ -190,8 +187,9 @@ def __init__(self, if not isinstance(default_transaction_options, TransactionOptions): raise TypeError( "default_transaction_options must be an instance of " - "pymongo.client_session.TransactionOptions, not: %r" % - (default_transaction_options,)) + "pymongo.client_session.TransactionOptions, not: %r" + % (default_transaction_options,) + ) self._default_transaction_options = default_transaction_options self._snapshot = snapshot @@ -245,35 +243,41 @@ class TransactionOptions(object): .. versionadded:: 3.7 """ - def __init__(self, read_concern=None, write_concern=None, - read_preference=None, max_commit_time_ms=None): + + def __init__( + self, read_concern=None, write_concern=None, read_preference=None, max_commit_time_ms=None + ): self._read_concern = read_concern self._write_concern = write_concern self._read_preference = read_preference self._max_commit_time_ms = max_commit_time_ms if read_concern is not None: if not isinstance(read_concern, ReadConcern): - raise TypeError("read_concern must be an instance of " - "pymongo.read_concern.ReadConcern, not: %r" % - (read_concern,)) + raise TypeError( + "read_concern must be an instance of " + "pymongo.read_concern.ReadConcern, not: %r" % (read_concern,) + ) if write_concern is not None: if not isinstance(write_concern, WriteConcern): - raise TypeError("write_concern must be an instance of " - "pymongo.write_concern.WriteConcern, not: %r" % - (write_concern,)) + raise TypeError( + "write_concern must be an instance of " + "pymongo.write_concern.WriteConcern, not: %r" % (write_concern,) + ) if not write_concern.acknowledged: raise ConfigurationError( "transactions do not support unacknowledged write concern" - ": %r" % (write_concern,)) + ": %r" % (write_concern,) + ) if read_preference is not None: if not isinstance(read_preference, _ServerMode): - raise TypeError("%r is not valid for read_preference. See " - "pymongo.read_preferences for valid " - "options." % (read_preference,)) + raise TypeError( + "%r is not valid for read_preference. See " + "pymongo.read_preferences for valid " + "options." 
% (read_preference,) + ) if max_commit_time_ms is not None: if not isinstance(max_commit_time_ms, int): - raise TypeError( - "max_commit_time_ms must be an integer or None") + raise TypeError("max_commit_time_ms must be an integer or None") @property def read_concern(self): @@ -287,8 +291,7 @@ def write_concern(self): @property def read_preference(self): - """This transaction's :class:`~pymongo.read_preferences.ReadPreference`. - """ + """This transaction's :class:`~pymongo.read_preferences.ReadPreference`.""" return self._read_preference @property @@ -316,14 +319,15 @@ def _validate_session_write_concern(session, write_concern): return None else: raise ConfigurationError( - 'Explicit sessions are incompatible with ' - 'unacknowledged write concern: %r' % ( - write_concern,)) + "Explicit sessions are incompatible with " + "unacknowledged write concern: %r" % (write_concern,) + ) return session class _TransactionContext(object): """Internal transaction context manager for start_transaction.""" + def __init__(self, session): self.__session = session @@ -349,6 +353,7 @@ class _TxnState(object): class _Transaction(object): """Internal class to hold transaction information in a ClientSession.""" + def __init__(self, opts, client): self.opts = opts self.state = _TxnState.NONE @@ -412,10 +417,12 @@ def _max_time_expired_error(exc): # From the transactions spec, all the retryable writes errors plus # WriteConcernFailed. -_UNKNOWN_COMMIT_ERROR_CODES = _RETRYABLE_ERROR_CODES | frozenset([ - 64, # WriteConcernFailed - 50, # MaxTimeMSExpired -]) +_UNKNOWN_COMMIT_ERROR_CODES = _RETRYABLE_ERROR_CODES | frozenset( + [ + 64, # WriteConcernFailed + 50, # MaxTimeMSExpired + ] +) # From the Convenient API for Transactions spec, with_transaction must # halt retries after 120 seconds. @@ -441,6 +448,7 @@ class ClientSession(object): :class:`ClientSession`, call :meth:`~pymongo.mongo_client.MongoClient.start_session`. """ + def __init__(self, client, server_session, options, implicit): # A MongoClient, a _ServerSession, a SessionOptions, and a set. self._client = client @@ -524,8 +532,14 @@ def _inherit_option(self, name, val): return val return getattr(self.client, name) - def with_transaction(self, callback, read_concern=None, write_concern=None, - read_preference=None, max_commit_time_ms=None): + def with_transaction( + self, + callback, + read_concern=None, + write_concern=None, + read_preference=None, + max_commit_time_ms=None, + ): """Execute a callback in a transaction. This method starts a transaction on this session, executes ``callback`` @@ -613,17 +627,17 @@ def callback(session, custom_arg, custom_kwarg=None): """ start_time = time.monotonic() while True: - self.start_transaction( - read_concern, write_concern, read_preference, - max_commit_time_ms) + self.start_transaction(read_concern, write_concern, read_preference, max_commit_time_ms) try: ret = callback(self) except Exception as exc: if self.in_transaction: self.abort_transaction() - if (isinstance(exc, PyMongoError) and - exc.has_error_label("TransientTransactionError") and - _within_time_limit(start_time)): + if ( + isinstance(exc, PyMongoError) + and exc.has_error_label("TransientTransactionError") + and _within_time_limit(start_time) + ): # Retry the entire transaction. 
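Reviewer note: with_transaction's control flow is dense after the reflow; here is the skeleton, stripped to the transient-error retry (the constant and helper are illustrative, and the real loop additionally retries the commit on the UnknownTransactionCommitResult label):

import time

_TXN_LIMIT_SECONDS = 120  # cap from the convenient-transactions spec

def run_with_retries(callback, is_transient):
    start = time.monotonic()
    while True:
        try:
            return callback()
        except Exception as exc:
            if is_transient(exc) and time.monotonic() - start < _TXN_LIMIT_SECONDS:
                continue  # rerun the whole transaction
            raise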
continue raise @@ -636,14 +650,17 @@ def callback(session, custom_arg, custom_kwarg=None): try: self.commit_transaction() except PyMongoError as exc: - if (exc.has_error_label("UnknownTransactionCommitResult") - and _within_time_limit(start_time) - and not _max_time_expired_error(exc)): + if ( + exc.has_error_label("UnknownTransactionCommitResult") + and _within_time_limit(start_time) + and not _max_time_expired_error(exc) + ): # Retry the commit. continue - if (exc.has_error_label("TransientTransactionError") and - _within_time_limit(start_time)): + if exc.has_error_label("TransientTransactionError") and _within_time_limit( + start_time + ): # Retry the entire transaction. break raise @@ -651,8 +668,9 @@ def callback(session, custom_arg, custom_kwarg=None): # Commit succeeded. return ret - def start_transaction(self, read_concern=None, write_concern=None, - read_preference=None, max_commit_time_ms=None): + def start_transaction( + self, read_concern=None, write_concern=None, read_preference=None, max_commit_time_ms=None + ): """Start a multi-statement transaction. Takes the same arguments as :class:`TransactionOptions`. @@ -665,23 +683,22 @@ def start_transaction(self, read_concern=None, write_concern=None, self._check_ended() if self.options.snapshot: - raise InvalidOperation("Transactions are not supported in " - "snapshot sessions") + raise InvalidOperation("Transactions are not supported in " "snapshot sessions") if self.in_transaction: raise InvalidOperation("Transaction already in progress") read_concern = self._inherit_option("read_concern", read_concern) write_concern = self._inherit_option("write_concern", write_concern) - read_preference = self._inherit_option( - "read_preference", read_preference) + read_preference = self._inherit_option("read_preference", read_preference) if max_commit_time_ms is None: opts = self.options.default_transaction_options if opts: max_commit_time_ms = opts.max_commit_time_ms self._transaction.opts = TransactionOptions( - read_concern, write_concern, read_preference, max_commit_time_ms) + read_concern, write_concern, read_preference, max_commit_time_ms + ) self._transaction.reset() self._transaction.state = _TxnState.STARTING self._start_retryable_write() @@ -701,8 +718,7 @@ def commit_transaction(self): self._transaction.state = _TxnState.COMMITTED_EMPTY return elif state is _TxnState.ABORTED: - raise InvalidOperation( - "Cannot call commitTransaction after calling abortTransaction") + raise InvalidOperation("Cannot call commitTransaction after calling abortTransaction") elif state is _TxnState.COMMITTED: # We're explicitly retrying the commit, move the state back to # "in progress" so that in_transaction returns true. @@ -748,8 +764,7 @@ def abort_transaction(self): elif state is _TxnState.ABORTED: raise InvalidOperation("Cannot call abortTransaction twice") elif state in (_TxnState.COMMITTED, _TxnState.COMMITTED_EMPTY): - raise InvalidOperation( - "Cannot call abortTransaction after calling commitTransaction") + raise InvalidOperation("Cannot call abortTransaction after calling commitTransaction") try: self._finish_transaction_with_retry("abortTransaction") @@ -766,8 +781,10 @@ def _finish_transaction_with_retry(self, command_name): :Parameters: - `command_name`: Either "commitTransaction" or "abortTransaction". 
""" + def func(session, sock_info, retryable): return self._finish_transaction(sock_info, command_name) + return self._client._retry_internal(True, func, self, None) def _finish_transaction(self, sock_info, command_name): @@ -777,7 +794,7 @@ def _finish_transaction(self, sock_info, command_name): cmd = SON([(command_name, 1)]) if command_name == "commitTransaction": if opts.max_commit_time_ms: - cmd['maxTimeMS'] = opts.max_commit_time_ms + cmd["maxTimeMS"] = opts.max_commit_time_ms # Transaction spec says that after the initial commit attempt, # subsequent commitTransaction commands should be upgraded to use @@ -789,14 +806,11 @@ def _finish_transaction(self, sock_info, command_name): wc = WriteConcern(**wc_doc) if self._transaction.recovery_token: - cmd['recoveryToken'] = self._transaction.recovery_token + cmd["recoveryToken"] = self._transaction.recovery_token return self._client.admin._command( - sock_info, - cmd, - session=self, - write_concern=wc, - parse_write_concern_error=True) + sock_info, cmd, session=self, write_concern=wc, parse_write_concern_error=True + ) def _advance_cluster_time(self, cluster_time): """Internal cluster time helper.""" @@ -815,8 +829,7 @@ def advance_cluster_time(self, cluster_time): another `ClientSession` instance. """ if not isinstance(cluster_time, _Mapping): - raise TypeError( - "cluster_time must be a subclass of collections.Mapping") + raise TypeError("cluster_time must be a subclass of collections.Mapping") if not isinstance(cluster_time.get("clusterTime"), Timestamp): raise ValueError("Invalid cluster_time") self._advance_cluster_time(cluster_time) @@ -838,22 +851,21 @@ def advance_operation_time(self, operation_time): another `ClientSession` instance. """ if not isinstance(operation_time, Timestamp): - raise TypeError("operation_time must be an instance " - "of bson.timestamp.Timestamp") + raise TypeError("operation_time must be an instance " "of bson.timestamp.Timestamp") self._advance_operation_time(operation_time) def _process_response(self, reply): """Process a response to a command that was run with this session.""" - self._advance_cluster_time(reply.get('$clusterTime')) - self._advance_operation_time(reply.get('operationTime')) + self._advance_cluster_time(reply.get("$clusterTime")) + self._advance_operation_time(reply.get("operationTime")) if self._options.snapshot and self._snapshot_time is None: - if 'cursor' in reply: - ct = reply['cursor'].get('atClusterTime') + if "cursor" in reply: + ct = reply["cursor"].get("atClusterTime") else: - ct = reply.get('atClusterTime') + ct = reply.get("atClusterTime") self._snapshot_time = ct if self.in_transaction and self._transaction.sharded: - recovery_token = reply.get('recoveryToken') + recovery_token = reply.get("recoveryToken") if recovery_token: self._transaction.recovery_token = recovery_token @@ -872,8 +884,7 @@ def in_transaction(self): @property def _starting_transaction(self): - """True if this session is starting a multi-statement transaction. 
- """ + """True if this session is starting a multi-statement transaction.""" return self._transaction.starting() @property @@ -909,58 +920,56 @@ def _apply_to(self, command, is_retryable, read_preference, sock_info): self._update_read_concern(command, sock_info) self._server_session.last_use = time.monotonic() - command['lsid'] = self._server_session.session_id + command["lsid"] = self._server_session.session_id if is_retryable: - command['txnNumber'] = self._server_session.transaction_id + command["txnNumber"] = self._server_session.transaction_id return if self.in_transaction: if read_preference != ReadPreference.PRIMARY: raise InvalidOperation( - 'read preference in a transaction must be primary, not: ' - '%r' % (read_preference,)) + "read preference in a transaction must be primary, not: " + "%r" % (read_preference,) + ) if self._transaction.state == _TxnState.STARTING: # First command begins a new transaction. self._transaction.state = _TxnState.IN_PROGRESS - command['startTransaction'] = True + command["startTransaction"] = True if self._transaction.opts.read_concern: rc = self._transaction.opts.read_concern.document if rc: - command['readConcern'] = rc + command["readConcern"] = rc self._update_read_concern(command, sock_info) - command['txnNumber'] = self._server_session.transaction_id - command['autocommit'] = False + command["txnNumber"] = self._server_session.transaction_id + command["autocommit"] = False def _start_retryable_write(self): self._check_ended() self._server_session.inc_transaction_id() def _update_read_concern(self, cmd, sock_info): - if (self.options.causal_consistency - and self.operation_time is not None): - cmd.setdefault('readConcern', {})[ - 'afterClusterTime'] = self.operation_time + if self.options.causal_consistency and self.operation_time is not None: + cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time if self.options.snapshot: if sock_info.max_wire_version < 13: - raise ConfigurationError( - 'Snapshot reads require MongoDB 5.0 or later') - rc = cmd.setdefault('readConcern', {}) - rc['level'] = 'snapshot' + raise ConfigurationError("Snapshot reads require MongoDB 5.0 or later") + rc = cmd.setdefault("readConcern", {}) + rc["level"] = "snapshot" if self._snapshot_time is not None: - rc['atClusterTime'] = self._snapshot_time + rc["atClusterTime"] = self._snapshot_time def __copy__(self): - raise TypeError('A ClientSession cannot be copied, create a new session instead') + raise TypeError("A ClientSession cannot be copied, create a new session instead") class _ServerSession(object): def __init__(self, generation): # Ensure id is type 4, regardless of CodecOptions.uuid_representation. - self.session_id = {'id': Binary(uuid.uuid4().bytes, 4)} + self.session_id = {"id": Binary(uuid.uuid4().bytes, 4)} self.last_use = time.monotonic() self._transaction_id = 0 self.dirty = False @@ -994,6 +1003,7 @@ class _ServerSessionPool(collections.deque): This class is not thread-safe, access it while holding the Topology lock. """ + def __init__(self, *args, **kwargs): super(_ServerSessionPool, self).__init__(*args, **kwargs) self.generation = 0 @@ -1034,8 +1044,7 @@ def return_server_session(self, server_session, session_timeout_minutes): def return_server_session_no_lock(self, server_session): # Discard sessions from an old pool to avoid duplicate sessions in the # child process after a fork. 
- if (server_session.generation == self.generation and - not server_session.dirty): + if server_session.generation == self.generation and not server_session.dirty: self.appendleft(server_session) def _clear_stale(self, session_timeout_minutes): diff --git a/pymongo/collation.py b/pymongo/collation.py index 873d603336..3117b9ed62 100644 --- a/pymongo/collation.py +++ b/pymongo/collation.py @@ -48,10 +48,10 @@ class CollationAlternate(object): :class:`~pymongo.collation.Collation`. """ - NON_IGNORABLE = 'non-ignorable' + NON_IGNORABLE = "non-ignorable" """Spaces and punctuation are treated as base characters.""" - SHIFTED = 'shifted' + SHIFTED = "shifted" """Spaces and punctuation are *not* considered base characters. Spaces and punctuation are distinguished regardless when the @@ -67,10 +67,10 @@ class CollationMaxVariable(object): :class:`~pymongo.collation.Collation`. """ - PUNCT = 'punct' + PUNCT = "punct" """Both punctuation and spaces are ignored.""" - SPACE = 'space' + SPACE = "space" """Spaces alone are ignored.""" @@ -80,13 +80,13 @@ class CollationCaseFirst(object): :class:`~pymongo.collation.Collation`. """ - UPPER = 'upper' + UPPER = "upper" """Sort uppercase characters first.""" - LOWER = 'lower' + LOWER = "lower" """Sort lowercase characters first.""" - OFF = 'off' + OFF = "off" """Default for locale or collation strength.""" @@ -151,42 +151,41 @@ class Collation(object): __slots__ = ("__document",) - def __init__(self, locale, - caseLevel=None, - caseFirst=None, - strength=None, - numericOrdering=None, - alternate=None, - maxVariable=None, - normalization=None, - backwards=None, - **kwargs): - locale = common.validate_string('locale', locale) - self.__document = {'locale': locale} + def __init__( + self, + locale, + caseLevel=None, + caseFirst=None, + strength=None, + numericOrdering=None, + alternate=None, + maxVariable=None, + normalization=None, + backwards=None, + **kwargs + ): + locale = common.validate_string("locale", locale) + self.__document = {"locale": locale} if caseLevel is not None: - self.__document['caseLevel'] = common.validate_boolean( - 'caseLevel', caseLevel) + self.__document["caseLevel"] = common.validate_boolean("caseLevel", caseLevel) if caseFirst is not None: - self.__document['caseFirst'] = common.validate_string( - 'caseFirst', caseFirst) + self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst) if strength is not None: - self.__document['strength'] = common.validate_integer( - 'strength', strength) + self.__document["strength"] = common.validate_integer("strength", strength) if numericOrdering is not None: - self.__document['numericOrdering'] = common.validate_boolean( - 'numericOrdering', numericOrdering) + self.__document["numericOrdering"] = common.validate_boolean( + "numericOrdering", numericOrdering + ) if alternate is not None: - self.__document['alternate'] = common.validate_string( - 'alternate', alternate) + self.__document["alternate"] = common.validate_string("alternate", alternate) if maxVariable is not None: - self.__document['maxVariable'] = common.validate_string( - 'maxVariable', maxVariable) + self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable) if normalization is not None: - self.__document['normalization'] = common.validate_boolean( - 'normalization', normalization) + self.__document["normalization"] = common.validate_boolean( + "normalization", normalization + ) if backwards is not None: - self.__document['backwards'] = common.validate_boolean( - 'backwards', backwards) + 
self.__document["backwards"] = common.validate_boolean("backwards", backwards) self.__document.update(kwargs) @property @@ -201,8 +200,7 @@ def document(self): def __repr__(self): document = self.document - return 'Collation(%s)' % ( - ', '.join('%s=%r' % (key, document[key]) for key in document),) + return "Collation(%s)" % (", ".join("%s=%r" % (key, document[key]) for key in document),) def __eq__(self, other): if isinstance(other, Collation): @@ -220,6 +218,4 @@ def validate_collation_or_none(value): return value.document if isinstance(value, dict): return value - raise TypeError( - 'collation must be a dict, an instance of collation.Collation, ' - 'or None.') + raise TypeError("collation must be a dict, an instance of collation.Collation, " "or None.") diff --git a/pymongo/collection.py b/pymongo/collection.py index 774c290235..472e897dfe 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -16,40 +16,43 @@ import datetime import warnings - from collections import abc from bson.code import Code +from bson.codec_options import CodecOptions from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument -from bson.codec_options import CodecOptions from bson.son import SON -from pymongo import (common, - helpers, - message) -from pymongo.aggregation import (_CollectionAggregationCommand, - _CollectionRawAggregationCommand) +from pymongo import common, helpers, message +from pymongo.aggregation import ( + _CollectionAggregationCommand, + _CollectionRawAggregationCommand, +) from pymongo.bulk import _Bulk -from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor -from pymongo.collation import validate_collation_or_none from pymongo.change_stream import CollectionChangeStream +from pymongo.collation import validate_collation_or_none +from pymongo.command_cursor import CommandCursor, RawBatchCommandCursor from pymongo.cursor import Cursor, RawBatchCursor -from pymongo.errors import (ConfigurationError, - InvalidName, - InvalidOperation, - OperationFailure) +from pymongo.errors import ( + ConfigurationError, + InvalidName, + InvalidOperation, + OperationFailure, +) from pymongo.helpers import _check_write_command_response from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS from pymongo.operations import IndexModel from pymongo.read_preferences import ReadPreference -from pymongo.results import (BulkWriteResult, - DeleteResult, - InsertOneResult, - InsertManyResult, - UpdateResult) +from pymongo.results import ( + BulkWriteResult, + DeleteResult, + InsertManyResult, + InsertOneResult, + UpdateResult, +) from pymongo.write_concern import WriteConcern -_FIND_AND_MODIFY_DOC_FIELDS = {'value': 1} +_FIND_AND_MODIFY_DOC_FIELDS = {"value": 1} class ReturnDocument(object): @@ -57,6 +60,7 @@ class ReturnDocument(object): :meth:`~pymongo.collection.Collection.find_one_and_replace` and :meth:`~pymongo.collection.Collection.find_one_and_update`. """ + BEFORE = False """Return the original document before it was updated/replaced, or ``None`` if no document matches the query. @@ -66,12 +70,20 @@ class ReturnDocument(object): class Collection(common.BaseObject): - """A Mongo collection. 
- """ - - def __init__(self, database, name, create=False, codec_options=None, - read_preference=None, write_concern=None, read_concern=None, - session=None, **kwargs): + """A Mongo collection.""" + + def __init__( + self, + database, + name, + create=False, + codec_options=None, + read_preference=None, + write_concern=None, + read_concern=None, + session=None, + **kwargs + ): """Get / create a Mongo collection. Raises :class:`TypeError` if `name` is not an instance of @@ -151,24 +163,21 @@ def __init__(self, database, name, create=False, codec_options=None, codec_options or database.codec_options, read_preference or database.read_preference, write_concern or database.write_concern, - read_concern or database.read_concern) + read_concern or database.read_concern, + ) if not isinstance(name, str): raise TypeError("name must be an instance of str") if not name or ".." in name: raise InvalidName("collection names cannot be empty") - if "$" in name and not (name.startswith("oplog.$main") or - name.startswith("$cmd")): - raise InvalidName("collection names must not " - "contain '$': %r" % name) + if "$" in name and not (name.startswith("oplog.$main") or name.startswith("$cmd")): + raise InvalidName("collection names must not " "contain '$': %r" % name) if name[0] == "." or name[-1] == ".": - raise InvalidName("collection names must not start " - "or end with '.': %r" % name) + raise InvalidName("collection names must not start " "or end with '.': %r" % name) if "\x00" in name: - raise InvalidName("collection names must not contain the " - "null character") - collation = validate_collation_or_none(kwargs.pop('collation', None)) + raise InvalidName("collection names must not contain the " "null character") + collation = validate_collation_or_none(kwargs.pop("collation", None)) self.__database = database self.__name = name @@ -177,25 +186,31 @@ def __init__(self, database, name, create=False, codec_options=None, self.__create(kwargs, collation, session) self.__write_response_codec_options = self.codec_options._replace( - unicode_decode_error_handler='replace', - document_class=dict) + unicode_decode_error_handler="replace", document_class=dict + ) def _socket_for_reads(self, session): - return self.__database.client._socket_for_reads( - self._read_preference_for(session), session) + return self.__database.client._socket_for_reads(self._read_preference_for(session), session) def _socket_for_writes(self, session): return self.__database.client._socket_for_writes(session) - def _command(self, sock_info, command, secondary_ok=False, - read_preference=None, - codec_options=None, check=True, allowable_errors=None, - read_concern=None, - write_concern=None, - collation=None, - session=None, - retryable_write=False, - user_fields=None): + def _command( + self, + sock_info, + command, + secondary_ok=False, + read_preference=None, + codec_options=None, + check=True, + allowable_errors=None, + read_concern=None, + write_concern=None, + collation=None, + session=None, + retryable_write=False, + user_fields=None, + ): """Internal command helper. :Parameters: @@ -240,11 +255,11 @@ def _command(self, sock_info, command, secondary_ok=False, session=s, client=self.__database.client, retryable_write=retryable_write, - user_fields=user_fields) + user_fields=user_fields, + ) def __create(self, options, collation, session): - """Sends a create command with the given options. 
- """ + """Sends a create command with the given options.""" cmd = SON([("create", self.__name)]) if options: if "size" in options: @@ -252,9 +267,13 @@ def __create(self, options, collation, session): cmd.update(options) with self._socket_for_writes(session) as sock_info: self._command( - sock_info, cmd, read_preference=ReadPreference.PRIMARY, + sock_info, + cmd, + read_preference=ReadPreference.PRIMARY, write_concern=self._write_concern_for(session), - collation=collation, session=session) + collation=collation, + session=session, + ) def __getattr__(self, name): """Get a sub-collection of this collection by name. @@ -264,30 +283,31 @@ def __getattr__(self, name): :Parameters: - `name`: the name of the collection to get """ - if name.startswith('_'): + if name.startswith("_"): full_name = "%s.%s" % (self.__name, name) raise AttributeError( "Collection has no attribute %r. To access the %s" - " collection, use database['%s']." % ( - name, full_name, full_name)) + " collection, use database['%s']." % (name, full_name, full_name) + ) return self.__getitem__(name) def __getitem__(self, name): - return Collection(self.__database, - "%s.%s" % (self.__name, name), - False, - self.codec_options, - self.read_preference, - self.write_concern, - self.read_concern) + return Collection( + self.__database, + "%s.%s" % (self.__name, name), + False, + self.codec_options, + self.read_preference, + self.write_concern, + self.read_concern, + ) def __repr__(self): return "Collection(%r, %r)" % (self.__database, self.__name) def __eq__(self, other): if isinstance(other, Collection): - return (self.__database == other.database and - self.__name == other.name) + return self.__database == other.database and self.__name == other.name return NotImplemented def __ne__(self, other): @@ -297,9 +317,11 @@ def __hash__(self): return hash((self.__database, self.__name)) def __bool__(self): - raise NotImplementedError("Collection objects do not implement truth " - "value testing or bool(). Please compare " - "with None instead: collection is not None") + raise NotImplementedError( + "Collection objects do not implement truth " + "value testing or bool(). Please compare " + "with None instead: collection is not None" + ) @property def full_name(self): @@ -321,8 +343,9 @@ def database(self): """ return self.__database - def with_options(self, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def with_options( + self, codec_options=None, read_preference=None, write_concern=None, read_concern=None + ): """Get a clone of this collection changing the specified settings. >>> coll1.read_preference @@ -352,16 +375,17 @@ def with_options(self, codec_options=None, read_preference=None, default) the :attr:`read_concern` of this :class:`Collection` is used. """ - return Collection(self.__database, - self.__name, - False, - codec_options or self.codec_options, - read_preference or self.read_preference, - write_concern or self.write_concern, - read_concern or self.read_concern) - - def bulk_write(self, requests, ordered=True, - bypass_document_validation=False, session=None): + return Collection( + self.__database, + self.__name, + False, + codec_options or self.codec_options, + read_preference or self.read_preference, + write_concern or self.write_concern, + read_concern or self.read_concern, + ) + + def bulk_write(self, requests, ordered=True, bypass_document_validation=False, session=None): """Send a batch of write operations to the server. 
Requests are passed as a list of write operation instances ( @@ -442,22 +466,17 @@ def bulk_write(self, requests, ordered=True, return BulkWriteResult(bulk_api_result, True) return BulkWriteResult({}, False) - def _insert_one( - self, doc, ordered, - check_keys, write_concern, op_id, bypass_doc_val, - session): + def _insert_one(self, doc, ordered, check_keys, write_concern, op_id, bypass_doc_val, session): """Internal helper for inserting a single document.""" write_concern = write_concern or self.write_concern acknowledged = write_concern.acknowledged - command = SON([('insert', self.name), - ('ordered', ordered), - ('documents', [doc])]) + command = SON([("insert", self.name), ("ordered", ordered), ("documents", [doc])]) if not write_concern.is_server_default: - command['writeConcern'] = write_concern.document + command["writeConcern"] = write_concern.document def _insert_command(session, sock_info, retryable_write): if bypass_doc_val: - command['bypassDocumentValidation'] = True + command["bypassDocumentValidation"] = True result = sock_info.command( self.__database.name, @@ -467,18 +486,17 @@ def _insert_command(session, sock_info, retryable_write): check_keys=check_keys, session=session, client=self.__database.client, - retryable_write=retryable_write) + retryable_write=retryable_write, + ) _check_write_command_response(result) - self.__database.client._retryable_write( - acknowledged, _insert_command, session) + self.__database.client._retryable_write(acknowledged, _insert_command, session) if not isinstance(doc, RawBSONDocument): - return doc.get('_id') + return doc.get("_id") - def insert_one(self, document, bypass_document_validation=False, - session=None): + def insert_one(self, document, bypass_document_validation=False, session=None): """Insert a single document. >>> db.test.count_documents({'x': 1}) @@ -522,13 +540,18 @@ def insert_one(self, document, bypass_document_validation=False, write_concern = self._write_concern_for(session) return InsertOneResult( self._insert_one( - document, ordered=True, check_keys=False, - write_concern=write_concern, op_id=None, - bypass_doc_val=bypass_document_validation, session=session), - write_concern.acknowledged) + document, + ordered=True, + check_keys=False, + write_concern=write_concern, + op_id=None, + bypass_doc_val=bypass_document_validation, + session=session, + ), + write_concern.acknowledged, + ) - def insert_many(self, documents, ordered=True, - bypass_document_validation=False, session=None): + def insert_many(self, documents, ordered=True, bypass_document_validation=False, session=None): """Insert an iterable of documents. >>> db.test.count_documents({}) @@ -568,11 +591,14 @@ def insert_many(self, documents, ordered=True, .. 
versionadded:: 3.0 """ - if (not isinstance(documents, abc.Iterable) - or isinstance(documents, abc.Mapping) - or not documents): + if ( + not isinstance(documents, abc.Iterable) + or isinstance(documents, abc.Mapping) + or not documents + ): raise TypeError("documents must be a non-empty list") inserted_ids = [] + def gen(): """A generator that validates documents and handles _ids.""" for document in documents: @@ -589,49 +615,54 @@ def gen(): blk.execute(write_concern, session=session) return InsertManyResult(inserted_ids, write_concern.acknowledged) - def _update(self, sock_info, criteria, document, upsert=False, - check_keys=False, multi=False, - write_concern=None, op_id=None, ordered=True, - bypass_doc_val=False, collation=None, array_filters=None, - hint=None, session=None, retryable_write=False): + def _update( + self, + sock_info, + criteria, + document, + upsert=False, + check_keys=False, + multi=False, + write_concern=None, + op_id=None, + ordered=True, + bypass_doc_val=False, + collation=None, + array_filters=None, + hint=None, + session=None, + retryable_write=False, + ): """Internal update / replace helper.""" common.validate_boolean("upsert", upsert) collation = validate_collation_or_none(collation) write_concern = write_concern or self.write_concern acknowledged = write_concern.acknowledged - update_doc = SON([('q', criteria), - ('u', document), - ('multi', multi), - ('upsert', upsert)]) + update_doc = SON([("q", criteria), ("u", document), ("multi", multi), ("upsert", upsert)]) if collation is not None: if not acknowledged: - raise ConfigurationError( - 'Collation is unsupported for unacknowledged writes.') + raise ConfigurationError("Collation is unsupported for unacknowledged writes.") else: - update_doc['collation'] = collation + update_doc["collation"] = collation if array_filters is not None: if not acknowledged: - raise ConfigurationError( - 'arrayFilters is unsupported for unacknowledged writes.') + raise ConfigurationError("arrayFilters is unsupported for unacknowledged writes.") else: - update_doc['arrayFilters'] = array_filters + update_doc["arrayFilters"] = array_filters if hint is not None: if not acknowledged: - raise ConfigurationError( - 'hint is unsupported for unacknowledged writes.') + raise ConfigurationError("hint is unsupported for unacknowledged writes.") if not isinstance(hint, str): hint = helpers._index_document(hint) - update_doc['hint'] = hint + update_doc["hint"] = hint - command = SON([('update', self.name), - ('ordered', ordered), - ('updates', [update_doc])]) + command = SON([("update", self.name), ("ordered", ordered), ("updates", [update_doc])]) if not write_concern.is_server_default: - command['writeConcern'] = write_concern.document + command["writeConcern"] = write_concern.document # Update command. if bypass_doc_val: - command['bypassDocumentValidation'] = True + command["bypassDocumentValidation"] = True # The command result has to be published for APM unmodified # so we make a shallow copy here before adding updatedExisting. @@ -642,45 +673,74 @@ def _update(self, sock_info, criteria, document, upsert=False, codec_options=self.__write_response_codec_options, session=session, client=self.__database.client, - retryable_write=retryable_write).copy() + retryable_write=retryable_write, + ).copy() _check_write_command_response(result) # Add the updatedExisting field for compatibility. 
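Reviewer note: the compatibility shims at the end of _update, isolated so the behavior is obvious (illustrative helper mirroring the diff):

def normalize_update_result(result):
    result = result.copy()
    if result.get("n") and "upserted" not in result:
        result["updatedExisting"] = True
    else:
        result["updatedExisting"] = False
    if "upserted" in result:
        # The server reports the upserted _id in a one-element array.
        result["upserted"] = result["upserted"][0]["_id"]
    return result

assert normalize_update_result({"n": 1})["updatedExisting"] is True
r = normalize_update_result({"n": 1, "upserted": [{"index": 0, "_id": 42}]})
assert r["upserted"] == 42 and r["updatedExisting"] is False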
- if result.get('n') and 'upserted' not in result: - result['updatedExisting'] = True + if result.get("n") and "upserted" not in result: + result["updatedExisting"] = True else: - result['updatedExisting'] = False + result["updatedExisting"] = False # MongoDB >= 2.6.0 returns the upsert _id in an array # element. Break it out for backward compatibility. - if 'upserted' in result: - result['upserted'] = result['upserted'][0]['_id'] + if "upserted" in result: + result["upserted"] = result["upserted"][0]["_id"] if not acknowledged: return None return result def _update_retryable( - self, criteria, document, upsert=False, - check_keys=False, multi=False, - write_concern=None, op_id=None, ordered=True, - bypass_doc_val=False, collation=None, array_filters=None, - hint=None, session=None): + self, + criteria, + document, + upsert=False, + check_keys=False, + multi=False, + write_concern=None, + op_id=None, + ordered=True, + bypass_doc_val=False, + collation=None, + array_filters=None, + hint=None, + session=None, + ): """Internal update / replace helper.""" + def _update(session, sock_info, retryable_write): return self._update( - sock_info, criteria, document, upsert=upsert, - check_keys=check_keys, multi=multi, - write_concern=write_concern, op_id=op_id, ordered=ordered, - bypass_doc_val=bypass_doc_val, collation=collation, - array_filters=array_filters, hint=hint, session=session, - retryable_write=retryable_write) + sock_info, + criteria, + document, + upsert=upsert, + check_keys=check_keys, + multi=multi, + write_concern=write_concern, + op_id=op_id, + ordered=ordered, + bypass_doc_val=bypass_doc_val, + collation=collation, + array_filters=array_filters, + hint=hint, + session=session, + retryable_write=retryable_write, + ) return self.__database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, - _update, session) - - def replace_one(self, filter, replacement, upsert=False, - bypass_document_validation=False, collation=None, - hint=None, session=None): + (write_concern or self.write_concern).acknowledged and not multi, _update, session + ) + + def replace_one( + self, + filter, + replacement, + upsert=False, + bypass_document_validation=False, + collation=None, + hint=None, + session=None, + ): """Replace a single document matching the filter. >>> for doc in db.test.find({}): @@ -750,16 +810,29 @@ def replace_one(self, filter, replacement, upsert=False, write_concern = self._write_concern_for(session) return UpdateResult( self._update_retryable( - filter, replacement, upsert, + filter, + replacement, + upsert, write_concern=write_concern, bypass_doc_val=bypass_document_validation, - collation=collation, hint=hint, session=session), - write_concern.acknowledged) - - def update_one(self, filter, update, upsert=False, - bypass_document_validation=False, - collation=None, array_filters=None, hint=None, - session=None): + collation=collation, + hint=hint, + session=session, + ), + write_concern.acknowledged, + ) + + def update_one( + self, + filter, + update, + upsert=False, + bypass_document_validation=False, + collation=None, + array_filters=None, + hint=None, + session=None, + ): """Update a single document matching the filter. 
>>> for doc in db.test.find(): @@ -821,21 +894,36 @@ def update_one(self, filter, update, upsert=False, """ common.validate_is_mapping("filter", filter) common.validate_ok_for_update(update) - common.validate_list_or_none('array_filters', array_filters) + common.validate_list_or_none("array_filters", array_filters) write_concern = self._write_concern_for(session) return UpdateResult( self._update_retryable( - filter, update, upsert, check_keys=False, + filter, + update, + upsert, + check_keys=False, write_concern=write_concern, bypass_doc_val=bypass_document_validation, - collation=collation, array_filters=array_filters, - hint=hint, session=session), - write_concern.acknowledged) - - def update_many(self, filter, update, upsert=False, array_filters=None, - bypass_document_validation=False, collation=None, - hint=None, session=None): + collation=collation, + array_filters=array_filters, + hint=hint, + session=session, + ), + write_concern.acknowledged, + ) + + def update_many( + self, + filter, + update, + upsert=False, + array_filters=None, + bypass_document_validation=False, + collation=None, + hint=None, + session=None, + ): """Update one or more documents that match the filter. >>> for doc in db.test.find(): @@ -897,17 +985,25 @@ def update_many(self, filter, update, upsert=False, array_filters=None, """ common.validate_is_mapping("filter", filter) common.validate_ok_for_update(update) - common.validate_list_or_none('array_filters', array_filters) + common.validate_list_or_none("array_filters", array_filters) write_concern = self._write_concern_for(session) return UpdateResult( self._update_retryable( - filter, update, upsert, check_keys=False, multi=True, + filter, + update, + upsert, + check_keys=False, + multi=True, write_concern=write_concern, bypass_doc_val=bypass_document_validation, - collation=collation, array_filters=array_filters, - hint=hint, session=session), - write_concern.acknowledged) + collation=collation, + array_filters=array_filters, + hint=hint, + session=session, + ), + write_concern.acknowledged, + ) def drop(self, session=None): """Alias for :meth:`~pymongo.database.Database.drop_collection`. 
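With `update_one` and `update_many` both reformatted above, a hedged usage sketch of the `array_filters` path they share (assumes a reachable mongod running MongoDB 3.6+; names and values are illustrative):

```python
from pymongo import MongoClient

client = MongoClient()  # assumes localhost:27017
coll = client.test.scores
coll.drop()
coll.insert_one({"grades": [95, 102, 90, 150]})

# $[elem] rewrites only the array members matched by the corresponding
# array_filters entry; update_many differs from update_one solely in
# passing multi=True through _update_retryable.
result = coll.update_one(
    {},
    {"$set": {"grades.$[elem]": 100}},
    array_filters=[{"elem": {"$gte": 100}}],
)
print(result.modified_count)  # 1
```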
@@ -932,38 +1028,43 @@ def drop(self, session=None): self.codec_options, self.read_preference, self.write_concern, - self.read_concern) + self.read_concern, + ) dbo.drop_collection(self.__name, session=session) def _delete( - self, sock_info, criteria, multi, - write_concern=None, op_id=None, ordered=True, - collation=None, hint=None, session=None, retryable_write=False): + self, + sock_info, + criteria, + multi, + write_concern=None, + op_id=None, + ordered=True, + collation=None, + hint=None, + session=None, + retryable_write=False, + ): """Internal delete helper.""" common.validate_is_mapping("filter", criteria) write_concern = write_concern or self.write_concern acknowledged = write_concern.acknowledged - delete_doc = SON([('q', criteria), - ('limit', int(not multi))]) + delete_doc = SON([("q", criteria), ("limit", int(not multi))]) collation = validate_collation_or_none(collation) if collation is not None: if not acknowledged: - raise ConfigurationError( - 'Collation is unsupported for unacknowledged writes.') + raise ConfigurationError("Collation is unsupported for unacknowledged writes.") else: - delete_doc['collation'] = collation + delete_doc["collation"] = collation if hint is not None: if not acknowledged: - raise ConfigurationError( - 'hint is unsupported for unacknowledged writes.') + raise ConfigurationError("hint is unsupported for unacknowledged writes.") if not isinstance(hint, str): hint = helpers._index_document(hint) - delete_doc['hint'] = hint - command = SON([('delete', self.name), - ('ordered', ordered), - ('deletes', [delete_doc])]) + delete_doc["hint"] = hint + command = SON([("delete", self.name), ("ordered", ordered), ("deletes", [delete_doc])]) if not write_concern.is_server_default: - command['writeConcern'] = write_concern.document + command["writeConcern"] = write_concern.document # Delete command. result = sock_info.command( @@ -973,25 +1074,41 @@ def _delete( codec_options=self.__write_response_codec_options, session=session, client=self.__database.client, - retryable_write=retryable_write) + retryable_write=retryable_write, + ) _check_write_command_response(result) return result def _delete_retryable( - self, criteria, multi, - write_concern=None, op_id=None, ordered=True, - collation=None, hint=None, session=None): + self, + criteria, + multi, + write_concern=None, + op_id=None, + ordered=True, + collation=None, + hint=None, + session=None, + ): """Internal delete helper.""" + def _delete(session, sock_info, retryable_write): return self._delete( - sock_info, criteria, multi, - write_concern=write_concern, op_id=op_id, ordered=ordered, - collation=collation, hint=hint, session=session, - retryable_write=retryable_write) + sock_info, + criteria, + multi, + write_concern=write_concern, + op_id=op_id, + ordered=ordered, + collation=collation, + hint=hint, + session=session, + retryable_write=retryable_write, + ) return self.__database.client._retryable_write( - (write_concern or self.write_concern).acknowledged and not multi, - _delete, session) + (write_concern or self.write_concern).acknowledged and not multi, _delete, session + ) def delete_one(self, filter, collation=None, hint=None, session=None): """Delete a single document matching the filter. 
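Before the `delete_one` hunk below: `_delete` encodes the one/many distinction as the `limit` field via `int(not multi)`. A self-contained sketch of both command shapes (collection name and filter are made up):

```python
from bson.son import SON

# delete_one  -> multi=False -> {"limit": 1}
# delete_many -> multi=True  -> {"limit": 0}  (no limit)
for multi in (False, True):
    delete_doc = SON([("q", {"x": 1}), ("limit", int(not multi))])
    command = SON([("delete", "coll"),
                   ("ordered", True),
                   ("deletes", [delete_doc])])
    print(command)
```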
@@ -1032,10 +1149,15 @@ def delete_one(self, filter, collation=None, hint=None, session=None): write_concern = self._write_concern_for(session) return DeleteResult( self._delete_retryable( - filter, False, + filter, + False, write_concern=write_concern, - collation=collation, hint=hint, session=session), - write_concern.acknowledged) + collation=collation, + hint=hint, + session=session, + ), + write_concern.acknowledged, + ) def delete_many(self, filter, collation=None, hint=None, session=None): """Delete one or more documents matching the filter. @@ -1076,10 +1198,15 @@ def delete_many(self, filter, collation=None, hint=None, session=None): write_concern = self._write_concern_for(session) return DeleteResult( self._delete_retryable( - filter, True, + filter, + True, write_concern=write_concern, - collation=collation, hint=hint, session=session), - write_concern.acknowledged) + collation=collation, + hint=hint, + session=session, + ), + write_concern.acknowledged, + ) def find_one(self, filter=None, *args, **kwargs): """Get a single document from the database. @@ -1106,8 +1233,7 @@ def find_one(self, filter=None, *args, **kwargs): >>> collection.find_one(max_time_ms=100) """ - if (filter is not None and not - isinstance(filter, abc.Mapping)): + if filter is not None and not isinstance(filter, abc.Mapping): filter = {"_id": filter} cursor = self.find(filter, *args, **kwargs) @@ -1333,8 +1459,7 @@ def find_raw_batches(self, *args, **kwargs): """ # OP_MSG is required to support encryption. if self.__database.client._encrypter: - raise InvalidOperation( - "find_raw_batches does not support auto encryption") + raise InvalidOperation("find_raw_batches does not support auto encryption") return RawBatchCursor(self, *args, **kwargs) @@ -1350,13 +1475,13 @@ def _count_cmd(self, session, sock_info, secondary_ok, cmd, collation): codec_options=self.__write_response_codec_options, read_concern=self.read_concern, collation=collation, - session=session) + session=session, + ) if res.get("errmsg", "") == "ns missing": return 0 return int(res["n"]) - def _aggregate_one_result( - self, sock_info, secondary_ok, cmd, collation, session): + def _aggregate_one_result(self, sock_info, secondary_ok, cmd, collation, session): """Internal helper to run an aggregate that returns a single result.""" result = self._command( sock_info, @@ -1366,11 +1491,12 @@ def _aggregate_one_result( codec_options=self.__write_response_codec_options, read_concern=self.read_concern, collation=collation, - session=session) + session=session, + ) # cursor will not be present for NamespaceNotFound errors. - if 'cursor' not in result: + if "cursor" not in result: return None - batch = result['cursor']['firstBatch'] + batch = result["cursor"]["firstBatch"] return batch[0] if batch else None def estimated_document_count(self, **kwargs): @@ -1391,34 +1517,31 @@ def estimated_document_count(self, **kwargs): .. 
versionadded:: 3.7 """ - if 'session' in kwargs: - raise ConfigurationError( - 'estimated_document_count does not support sessions') + if "session" in kwargs: + raise ConfigurationError("estimated_document_count does not support sessions") def _cmd(session, server, sock_info, secondary_ok): if sock_info.max_wire_version >= 12: # MongoDB 4.9+ pipeline = [ - {'$collStats': {'count': {}}}, - {'$group': {'_id': 1, 'n': {'$sum': '$count'}}}, + {"$collStats": {"count": {}}}, + {"$group": {"_id": 1, "n": {"$sum": "$count"}}}, ] - cmd = SON([('aggregate', self.__name), - ('pipeline', pipeline), - ('cursor', {})]) + cmd = SON([("aggregate", self.__name), ("pipeline", pipeline), ("cursor", {})]) cmd.update(kwargs) result = self._aggregate_one_result( - sock_info, secondary_ok, cmd, collation=None, session=session) + sock_info, secondary_ok, cmd, collation=None, session=session + ) if not result: return 0 - return int(result['n']) + return int(result["n"]) else: # MongoDB < 4.9 - cmd = SON([('count', self.__name)]) + cmd = SON([("count", self.__name)]) cmd.update(kwargs) return self._count_cmd(None, sock_info, secondary_ok, cmd, None) - return self.__database.client._retryable_read( - _cmd, self.read_preference, None) + return self.__database.client._retryable_read(_cmd, self.read_preference, None) def count_documents(self, filter, session=None, **kwargs): """Count the number of documents in this collection. @@ -1478,29 +1601,27 @@ def count_documents(self, filter, session=None, **kwargs): .. _$center: https://docs.mongodb.com/manual/reference/operator/query/center/#op._S_center .. _$centerSphere: https://docs.mongodb.com/manual/reference/operator/query/centerSphere/#op._S_centerSphere """ - pipeline = [{'$match': filter}] - if 'skip' in kwargs: - pipeline.append({'$skip': kwargs.pop('skip')}) - if 'limit' in kwargs: - pipeline.append({'$limit': kwargs.pop('limit')}) - pipeline.append({'$group': {'_id': 1, 'n': {'$sum': 1}}}) - cmd = SON([('aggregate', self.__name), - ('pipeline', pipeline), - ('cursor', {})]) + pipeline = [{"$match": filter}] + if "skip" in kwargs: + pipeline.append({"$skip": kwargs.pop("skip")}) + if "limit" in kwargs: + pipeline.append({"$limit": kwargs.pop("limit")}) + pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}}) + cmd = SON([("aggregate", self.__name), ("pipeline", pipeline), ("cursor", {})]) if "hint" in kwargs and not isinstance(kwargs["hint"], str): kwargs["hint"] = helpers._index_document(kwargs["hint"]) - collation = validate_collation_or_none(kwargs.pop('collation', None)) + collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd.update(kwargs) def _cmd(session, server, sock_info, secondary_ok): - result = self._aggregate_one_result( - sock_info, secondary_ok, cmd, collation, session) + result = self._aggregate_one_result(sock_info, secondary_ok, cmd, collation, session) if not result: return 0 - return result['n'] + return result["n"] return self.__database.client._retryable_read( - _cmd, self._read_preference_for(session), session) + _cmd, self._read_preference_for(session), session + ) def create_indexes(self, indexes, session=None, **kwargs): """Create one or more indexes on this collection. @@ -1539,7 +1660,7 @@ def create_indexes(self, indexes, session=None, **kwargs): .. 
_createIndexes: https://docs.mongodb.com/manual/reference/command/createIndexes/ """ - common.validate_list('indexes', indexes) + common.validate_list("indexes", indexes) return self.__create_indexes(indexes, session, **kwargs) def __create_indexes(self, indexes, session, **kwargs): @@ -1561,25 +1682,28 @@ def gen_indexes(): for index in indexes: if not isinstance(index, IndexModel): raise TypeError( - "%r is not an instance of " - "pymongo.operations.IndexModel" % (index,)) + "%r is not an instance of " "pymongo.operations.IndexModel" % (index,) + ) document = index.document names.append(document["name"]) yield document - cmd = SON([('createIndexes', self.name), - ('indexes', list(gen_indexes()))]) + cmd = SON([("createIndexes", self.name), ("indexes", list(gen_indexes()))]) cmd.update(kwargs) - if 'commitQuorum' in kwargs and not supports_quorum: + if "commitQuorum" in kwargs and not supports_quorum: raise ConfigurationError( "Must be connected to MongoDB 4.4+ to use the " - "commitQuorum option for createIndexes") + "commitQuorum option for createIndexes" + ) self._command( - sock_info, cmd, read_preference=ReadPreference.PRIMARY, + sock_info, + cmd, + read_preference=ReadPreference.PRIMARY, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, write_concern=self._write_concern_for(session), - session=session) + session=session, + ) return names def create_index(self, keys, session=None, **kwargs): @@ -1760,12 +1884,14 @@ def drop_index(self, index_or_name, session=None, **kwargs): cmd = SON([("dropIndexes", self.__name), ("index", name)]) cmd.update(kwargs) with self._socket_for_writes(session) as sock_info: - self._command(sock_info, - cmd, - read_preference=ReadPreference.PRIMARY, - allowable_errors=["ns not found", 26], - write_concern=self._write_concern_for(session), - session=session) + self._command( + sock_info, + cmd, + read_preference=ReadPreference.PRIMARY, + allowable_errors=["ns not found", 26], + write_concern=self._write_concern_for(session), + session=session, + ) def list_indexes(self, session=None): """Get a cursor over the index documents for this collection. @@ -1788,33 +1914,31 @@ def list_indexes(self, session=None): .. versionadded:: 3.0 """ codec_options = CodecOptions(SON) - coll = self.with_options(codec_options=codec_options, - read_preference=ReadPreference.PRIMARY) - read_pref = ((session and session._txn_read_preference()) - or ReadPreference.PRIMARY) + coll = self.with_options( + codec_options=codec_options, read_preference=ReadPreference.PRIMARY + ) + read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY def _cmd(session, server, sock_info, secondary_ok): cmd = SON([("listIndexes", self.__name), ("cursor", {})]) with self.__database.client._tmp_session(session, False) as s: try: - cursor = self._command(sock_info, cmd, secondary_ok, - read_pref, - codec_options, - session=s)["cursor"] + cursor = self._command( + sock_info, cmd, secondary_ok, read_pref, codec_options, session=s + )["cursor"] except OperationFailure as exc: # Ignore NamespaceNotFound errors to match the behavior # of reading from *.system.indexes. 
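A note on the NamespaceNotFound handling this comment introduces (the guard continues just below): only server error code 26 is swallowed, and it is replaced by an empty cursor document so iteration simply yields nothing. Hypothetical usage, assuming a reachable mongod:

```python
from pymongo import MongoClient

client = MongoClient()
missing = client.test.no_such_collection
# NamespaceNotFound (code 26) is absorbed by list_indexes, so a missing
# collection produces an empty result rather than an OperationFailure.
print(list(missing.list_indexes()))  # []
```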
if exc.code != 26: raise - cursor = {'id': 0, 'firstBatch': []} + cursor = {"id": 0, "firstBatch": []} cmd_cursor = CommandCursor( - coll, cursor, sock_info.address, session=s, - explicit_session=session is not None) + coll, cursor, sock_info.address, session=s, explicit_session=session is not None + ) cmd_cursor._maybe_pin_connection(sock_info) return cmd_cursor - return self.__database.client._retryable_read( - _cmd, read_pref, session) + return self.__database.client._retryable_read(_cmd, read_pref, session) def index_information(self, session=None): """Get information on this collection's indexes. @@ -1870,9 +1994,9 @@ def options(self, session=None): self.codec_options, self.read_preference, self.write_concern, - self.read_concern) - cursor = dbo.list_collections( - session=session, filter={"name": self.__name}) + self.read_concern, + ) + cursor = dbo.list_collections(session=session, filter={"name": self.__name}) result = None for doc in cursor: @@ -1888,14 +2012,23 @@ def options(self, session=None): return options - def _aggregate(self, aggregation_command, pipeline, cursor_class, session, - explicit_session, **kwargs): + def _aggregate( + self, aggregation_command, pipeline, cursor_class, session, explicit_session, **kwargs + ): cmd = aggregation_command( - self, cursor_class, pipeline, kwargs, explicit_session, - user_fields={'cursor': {'firstBatch': 1}}) + self, + cursor_class, + pipeline, + kwargs, + explicit_session, + user_fields={"cursor": {"firstBatch": 1}}, + ) return self.__database.client._retryable_read( - cmd.get_cursor, cmd.get_read_preference(session), session, - retryable=not cmd._performs_write) + cmd.get_cursor, + cmd.get_read_preference(session), + session, + retryable=not cmd._performs_write, + ) def aggregate(self, pipeline, session=None, **kwargs): """Perform an aggregation using the aggregation framework on this @@ -1968,12 +2101,14 @@ def aggregate(self, pipeline, session=None, **kwargs): https://docs.mongodb.com/manual/reference/command/aggregate """ with self.__database.client._tmp_session(session, close=False) as s: - return self._aggregate(_CollectionAggregationCommand, - pipeline, - CommandCursor, - session=s, - explicit_session=session is not None, - **kwargs) + return self._aggregate( + _CollectionAggregationCommand, + pipeline, + CommandCursor, + session=s, + explicit_session=session is not None, + **kwargs + ) def aggregate_raw_batches(self, pipeline, session=None, **kwargs): """Perform an aggregation and retrieve batches of raw BSON. @@ -2001,20 +2136,30 @@ def aggregate_raw_batches(self, pipeline, session=None, **kwargs): """ # OP_MSG is required to support encryption. 
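Before the auto-encryption guard that follows (raw-batch cursors skip BSON decoding, which is why they cannot be combined with automatic client-side encryption), a hedged usage sketch of the plain `aggregate` path reformatted above; the data and pipeline are illustrative and assume a reachable mongod:

```python
from pymongo import MongoClient

client = MongoClient()
coll = client.test.orders
coll.drop()
coll.insert_many([{"status": "A", "amount": 10},
                  {"status": "A", "amount": 5},
                  {"status": "B", "amount": 7}])

pipeline = [
    {"$match": {"status": "A"}},
    {"$group": {"_id": "$status", "total": {"$sum": "$amount"}}},
]
for doc in coll.aggregate(pipeline):
    print(doc)  # {'_id': 'A', 'total': 15}
```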
if self.__database.client._encrypter: - raise InvalidOperation( - "aggregate_raw_batches does not support auto encryption") + raise InvalidOperation("aggregate_raw_batches does not support auto encryption") with self.__database.client._tmp_session(session, close=False) as s: - return self._aggregate(_CollectionRawAggregationCommand, - pipeline, - RawBatchCommandCursor, - session=s, - explicit_session=session is not None, - **kwargs) - - def watch(self, pipeline=None, full_document=None, resume_after=None, - max_await_time_ms=None, batch_size=None, collation=None, - start_at_operation_time=None, session=None, start_after=None): + return self._aggregate( + _CollectionRawAggregationCommand, + pipeline, + RawBatchCommandCursor, + session=s, + explicit_session=session is not None, + **kwargs + ) + + def watch( + self, + pipeline=None, + full_document=None, + resume_after=None, + max_await_time_ms=None, + batch_size=None, + collation=None, + start_at_operation_time=None, + session=None, + start_after=None, + ): """Watch changes on this collection. Performs an aggregation with an implicit initial ``$changeStream`` @@ -2112,9 +2257,17 @@ def watch(self, pipeline=None, full_document=None, resume_after=None, https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.rst """ return CollectionChangeStream( - self, pipeline, full_document, resume_after, max_await_time_ms, - batch_size, collation, start_at_operation_time, session, - start_after) + self, + pipeline, + full_document, + resume_after, + max_await_time_ms, + batch_size, + collation, + start_at_operation_time, + session, + start_after, + ) def rename(self, new_name, session=None, **kwargs): """Rename this collection. @@ -2163,10 +2316,13 @@ def rename(self, new_name, session=None, **kwargs): with self._socket_for_writes(session) as sock_info: with self.__database.client._tmp_session(session) as s: return sock_info.command( - 'admin', cmd, + "admin", + cmd, write_concern=write_concern, parse_write_concern_error=True, - session=s, client=self.__database.client) + session=s, + client=self.__database.client, + ) def distinct(self, key, filter=None, session=None, **kwargs): """Get a list of distinct values for `key` among all documents @@ -2205,48 +2361,60 @@ def distinct(self, key, filter=None, session=None, **kwargs): """ if not isinstance(key, str): raise TypeError("key must be an instance of str") - cmd = SON([("distinct", self.__name), - ("key", key)]) + cmd = SON([("distinct", self.__name), ("key", key)]) if filter is not None: if "query" in kwargs: raise ConfigurationError("can't pass both filter and query") kwargs["query"] = filter - collation = validate_collation_or_none(kwargs.pop('collation', None)) + collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd.update(kwargs) + def _cmd(session, server, sock_info, secondary_ok): return self._command( - sock_info, cmd, secondary_ok, read_concern=self.read_concern, - collation=collation, session=session, - user_fields={"values": 1})["values"] + sock_info, + cmd, + secondary_ok, + read_concern=self.read_concern, + collation=collation, + session=session, + user_fields={"values": 1}, + )["values"] return self.__database.client._retryable_read( - _cmd, self._read_preference_for(session), session) + _cmd, self._read_preference_for(session), session + ) def _write_concern_for_cmd(self, cmd, session): - raw_wc = cmd.get('writeConcern') + raw_wc = cmd.get("writeConcern") if raw_wc is not None: return WriteConcern(**raw_wc) else: return 
self._write_concern_for(session) - def __find_and_modify(self, filter, projection, sort, upsert=None, - return_document=ReturnDocument.BEFORE, - array_filters=None, hint=None, session=None, - **kwargs): + def __find_and_modify( + self, + filter, + projection, + sort, + upsert=None, + return_document=ReturnDocument.BEFORE, + array_filters=None, + hint=None, + session=None, + **kwargs + ): """Internal findAndModify helper.""" common.validate_is_mapping("filter", filter) if not isinstance(return_document, bool): - raise ValueError("return_document must be " - "ReturnDocument.BEFORE or ReturnDocument.AFTER") - collation = validate_collation_or_none(kwargs.pop('collation', None)) - cmd = SON([("findAndModify", self.__name), - ("query", filter), - ("new", return_document)]) + raise ValueError( + "return_document must be " "ReturnDocument.BEFORE or ReturnDocument.AFTER" + ) + collation = validate_collation_or_none(kwargs.pop("collation", None)) + cmd = SON([("findAndModify", self.__name), ("query", filter), ("new", return_document)]) cmd.update(kwargs) if projection is not None: - cmd["fields"] = helpers._fields_list_to_dict(projection, - "projection") + cmd["fields"] = helpers._fields_list_to_dict(projection, "projection") if sort is not None: cmd["sort"] = helpers._index_document(sort) if upsert is not None: @@ -2262,35 +2430,38 @@ def _find_and_modify(session, sock_info, retryable_write): if array_filters is not None: if not write_concern.acknowledged: raise ConfigurationError( - 'arrayFilters is unsupported for unacknowledged ' - 'writes.') + "arrayFilters is unsupported for unacknowledged " "writes." + ) cmd["arrayFilters"] = array_filters if hint is not None: if sock_info.max_wire_version < 8: - raise ConfigurationError( - 'Must be connected to MongoDB 4.2+ to use hint.') + raise ConfigurationError("Must be connected to MongoDB 4.2+ to use hint.") if not write_concern.acknowledged: - raise ConfigurationError( - 'hint is unsupported for unacknowledged writes.') - cmd['hint'] = hint + raise ConfigurationError("hint is unsupported for unacknowledged writes.") + cmd["hint"] = hint if not write_concern.is_server_default: - cmd['writeConcern'] = write_concern.document - out = self._command(sock_info, cmd, - read_preference=ReadPreference.PRIMARY, - write_concern=write_concern, - collation=collation, session=session, - retryable_write=retryable_write, - user_fields=_FIND_AND_MODIFY_DOC_FIELDS) + cmd["writeConcern"] = write_concern.document + out = self._command( + sock_info, + cmd, + read_preference=ReadPreference.PRIMARY, + write_concern=write_concern, + collation=collation, + session=session, + retryable_write=retryable_write, + user_fields=_FIND_AND_MODIFY_DOC_FIELDS, + ) _check_write_command_response(out) return out.get("value") return self.__database.client._retryable_write( - write_concern.acknowledged, _find_and_modify, session) + write_concern.acknowledged, _find_and_modify, session + ) - def find_one_and_delete(self, filter, - projection=None, sort=None, hint=None, - session=None, **kwargs): + def find_one_and_delete( + self, filter, projection=None, sort=None, hint=None, session=None, **kwargs + ): """Finds a single document and deletes it, returning the document. >>> db.test.count_documents({'x': 1}) @@ -2355,14 +2526,23 @@ def find_one_and_delete(self, filter, Added the `collation` option. .. 
versionadded:: 3.0 """ - kwargs['remove'] = True - return self.__find_and_modify(filter, projection, sort, - hint=hint, session=session, **kwargs) - - def find_one_and_replace(self, filter, replacement, - projection=None, sort=None, upsert=False, - return_document=ReturnDocument.BEFORE, - hint=None, session=None, **kwargs): + kwargs["remove"] = True + return self.__find_and_modify( + filter, projection, sort, hint=hint, session=session, **kwargs + ) + + def find_one_and_replace( + self, + filter, + replacement, + projection=None, + sort=None, + upsert=False, + return_document=ReturnDocument.BEFORE, + hint=None, + session=None, + **kwargs + ): """Finds a single document and replaces it, returning either the original or the replaced document. @@ -2434,16 +2614,24 @@ def find_one_and_replace(self, filter, replacement, .. versionadded:: 3.0 """ common.validate_ok_for_replace(replacement) - kwargs['update'] = replacement - return self.__find_and_modify(filter, projection, - sort, upsert, return_document, - hint=hint, session=session, **kwargs) - - def find_one_and_update(self, filter, update, - projection=None, sort=None, upsert=False, - return_document=ReturnDocument.BEFORE, - array_filters=None, hint=None, session=None, - **kwargs): + kwargs["update"] = replacement + return self.__find_and_modify( + filter, projection, sort, upsert, return_document, hint=hint, session=session, **kwargs + ) + + def find_one_and_update( + self, + filter, + update, + projection=None, + sort=None, + upsert=False, + return_document=ReturnDocument.BEFORE, + array_filters=None, + hint=None, + session=None, + **kwargs + ): """Finds a single document and updates it, returning either the original or the updated document. @@ -2557,12 +2745,19 @@ def find_one_and_update(self, filter, update, .. versionadded:: 3.0 """ common.validate_ok_for_update(update) - common.validate_list_or_none('array_filters', array_filters) - kwargs['update'] = update - return self.__find_and_modify(filter, projection, - sort, upsert, return_document, - array_filters, hint=hint, - session=session, **kwargs) + common.validate_list_or_none("array_filters", array_filters) + kwargs["update"] = update + return self.__find_and_modify( + filter, + projection, + sort, + upsert, + return_document, + array_filters, + hint=hint, + session=session, + **kwargs + ) def __iter__(self): return self @@ -2573,15 +2768,16 @@ def __next__(self): next = __next__ def __call__(self, *args, **kwargs): - """This is only here so that some API misusages are easier to debug. - """ + """This is only here so that some API misusages are easier to debug.""" if "." not in self.__name: - raise TypeError("'Collection' object is not callable. If you " - "meant to call the '%s' method on a 'Database' " - "object it is failing because no such method " - "exists." % - self.__name) - raise TypeError("'Collection' object is not callable. If you meant to " - "call the '%s' method on a 'Collection' object it is " - "failing because no such method exists." % - self.__name.split(".")[-1]) + raise TypeError( + "'Collection' object is not callable. If you " + "meant to call the '%s' method on a 'Database' " + "object it is failing because no such method " + "exists." % self.__name + ) + raise TypeError( + "'Collection' object is not callable. If you meant to " + "call the '%s' method on a 'Collection' object it is " + "failing because no such method exists." 
% self.__name.split(".")[-1] + ) diff --git a/pymongo/command_cursor.py b/pymongo/command_cursor.py index 21822ac61b..24725485f6 100644 --- a/pymongo/command_cursor.py +++ b/pymongo/command_cursor.py @@ -17,35 +17,39 @@ from collections import deque from bson import _convert_raw_document_lists_to_streams -from pymongo.cursor import _SocketManager, _CURSOR_CLOSED_ERRORS -from pymongo.errors import (ConnectionFailure, - InvalidOperation, - OperationFailure) -from pymongo.message import (_CursorAddress, - _GetMore, - _RawBatchGetMore) +from pymongo.cursor import _CURSOR_CLOSED_ERRORS, _SocketManager +from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.message import _CursorAddress, _GetMore, _RawBatchGetMore from pymongo.response import PinnedResponse class CommandCursor(object): """A cursor / iterator over command cursors.""" + _getmore_class = _GetMore - def __init__(self, collection, cursor_info, address, - batch_size=0, max_await_time_ms=None, session=None, - explicit_session=False): + def __init__( + self, + collection, + cursor_info, + address, + batch_size=0, + max_await_time_ms=None, + session=None, + explicit_session=False, + ): """Create a new command cursor.""" self.__sock_mgr = None self.__collection = collection - self.__id = cursor_info['id'] - self.__data = deque(cursor_info['firstBatch']) - self.__postbatchresumetoken = cursor_info.get('postBatchResumeToken') + self.__id = cursor_info["id"] + self.__data = deque(cursor_info["firstBatch"]) + self.__postbatchresumetoken = cursor_info.get("postBatchResumeToken") self.__address = address self.__batch_size = batch_size self.__max_await_time_ms = max_await_time_ms self.__session = session self.__explicit_session = explicit_session - self.__killed = (self.__id == 0) + self.__killed = self.__id == 0 if self.__killed: self.__end_session(True) @@ -56,22 +60,19 @@ def __init__(self, collection, cursor_info, address, self.batch_size(batch_size) - if (not isinstance(max_await_time_ms, int) - and max_await_time_ms is not None): + if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: raise TypeError("max_await_time_ms must be an integer or None") def __del__(self): self.__die() def __die(self, synchronous=False): - """Closes this cursor. - """ + """Closes this cursor.""" already_killed = self.__killed self.__killed = True if self.__id and not already_killed: cursor_id = self.__id - address = _CursorAddress( - self.__address, self.__ns) + address = _CursorAddress(self.__address, self.__ns) else: # Skip killCursors. cursor_id = 0 @@ -82,7 +83,8 @@ def __die(self, synchronous=False): address, self.__sock_mgr, self.__session, - self.__explicit_session) + self.__explicit_session, + ) if not self.__explicit_session: self.__session = None self.__sock_mgr = None @@ -93,8 +95,7 @@ def __end_session(self, synchronous): self.__session = None def close(self): - """Explicitly close / kill this cursor. - """ + """Explicitly close / kill this cursor.""" self.__die(True) def batch_size(self, batch_size): @@ -147,12 +148,12 @@ def _maybe_pin_connection(self, sock_info): self.__sock_mgr = sock_mgr def __send_message(self, operation): - """Send a getmore message and handle the response. 
- """ + """Send a getmore message and handle the response.""" client = self.__collection.database.client try: response = client._run_operation( - operation, self._unpack_response, address=self.__address) + operation, self._unpack_response, address=self.__address + ) except OperationFailure as exc: if exc.code in _CURSOR_CLOSED_ERRORS: # Don't send killCursors because the cursor is already closed. @@ -172,13 +173,12 @@ def __send_message(self, operation): if isinstance(response, PinnedResponse): if not self.__sock_mgr: - self.__sock_mgr = _SocketManager(response.socket_info, - response.more_to_come) + self.__sock_mgr = _SocketManager(response.socket_info, response.more_to_come) if response.from_command: - cursor = response.docs[0]['cursor'] - documents = cursor['nextBatch'] - self.__postbatchresumetoken = cursor.get('postBatchResumeToken') - self.__id = cursor['id'] + cursor = response.docs[0]["cursor"] + documents = cursor["nextBatch"] + self.__postbatchresumetoken = cursor.get("postBatchResumeToken") + self.__id = cursor["id"] else: documents = response.docs self.__id = response.data.cursor_id @@ -187,10 +187,10 @@ def __send_message(self, operation): self.close() self.__data = deque(documents) - def _unpack_response(self, response, cursor_id, codec_options, - user_fields=None, legacy_response=False): - return response.unpack_response(cursor_id, codec_options, user_fields, - legacy_response) + def _unpack_response( + self, response, cursor_id, codec_options, user_fields=None, legacy_response=False + ): + return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response) def _refresh(self): """Refreshes the cursor with more data from the server. @@ -203,19 +203,23 @@ def _refresh(self): return len(self.__data) if self.__id: # Get More - dbname, collname = self.__ns.split('.', 1) + dbname, collname = self.__ns.split(".", 1) read_pref = self.__collection._read_preference_for(self.session) self.__send_message( - self._getmore_class(dbname, - collname, - self.__batch_size, - self.__id, - self.__collection.codec_options, - read_pref, - self.__session, - self.__collection.database.client, - self.__max_await_time_ms, - self.__sock_mgr, False)) + self._getmore_class( + dbname, + collname, + self.__batch_size, + self.__id, + self.__collection.codec_options, + read_pref, + self.__session, + self.__collection.database.client, + self.__max_await_time_ms, + self.__sock_mgr, + False, + ) + ) else: # Cursor id is zero nothing else to return self.__die(True) @@ -294,9 +298,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): class RawBatchCommandCursor(CommandCursor): _getmore_class = _RawBatchGetMore - def __init__(self, collection, cursor_info, address, - batch_size=0, max_await_time_ms=None, session=None, - explicit_session=False): + def __init__( + self, + collection, + cursor_info, + address, + batch_size=0, + max_await_time_ms=None, + session=None, + explicit_session=False, + ): """Create a new cursor / iterator over raw batches of BSON data. Should not be called directly by application developers - @@ -305,15 +316,21 @@ def __init__(self, collection, cursor_info, address, .. seealso:: The MongoDB documentation on `cursors `_. 
""" - assert not cursor_info.get('firstBatch') + assert not cursor_info.get("firstBatch") super(RawBatchCommandCursor, self).__init__( - collection, cursor_info, address, batch_size, - max_await_time_ms, session, explicit_session) - - def _unpack_response(self, response, cursor_id, codec_options, - user_fields=None, legacy_response=False): - raw_response = response.raw_response( - cursor_id, user_fields=user_fields) + collection, + cursor_info, + address, + batch_size, + max_await_time_ms, + session, + explicit_session, + ) + + def _unpack_response( + self, response, cursor_id, codec_options, user_fields=None, legacy_response=False + ): + raw_response = response.raw_response(cursor_id, user_fields=user_fields) if not legacy_response: # OP_MSG returns firstBatch/nextBatch documents as a BSON array # Re-assemble the array of documents into a document stream diff --git a/pymongo/common.py b/pymongo/common.py index 772f2f299b..3c588852e5 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -17,8 +17,7 @@ import datetime import warnings - -from collections import abc, OrderedDict +from collections import OrderedDict, abc from urllib.parse import unquote_plus from bson import SON @@ -26,20 +25,22 @@ from bson.codec_options import CodecOptions, TypeRegistry from bson.raw_bson import RawBSONDocument from pymongo.auth import MECHANISMS -from pymongo.compression_support import (validate_compressors, - validate_zlib_compression_level) +from pymongo.compression_support import ( + validate_compressors, + validate_zlib_compression_level, +) from pymongo.driver_info import DriverInfo -from pymongo.server_api import ServerApi from pymongo.errors import ConfigurationError from pymongo.monitoring import _validate_event_listeners from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _MONGOS_MODES, _ServerMode +from pymongo.server_api import ServerApi from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern ORDERED_TYPES = (SON, OrderedDict) # Defaults until we connect to a server and get updated limits. -MAX_BSON_SIZE = 16 * (1024 ** 2) +MAX_BSON_SIZE = 16 * (1024**2) MAX_MESSAGE_SIZE = 2 * MAX_BSON_SIZE MIN_WIRE_VERSION = 0 MAX_WIRE_VERSION = 0 @@ -120,10 +121,10 @@ def partition_node(node): """Split a host:port string into (host, int(port)) pair.""" host = node port = 27017 - idx = node.rfind(':') + idx = node.rfind(":") if idx != -1: - host, port = node[:idx], int(node[idx + 1:]) - if host.startswith('['): + host, port = node[:idx], int(node[idx + 1 :]) + if host.startswith("["): host = host[1:-1] return host, port @@ -146,11 +147,11 @@ def raise_config_error(key, dummy): # Mapping of URI uuid representation options to valid subtypes. 
_UUID_REPRESENTATIONS = { - 'unspecified': UuidRepresentation.UNSPECIFIED, - 'standard': UuidRepresentation.STANDARD, - 'pythonLegacy': UuidRepresentation.PYTHON_LEGACY, - 'javaLegacy': UuidRepresentation.JAVA_LEGACY, - 'csharpLegacy': UuidRepresentation.CSHARP_LEGACY + "unspecified": UuidRepresentation.UNSPECIFIED, + "standard": UuidRepresentation.STANDARD, + "pythonLegacy": UuidRepresentation.PYTHON_LEGACY, + "javaLegacy": UuidRepresentation.JAVA_LEGACY, + "csharpLegacy": UuidRepresentation.CSHARP_LEGACY, } @@ -164,95 +165,81 @@ def validate_boolean(option, value): def validate_boolean_or_string(option, value): """Validates that value is True, False, 'true', or 'false'.""" if isinstance(value, str): - if value not in ('true', 'false'): - raise ValueError("The value of %s must be " - "'true' or 'false'" % (option,)) - return value == 'true' + if value not in ("true", "false"): + raise ValueError("The value of %s must be " "'true' or 'false'" % (option,)) + return value == "true" return validate_boolean(option, value) def validate_integer(option, value): - """Validates that 'value' is an integer (or basestring representation). - """ + """Validates that 'value' is an integer (or basestring representation).""" if isinstance(value, int): return value elif isinstance(value, str): try: return int(value) except ValueError: - raise ValueError("The value of %s must be " - "an integer" % (option,)) + raise ValueError("The value of %s must be " "an integer" % (option,)) raise TypeError("Wrong type for %s, value must be an integer" % (option,)) def validate_positive_integer(option, value): - """Validate that 'value' is a positive integer, which does not include 0. - """ + """Validate that 'value' is a positive integer, which does not include 0.""" val = validate_integer(option, value) if val <= 0: - raise ValueError("The value of %s must be " - "a positive integer" % (option,)) + raise ValueError("The value of %s must be " "a positive integer" % (option,)) return val def validate_non_negative_integer(option, value): - """Validate that 'value' is a positive integer or 0. - """ + """Validate that 'value' is a positive integer or 0.""" val = validate_integer(option, value) if val < 0: - raise ValueError("The value of %s must be " - "a non negative integer" % (option,)) + raise ValueError("The value of %s must be " "a non negative integer" % (option,)) return val def validate_readable(option, value): - """Validates that 'value' is file-like and readable. - """ + """Validates that 'value' is file-like and readable.""" if value is None: return value # First make sure its a string py3.3 open(True, 'r') succeeds # Used in ssl cert checking due to poor ssl module error reporting value = validate_string(option, value) - open(value, 'r').close() + open(value, "r").close() return value def validate_positive_integer_or_none(option, value): - """Validate that 'value' is a positive integer or None. - """ + """Validate that 'value' is a positive integer or None.""" if value is None: return value return validate_positive_integer(option, value) def validate_non_negative_integer_or_none(option, value): - """Validate that 'value' is a positive integer or 0 or None. - """ + """Validate that 'value' is a positive integer or 0 or None.""" if value is None: return value return validate_non_negative_integer(option, value) def validate_string(option, value): - """Validates that 'value' is an instance of `str`. 
- """ + """Validates that 'value' is an instance of `str`.""" if isinstance(value, str): return value - raise TypeError("Wrong type for %s, value must be an instance of " - "str" % (option,)) + raise TypeError("Wrong type for %s, value must be an instance of " "str" % (option,)) def validate_string_or_none(option, value): - """Validates that 'value' is an instance of `basestring` or `None`. - """ + """Validates that 'value' is an instance of `basestring` or `None`.""" if value is None: return value return validate_string(option, value) def validate_int_or_basestring(option, value): - """Validates that 'value' is an integer or string. - """ + """Validates that 'value' is an integer or string.""" if isinstance(value, int): return value elif isinstance(value, str): @@ -260,13 +247,11 @@ def validate_int_or_basestring(option, value): return int(value) except ValueError: return value - raise TypeError("Wrong type for %s, value must be an " - "integer or a string" % (option,)) + raise TypeError("Wrong type for %s, value must be an " "integer or a string" % (option,)) def validate_non_negative_int_or_basestring(option, value): - """Validates that 'value' is an integer or string. - """ + """Validates that 'value' is an integer or string.""" if isinstance(value, int): return value elif isinstance(value, str): @@ -275,13 +260,14 @@ def validate_non_negative_int_or_basestring(option, value): except ValueError: return value return validate_non_negative_integer(option, val) - raise TypeError("Wrong type for %s, value must be an " - "non negative integer or a string" % (option,)) + raise TypeError( + "Wrong type for %s, value must be an " "non negative integer or a string" % (option,) + ) def validate_positive_float(option, value): """Validates that 'value' is a float, or can be converted to one, and is - positive. + positive. """ errmsg = "%s must be an integer or float" % (option,) try: @@ -294,8 +280,7 @@ def validate_positive_float(option, value): # float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at # one billion - this is a reasonable approximation for infinity if not 0 < value < 1e9: - raise ValueError("%s must be greater than 0 and " - "less than one billion" % (option,)) + raise ValueError("%s must be greater than 0 and " "less than one billion" % (option,)) return value @@ -324,7 +309,7 @@ def validate_timeout_or_zero(option, value): config error. """ if value is None: - raise ConfigurationError("%s cannot be None" % (option, )) + raise ConfigurationError("%s cannot be None" % (option,)) if value == 0 or value == "0": return 0 return validate_positive_float(option, value) / 1000.0 @@ -349,8 +334,7 @@ def validate_max_staleness(option, value): def validate_read_preference(dummy, value): - """Validate a read preference. - """ + """Validate a read preference.""" if not isinstance(value, _ServerMode): raise TypeError("%r is not a read preference." % (value,)) return value @@ -369,33 +353,32 @@ def validate_read_preference_mode(dummy, value): def validate_auth_mechanism(option, value): - """Validate the authMechanism URI option. - """ + """Validate the authMechanism URI option.""" if value not in MECHANISMS: raise ValueError("%s must be in %s" % (option, tuple(MECHANISMS))) return value def validate_uuid_representation(dummy, value): - """Validate the uuid representation option selected in the URI. 
- """ + """Validate the uuid representation option selected in the URI.""" try: return _UUID_REPRESENTATIONS[value] except KeyError: - raise ValueError("%s is an invalid UUID representation. " - "Must be one of " - "%s" % (value, tuple(_UUID_REPRESENTATIONS))) + raise ValueError( + "%s is an invalid UUID representation. " + "Must be one of " + "%s" % (value, tuple(_UUID_REPRESENTATIONS)) + ) def validate_read_preference_tags(name, value): - """Parse readPreferenceTags if passed as a client kwarg. - """ + """Parse readPreferenceTags if passed as a client kwarg.""" if not isinstance(value, list): value = [value] tag_sets = [] for tag_set in value: - if tag_set == '': + if tag_set == "": tag_sets.append({}) continue try: @@ -405,37 +388,41 @@ def validate_read_preference_tags(name, value): tags[unquote_plus(key)] = unquote_plus(val) tag_sets.append(tags) except Exception: - raise ValueError("%r not a valid " - "value for %s" % (tag_set, name)) + raise ValueError("%r not a valid " "value for %s" % (tag_set, name)) return tag_sets -_MECHANISM_PROPS = frozenset(['SERVICE_NAME', - 'CANONICALIZE_HOST_NAME', - 'SERVICE_REALM', - 'AWS_SESSION_TOKEN']) +_MECHANISM_PROPS = frozenset( + ["SERVICE_NAME", "CANONICALIZE_HOST_NAME", "SERVICE_REALM", "AWS_SESSION_TOKEN"] +) def validate_auth_mechanism_properties(option, value): """Validate authMechanismProperties.""" value = validate_string(option, value) props = {} - for opt in value.split(','): + for opt in value.split(","): try: - key, val = opt.split(':') + key, val = opt.split(":") except ValueError: # Try not to leak the token. - if 'AWS_SESSION_TOKEN' in opt: - opt = ('AWS_SESSION_TOKEN:, did you forget ' - 'to percent-escape the token with quote_plus?') - raise ValueError("auth mechanism properties must be " - "key:value pairs like SERVICE_NAME:" - "mongodb, not %s." % (opt,)) + if "AWS_SESSION_TOKEN" in opt: + opt = ( + "AWS_SESSION_TOKEN:, did you forget " + "to percent-escape the token with quote_plus?" + ) + raise ValueError( + "auth mechanism properties must be " + "key:value pairs like SERVICE_NAME:" + "mongodb, not %s." % (opt,) + ) if key not in _MECHANISM_PROPS: - raise ValueError("%s is not a supported auth " - "mechanism property. Must be one of " - "%s." % (key, tuple(_MECHANISM_PROPS))) - if key == 'CANONICALIZE_HOST_NAME': + raise ValueError( + "%s is not a supported auth " + "mechanism property. Must be one of " + "%s." 
% (key, tuple(_MECHANISM_PROPS)) + ) + if key == "CANONICALIZE_HOST_NAME": props[key] = validate_boolean_or_string(key, val) else: props[key] = unquote_plus(val) @@ -446,17 +433,18 @@ def validate_auth_mechanism_properties(option, value): def validate_document_class(option, value): """Validate the document_class option.""" if not issubclass(value, (abc.MutableMapping, RawBSONDocument)): - raise TypeError("%s must be dict, bson.son.SON, " - "bson.raw_bson.RawBSONDocument, or a " - "sublass of collections.MutableMapping" % (option,)) + raise TypeError( + "%s must be dict, bson.son.SON, " + "bson.raw_bson.RawBSONDocument, or a " + "sublass of collections.MutableMapping" % (option,) + ) return value def validate_type_registry(option, value): """Validate the type_registry option.""" if value is not None and not isinstance(value, TypeRegistry): - raise TypeError("%s must be an instance of %s" % ( - option, TypeRegistry)) + raise TypeError("%s must be an instance of %s" % (option, TypeRegistry)) return value @@ -477,26 +465,32 @@ def validate_list_or_none(option, value): def validate_list_or_mapping(option, value): """Validates that 'value' is a list or a document.""" if not isinstance(value, (abc.Mapping, list)): - raise TypeError("%s must either be a list or an instance of dict, " - "bson.son.SON, or any other type that inherits from " - "collections.Mapping" % (option,)) + raise TypeError( + "%s must either be a list or an instance of dict, " + "bson.son.SON, or any other type that inherits from " + "collections.Mapping" % (option,) + ) def validate_is_mapping(option, value): """Validate the type of method arguments that expect a document.""" if not isinstance(value, abc.Mapping): - raise TypeError("%s must be an instance of dict, bson.son.SON, or " - "any other type that inherits from " - "collections.Mapping" % (option,)) + raise TypeError( + "%s must be an instance of dict, bson.son.SON, or " + "any other type that inherits from " + "collections.Mapping" % (option,) + ) def validate_is_document_type(option, value): """Validate the type of method arguments that expect a MongoDB document.""" if not isinstance(value, (abc.MutableMapping, RawBSONDocument)): - raise TypeError("%s must be an instance of dict, bson.son.SON, " - "bson.raw_bson.RawBSONDocument, or " - "a type that inherits from " - "collections.MutableMapping" % (option,)) + raise TypeError( + "%s must be an instance of dict, bson.son.SON, " + "bson.raw_bson.RawBSONDocument, or " + "a type that inherits from " + "collections.MutableMapping" % (option,) + ) def validate_appname_or_none(option, value): @@ -505,7 +499,7 @@ def validate_appname_or_none(option, value): return value validate_string(option, value) # We need length in bytes, so encode utf8 first. - if len(value.encode('utf-8')) > 128: + if len(value.encode("utf-8")) > 128: raise ValueError("%s must be <= 128 bytes" % (option,)) return value @@ -543,8 +537,8 @@ def validate_ok_for_replace(replacement): # Replacement can be {} if replacement and not isinstance(replacement, RawBSONDocument): first = next(iter(replacement)) - if first.startswith('$'): - raise ValueError('replacement can not include $ operators') + if first.startswith("$"): + raise ValueError("replacement can not include $ operators") def validate_ok_for_update(update): @@ -552,30 +546,30 @@ def validate_ok_for_update(update): validate_list_or_mapping("update", update) # Update cannot be {}. 
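As the comment says, the empty-update check comes next. Together with `validate_ok_for_replace` earlier in this hunk, the two guards are complementary: a replacement document must not begin with a `$` operator, while an update document must. A quick illustration against the real validators (requires only pymongo installed):

```python
from pymongo import common

common.validate_ok_for_update({"$set": {"x": 1}})  # accepted
common.validate_ok_for_replace({"x": 1})           # accepted
try:
    common.validate_ok_for_update({"x": 1})        # no $ operator
except ValueError as exc:
    print(exc)  # update only works with $ operators
```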
if not update: - raise ValueError('update cannot be empty') + raise ValueError("update cannot be empty") is_document = not isinstance(update, list) first = next(iter(update)) - if is_document and not first.startswith('$'): - raise ValueError('update only works with $ operators') + if is_document and not first.startswith("$"): + raise ValueError("update only works with $ operators") -_UNICODE_DECODE_ERROR_HANDLERS = frozenset(['strict', 'replace', 'ignore']) +_UNICODE_DECODE_ERROR_HANDLERS = frozenset(["strict", "replace", "ignore"]) def validate_unicode_decode_error_handler(dummy, value): - """Validate the Unicode decode error handler option of CodecOptions. - """ + """Validate the Unicode decode error handler option of CodecOptions.""" if value not in _UNICODE_DECODE_ERROR_HANDLERS: - raise ValueError("%s is an invalid Unicode decode error handler. " - "Must be one of " - "%s" % (value, tuple(_UNICODE_DECODE_ERROR_HANDLERS))) + raise ValueError( + "%s is an invalid Unicode decode error handler. " + "Must be one of " + "%s" % (value, tuple(_UNICODE_DECODE_ERROR_HANDLERS)) + ) return value def validate_tzinfo(dummy, value): - """Validate the tzinfo option - """ + """Validate the tzinfo option""" if value is not None and not isinstance(value, datetime.tzinfo): raise TypeError("%s must be an instance of datetime.tzinfo" % value) return value @@ -586,9 +580,9 @@ def validate_auto_encryption_opts_or_none(option, value): if value is None: return value from pymongo.encryption_options import AutoEncryptionOpts + if not isinstance(value, AutoEncryptionOpts): - raise TypeError("%s must be an instance of AutoEncryptionOpts" % ( - option,)) + raise TypeError("%s must be an instance of AutoEncryptionOpts" % (option,)) return value @@ -596,7 +590,7 @@ def validate_auto_encryption_opts_or_none(option, value): # Dictionary where keys are the names of public URI options, and values # are lists of aliases for that option. URI_OPTIONS_ALIAS_MAP = { - 'tls': ['ssl'], + "tls": ["ssl"], } # Dictionary where keys are the names of URI options, and values @@ -604,73 +598,73 @@ def validate_auto_encryption_opts_or_none(option, value): # alias uses a different validator than its public counterpart, it should be # included here as a key, value pair. 
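One validator from this hunk is worth a quick illustration before the option tables that follow: `validate_unicode_decode_error_handler` admits only the three standard codec handlers, so a typo fails fast at client construction rather than at decode time:

```python
from pymongo import common

print(common.validate_unicode_decode_error_handler(None, "replace"))
try:
    common.validate_unicode_decode_error_handler(None, "backslashreplace")
except ValueError as exc:
    print(exc)  # names the three permitted handlers
```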
@@ -604,73 +598,73 @@
 # alias uses a different validator than its public counterpart, it should be
 # included here as a key, value pair.
 URI_OPTIONS_VALIDATOR_MAP = {
-    'appname': validate_appname_or_none,
-    'authmechanism': validate_auth_mechanism,
-    'authmechanismproperties': validate_auth_mechanism_properties,
-    'authsource': validate_string,
-    'compressors': validate_compressors,
-    'connecttimeoutms': validate_timeout_or_none_or_zero,
-    'directconnection': validate_boolean_or_string,
-    'heartbeatfrequencyms': validate_timeout_or_none,
-    'journal': validate_boolean_or_string,
-    'localthresholdms': validate_positive_float_or_zero,
-    'maxidletimems': validate_timeout_or_none,
-    'maxconnecting': validate_positive_integer,
-    'maxpoolsize': validate_non_negative_integer_or_none,
-    'maxstalenessseconds': validate_max_staleness,
-    'readconcernlevel': validate_string_or_none,
-    'readpreference': validate_read_preference_mode,
-    'readpreferencetags': validate_read_preference_tags,
-    'replicaset': validate_string_or_none,
-    'retryreads': validate_boolean_or_string,
-    'retrywrites': validate_boolean_or_string,
-    'loadbalanced': validate_boolean_or_string,
-    'serverselectiontimeoutms': validate_timeout_or_zero,
-    'sockettimeoutms': validate_timeout_or_none_or_zero,
-    'tls': validate_boolean_or_string,
-    'tlsallowinvalidcertificates': validate_boolean_or_string,
-    'tlsallowinvalidhostnames': validate_boolean_or_string,
-    'tlscafile': validate_readable,
-    'tlscertificatekeyfile': validate_readable,
-    'tlscertificatekeyfilepassword': validate_string_or_none,
-    'tlsdisableocspendpointcheck': validate_boolean_or_string,
-    'tlsinsecure': validate_boolean_or_string,
-    'w': validate_non_negative_int_or_basestring,
-    'wtimeoutms': validate_non_negative_integer,
-    'zlibcompressionlevel': validate_zlib_compression_level,
-    'srvservicename': validate_string,
-    'srvmaxhosts': validate_non_negative_integer
+    "appname": validate_appname_or_none,
+    "authmechanism": validate_auth_mechanism,
+    "authmechanismproperties": validate_auth_mechanism_properties,
+    "authsource": validate_string,
+    "compressors": validate_compressors,
+    "connecttimeoutms": validate_timeout_or_none_or_zero,
+    "directconnection": validate_boolean_or_string,
+    "heartbeatfrequencyms": validate_timeout_or_none,
+    "journal": validate_boolean_or_string,
+    "localthresholdms": validate_positive_float_or_zero,
+    "maxidletimems": validate_timeout_or_none,
+    "maxconnecting": validate_positive_integer,
+    "maxpoolsize": validate_non_negative_integer_or_none,
+    "maxstalenessseconds": validate_max_staleness,
+    "readconcernlevel": validate_string_or_none,
+    "readpreference": validate_read_preference_mode,
+    "readpreferencetags": validate_read_preference_tags,
+    "replicaset": validate_string_or_none,
+    "retryreads": validate_boolean_or_string,
+    "retrywrites": validate_boolean_or_string,
+    "loadbalanced": validate_boolean_or_string,
+    "serverselectiontimeoutms": validate_timeout_or_zero,
+    "sockettimeoutms": validate_timeout_or_none_or_zero,
+    "tls": validate_boolean_or_string,
+    "tlsallowinvalidcertificates": validate_boolean_or_string,
+    "tlsallowinvalidhostnames": validate_boolean_or_string,
+    "tlscafile": validate_readable,
+    "tlscertificatekeyfile": validate_readable,
+    "tlscertificatekeyfilepassword": validate_string_or_none,
+    "tlsdisableocspendpointcheck": validate_boolean_or_string,
+    "tlsinsecure": validate_boolean_or_string,
+    "w": validate_non_negative_int_or_basestring,
+    "wtimeoutms": validate_non_negative_integer,
+    "zlibcompressionlevel": validate_zlib_compression_level,
+    "srvservicename": validate_string,
+    "srvmaxhosts": validate_non_negative_integer,
 }
 
 # Dictionary where keys are the names of URI options specific to pymongo,
 # and values are functions that validate user-input values for those options.
 NONSPEC_OPTIONS_VALIDATOR_MAP = {
-    'connect': validate_boolean_or_string,
-    'driver': validate_driver_or_none,
-    'server_api': validate_server_api_or_none,
-    'fsync': validate_boolean_or_string,
-    'minpoolsize': validate_non_negative_integer,
-    'tlscrlfile': validate_readable,
-    'tz_aware': validate_boolean_or_string,
-    'unicode_decode_error_handler': validate_unicode_decode_error_handler,
-    'uuidrepresentation': validate_uuid_representation,
-    'waitqueuemultiple': validate_non_negative_integer_or_none,
-    'waitqueuetimeoutms': validate_timeout_or_none,
+    "connect": validate_boolean_or_string,
+    "driver": validate_driver_or_none,
+    "server_api": validate_server_api_or_none,
+    "fsync": validate_boolean_or_string,
+    "minpoolsize": validate_non_negative_integer,
+    "tlscrlfile": validate_readable,
+    "tz_aware": validate_boolean_or_string,
+    "unicode_decode_error_handler": validate_unicode_decode_error_handler,
+    "uuidrepresentation": validate_uuid_representation,
+    "waitqueuemultiple": validate_non_negative_integer_or_none,
+    "waitqueuetimeoutms": validate_timeout_or_none,
 }
 
 # Dictionary where keys are the names of keyword-only options for the
 # MongoClient constructor, and values are functions that validate user-input
 # values for those options.
 KW_VALIDATORS = {
-    'document_class': validate_document_class,
-    'type_registry': validate_type_registry,
-    'read_preference': validate_read_preference,
-    'event_listeners': _validate_event_listeners,
-    'tzinfo': validate_tzinfo,
-    'username': validate_string_or_none,
-    'password': validate_string_or_none,
-    'server_selector': validate_is_callable_or_none,
-    'auto_encryption_opts': validate_auto_encryption_opts_or_none,
+    "document_class": validate_document_class,
+    "type_registry": validate_type_registry,
+    "read_preference": validate_read_preference,
+    "event_listeners": _validate_event_listeners,
+    "tzinfo": validate_tzinfo,
+    "username": validate_string_or_none,
+    "password": validate_string_or_none,
+    "server_selector": validate_is_callable_or_none,
+    "auto_encryption_opts": validate_auto_encryption_opts_or_none,
 }
 
 # Dictionary where keys are any URI option name, and values are the
@@ -678,7 +672,7 @@ def validate_auto_encryption_opts_or_none(option, value):
 # variant need not be included here. Options whose public and internal
 # names are the same need not be included here.
 INTERNAL_URI_OPTION_NAME_MAP = {
-    'ssl': 'tls',
+    "ssl": "tls",
 }
 
 # Map from deprecated URI option names to a tuple indicating the method of
@@ -700,8 +694,7 @@ def validate_auto_encryption_opts_or_none(option, value):
 for optname, aliases in URI_OPTIONS_ALIAS_MAP.items():
     for alias in aliases:
         if alias not in URI_OPTIONS_VALIDATOR_MAP:
-            URI_OPTIONS_VALIDATOR_MAP[alias] = (
-                URI_OPTIONS_VALIDATOR_MAP[optname])
+            URI_OPTIONS_VALIDATOR_MAP[alias] = URI_OPTIONS_VALIDATOR_MAP[optname]
 
 # Map containing all URI option and keyword argument validators.
 VALIDATORS = URI_OPTIONS_VALIDATOR_MAP.copy()
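
# ---------------------------------------------------------------------------
# Editorial sketch, not part of the patch: how the maps above combine.
# validate() lowercases the option name and dispatches through the merged
# VALIDATORS table, so the 'ssl' alias resolves to the same validator as its
# public counterpart 'tls'. A minimal sketch of using the public helper:
from pymongo.common import validate

print(validate("maxPoolSize", 100))  # -> ('maxpoolsize', 100)
print(validate("tls", "true"))       # -> ('tls', True)
print(validate("ssl", "false"))      # alias: same validator -> ('ssl', False)
# ---------------------------------------------------------------------------
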
@@ -709,32 +702,29 @@
 # List of timeout-related options.
 TIMEOUT_OPTIONS = [
-    'connecttimeoutms',
-    'heartbeatfrequencyms',
-    'maxidletimems',
-    'maxstalenessseconds',
-    'serverselectiontimeoutms',
-    'sockettimeoutms',
-    'waitqueuetimeoutms',
+    "connecttimeoutms",
+    "heartbeatfrequencyms",
+    "maxidletimems",
+    "maxstalenessseconds",
+    "serverselectiontimeoutms",
+    "sockettimeoutms",
+    "waitqueuetimeoutms",
 ]
 
 
-_AUTH_OPTIONS = frozenset(['authmechanismproperties'])
+_AUTH_OPTIONS = frozenset(["authmechanismproperties"])
 
 
 def validate_auth_option(option, value):
-    """Validate optional authentication parameters.
-    """
+    """Validate optional authentication parameters."""
     lower, value = validate(option, value)
     if lower not in _AUTH_OPTIONS:
-        raise ConfigurationError('Unknown '
-                                 'authentication option: %s' % (option,))
+        raise ConfigurationError("Unknown " "authentication option: %s" % (option,))
     return option, value
 
 
 def validate(option, value):
-    """Generic validation function.
-    """
+    """Generic validation function."""
     lower = option.lower()
     validator = VALIDATORS.get(lower, raise_config_error)
     value = validator(option, value)
@@ -763,8 +753,7 @@
     for opt, value in options.items():
         normed_key = get_normed_key(opt)
         try:
-            validator = URI_OPTIONS_VALIDATOR_MAP.get(
-                normed_key, raise_config_error)
+            validator = URI_OPTIONS_VALIDATOR_MAP.get(normed_key, raise_config_error)
             value = validator(opt, value)
         except (ValueError, TypeError, ConfigurationError) as exc:
             if warn:
@@ -777,14 +766,7 @@
 
 
 # List of write-concern-related options.
-WRITE_CONCERN_OPTIONS = frozenset([
-    'w',
-    'wtimeout',
-    'wtimeoutms',
-    'fsync',
-    'j',
-    'journal'
-])
+WRITE_CONCERN_OPTIONS = frozenset(["w", "wtimeout", "wtimeoutms", "fsync", "j", "journal"])
 
 
 class BaseObject(object):
@@ -794,28 +776,32 @@
     SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO MONGODB.
     """
 
-    def __init__(self, codec_options, read_preference, write_concern,
-                 read_concern):
+    def __init__(self, codec_options, read_preference, write_concern, read_concern):
 
         if not isinstance(codec_options, CodecOptions):
-            raise TypeError("codec_options must be an instance of "
-                            "bson.codec_options.CodecOptions")
+            raise TypeError(
+                "codec_options must be an instance of " "bson.codec_options.CodecOptions"
+            )
         self.__codec_options = codec_options
 
         if not isinstance(read_preference, _ServerMode):
-            raise TypeError("%r is not valid for read_preference. See "
-                            "pymongo.read_preferences for valid "
-                            "options." % (read_preference,))
+            raise TypeError(
+                "%r is not valid for read_preference. See "
+                "pymongo.read_preferences for valid "
+                "options." % (read_preference,)
+            )
         self.__read_preference = read_preference
 
         if not isinstance(write_concern, WriteConcern):
-            raise TypeError("write_concern must be an instance of "
-                            "pymongo.write_concern.WriteConcern")
+            raise TypeError(
+                "write_concern must be an instance of " "pymongo.write_concern.WriteConcern"
+            )
         self.__write_concern = write_concern
 
         if not isinstance(read_concern, ReadConcern):
-            raise TypeError("read_concern must be an instance of "
-                            "pymongo.read_concern.ReadConcern")
+            raise TypeError(
+                "read_concern must be an instance of " "pymongo.read_concern.ReadConcern"
+            )
         self.__read_concern = read_concern
 
     @property
@@ -836,8 +822,7 @@ def write_concern(self):
         return self.__write_concern
 
     def _write_concern_for(self, session):
-        """Read only access to the write concern of this instance or session.
-        """
+        """Read only access to the write concern of this instance or session."""
        # Override this operation's write concern with the transaction's.
         if session and session.in_transaction:
             return DEFAULT_WRITE_CONCERN
@@ -853,8 +838,7 @@ def read_preference(self):
         return self.__read_preference
 
     def _read_preference_for(self, session):
-        """Read only access to the read preference of this instance or session.
-        """
+        """Read only access to the read preference of this instance or session."""
         # Override this operation's read preference with the transaction's.
         if session:
             return session._txn_read_preference() or self.__read_preference
diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py
index d367595288..d3921ad2e8 100644
--- a/pymongo/compression_support.py
+++ b/pymongo/compression_support.py
@@ -16,6 +16,7 @@
 
 try:
     import snappy
+
     _HAVE_SNAPPY = True
 except ImportError:
     # python-snappy isn't available.
@@ -23,6 +24,7 @@
 
 try:
     import zlib
+
     _HAVE_ZLIB = True
 except ImportError:
     # Python built without zlib support.
@@ -30,6 +32,7 @@
 
 try:
     from zstandard import ZstdCompressor, ZstdDecompressor
+
     _HAVE_ZSTD = True
 except ImportError:
     _HAVE_ZSTD = False
@@ -58,17 +61,20 @@ def validate_compressors(dummy, value):
             compressors.remove(compressor)
             warnings.warn(
                 "Wire protocol compression with snappy is not available. "
-                "You must install the python-snappy module for snappy support.")
+                "You must install the python-snappy module for snappy support."
+            )
         elif compressor == "zlib" and not _HAVE_ZLIB:
             compressors.remove(compressor)
             warnings.warn(
                 "Wire protocol compression with zlib is not available. "
-                "The zlib module is not available.")
+                "The zlib module is not available."
+            )
         elif compressor == "zstd" and not _HAVE_ZSTD:
             compressors.remove(compressor)
             warnings.warn(
                 "Wire protocol compression with zstandard is not available. "
-                "You must install the zstandard module for zstandard support.")
+                "You must install the zstandard module for zstandard support."
+            )
     return compressors
 
 
@@ -78,8 +84,7 @@ def validate_zlib_compression_level(option, value):
     except:
         raise TypeError("%s must be an integer, not %r." % (option, value))
     if level < -1 or level > 9:
-        raise ValueError(
-            "%s must be between -1 and 9, not %d." % (option, level))
+        raise ValueError("%s must be between -1 and 9, not %d." % (option, level))
     return level
 
 
diff --git a/pymongo/cursor.py b/pymongo/cursor.py
index c38adaf377..72b0320ccb 100644
--- a/pymongo/cursor.py
+++ b/pymongo/cursor.py
@@ -17,54 +17,54 @@
 import copy
 import threading
 import warnings
-
 from collections import deque
 
 from bson import RE_TYPE, _convert_raw_document_lists_to_streams
 from bson.code import Code
 from bson.son import SON
 from pymongo import helpers
-from pymongo.common import validate_boolean, validate_is_mapping
 from pymongo.collation import validate_collation_or_none
-from pymongo.errors import (ConnectionFailure,
-                            InvalidOperation,
-                            OperationFailure)
-from pymongo.message import (_CursorAddress,
-                             _GetMore,
-                             _RawBatchGetMore,
-                             _Query,
-                             _RawBatchQuery)
+from pymongo.common import validate_boolean, validate_is_mapping
+from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
+from pymongo.message import (
+    _CursorAddress,
+    _GetMore,
+    _Query,
+    _RawBatchGetMore,
+    _RawBatchQuery,
+)
 from pymongo.response import PinnedResponse
 
 # These errors mean that the server has already killed the cursor so there is
 # no need to send killCursors.
-_CURSOR_CLOSED_ERRORS = frozenset([
-    43,    # CursorNotFound
-    50,    # MaxTimeMSExpired
-    175,   # QueryPlanKilled
-    237,   # CursorKilled
-
-    # On a tailable cursor, the following errors mean the capped collection
-    # rolled over.
-    # MongoDB 2.6:
-    # {'$err': 'Runner killed during getMore', 'code': 28617, 'ok': 0}
-    28617,
-    # MongoDB 3.0:
-    # {'$err': 'getMore executor error: UnknownError no details available',
-    #  'code': 17406, 'ok': 0}
-    17406,
-    # MongoDB 3.2 + 3.4:
-    # {'ok': 0.0, 'errmsg': 'GetMore command executor error:
-    #  CappedPositionLost: CollectionScan died due to failure to restore
-    #  tailable cursor position. Last seen record id: RecordId(3)',
-    #  'code': 96}
-    96,
-    # MongoDB 3.6+:
-    # {'ok': 0.0, 'errmsg': 'errmsg: "CollectionScan died due to failure to
-    #  restore tailable cursor position. Last seen record id: RecordId(3)"',
-    #  'code': 136, 'codeName': 'CappedPositionLost'}
-    136,
-])
+_CURSOR_CLOSED_ERRORS = frozenset(
+    [
+        43,  # CursorNotFound
+        50,  # MaxTimeMSExpired
+        175,  # QueryPlanKilled
+        237,  # CursorKilled
+        # On a tailable cursor, the following errors mean the capped collection
+        # rolled over.
+        # MongoDB 2.6:
+        # {'$err': 'Runner killed during getMore', 'code': 28617, 'ok': 0}
+        28617,
+        # MongoDB 3.0:
+        # {'$err': 'getMore executor error: UnknownError no details available',
+        #  'code': 17406, 'ok': 0}
+        17406,
+        # MongoDB 3.2 + 3.4:
+        # {'ok': 0.0, 'errmsg': 'GetMore command executor error:
+        #  CappedPositionLost: CollectionScan died due to failure to restore
+        #  tailable cursor position. Last seen record id: RecordId(3)',
+        #  'code': 96}
+        96,
+        # MongoDB 3.6+:
+        # {'ok': 0.0, 'errmsg': 'errmsg: "CollectionScan died due to failure to
+        #  restore tailable cursor position. Last seen record id: RecordId(3)"',
+        #  'code': 136, 'codeName': 'CappedPositionLost'}
+        136,
+    ]
+)
 
 _QUERY_OPTIONS = {
     "tailable_cursor": 2,
@@ -73,7 +73,8 @@
     "no_timeout": 16,
     "await_data": 32,
     "exhaust": 64,
-    "partial": 128}
+    "partial": 128,
+}
 
 
 class CursorType(object):
@@ -106,8 +107,8 @@ class CursorType(object):
 
 
 class _SocketManager(object):
-    """Used with exhaust cursors to ensure the socket is returned.
-    """
+    """Used with exhaust cursors to ensure the socket is returned."""
+
     def __init__(self, sock, more_to_come):
         self.sock = sock
         self.more_to_come = more_to_come
@@ -118,8 +119,7 @@ def update_exhaust(self, more_to_come):
         self.more_to_come = more_to_come
 
     def close(self):
-        """Return this instance's socket to the connection pool.
-        """
+        """Return this instance's socket to the connection pool."""
         if not self.closed:
             self.closed = True
             self.sock.unpin()
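
# ---------------------------------------------------------------------------
# Editorial sketch, not part of the patch: why the error list above matters.
# When a getMore fails with one of these codes the server has already freed
# the cursor, so sending killCursors would be redundant. A rough equivalent
# of the check performed later in Cursor.__send_message:
from pymongo.errors import OperationFailure

_SERVER_KILLED = {43, 50, 96, 136, 175, 237, 17406, 28617}

def cursor_already_closed(exc):
    """True when killCursors should be skipped for this failure."""
    return isinstance(exc, OperationFailure) and exc.code in _SERVER_KILLED
# ---------------------------------------------------------------------------
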
- """ + """A cursor / iterator over Mongo query results.""" + _query_class = _Query _getmore_class = _GetMore - def __init__(self, collection, filter=None, projection=None, skip=0, - limit=0, no_cursor_timeout=False, - cursor_type=CursorType.NON_TAILABLE, - sort=None, allow_partial_results=False, oplog_replay=False, - batch_size=0, - collation=None, hint=None, max_scan=None, max_time_ms=None, - max=None, min=None, return_key=None, show_record_id=None, - snapshot=None, comment=None, session=None, - allow_disk_use=None): + def __init__( + self, + collection, + filter=None, + projection=None, + skip=0, + limit=0, + no_cursor_timeout=False, + cursor_type=CursorType.NON_TAILABLE, + sort=None, + allow_partial_results=False, + oplog_replay=False, + batch_size=0, + collation=None, + hint=None, + max_scan=None, + max_time_ms=None, + max=None, + min=None, + return_key=None, + show_record_id=None, + snapshot=None, + comment=None, + session=None, + allow_disk_use=None, + ): """Create a new cursor. Should not be called directly by application developers - see @@ -174,15 +191,22 @@ def __init__(self, collection, filter=None, projection=None, skip=0, raise TypeError("limit must be an instance of int") validate_boolean("no_cursor_timeout", no_cursor_timeout) if no_cursor_timeout and not self.__explicit_session: - warnings.warn("use an explicit session with no_cursor_timeout=True " - "otherwise the cursor may still timeout after " - "30 minutes, for more info see " - "https://docs.mongodb.com/v4.4/reference/method/" - "cursor.noCursorTimeout/" - "#session-idle-timeout-overrides-nocursortimeout", - UserWarning, stacklevel=2) - if cursor_type not in (CursorType.NON_TAILABLE, CursorType.TAILABLE, - CursorType.TAILABLE_AWAIT, CursorType.EXHAUST): + warnings.warn( + "use an explicit session with no_cursor_timeout=True " + "otherwise the cursor may still timeout after " + "30 minutes, for more info see " + "https://docs.mongodb.com/v4.4/reference/method/" + "cursor.noCursorTimeout/" + "#session-idle-timeout-overrides-nocursortimeout", + UserWarning, + stacklevel=2, + ) + if cursor_type not in ( + CursorType.NON_TAILABLE, + CursorType.TAILABLE, + CursorType.TAILABLE_AWAIT, + CursorType.EXHAUST, + ): raise ValueError("not a valid value for cursor_type") validate_boolean("allow_partial_results", allow_partial_results) validate_boolean("oplog_replay", oplog_replay) @@ -220,8 +244,7 @@ def __init__(self, collection, filter=None, projection=None, skip=0, # Exhaust cursor support if cursor_type == CursorType.EXHAUST: if self.__collection.database.client.is_mongos: - raise InvalidOperation('Exhaust cursors are ' - 'not supported by mongos') + raise InvalidOperation("Exhaust cursors are " "not supported by mongos") if limit: raise InvalidOperation("Can't use limit and exhaust together.") self.__exhaust = True @@ -264,8 +287,7 @@ def collection(self): @property def retrieved(self): - """The number of documents retrieved so far. 
- """ + """The number of documents retrieved so far.""" return self.__retrieved def __del__(self): @@ -307,28 +329,46 @@ def _clone(self, deepcopy=True, base=None): else: base = self._clone_base(None) - values_to_clone = ("spec", "projection", "skip", "limit", - "max_time_ms", "max_await_time_ms", "comment", - "max", "min", "ordering", "explain", "hint", - "batch_size", "max_scan", - "query_flags", "collation", "empty", - "show_record_id", "return_key", "allow_disk_use", - "snapshot", "exhaust") - data = dict((k, v) for k, v in self.__dict__.items() - if k.startswith('_Cursor__') and k[9:] in values_to_clone) + values_to_clone = ( + "spec", + "projection", + "skip", + "limit", + "max_time_ms", + "max_await_time_ms", + "comment", + "max", + "min", + "ordering", + "explain", + "hint", + "batch_size", + "max_scan", + "query_flags", + "collation", + "empty", + "show_record_id", + "return_key", + "allow_disk_use", + "snapshot", + "exhaust", + ) + data = dict( + (k, v) + for k, v in self.__dict__.items() + if k.startswith("_Cursor__") and k[9:] in values_to_clone + ) if deepcopy: data = self._deepcopy(data) base.__dict__.update(data) return base def _clone_base(self, session): - """Creates an empty Cursor object for information to be copied into. - """ + """Creates an empty Cursor object for information to be copied into.""" return self.__class__(self.__collection, session=session) def __die(self, synchronous=False): - """Closes this cursor. - """ + """Closes this cursor.""" try: already_killed = self.__killed except AttributeError: @@ -338,8 +378,7 @@ def __die(self, synchronous=False): self.__killed = True if self.__id and not already_killed: cursor_id = self.__id - address = _CursorAddress( - self.__address, "%s.%s" % (self.__dbname, self.__collname)) + address = _CursorAddress(self.__address, "%s.%s" % (self.__dbname, self.__collname)) else: # Skip killCursors. cursor_id = 0 @@ -350,19 +389,18 @@ def __die(self, synchronous=False): address, self.__sock_mgr, self.__session, - self.__explicit_session) + self.__explicit_session, + ) if not self.__explicit_session: self.__session = None self.__sock_mgr = None def close(self): - """Explicitly close / kill this cursor. - """ + """Explicitly close / kill this cursor.""" self.__die(True) def __query_spec(self): - """Get the spec to use for a query. - """ + """Get the spec to use for a query.""" operators = {} if self.__ordering: operators["$orderby"] = self.__ordering @@ -409,16 +447,15 @@ def __query_spec(self): # that breaks commands like count and find_and_modify. # Checking spec.keys()[0] covers the case that the spec # was passed as an instance of SON or OrderedDict. - elif ("query" in self.__spec and - (len(self.__spec) == 1 or - next(iter(self.__spec)) == "query")): + elif "query" in self.__spec and ( + len(self.__spec) == 1 or next(iter(self.__spec)) == "query" + ): return SON({"$query": self.__spec}) return self.__spec def __check_okay_to_chain(self): - """Check if it is okay to chain more options onto this cursor. 
- """ + """Check if it is okay to chain more options onto this cursor.""" if self.__retrieved or self.__id is not None: raise InvalidOperation("cannot set options after executing query") @@ -436,8 +473,7 @@ def add_option(self, mask): if self.__limit: raise InvalidOperation("Can't use limit and exhaust together.") if self.__collection.database.client.is_mongos: - raise InvalidOperation('Exhaust cursors are ' - 'not supported by mongos') + raise InvalidOperation("Exhaust cursors are " "not supported by mongos") self.__exhaust = True self.__query_flags |= mask @@ -475,7 +511,7 @@ def allow_disk_use(self, allow_disk_use): .. versionadded:: 3.11 """ if not isinstance(allow_disk_use, bool): - raise TypeError('allow_disk_use must be a bool') + raise TypeError("allow_disk_use must be a bool") self.__check_okay_to_chain() self.__allow_disk_use = allow_disk_use @@ -566,8 +602,7 @@ def max_time_ms(self, max_time_ms): :Parameters: - `max_time_ms`: the time limit after which the operation is aborted """ - if (not isinstance(max_time_ms, int) - and max_time_ms is not None): + if not isinstance(max_time_ms, int) and max_time_ms is not None: raise TypeError("max_time_ms must be an integer or None") self.__check_okay_to_chain() @@ -591,8 +626,7 @@ def max_await_time_ms(self, max_await_time_ms): .. versionadded:: 3.2 """ - if (not isinstance(max_await_time_ms, int) - and max_await_time_ms is not None): + if not isinstance(max_await_time_ms, int) and max_await_time_ms is not None: raise TypeError("max_await_time_ms must be an integer or None") self.__check_okay_to_chain() @@ -652,15 +686,15 @@ def __getitem__(self, index): skip = 0 if index.start is not None: if index.start < 0: - raise IndexError("Cursor instances do not support " - "negative indices") + raise IndexError("Cursor instances do not support " "negative indices") skip = index.start if index.stop is not None: limit = index.stop - skip if limit < 0: - raise IndexError("stop index must be greater than start " - "index for slice %r" % index) + raise IndexError( + "stop index must be greater than start " "index for slice %r" % index + ) if limit == 0: self.__empty = True else: @@ -672,8 +706,7 @@ def __getitem__(self, index): if isinstance(index, int): if index < 0: - raise IndexError("Cursor instances do not support negative " - "indices") + raise IndexError("Cursor instances do not support negative " "indices") clone = self.clone() clone.skip(index + self.__skip) clone.limit(-1) # use a hard limit @@ -681,8 +714,7 @@ def __getitem__(self, index): for doc in clone: return doc raise IndexError("no such item for Cursor instance") - raise TypeError("index %r cannot be applied to Cursor " - "instances" % index) + raise TypeError("index %r cannot be applied to Cursor " "instances" % index) def max_scan(self, max_scan): """**DEPRECATED** - Limit the number of documents to scan when @@ -818,14 +850,13 @@ def distinct(self, key): if self.__spec: options["query"] = self.__spec if self.__max_time_ms is not None: - options['maxTimeMS'] = self.__max_time_ms + options["maxTimeMS"] = self.__max_time_ms if self.__comment: - options['comment'] = self.__comment + options["comment"] = self.__comment if self.__collation is not None: - options['collation'] = self.__collation + options["collation"] = self.__collation - return self.__collection.distinct( - key, session=self.__session, **options) + return self.__collection.distinct(key, session=self.__session, **options) def explain(self): """Returns an explain plan record for this cursor. 
@@ -964,12 +995,12 @@ def __send_message(self, operation):
         client = self.__collection.database.client
         # OP_MSG is required to support exhaust cursors with encryption.
         if client._encrypter and self.__exhaust:
-            raise InvalidOperation(
-                "exhaust cursors do not support auto encryption")
+            raise InvalidOperation("exhaust cursors do not support auto encryption")
 
         try:
             response = client._run_operation(
-                operation, self._unpack_response, address=self.__address)
+                operation, self._unpack_response, address=self.__address
+            )
         except OperationFailure as exc:
             if exc.code in _CURSOR_CLOSED_ERRORS or self.__exhaust:
                 # Don't send killCursors because the cursor is already closed.
@@ -979,8 +1010,10 @@ def __send_message(self, operation):
                 # due to capped collection roll over. Setting
                 # self.__killed to True ensures Cursor.alive will be
                 # False. No need to re-raise.
-                if (exc.code in _CURSOR_CLOSED_ERRORS and
-                        self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]):
+                if (
+                    exc.code in _CURSOR_CLOSED_ERRORS
+                    and self.__query_flags & _QUERY_OPTIONS["tailable_cursor"]
+                ):
                     return
             raise
         except ConnectionFailure:
@@ -995,23 +1028,22 @@ def __send_message(self, operation):
         self.__address = response.address
         if isinstance(response, PinnedResponse):
             if not self.__sock_mgr:
-                self.__sock_mgr = _SocketManager(response.socket_info,
-                                                 response.more_to_come)
+                self.__sock_mgr = _SocketManager(response.socket_info, response.more_to_come)
 
         cmd_name = operation.name
         docs = response.docs
         if response.from_command:
             if cmd_name != "explain":
-                cursor = docs[0]['cursor']
-                self.__id = cursor['id']
-                if cmd_name == 'find':
-                    documents = cursor['firstBatch']
+                cursor = docs[0]["cursor"]
+                self.__id = cursor["id"]
+                if cmd_name == "find":
+                    documents = cursor["firstBatch"]
                     # Update the namespace used for future getMore commands.
-                    ns = cursor.get('ns')
+                    ns = cursor.get("ns")
                     if ns:
-                        self.__dbname, self.__collname = ns.split('.', 1)
+                        self.__dbname, self.__collname = ns.split(".", 1)
                 else:
-                    documents = cursor['nextBatch']
+                    documents = cursor["nextBatch"]
                 self.__data = deque(documents)
                 self.__retrieved += len(documents)
             else:
@@ -1031,16 +1063,15 @@ def __send_message(self, operation):
         if self.__limit and self.__id and self.__limit <= self.__retrieved:
             self.close()
 
-    def _unpack_response(self, response, cursor_id, codec_options,
-                         user_fields=None, legacy_response=False):
-        return response.unpack_response(cursor_id, codec_options, user_fields,
-                                        legacy_response)
+    def _unpack_response(
+        self, response, cursor_id, codec_options, user_fields=None, legacy_response=False
+    ):
+        return response.unpack_response(cursor_id, codec_options, user_fields, legacy_response)
 
     def _read_preference(self):
         if self.__read_preference is None:
             # Save the read preference for getMore commands.
-            self.__read_preference = self.__collection._read_preference_for(
-                self.session)
+            self.__read_preference = self.__collection._read_preference_for(self.session)
         return self.__read_preference
 
     def _refresh(self):
@@ -1060,23 +1091,26 @@ def _refresh(self):
             if (self.__min or self.__max) and not self.__hint:
                 raise InvalidOperation(
                     "Passing a 'hint' is required when using the min/max query"
-                    " option to ensure the query utilizes the correct index")
-            q = self._query_class(self.__query_flags,
-                                  self.__collection.database.name,
-                                  self.__collection.name,
-                                  self.__skip,
-                                  self.__query_spec(),
-                                  self.__projection,
-                                  self.__codec_options,
-                                  self._read_preference(),
-                                  self.__limit,
-                                  self.__batch_size,
-                                  self.__read_concern,
-                                  self.__collation,
-                                  self.__session,
-                                  self.__collection.database.client,
-                                  self.__allow_disk_use,
-                                  self.__exhaust)
+                    " option to ensure the query utilizes the correct index"
+                )
+            q = self._query_class(
+                self.__query_flags,
+                self.__collection.database.name,
+                self.__collection.name,
+                self.__skip,
+                self.__query_spec(),
+                self.__projection,
+                self.__codec_options,
+                self._read_preference(),
+                self.__limit,
+                self.__batch_size,
+                self.__read_concern,
+                self.__collation,
+                self.__session,
+                self.__collection.database.client,
+                self.__allow_disk_use,
+                self.__exhaust,
+            )
             self.__send_message(q)
         elif self.__id:  # Get More
             if self.__limit:
@@ -1086,17 +1120,19 @@ def _refresh(self):
             else:
                 limit = self.__batch_size
             # Exhaust cursors don't send getMore messages.
-            g = self._getmore_class(self.__dbname,
-                                    self.__collname,
-                                    limit,
-                                    self.__id,
-                                    self.__codec_options,
-                                    self._read_preference(),
-                                    self.__session,
-                                    self.__collection.database.client,
-                                    self.__max_await_time_ms,
-                                    self.__sock_mgr,
-                                    self.__exhaust)
+            g = self._getmore_class(
+                self.__dbname,
+                self.__collname,
+                limit,
+                self.__id,
+                self.__codec_options,
+                self._read_preference(),
+                self.__session,
+                self.__collection.database.client,
+                self.__max_await_time_ms,
+                self.__sock_mgr,
+                self.__exhaust,
+            )
             self.__send_message(g)
 
         return len(self.__data)
@@ -1189,7 +1225,7 @@ def _deepcopy(self, x, memo=None):
         Regular expressions cannot be deep copied but as they are
         immutable we don't have to copy them when cloning.
         """
-        if not hasattr(x, 'items'):
+        if not hasattr(x, "items"):
             y, is_list, iterator = [], True, enumerate(x)
         else:
             y, is_list, iterator = {}, False, x.items()
@@ -1233,10 +1269,10 @@ def __init__(self, *args, **kwargs):
         """
         super(RawBatchCursor, self).__init__(*args, **kwargs)
 
-    def _unpack_response(self, response, cursor_id, codec_options,
-                         user_fields=None, legacy_response=False):
-        raw_response = response.raw_response(
-            cursor_id, user_fields=user_fields)
+    def _unpack_response(
+        self, response, cursor_id, codec_options, user_fields=None, legacy_response=False
+    ):
+        raw_response = response.raw_response(cursor_id, user_fields=user_fields)
         if not legacy_response:
             # OP_MSG returns firstBatch/nextBatch documents as a BSON array
             # Re-assemble the array of documents into a document stream
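
# ---------------------------------------------------------------------------
# Editorial sketch, not part of the patch: the command replies handled in
# __send_message above look roughly like
#   find    -> {"cursor": {"id": 123, "ns": "db.coll", "firstBatch": [...]}, "ok": 1}
#   getMore -> {"cursor": {"id": 123, "nextBatch": [...]}, "ok": 1}
# A minimal, self-contained version of the batch extraction:
def extract_batch(reply, cmd_name):
    cursor = reply["cursor"]
    batch = cursor["firstBatch"] if cmd_name == "find" else cursor["nextBatch"]
    return cursor["id"], batch

cursor_id, docs = extract_batch(
    {"cursor": {"id": 0, "ns": "test.docs", "firstBatch": [{"_id": 1}]}, "ok": 1}, "find"
)
assert cursor_id == 0 and docs == [{"_id": 1}]
# ---------------------------------------------------------------------------
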
diff --git a/pymongo/daemon.py b/pymongo/daemon.py
index f0253547d9..53141751ac 100644
--- a/pymongo/daemon.py
+++ b/pymongo/daemon.py
@@ -24,7 +24,6 @@
 import sys
 import warnings
 
-
 # The maximum amount of time to wait for the intermediate subprocess.
 _WAIT_TIMEOUT = 10
 _THIS_FILE = os.path.realpath(__file__)
@@ -53,23 +52,29 @@ def _silence_resource_warning(popen):
         popen.returncode = 0
 
 
-if sys.platform == 'win32':
+if sys.platform == "win32":
     # On Windows we spawn the daemon process simply by using DETACHED_PROCESS.
-    _DETACHED_PROCESS = getattr(subprocess, 'DETACHED_PROCESS', 0x00000008)
+    _DETACHED_PROCESS = getattr(subprocess, "DETACHED_PROCESS", 0x00000008)
 
     def _spawn_daemon(args):
         """Spawn a daemon process (Windows)."""
         try:
-            with open(os.devnull, 'r+b') as devnull:
+            with open(os.devnull, "r+b") as devnull:
                 popen = subprocess.Popen(
                     args,
                     creationflags=_DETACHED_PROCESS,
-                    stdin=devnull, stderr=devnull, stdout=devnull)
+                    stdin=devnull,
+                    stderr=devnull,
+                    stdout=devnull,
+                )
                 _silence_resource_warning(popen)
         except FileNotFoundError as exc:
-            warnings.warn(f'Failed to start {args[0]}: is it on your $PATH?\n'
-                          f'Original exception: {exc}', RuntimeWarning,
-                          stacklevel=2)
+            warnings.warn(
+                f"Failed to start {args[0]}: is it on your $PATH?\n" f"Original exception: {exc}",
+                RuntimeWarning,
+                stacklevel=2,
+            )
+
 else:
     # On Unix we spawn the daemon process with a double Popen.
     # 1) The first Popen runs this file as a Python script using the current
@@ -85,16 +90,16 @@ def _spawn_daemon(args):
     def _spawn(args):
         """Spawn the process and silence stdout/stderr."""
         try:
-            with open(os.devnull, 'r+b') as devnull:
+            with open(os.devnull, "r+b") as devnull:
                 return subprocess.Popen(
-                    args,
-                    close_fds=True,
-                    stdin=devnull, stderr=devnull, stdout=devnull)
+                    args, close_fds=True, stdin=devnull, stderr=devnull, stdout=devnull
+                )
         except FileNotFoundError as exc:
-            warnings.warn(f'Failed to start {args[0]}: is it on your $PATH?\n'
-                          f'Original exception: {exc}', RuntimeWarning,
-                          stacklevel=2)
-
+            warnings.warn(
+                f"Failed to start {args[0]}: is it on your $PATH?\n" f"Original exception: {exc}",
+                RuntimeWarning,
+                stacklevel=2,
+            )
 
     def _spawn_daemon_double_popen(args):
         """Spawn a daemon process using a double subprocess.Popen."""
@@ -105,7 +110,6 @@ def _spawn_daemon_double_popen(args):
         # processes.
         _popen_wait(temp_proc, _WAIT_TIMEOUT)
 
-
     def _spawn_daemon(args):
         """Spawn a daemon process (Unix)."""
         # "If Python is unable to retrieve the real path to its executable,
@@ -123,10 +127,9 @@ def _spawn_daemon(args):
         # until the main application exits.
         _spawn(args)
 
-
-    if __name__ == '__main__':
+    if __name__ == "__main__":
         # Attempt to start a new session to decouple from the parent.
-        if hasattr(os, 'setsid'):
+        if hasattr(os, "setsid"):
             try:
                 os.setsid()
             except OSError:
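
# ---------------------------------------------------------------------------
# Editorial sketch, not part of the patch: the Unix double-Popen strategy
# described above, reduced to its essence. The intermediate process spawns
# the real daemon and exits immediately, so the grandchild is reparented and
# outlives the caller. This is a simplified stand-in, not pymongo's code.
import subprocess
import sys

def spawn_detached(args):
    helper = "import subprocess, sys; subprocess.Popen(sys.argv[1:], close_fds=True)"
    # The first child exits as soon as it has launched `args`.
    subprocess.Popen([sys.executable, "-c", helper] + list(args), close_fds=True).wait()
# ---------------------------------------------------------------------------
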
- """ - - def __init__(self, client, name, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + """A Mongo database.""" + + def __init__( + self, + client, + name, + codec_options=None, + read_preference=None, + write_concern=None, + read_concern=None, + ): """Get a database by client and name. Raises :class:`TypeError` if `name` is not an instance of @@ -95,12 +98,13 @@ def __init__(self, client, name, codec_options=None, read_preference=None, codec_options or client.codec_options, read_preference or client.read_preference, write_concern or client.write_concern, - read_concern or client.read_concern) + read_concern or client.read_concern, + ) if not isinstance(name, str): raise TypeError("name must be an instance of str") - if name != '$external': + if name != "$external": _check_name(name) self.__name = name @@ -116,8 +120,9 @@ def name(self): """The name of this :class:`Database`.""" return self.__name - def with_options(self, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def with_options( + self, codec_options=None, read_preference=None, write_concern=None, read_concern=None + ): """Get a clone of this database changing the specified settings. >>> db1.read_preference @@ -149,17 +154,18 @@ def with_options(self, codec_options=None, read_preference=None, .. versionadded:: 3.8 """ - return Database(self.client, - self.__name, - codec_options or self.codec_options, - read_preference or self.read_preference, - write_concern or self.write_concern, - read_concern or self.read_concern) + return Database( + self.client, + self.__name, + codec_options or self.codec_options, + read_preference or self.read_preference, + write_concern or self.write_concern, + read_concern or self.read_concern, + ) def __eq__(self, other): if isinstance(other, Database): - return (self.__client == other.client and - self.__name == other.name) + return self.__client == other.client and self.__name == other.name return NotImplemented def __ne__(self, other): @@ -179,10 +185,11 @@ def __getattr__(self, name): :Parameters: - `name`: the name of the collection to get """ - if name.startswith('_'): + if name.startswith("_"): raise AttributeError( "Database has no attribute %r. To access the %s" - " collection, use database[%r]." % (name, name, name)) + " collection, use database[%r]." % (name, name, name) + ) return self.__getitem__(name) def __getitem__(self, name): @@ -195,8 +202,9 @@ def __getitem__(self, name): """ return Collection(self, name) - def get_collection(self, name, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def get_collection( + self, name, codec_options=None, read_preference=None, write_concern=None, read_concern=None + ): """Get a :class:`~pymongo.collection.Collection` with the given name and options. @@ -235,12 +243,19 @@ def get_collection(self, name, codec_options=None, read_preference=None, used. """ return Collection( - self, name, False, codec_options, read_preference, - write_concern, read_concern) - - def create_collection(self, name, codec_options=None, - read_preference=None, write_concern=None, - read_concern=None, session=None, **kwargs): + self, name, False, codec_options, read_preference, write_concern, read_concern + ) + + def create_collection( + self, + name, + codec_options=None, + read_preference=None, + write_concern=None, + read_concern=None, + session=None, + **kwargs + ): """Create a new :class:`~pymongo.collection.Collection` in this database. 
@@ -306,14 +321,22 @@
         with self.__client._tmp_session(session) as s:
             # Skip this check in a transaction where listCollections is not
             # supported.
-            if ((not s or not s.in_transaction) and
-                    name in self.list_collection_names(
-                        filter={"name": name}, session=s)):
+            if (not s or not s.in_transaction) and name in self.list_collection_names(
+                filter={"name": name}, session=s
+            ):
                 raise CollectionInvalid("collection %s already exists" % name)
 
-            return Collection(self, name, True, codec_options,
-                              read_preference, write_concern,
-                              read_concern, session=s, **kwargs)
+            return Collection(
+                self,
+                name,
+                True,
+                codec_options,
+                read_preference,
+                write_concern,
+                read_concern,
+                session=s,
+                **kwargs
+            )
 
     def aggregate(self, pipeline, session=None, **kwargs):
         """Perform a database-level aggregation.
@@ -381,15 +404,29 @@
         """
         with self.client._tmp_session(session, close=False) as s:
             cmd = _DatabaseAggregationCommand(
-                self, CommandCursor, pipeline, kwargs, session is not None,
-                user_fields={'cursor': {'firstBatch': 1}})
+                self,
+                CommandCursor,
+                pipeline,
+                kwargs,
+                session is not None,
+                user_fields={"cursor": {"firstBatch": 1}},
+            )
             return self.client._retryable_read(
-                cmd.get_cursor, cmd.get_read_preference(s), s,
-                retryable=not cmd._performs_write)
-
-    def watch(self, pipeline=None, full_document=None, resume_after=None,
-              max_await_time_ms=None, batch_size=None, collation=None,
-              start_at_operation_time=None, session=None, start_after=None):
+                cmd.get_cursor, cmd.get_read_preference(s), s, retryable=not cmd._performs_write
+            )
+
+    def watch(
+        self,
+        pipeline=None,
+        full_document=None,
+        resume_after=None,
+        max_await_time_ms=None,
+        batch_size=None,
+        collation=None,
+        start_at_operation_time=None,
+        session=None,
+        start_after=None,
+    ):
         """Watch changes on this database.
 
         Performs an aggregation with an implicit initial ``$changeStream``
@@ -475,15 +512,33 @@ def watch(self, pipeline=None, full_document=None, resume_after=None,
         https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.rst
         """
         return DatabaseChangeStream(
-            self, pipeline, full_document, resume_after, max_await_time_ms,
-            batch_size, collation, start_at_operation_time, session,
-            start_after)
-
-    def _command(self, sock_info, command, secondary_ok=False, value=1, check=True,
-                 allowable_errors=None, read_preference=ReadPreference.PRIMARY,
-                 codec_options=DEFAULT_CODEC_OPTIONS,
-                 write_concern=None,
-                 parse_write_concern_error=False, session=None, **kwargs):
+            self,
+            pipeline,
+            full_document,
+            resume_after,
+            max_await_time_ms,
+            batch_size,
+            collation,
+            start_at_operation_time,
+            session,
+            start_after,
+        )
+
+    def _command(
+        self,
+        sock_info,
+        command,
+        secondary_ok=False,
+        value=1,
+        check=True,
+        allowable_errors=None,
+        read_preference=ReadPreference.PRIMARY,
+        codec_options=DEFAULT_CODEC_OPTIONS,
+        write_concern=None,
+        parse_write_concern_error=False,
+        session=None,
+        **kwargs
+    ):
         """Internal command helper."""
         if isinstance(command, str):
             command = SON([(command, value)])
@@ -501,11 +556,20 @@ def _command(self, sock_info, command, secondary_ok=False, value=1, check=True,
             write_concern=write_concern,
             parse_write_concern_error=parse_write_concern_error,
             session=s,
-            client=self.__client)
-
-    def command(self, command, value=1, check=True,
-                allowable_errors=None, read_preference=None,
-                codec_options=DEFAULT_CODEC_OPTIONS, session=None, **kwargs):
+            client=self.__client,
+        )
+
+    def command(
+        self,
+        command,
+        value=1,
+        check=True,
+        allowable_errors=None,
+        read_preference=None,
+        codec_options=DEFAULT_CODEC_OPTIONS,
+        session=None,
+        **kwargs
+    ):
         """Issue a MongoDB command.
 
         Send command `command` to the database and return the
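
# ---------------------------------------------------------------------------
# Editorial sketch, not part of the patch: using the Database.watch() defined
# above. Change streams require a replica set; `db` is a hypothetical
# Database from a connected client. The stream is an iterator of change
# documents, here filtered to inserts.
with db.watch([{"$match": {"operationType": "insert"}}]) as stream:
    for change in stream:
        print(change["operationType"], change["documentKey"])
        break  # stop after the first change for this sketch
# ---------------------------------------------------------------------------
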
""" if read_preference is None: - read_preference = ((session and session._txn_read_preference()) - or ReadPreference.PRIMARY) - with self.__client._socket_for_reads( - read_preference, session) as (sock_info, secondary_ok): - return self._command(sock_info, command, secondary_ok, value, - check, allowable_errors, read_preference, - codec_options, session=session, **kwargs) - - def _retryable_read_command(self, command, value=1, check=True, - allowable_errors=None, read_preference=None, - codec_options=DEFAULT_CODEC_OPTIONS, session=None, **kwargs): + read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY + with self.__client._socket_for_reads(read_preference, session) as (sock_info, secondary_ok): + return self._command( + sock_info, + command, + secondary_ok, + value, + check, + allowable_errors, + read_preference, + codec_options, + session=session, + **kwargs + ) + + def _retryable_read_command( + self, + command, + value=1, + check=True, + allowable_errors=None, + read_preference=None, + codec_options=DEFAULT_CODEC_OPTIONS, + session=None, + **kwargs + ): """Same as command but used for retryable read commands.""" if read_preference is None: - read_preference = ((session and session._txn_read_preference()) - or ReadPreference.PRIMARY) + read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY def _cmd(session, server, sock_info, secondary_ok): - return self._command(sock_info, command, secondary_ok, value, - check, allowable_errors, read_preference, - codec_options, session=session, **kwargs) + return self._command( + sock_info, + command, + secondary_ok, + value, + check, + allowable_errors, + read_preference, + codec_options, + session=session, + **kwargs + ) - return self.__client._retryable_read( - _cmd, read_preference, session) + return self.__client._retryable_read(_cmd, read_preference, session) - def _list_collections(self, sock_info, secondary_okay, session, - read_preference, **kwargs): + def _list_collections(self, sock_info, secondary_okay, session, read_preference, **kwargs): """Internal listCollections helper.""" - coll = self.get_collection( - "$cmd", read_preference=read_preference) - cmd = SON([("listCollections", 1), - ("cursor", {})]) + coll = self.get_collection("$cmd", read_preference=read_preference) + cmd = SON([("listCollections", 1), ("cursor", {})]) cmd.update(kwargs) - with self.__client._tmp_session( - session, close=False) as tmp_session: + with self.__client._tmp_session(session, close=False) as tmp_session: cursor = self._command( - sock_info, cmd, secondary_okay, - read_preference=read_preference, - session=tmp_session)["cursor"] + sock_info, cmd, secondary_okay, read_preference=read_preference, session=tmp_session + )["cursor"] cmd_cursor = CommandCursor( coll, cursor, sock_info.address, session=tmp_session, - explicit_session=session is not None) + explicit_session=session is not None, + ) cmd_cursor._maybe_pin_connection(sock_info) return cmd_cursor @@ -657,17 +739,15 @@ def list_collections(self, session=None, filter=None, **kwargs): .. 
versionadded:: 3.6 """ if filter is not None: - kwargs['filter'] = filter - read_pref = ((session and session._txn_read_preference()) - or ReadPreference.PRIMARY) + kwargs["filter"] = filter + read_pref = (session and session._txn_read_preference()) or ReadPreference.PRIMARY def _cmd(session, server, sock_info, secondary_okay): return self._list_collections( - sock_info, secondary_okay, session, read_preference=read_pref, - **kwargs) + sock_info, secondary_okay, session, read_preference=read_pref, **kwargs + ) - return self.__client._retryable_read( - _cmd, read_pref, session) + return self.__client._retryable_read(_cmd, read_pref, session) def list_collection_names(self, session=None, filter=None, **kwargs): """Get a list of all the collection names in this database. @@ -703,8 +783,7 @@ def list_collection_names(self, session=None, filter=None, **kwargs): if not filter or (len(filter) == 1 and "name" in filter): kwargs["nameOnly"] = True - return [result["name"] - for result in self.list_collections(session=session, **kwargs)] + return [result["name"] for result in self.list_collections(session=session, **kwargs)] def drop_collection(self, name_or_collection, session=None): """Drop a collection. @@ -736,15 +815,18 @@ def drop_collection(self, name_or_collection, session=None): with self.__client._socket_for_writes(session) as sock_info: return self._command( - sock_info, 'drop', value=name, - allowable_errors=['ns not found', 26], + sock_info, + "drop", + value=name, + allowable_errors=["ns not found", 26], write_concern=self._write_concern_for(session), parse_write_concern_error=True, - session=session) + session=session, + ) - def validate_collection(self, name_or_collection, - scandata=False, full=False, session=None, - background=None): + def validate_collection( + self, name_or_collection, scandata=False, full=False, session=None, background=None + ): """Validate a collection. Returns a dict of validation info. Raises CollectionInvalid if @@ -779,12 +861,9 @@ def validate_collection(self, name_or_collection, name = name.name if not isinstance(name, str): - raise TypeError("name_or_collection must be an instance of str or " - "Collection") + raise TypeError("name_or_collection must be an instance of str or " "Collection") - cmd = SON([("validate", name), - ("scandata", scandata), - ("full", full)]) + cmd = SON([("validate", name), ("scandata", scandata), ("full", full)]) if background is not None: cmd["background"] = background @@ -801,10 +880,8 @@ def validate_collection(self, name_or_collection, for _, res in result["raw"].items(): if "result" in res: info = res["result"] - if (info.find("exception") != -1 or - info.find("corrupt") != -1): - raise CollectionInvalid("%s invalid: " - "%s" % (name, info)) + if info.find("exception") != -1 or info.find("corrupt") != -1: + raise CollectionInvalid("%s invalid: " "%s" % (name, info)) elif not res.get("valid", False): valid = False break @@ -826,9 +903,11 @@ def __next__(self): next = __next__ def __bool__(self): - raise NotImplementedError("Database objects do not implement truth " - "value testing or bool(). Please compare " - "with None instead: database is not None") + raise NotImplementedError( + "Database objects do not implement truth " + "value testing or bool(). 
Please compare " + "with None instead: database is not None" + ) def dereference(self, dbref, session=None, **kwargs): """Dereference a :class:`~bson.dbref.DBRef`, getting the @@ -854,8 +933,8 @@ def dereference(self, dbref, session=None, **kwargs): if not isinstance(dbref, DBRef): raise TypeError("cannot dereference a %s" % type(dbref)) if dbref.database is not None and dbref.database != self.__name: - raise ValueError("trying to dereference a DBRef that points to " - "another database (%r not %r)" % (dbref.database, - self.__name)) - return self[dbref.collection].find_one( - {"_id": dbref.id}, session=session, **kwargs) + raise ValueError( + "trying to dereference a DBRef that points to " + "another database (%r not %r)" % (dbref.database, self.__name) + ) + return self[dbref.collection].find_one({"_id": dbref.id}, session=session, **kwargs) diff --git a/pymongo/driver_info.py b/pymongo/driver_info.py index 5e0843e4df..f6f6c00347 100644 --- a/pymongo/driver_info.py +++ b/pymongo/driver_info.py @@ -17,7 +17,7 @@ from collections import namedtuple -class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])): +class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])): """Info about a driver wrapping PyMongo. The MongoDB server logs PyMongo's name, version, and platform whenever @@ -26,11 +26,14 @@ class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])): like 'MyDriver', '1.2.3', 'some platform info'. Any of these strings may be None to accept PyMongo's default. """ + def __new__(cls, name, version=None, platform=None): self = super(DriverInfo, cls).__new__(cls, name, version, platform) for key, value in self._asdict().items(): if value is not None and not isinstance(value, str): - raise TypeError("Wrong type for DriverInfo %s option, value " - "must be an instance of str" % (key,)) + raise TypeError( + "Wrong type for DriverInfo %s option, value " + "must be an instance of str" % (key,) + ) return self diff --git a/pymongo/encryption.py b/pymongo/encryption.py index 064ba48d51..85bf5e7132 100644 --- a/pymongo/encryption.py +++ b/pymongo/encryption.py @@ -26,35 +26,32 @@ from pymongocrypt.explicit_encrypter import ExplicitEncrypter from pymongocrypt.mongocrypt import MongoCryptOptions from pymongocrypt.state_machine import MongoCryptCallback + _HAVE_PYMONGOCRYPT = True except ImportError: _HAVE_PYMONGOCRYPT = False MongoCryptCallback = object from bson import _dict_to_bson, decode, encode +from bson.binary import STANDARD, UUID_SUBTYPE, Binary from bson.codec_options import CodecOptions -from bson.binary import (Binary, - STANDARD, - UUID_SUBTYPE) from bson.errors import BSONError -from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS, - RawBSONDocument, - _inflate_bson) +from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson from bson.son import SON - -from pymongo.errors import (ConfigurationError, - EncryptionError, - InvalidOperation, - ServerSelectionTimeoutError) +from pymongo.daemon import _spawn_daemon from pymongo.encryption_options import AutoEncryptionOpts +from pymongo.errors import ( + ConfigurationError, + EncryptionError, + InvalidOperation, + ServerSelectionTimeoutError, +) from pymongo.mongo_client import MongoClient -from pymongo.pool import _configured_socket, PoolOptions +from pymongo.pool import PoolOptions, _configured_socket from pymongo.read_concern import ReadConcern from pymongo.ssl_support import get_ssl_context from pymongo.uri_parser import parse_host from pymongo.write_concern import 
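
# ---------------------------------------------------------------------------
# Editorial sketch, not part of the patch: Database.dereference() above is
# equivalent to a find_one on the referenced collection; a DBRef naming a
# different database raises ValueError. Assumes a local mongod; `db` is
# hypothetical.
from bson.dbref import DBRef
from pymongo import MongoClient

db = MongoClient().test
user_id = db["users"].insert_one({"name": "ada"}).inserted_id
doc = db.dereference(DBRef("users", user_id))
assert doc["_id"] == user_id  # same as db["users"].find_one({"_id": user_id})
# ---------------------------------------------------------------------------
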
diff --git a/pymongo/driver_info.py b/pymongo/driver_info.py
index 5e0843e4df..f6f6c00347 100644
--- a/pymongo/driver_info.py
+++ b/pymongo/driver_info.py
@@ -17,7 +17,7 @@
 from collections import namedtuple
 
 
-class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])):
+class DriverInfo(namedtuple("DriverInfo", ["name", "version", "platform"])):
     """Info about a driver wrapping PyMongo.
 
     The MongoDB server logs PyMongo's name, version, and platform whenever
@@ -26,11 +26,14 @@ class DriverInfo(namedtuple('DriverInfo', ['name', 'version', 'platform'])):
     like 'MyDriver', '1.2.3', 'some platform info'. Any of these strings may be
     None to accept PyMongo's default.
     """
+
     def __new__(cls, name, version=None, platform=None):
         self = super(DriverInfo, cls).__new__(cls, name, version, platform)
         for key, value in self._asdict().items():
             if value is not None and not isinstance(value, str):
-                raise TypeError("Wrong type for DriverInfo %s option, value "
-                                "must be an instance of str" % (key,))
+                raise TypeError(
+                    "Wrong type for DriverInfo %s option, value "
+                    "must be an instance of str" % (key,)
+                )
 
         return self
diff --git a/pymongo/encryption.py b/pymongo/encryption.py
index 064ba48d51..85bf5e7132 100644
--- a/pymongo/encryption.py
+++ b/pymongo/encryption.py
@@ -26,35 +26,32 @@
     from pymongocrypt.explicit_encrypter import ExplicitEncrypter
     from pymongocrypt.mongocrypt import MongoCryptOptions
     from pymongocrypt.state_machine import MongoCryptCallback
+
     _HAVE_PYMONGOCRYPT = True
 except ImportError:
     _HAVE_PYMONGOCRYPT = False
     MongoCryptCallback = object
 
 from bson import _dict_to_bson, decode, encode
+from bson.binary import STANDARD, UUID_SUBTYPE, Binary
 from bson.codec_options import CodecOptions
-from bson.binary import (Binary,
-                         STANDARD,
-                         UUID_SUBTYPE)
 from bson.errors import BSONError
-from bson.raw_bson import (DEFAULT_RAW_BSON_OPTIONS,
-                           RawBSONDocument,
-                           _inflate_bson)
+from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson
 from bson.son import SON
-
-from pymongo.errors import (ConfigurationError,
-                            EncryptionError,
-                            InvalidOperation,
-                            ServerSelectionTimeoutError)
+from pymongo.daemon import _spawn_daemon
 from pymongo.encryption_options import AutoEncryptionOpts
+from pymongo.errors import (
+    ConfigurationError,
+    EncryptionError,
+    InvalidOperation,
+    ServerSelectionTimeoutError,
+)
 from pymongo.mongo_client import MongoClient
-from pymongo.pool import _configured_socket, PoolOptions
+from pymongo.pool import PoolOptions, _configured_socket
 from pymongo.read_concern import ReadConcern
 from pymongo.ssl_support import get_ssl_context
 from pymongo.uri_parser import parse_host
 from pymongo.write_concern import WriteConcern
-from pymongo.daemon import _spawn_daemon
-
 
 _HTTPS_PORT = 443
 _KMS_CONNECT_TIMEOUT = 10  # TODO: CDRIVER-3262 will define this value.
@@ -63,8 +60,7 @@
 _DATA_KEY_OPTS = CodecOptions(document_class=SON, uuid_representation=STANDARD)
 # Use RawBSONDocument codec options to avoid needlessly decoding
 # documents from the key vault.
-_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument,
-                               uuid_representation=STANDARD)
+_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument, uuid_representation=STANDARD)
 
 
 @contextlib.contextmanager
@@ -90,8 +86,9 @@ def __init__(self, client, key_vault_coll, mongocryptd_client, opts):
             self.client_ref = None
         self.key_vault_coll = key_vault_coll.with_options(
             codec_options=_KEY_VAULT_OPTS,
-            read_concern=ReadConcern(level='majority'),
-            write_concern=WriteConcern(w='majority'))
+            read_concern=ReadConcern(level="majority"),
+            write_concern=WriteConcern(w="majority"),
+        )
         self.mongocryptd_client = mongocryptd_client
         self.opts = opts
         self._spawned = False
@@ -113,16 +110,19 @@ def kms_request(self, kms_context):
         # Enable strict certificate verification, OCSP, match hostname, and
         # SNI using the system default CA certificates.
         ctx = get_ssl_context(
-            None,   # certfile
-            None,   # passphrase
-            None,   # ca_certs
-            None,   # crlfile
+            None,  # certfile
+            None,  # passphrase
+            None,  # ca_certs
+            None,  # crlfile
             False,  # allow_invalid_certificates
             False,  # allow_invalid_hostnames
-            False)  # disable_ocsp_endpoint_check
-        opts = PoolOptions(connect_timeout=_KMS_CONNECT_TIMEOUT,
-                           socket_timeout=_KMS_CONNECT_TIMEOUT,
-                           ssl_context=ctx)
+            False,
+        )  # disable_ocsp_endpoint_check
+        opts = PoolOptions(
+            connect_timeout=_KMS_CONNECT_TIMEOUT,
+            socket_timeout=_KMS_CONNECT_TIMEOUT,
+            ssl_context=ctx,
+        )
         host, port = parse_host(endpoint, _HTTPS_PORT)
         conn = _configured_socket((host, port), opts)
         try:
@@ -130,7 +130,7 @@ def kms_request(self, kms_context):
             while kms_context.bytes_needed > 0:
                 data = conn.recv(kms_context.bytes_needed)
                 if not data:
-                    raise OSError('KMS connection closed')
+                    raise OSError("KMS connection closed")
                 kms_context.feed(data)
         finally:
             conn.close()
@@ -148,8 +148,7 @@ def collection_info(self, database, filter):
         :Returns:
           The first document from the listCollections command response as BSON.
         """
-        with self.client_ref()[database].list_collections(
-                filter=RawBSONDocument(filter)) as cursor:
+        with self.client_ref()[database].list_collections(filter=RawBSONDocument(filter)) as cursor:
             for doc in cursor:
                 return _dict_to_bson(doc, False, _DATA_KEY_OPTS)
 
@@ -160,7 +159,7 @@ def spawn(self):
         successfully.
         """
         self._spawned = True
-        args = [self.opts._mongocryptd_spawn_path or 'mongocryptd']
+        args = [self.opts._mongocryptd_spawn_path or "mongocryptd"]
         args.extend(self.opts._mongocryptd_spawn_args)
         _spawn_daemon(args)
 
@@ -181,15 +180,15 @@ def mark_command(self, database, cmd):
         inflated_cmd = _inflate_bson(cmd, DEFAULT_RAW_BSON_OPTIONS)
         try:
             res = self.mongocryptd_client[database].command(
-                inflated_cmd,
-                codec_options=DEFAULT_RAW_BSON_OPTIONS)
+                inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS
+            )
         except ServerSelectionTimeoutError:
             if self.opts._mongocryptd_bypass_spawn:
                 raise
             self.spawn()
             res = self.mongocryptd_client[database].command(
-                inflated_cmd,
-                codec_options=DEFAULT_RAW_BSON_OPTIONS)
+                inflated_cmd, codec_options=DEFAULT_RAW_BSON_OPTIONS
+            )
         return res.raw
 
     def fetch_keys(self, filter):
@@ -215,9 +214,9 @@ def insert_data_key(self, data_key):
           The _id of the inserted data key document.
         """
         raw_doc = RawBSONDocument(data_key, _KEY_VAULT_OPTS)
-        data_key_id = raw_doc.get('_id')
+        data_key_id = raw_doc.get("_id")
         if not isinstance(data_key_id, uuid.UUID):
-            raise TypeError('data_key _id must be a UUID')
+            raise TypeError("data_key _id must be a UUID")
         self.key_vault_coll.insert_one(raw_doc)
         return Binary(data_key_id.bytes, subtype=UUID_SUBTYPE)
@@ -252,6 +251,7 @@ class _Encrypter(object):
 
     This class is used to support automatic encryption and decryption of
     MongoDB commands."""
+
     def __init__(self, client, opts):
         """Create a _Encrypter for a client.
 
@@ -273,8 +273,7 @@ def _get_internal_client(encrypter, mongo_client):
             # Else - limited pool size, use an internal client.
             if encrypter._internal_client is not None:
                 return encrypter._internal_client
-            internal_client = mongo_client._duplicate(
-                minPoolSize=0, auto_encryption_opts=None)
+            internal_client = mongo_client._duplicate(minPoolSize=0, auto_encryption_opts=None)
             encrypter._internal_client = internal_client
             return internal_client
 
@@ -288,17 +287,17 @@ def _get_internal_client(encrypter, mongo_client):
         else:
             metadata_client = _get_internal_client(self, client)
 
-        db, coll = opts._key_vault_namespace.split('.', 1)
+        db, coll = opts._key_vault_namespace.split(".", 1)
         key_vault_coll = key_vault_client[db][coll]
 
         mongocryptd_client = MongoClient(
-            opts._mongocryptd_uri, connect=False,
-            serverSelectionTimeoutMS=_MONGOCRYPTD_TIMEOUT_MS)
+            opts._mongocryptd_uri, connect=False, serverSelectionTimeoutMS=_MONGOCRYPTD_TIMEOUT_MS
+        )
 
-        io_callbacks = _EncryptionIO(
-            metadata_client, key_vault_coll, mongocryptd_client, opts)
-        self._auto_encrypter = AutoEncrypter(io_callbacks, MongoCryptOptions(
-            opts._kms_providers, schema_map))
+        io_callbacks = _EncryptionIO(metadata_client, key_vault_coll, mongocryptd_client, opts)
+        self._auto_encrypter = AutoEncrypter(
+            io_callbacks, MongoCryptOptions(opts._kms_providers, schema_map)
+        )
         self._closed = False
 
     def encrypt(self, database, cmd, check_keys, codec_options):
@@ -316,15 +315,14 @@ def encrypt(self, database, cmd, check_keys, codec_options):
         self._check_closed()
         # Workaround for $clusterTime which is incompatible with
         # check_keys.
-        cluster_time = check_keys and cmd.pop('$clusterTime', None)
+        cluster_time = check_keys and cmd.pop("$clusterTime", None)
         encoded_cmd = _dict_to_bson(cmd, check_keys, codec_options)
         with _wrap_encryption_errors():
             encrypted_cmd = self._auto_encrypter.encrypt(database, encoded_cmd)
         # TODO: PYTHON-1922 avoid decoding the encrypted_cmd.
-        encrypt_cmd = _inflate_bson(
-            encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
+        encrypt_cmd = _inflate_bson(encrypted_cmd, DEFAULT_RAW_BSON_OPTIONS)
         if cluster_time:
-            encrypt_cmd['$clusterTime'] = cluster_time
+            encrypt_cmd["$clusterTime"] = cluster_time
         return encrypt_cmd
 
     def decrypt(self, response):
@@ -355,17 +353,22 @@ def close(self):
 
 class Algorithm(object):
     """An enum that defines the supported encryption algorithms."""
-    AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic = (
-        "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic")
-    AEAD_AES_256_CBC_HMAC_SHA_512_Random = (
-        "AEAD_AES_256_CBC_HMAC_SHA_512-Random")
+
+    AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic = "AEAD_AES_256_CBC_HMAC_SHA_512-Deterministic"
+    AEAD_AES_256_CBC_HMAC_SHA_512_Random = "AEAD_AES_256_CBC_HMAC_SHA_512-Random"
 
 
 class ClientEncryption(object):
     """Explicit client-side field level encryption."""
 
-    def __init__(self, kms_providers, key_vault_namespace, key_vault_client,
-                 codec_options, kms_tls_options=None):
+    def __init__(
+        self,
+        kms_providers,
+        key_vault_namespace,
+        key_vault_client,
+        codec_options,
+        kms_tls_options=None,
+    ):
         """Explicit client-side field level encryption.
 
         The ClientEncryption class encapsulates explicit operations on a key
@@ -439,28 +442,31 @@ def __init__(self, kms_providers, key_vault_namespace, key_vault_client,
             raise ConfigurationError(
                 "client-side field level encryption requires the pymongocrypt "
                 "library: install a compatible version with: "
-                "python -m pip install 'pymongo[encryption]'")
+                "python -m pip install 'pymongo[encryption]'"
+            )
 
         if not isinstance(codec_options, CodecOptions):
-            raise TypeError("codec_options must be an instance of "
-                            "bson.codec_options.CodecOptions")
+            raise TypeError(
+                "codec_options must be an instance of " "bson.codec_options.CodecOptions"
+            )
 
         self._kms_providers = kms_providers
         self._key_vault_namespace = key_vault_namespace
         self._key_vault_client = key_vault_client
         self._codec_options = codec_options
 
-        db, coll = key_vault_namespace.split('.', 1)
+        db, coll = key_vault_namespace.split(".", 1)
         key_vault_coll = key_vault_client[db][coll]
 
-        opts = AutoEncryptionOpts(kms_providers, key_vault_namespace,
-                                  kms_tls_options=kms_tls_options)
+        opts = AutoEncryptionOpts(
+            kms_providers, key_vault_namespace, kms_tls_options=kms_tls_options
+        )
         self._io_callbacks = _EncryptionIO(None, key_vault_coll, None, opts)
         self._encryption = ExplicitEncrypter(
-            self._io_callbacks, MongoCryptOptions(kms_providers, None))
+            self._io_callbacks, MongoCryptOptions(kms_providers, None)
+        )
 
-    def create_data_key(self, kms_provider, master_key=None,
-                        key_alt_names=None):
+    def create_data_key(self, kms_provider, master_key=None, key_alt_names=None):
         """Create and insert a new data key into the key vault collection.
 
         :Parameters:
@@ -529,8 +535,8 @@ def create_data_key(self, kms_provider, master_key=None,
         self._check_closed()
         with _wrap_encryption_errors():
             return self._encryption.create_data_key(
-                kms_provider, master_key=master_key,
-                key_alt_names=key_alt_names)
+                kms_provider, master_key=master_key, key_alt_names=key_alt_names
+            )
 
     def encrypt(self, value, algorithm, key_id=None, key_alt_name=None):
         """Encrypt a BSON value with a given key and algorithm.
@@ -551,17 +557,17 @@ def encrypt(self, value, algorithm, key_id=None, key_alt_name=None):
           The encrypted value, a :class:`~bson.binary.Binary` with subtype 6.
         """
         self._check_closed()
-        if (key_id is not None and not (
-                isinstance(key_id, Binary) and
-                key_id.subtype == UUID_SUBTYPE)):
-            raise TypeError(
-                'key_id must be a bson.binary.Binary with subtype 4')
+        if key_id is not None and not (
+            isinstance(key_id, Binary) and key_id.subtype == UUID_SUBTYPE
+        ):
+            raise TypeError("key_id must be a bson.binary.Binary with subtype 4")
 
-        doc = encode({'v': value}, codec_options=self._codec_options)
+        doc = encode({"v": value}, codec_options=self._codec_options)
         with _wrap_encryption_errors():
             encrypted_doc = self._encryption.encrypt(
-                doc, algorithm, key_id=key_id, key_alt_name=key_alt_name)
-            return decode(encrypted_doc)['v']
+                doc, algorithm, key_id=key_id, key_alt_name=key_alt_name
+            )
+            return decode(encrypted_doc)["v"]
 
     def decrypt(self, value):
         """Decrypt an encrypted value.
@@ -575,14 +581,12 @@ def decrypt(self, value):
         """
         self._check_closed()
         if not (isinstance(value, Binary) and value.subtype == 6):
-            raise TypeError(
-                'value to decrypt must be a bson.binary.Binary with subtype 6')
+            raise TypeError("value to decrypt must be a bson.binary.Binary with subtype 6")
 
         with _wrap_encryption_errors():
-            doc = encode({'v': value})
+            doc = encode({"v": value})
             decrypted_doc = self._encryption.decrypt(doc)
-            return decode(decrypted_doc,
-                          codec_options=self._codec_options)['v']
+            return decode(decrypted_doc, codec_options=self._codec_options)["v"]
 
     def __enter__(self):
         return self
Automatic client-side field level encryption requires MongoDB 4.2 @@ -142,7 +147,8 @@ def __init__(self, kms_providers, key_vault_namespace, raise ConfigurationError( "client side encryption requires the pymongocrypt library: " "install a compatible version with: " - "python -m pip install 'pymongo[encryption]'") + "python -m pip install 'pymongo[encryption]'" + ) self._kms_providers = kms_providers self._key_vault_namespace = key_vault_namespace @@ -152,12 +158,12 @@ def __init__(self, kms_providers, key_vault_namespace, self._mongocryptd_uri = mongocryptd_uri self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn self._mongocryptd_spawn_path = mongocryptd_spawn_path - self._mongocryptd_spawn_args = (copy.copy(mongocryptd_spawn_args) or - ['--idleShutdownTimeoutSecs=60']) + self._mongocryptd_spawn_args = copy.copy(mongocryptd_spawn_args) or [ + "--idleShutdownTimeoutSecs=60" + ] if not isinstance(self._mongocryptd_spawn_args, list): - raise TypeError('mongocryptd_spawn_args must be a list') - if not any('idleShutdownTimeoutSecs' in s - for s in self._mongocryptd_spawn_args): - self._mongocryptd_spawn_args.append('--idleShutdownTimeoutSecs=60') + raise TypeError("mongocryptd_spawn_args must be a list") + if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args): + self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60") # Maps KMS provider name to a SSLContext. self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options) diff --git a/pymongo/errors.py b/pymongo/errors.py index 0ee35827a7..49af6d6fca 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -23,13 +23,15 @@ try: from ssl import CertificateError as _CertificateError except ImportError: + class _CertificateError(ValueError): pass class PyMongoError(Exception): """Base class for all PyMongo exceptions.""" - def __init__(self, message='', error_labels=None): + + def __init__(self, message="", error_labels=None): super(PyMongoError, self).__init__(message) self._message = message self._error_labels = set(error_labels or []) @@ -70,10 +72,11 @@ class AutoReconnect(ConnectionFailure): Subclass of :exc:`~pymongo.errors.ConnectionFailure`. """ - def __init__(self, message='', errors=None): + + def __init__(self, message="", errors=None): error_labels = None if errors is not None and isinstance(errors, dict): - error_labels = errors.get('errorLabels') + error_labels = errors.get("errorLabels") super(AutoReconnect, self).__init__(message, error_labels) self.errors = self.details = errors or [] @@ -109,9 +112,11 @@ class NotPrimaryError(AutoReconnect): .. versionadded:: 3.12 """ - def __init__(self, message='', errors=None): + + def __init__(self, message="", errors=None): super(NotPrimaryError, self).__init__( - _format_detailed_error(message, errors), errors=errors) + _format_detailed_error(message, errors), errors=errors + ) class ServerSelectionTimeoutError(AutoReconnect): @@ -128,8 +133,7 @@ class ServerSelectionTimeoutError(AutoReconnect): class ConfigurationError(PyMongoError): - """Raised when something is incorrectly configured. 
- """ + """Raised when something is incorrectly configured.""" class OperationFailure(PyMongoError): @@ -142,9 +146,10 @@ class OperationFailure(PyMongoError): def __init__(self, error, code=None, details=None, max_wire_version=None): error_labels = None if details is not None: - error_labels = details.get('errorLabels') + error_labels = details.get("errorLabels") super(OperationFailure, self).__init__( - _format_detailed_error(error, details), error_labels=error_labels) + _format_detailed_error(error, details), error_labels=error_labels + ) self.__code = code self.__details = details self.__max_wire_version = max_wire_version @@ -155,8 +160,7 @@ def _max_wire_version(self): @property def code(self): - """The error code returned by the server, if any. - """ + """The error code returned by the server, if any.""" return self.__code @property @@ -172,7 +176,6 @@ def details(self): return self.__details - class CursorNotFound(OperationFailure): """Raised while iterating query results if the cursor is invalidated on the server. @@ -225,9 +228,9 @@ class BulkWriteError(OperationFailure): .. versionadded:: 2.7 """ + def __init__(self, results): - super(BulkWriteError, self).__init__( - "batch op errors occurred", 65, results) + super(BulkWriteError, self).__init__("batch op errors occurred", 65, results) def __reduce__(self): return self.__class__, (self.details,) @@ -250,8 +253,8 @@ class InvalidURI(ConfigurationError): class DocumentTooLarge(InvalidDocument): - """Raised when an encoded document is too large for the connected server. - """ + """Raised when an encoded document is too large for the connected server.""" + pass @@ -275,6 +278,6 @@ def cause(self): class _OperationCancelled(AutoReconnect): - """Internal error raised when a socket operation is cancelled. - """ + """Internal error raised when a socket operation is cancelled.""" + pass diff --git a/pymongo/event_loggers.py b/pymongo/event_loggers.py index 7d5501c372..b80d96d73c 100644 --- a/pymongo/event_loggers.py +++ b/pymongo/event_loggers.py @@ -42,22 +42,29 @@ class CommandLogger(monitoring.CommandListener): logs them at the `INFO` severity level using :mod:`logging`. .. versionadded:: 3.11 """ + def started(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} started on server " - "{0.connection_id}".format(event)) + logging.info( + "Command {0.command_name} with request id " + "{0.request_id} started on server " + "{0.connection_id}".format(event) + ) def succeeded(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "succeeded in {0.duration_micros} " - "microseconds".format(event)) + logging.info( + "Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "succeeded in {0.duration_micros} " + "microseconds".format(event) + ) def failed(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "failed in {0.duration_micros} " - "microseconds".format(event)) + logging.info( + "Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "failed in {0.duration_micros} " + "microseconds".format(event) + ) class ServerLogger(monitoring.ServerListener): @@ -70,9 +77,9 @@ class ServerLogger(monitoring.ServerListener): .. 
versionadded:: 3.11 """ + def opened(self, event): - logging.info("Server {0.server_address} added to topology " - "{0.topology_id}".format(event)) + logging.info("Server {0.server_address} added to topology " "{0.topology_id}".format(event)) def description_changed(self, event): previous_server_type = event.previous_description.server_type @@ -82,11 +89,13 @@ def description_changed(self, event): logging.info( "Server {0.server_address} changed type from " "{0.previous_description.server_type_name} to " - "{0.new_description.server_type_name}".format(event)) + "{0.new_description.server_type_name}".format(event) + ) def closed(self, event): - logging.warning("Server {0.server_address} removed from topology " - "{0.topology_id}".format(event)) + logging.warning( + "Server {0.server_address} removed from topology " "{0.topology_id}".format(event) + ) class HeartbeatLogger(monitoring.ServerHeartbeatListener): @@ -99,19 +108,22 @@ class HeartbeatLogger(monitoring.ServerHeartbeatListener): .. versionadded:: 3.11 """ + def started(self, event): - logging.info("Heartbeat sent to server " - "{0.connection_id}".format(event)) + logging.info("Heartbeat sent to server " "{0.connection_id}".format(event)) def succeeded(self, event): # The reply.document attribute was added in PyMongo 3.4. - logging.info("Heartbeat to server {0.connection_id} " - "succeeded with reply " - "{0.reply.document}".format(event)) + logging.info( + "Heartbeat to server {0.connection_id} " + "succeeded with reply " + "{0.reply.document}".format(event) + ) def failed(self, event): - logging.warning("Heartbeat to server {0.connection_id} " - "failed with error {0.reply}".format(event)) + logging.warning( + "Heartbeat to server {0.connection_id} " "failed with error {0.reply}".format(event) + ) class TopologyLogger(monitoring.TopologyListener): @@ -124,13 +136,14 @@ class TopologyLogger(monitoring.TopologyListener): .. versionadded:: 3.11 """ + def opened(self, event): - logging.info("Topology with id {0.topology_id} " - "opened".format(event)) + logging.info("Topology with id {0.topology_id} " "opened".format(event)) def description_changed(self, event): - logging.info("Topology description updated for " - "topology id {0.topology_id}".format(event)) + logging.info( + "Topology description updated for " "topology id {0.topology_id}".format(event) + ) previous_topology_type = event.previous_description.topology_type new_topology_type = event.new_description.topology_type if new_topology_type != previous_topology_type: @@ -138,7 +151,8 @@ def description_changed(self, event): logging.info( "Topology {0.topology_id} changed type from " "{0.previous_description.topology_type_name} to " - "{0.new_description.topology_type_name}".format(event)) + "{0.new_description.topology_type_name}".format(event) + ) # The has_writable_server and has_readable_server methods # were added in PyMongo 3.4. if not event.new_description.has_writable_server(): @@ -147,8 +161,7 @@ def description_changed(self, event): logging.warning("No readable servers available.") def closed(self, event): - logging.info("Topology with id {0.topology_id} " - "closed".format(event)) + logging.info("Topology with id {0.topology_id} " "closed".format(event)) class ConnectionPoolLogger(monitoring.ConnectionPoolListener): @@ -168,6 +181,7 @@ class ConnectionPoolLogger(monitoring.ConnectionPoolListener): .. 
versionadded:: 3.11 """ + def pool_created(self, event): logging.info("[pool {0.address}] pool created".format(event)) @@ -181,30 +195,39 @@ def pool_closed(self, event): logging.info("[pool {0.address}] pool closed".format(event)) def connection_created(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " - "connection created".format(event)) + logging.info( + "[pool {0.address}][conn #{0.connection_id}] " "connection created".format(event) + ) def connection_ready(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " - "connection setup succeeded".format(event)) + logging.info( + "[pool {0.address}][conn #{0.connection_id}] " + "connection setup succeeded".format(event) + ) def connection_closed(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " - "connection closed, reason: " - "{0.reason}".format(event)) + logging.info( + "[pool {0.address}][conn #{0.connection_id}] " + "connection closed, reason: " + "{0.reason}".format(event) + ) def connection_check_out_started(self, event): - logging.info("[pool {0.address}] connection check out " - "started".format(event)) + logging.info("[pool {0.address}] connection check out " "started".format(event)) def connection_check_out_failed(self, event): - logging.info("[pool {0.address}] connection check out " - "failed, reason: {0.reason}".format(event)) + logging.info( + "[pool {0.address}] connection check out " "failed, reason: {0.reason}".format(event) + ) def connection_checked_out(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " - "connection checked out of pool".format(event)) + logging.info( + "[pool {0.address}][conn #{0.connection_id}] " + "connection checked out of pool".format(event) + ) def connection_checked_in(self, event): - logging.info("[pool {0.address}][conn #{0.connection_id}] " - "connection checked into pool".format(event)) + logging.info( + "[pool {0.address}][conn #{0.connection_id}] " + "connection checked into pool".format(event) + ) diff --git a/pymongo/hello.py b/pymongo/hello.py index 0ad06e9619..579929fbd6 100644 --- a/pymongo/hello.py +++ b/pymongo/hello.py @@ -21,36 +21,36 @@ class HelloCompat: - CMD = 'hello' - LEGACY_CMD = 'ismaster' - PRIMARY = 'isWritablePrimary' - LEGACY_PRIMARY = 'ismaster' - LEGACY_ERROR = 'not master' + CMD = "hello" + LEGACY_CMD = "ismaster" + PRIMARY = "isWritablePrimary" + LEGACY_PRIMARY = "ismaster" + LEGACY_ERROR = "not master" def _get_server_type(doc): """Determine the server type from a hello response.""" - if not doc.get('ok'): + if not doc.get("ok"): return SERVER_TYPE.Unknown - if doc.get('serviceId'): + if doc.get("serviceId"): return SERVER_TYPE.LoadBalancer - elif doc.get('isreplicaset'): + elif doc.get("isreplicaset"): return SERVER_TYPE.RSGhost - elif doc.get('setName'): - if doc.get('hidden'): + elif doc.get("setName"): + if doc.get("hidden"): return SERVER_TYPE.RSOther elif doc.get(HelloCompat.PRIMARY): return SERVER_TYPE.RSPrimary elif doc.get(HelloCompat.LEGACY_PRIMARY): return SERVER_TYPE.RSPrimary - elif doc.get('secondary'): + elif doc.get("secondary"): return SERVER_TYPE.RSSecondary - elif doc.get('arbiterOnly'): + elif doc.get("arbiterOnly"): return SERVER_TYPE.RSArbiter else: return SERVER_TYPE.RSOther - elif doc.get('msg') == 'isdbgrid': + elif doc.get("msg") == "isdbgrid": return SERVER_TYPE.Mongos else: return SERVER_TYPE.Standalone @@ -61,8 +61,8 @@ class Hello(object): .. 
versionadded:: 3.12 """ - __slots__ = ('_doc', '_server_type', '_is_writable', '_is_readable', - '_awaitable') + + __slots__ = ("_doc", "_server_type", "_is_writable", "_is_readable", "_awaitable") def __init__(self, doc, awaitable=False): self._server_type = _get_server_type(doc) @@ -71,11 +71,10 @@ def __init__(self, doc, awaitable=False): SERVER_TYPE.RSPrimary, SERVER_TYPE.Standalone, SERVER_TYPE.Mongos, - SERVER_TYPE.LoadBalancer) + SERVER_TYPE.LoadBalancer, + ) - self._is_readable = ( - self.server_type == SERVER_TYPE.RSSecondary - or self._is_writable) + self._is_readable = self.server_type == SERVER_TYPE.RSSecondary or self._is_writable self._awaitable = awaitable @property @@ -93,64 +92,70 @@ def server_type(self): @property def all_hosts(self): """List of hosts, passives, and arbiters known to this server.""" - return set(map(common.clean_node, itertools.chain( - self._doc.get('hosts', []), - self._doc.get('passives', []), - self._doc.get('arbiters', [])))) + return set( + map( + common.clean_node, + itertools.chain( + self._doc.get("hosts", []), + self._doc.get("passives", []), + self._doc.get("arbiters", []), + ), + ) + ) @property def tags(self): """Replica set member tags or empty dict.""" - return self._doc.get('tags', {}) + return self._doc.get("tags", {}) @property def primary(self): """This server's opinion about who the primary is, or None.""" - if self._doc.get('primary'): - return common.partition_node(self._doc['primary']) + if self._doc.get("primary"): + return common.partition_node(self._doc["primary"]) else: return None @property def replica_set_name(self): """Replica set name or None.""" - return self._doc.get('setName') + return self._doc.get("setName") @property def max_bson_size(self): - return self._doc.get('maxBsonObjectSize', common.MAX_BSON_SIZE) + return self._doc.get("maxBsonObjectSize", common.MAX_BSON_SIZE) @property def max_message_size(self): - return self._doc.get('maxMessageSizeBytes', 2 * self.max_bson_size) + return self._doc.get("maxMessageSizeBytes", 2 * self.max_bson_size) @property def max_write_batch_size(self): - return self._doc.get('maxWriteBatchSize', common.MAX_WRITE_BATCH_SIZE) + return self._doc.get("maxWriteBatchSize", common.MAX_WRITE_BATCH_SIZE) @property def min_wire_version(self): - return self._doc.get('minWireVersion', common.MIN_WIRE_VERSION) + return self._doc.get("minWireVersion", common.MIN_WIRE_VERSION) @property def max_wire_version(self): - return self._doc.get('maxWireVersion', common.MAX_WIRE_VERSION) + return self._doc.get("maxWireVersion", common.MAX_WIRE_VERSION) @property def set_version(self): - return self._doc.get('setVersion') + return self._doc.get("setVersion") @property def election_id(self): - return self._doc.get('electionId') + return self._doc.get("electionId") @property def cluster_time(self): - return self._doc.get('$clusterTime') + return self._doc.get("$clusterTime") @property def logical_session_timeout_minutes(self): - return self._doc.get('logicalSessionTimeoutMinutes') + return self._doc.get("logicalSessionTimeoutMinutes") @property def is_writable(self): @@ -162,17 +167,17 @@ def is_readable(self): @property def me(self): - me = self._doc.get('me') + me = self._doc.get("me") if me: return common.clean_node(me) @property def last_write_date(self): - return self._doc.get('lastWrite', {}).get('lastWriteDate') + return self._doc.get("lastWrite", {}).get("lastWriteDate") @property def compressors(self): - return self._doc.get('compression') + return self._doc.get("compression") @property def 
sasl_supported_mechs(self): @@ -184,16 +189,16 @@ def sasl_supported_mechs(self): ["SCRAM-SHA-1", "SCRAM-SHA-256"] """ - return self._doc.get('saslSupportedMechs', []) + return self._doc.get("saslSupportedMechs", []) @property def speculative_authenticate(self): """The speculativeAuthenticate field.""" - return self._doc.get('speculativeAuthenticate') + return self._doc.get("speculativeAuthenticate") @property def topology_version(self): - return self._doc.get('topologyVersion') + return self._doc.get("topologyVersion") @property def awaitable(self): @@ -201,8 +206,8 @@ def awaitable(self): @property def service_id(self): - return self._doc.get('serviceId') + return self._doc.get("serviceId") @property def hello_ok(self): - return self._doc.get('helloOk', False) + return self._doc.get("helloOk", False) diff --git a/pymongo/helpers.py b/pymongo/helpers.py index a9d40d8103..3345ea9378 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -16,45 +16,55 @@ import sys import traceback - from collections import abc from bson.son import SON from pymongo import ASCENDING -from pymongo.errors import (CursorNotFound, - DuplicateKeyError, - ExecutionTimeout, - NotPrimaryError, - OperationFailure, - WriteError, - WriteConcernError, - WTimeoutError) +from pymongo.errors import ( + CursorNotFound, + DuplicateKeyError, + ExecutionTimeout, + NotPrimaryError, + OperationFailure, + WriteConcernError, + WriteError, + WTimeoutError, +) from pymongo.hello import HelloCompat # From the SDAM spec, the "node is shutting down" codes. -_SHUTDOWN_CODES = frozenset([ - 11600, # InterruptedAtShutdown - 91, # ShutdownInProgress -]) +_SHUTDOWN_CODES = frozenset( + [ + 11600, # InterruptedAtShutdown + 91, # ShutdownInProgress + ] +) # From the SDAM spec, the "not primary" error codes are combined with the # "node is recovering" error codes (of which the "node is shutting down" # errors are a subset). -_NOT_PRIMARY_CODES = frozenset([ - 10058, # LegacyNotPrimary <=3.2 "not primary" error code - 10107, # NotWritablePrimary - 13435, # NotPrimaryNoSecondaryOk - 11602, # InterruptedDueToReplStateChange - 13436, # NotPrimaryOrSecondary - 189, # PrimarySteppedDown -]) | _SHUTDOWN_CODES +_NOT_PRIMARY_CODES = ( + frozenset( + [ + 10058, # LegacyNotPrimary <=3.2 "not primary" error code + 10107, # NotWritablePrimary + 13435, # NotPrimaryNoSecondaryOk + 11602, # InterruptedDueToReplStateChange + 13436, # NotPrimaryOrSecondary + 189, # PrimarySteppedDown + ] + ) + | _SHUTDOWN_CODES +) # From the retryable writes spec. -_RETRYABLE_ERROR_CODES = _NOT_PRIMARY_CODES | frozenset([ - 7, # HostNotFound - 6, # HostUnreachable - 89, # NetworkTimeout - 9001, # SocketException - 262, # ExceededTimeLimit -]) +_RETRYABLE_ERROR_CODES = _NOT_PRIMARY_CODES | frozenset( + [ + 7, # HostNotFound + 6, # HostUnreachable + 89, # NetworkTimeout + 9001, # SocketException + 262, # ExceededTimeLimit + ] +) def _gen_index_name(keys): @@ -75,8 +85,9 @@ def _index_list(key_or_list, direction=None): if isinstance(key_or_list, abc.ItemsView): return list(key_or_list) elif not isinstance(key_or_list, (list, tuple)): - raise TypeError("if no direction is specified, " - "key_or_list must be an instance of list") + raise TypeError( + "if no direction is specified, " "key_or_list must be an instance of list" + ) return key_or_list @@ -86,44 +97,44 @@ def _index_document(index_list): Takes a list of (key, direction) pairs. 
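As an aside on the (key, direction) contract just stated, the accepted and rejected shapes look like this in practice (the collection name is illustrative):

from pymongo import ASCENDING, DESCENDING, MongoClient

coll = MongoClient().test.events
# A list of (key, direction) pairs preserves order and is accepted:
coll.create_index([("user_id", ASCENDING), ("created", DESCENDING)])
# A plain dict is rejected with the TypeError below, since its intended
# key order is ambiguous:
# coll.create_index({"user_id": 1, "created": -1})  # TypeError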
""" if isinstance(index_list, abc.Mapping): - raise TypeError("passing a dict to sort/create_index/hint is not " - "allowed - use a list of tuples instead. did you " - "mean %r?" % list(index_list.items())) + raise TypeError( + "passing a dict to sort/create_index/hint is not " + "allowed - use a list of tuples instead. did you " + "mean %r?" % list(index_list.items()) + ) elif not isinstance(index_list, (list, tuple)): - raise TypeError("must use a list of (key, direction) pairs, " - "not: " + repr(index_list)) + raise TypeError("must use a list of (key, direction) pairs, " "not: " + repr(index_list)) if not len(index_list): raise ValueError("key_or_list must not be the empty list") index = SON() for (key, value) in index_list: if not isinstance(key, str): - raise TypeError( - "first item in each key pair must be an instance of str") + raise TypeError("first item in each key pair must be an instance of str") if not isinstance(value, (str, int, abc.Mapping)): - raise TypeError("second item in each key pair must be 1, -1, " - "'2d', or another valid MongoDB index specifier.") + raise TypeError( + "second item in each key pair must be 1, -1, " + "'2d', or another valid MongoDB index specifier." + ) index[key] = value return index -def _check_command_response(response, max_wire_version, - allowable_errors=None, - parse_write_concern_error=False): - """Check the response to a command for errors. - """ +def _check_command_response( + response, max_wire_version, allowable_errors=None, parse_write_concern_error=False +): + """Check the response to a command for errors.""" if "ok" not in response: # Server didn't recognize our message as a command. - raise OperationFailure(response.get("$err"), - response.get("code"), - response, - max_wire_version) + raise OperationFailure( + response.get("$err"), response.get("code"), response, max_wire_version + ) - if parse_write_concern_error and 'writeConcernError' in response: + if parse_write_concern_error and "writeConcernError" in response: _error = response["writeConcernError"] _labels = response.get("errorLabels") if _labels: - _error.update({'errorLabels': _labels}) + _error.update({"errorLabels": _labels}) _raise_write_concern_error(_error) if response["ok"]: @@ -180,12 +191,10 @@ def _raise_last_write_error(write_errors): def _raise_write_concern_error(error): - if "errInfo" in error and error["errInfo"].get('wtimeout'): + if "errInfo" in error and error["errInfo"].get("wtimeout"): # Make sure we raise WTimeoutError - raise WTimeoutError( - error.get("errmsg"), error.get("code"), error) - raise WriteConcernError( - error.get("errmsg"), error.get("code"), error) + raise WTimeoutError(error.get("errmsg"), error.get("code"), error) + raise WriteConcernError(error.get("errmsg"), error.get("code"), error) def _get_wce_doc(result): @@ -201,8 +210,7 @@ def _get_wce_doc(result): def _check_write_command_response(result): - """Backward compatibility helper for write command error handling. 
- """ + """Backward compatibility helper for write command error handling.""" # Prefer write errors over write concern errors write_errors = result.get("writeErrors") if write_errors: @@ -227,12 +235,12 @@ def _fields_list_to_dict(fields, option_name): if isinstance(fields, (abc.Sequence, abc.Set)): if not all(isinstance(field, str) for field in fields): - raise TypeError("%s must be a list of key names, each an " - "instance of str" % (option_name,)) + raise TypeError( + "%s must be a list of key names, each an " "instance of str" % (option_name,) + ) return dict.fromkeys(fields, 1) - raise TypeError("%s must be a mapping or " - "list of key names" % (option_name,)) + raise TypeError("%s must be a mapping or " "list of key names" % (option_name,)) def _handle_exception(): @@ -244,8 +252,7 @@ def _handle_exception(): if sys.stderr: einfo = sys.exc_info() try: - traceback.print_exception(einfo[0], einfo[1], einfo[2], - None, sys.stderr) + traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr) except IOError: pass finally: diff --git a/pymongo/max_staleness_selectors.py b/pymongo/max_staleness_selectors.py index 6bc2fe7232..28b0bb615e 100644 --- a/pymongo/max_staleness_selectors.py +++ b/pymongo/max_staleness_selectors.py @@ -30,28 +30,27 @@ from pymongo.errors import ConfigurationError from pymongo.server_type import SERVER_TYPE - # Constant defined in Max Staleness Spec: An idle primary writes a no-op every # 10 seconds to refresh secondaries' lastWriteDate values. IDLE_WRITE_PERIOD = 10 SMALLEST_MAX_STALENESS = 90 -def _validate_max_staleness(max_staleness, - heartbeat_frequency): +def _validate_max_staleness(max_staleness, heartbeat_frequency): # We checked for max staleness -1 before this, it must be positive here. if max_staleness < heartbeat_frequency + IDLE_WRITE_PERIOD: raise ConfigurationError( "maxStalenessSeconds must be at least heartbeatFrequencyMS +" " %d seconds. maxStalenessSeconds is set to %d," - " heartbeatFrequencyMS is set to %d." % ( - IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000)) + " heartbeatFrequencyMS is set to %d." + % (IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000) + ) if max_staleness < SMALLEST_MAX_STALENESS: raise ConfigurationError( "maxStalenessSeconds must be at least %d. " - "maxStalenessSeconds is set to %d." % ( - SMALLEST_MAX_STALENESS, max_staleness)) + "maxStalenessSeconds is set to %d." % (SMALLEST_MAX_STALENESS, max_staleness) + ) def _with_primary(max_staleness, selection): @@ -63,9 +62,10 @@ def _with_primary(max_staleness, selection): if s.server_type == SERVER_TYPE.RSSecondary: # See max-staleness.rst for explanation of this formula. staleness = ( - (s.last_update_time - s.last_write_date) - - (primary.last_update_time - primary.last_write_date) + - selection.heartbeat_frequency) + (s.last_update_time - s.last_write_date) + - (primary.last_update_time - primary.last_write_date) + + selection.heartbeat_frequency + ) if staleness <= max_staleness: sds.append(s) @@ -88,9 +88,7 @@ def _no_primary(max_staleness, selection): for s in selection.server_descriptions: if s.server_type == SERVER_TYPE.RSSecondary: # See max-staleness.rst for explanation of this formula. 
- staleness = (smax.last_write_date - - s.last_write_date + - selection.heartbeat_frequency) + staleness = smax.last_write_date - s.last_write_date + selection.heartbeat_frequency if staleness <= max_staleness: sds.append(s) diff --git a/pymongo/message.py b/pymongo/message.py index fe203c8431..f41a4e10d8 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -23,38 +23,34 @@ import datetime import random import struct - from io import BytesIO as _BytesIO import bson -from bson import (CodecOptions, - encode, - _decode_selective, - _dict_to_bson, - _make_c_string) +from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode from bson.int64 import Int64 -from bson.raw_bson import (_inflate_bson, DEFAULT_RAW_BSON_OPTIONS, - RawBSONDocument) +from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson from bson.son import SON try: from pymongo import _cmessage + _use_c = True except ImportError: _use_c = False -from pymongo.errors import (ConfigurationError, - CursorNotFound, - DocumentTooLarge, - ExecutionTimeout, - InvalidOperation, - NotPrimaryError, - OperationFailure, - ProtocolError) +from pymongo.errors import ( + ConfigurationError, + CursorNotFound, + DocumentTooLarge, + ExecutionTimeout, + InvalidOperation, + NotPrimaryError, + OperationFailure, + ProtocolError, +) from pymongo.hello import HelloCompat from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern - MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 @@ -65,26 +61,21 @@ _UPDATE = 1 _DELETE = 2 -_EMPTY = b'' -_BSONOBJ = b'\x03' -_ZERO_8 = b'\x00' -_ZERO_16 = b'\x00\x00' -_ZERO_32 = b'\x00\x00\x00\x00' -_ZERO_64 = b'\x00\x00\x00\x00\x00\x00\x00\x00' -_SKIPLIM = b'\x00\x00\x00\x00\xff\xff\xff\xff' +_EMPTY = b"" +_BSONOBJ = b"\x03" +_ZERO_8 = b"\x00" +_ZERO_16 = b"\x00\x00" +_ZERO_32 = b"\x00\x00\x00\x00" +_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00" +_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff" _OP_MAP = { - _INSERT: b'\x04documents\x00\x00\x00\x00\x00', - _UPDATE: b'\x04updates\x00\x00\x00\x00\x00', - _DELETE: b'\x04deletes\x00\x00\x00\x00\x00', -} -_FIELD_MAP = { - 'insert': 'documents', - 'update': 'updates', - 'delete': 'deletes' + _INSERT: b"\x04documents\x00\x00\x00\x00\x00", + _UPDATE: b"\x04updates\x00\x00\x00\x00\x00", + _DELETE: b"\x04deletes\x00\x00\x00\x00\x00", } +_FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"} -_UNICODE_REPLACE_CODEC_OPTIONS = CodecOptions( - unicode_decode_error_handler='replace') +_UNICODE_REPLACE_CODEC_OPTIONS = CodecOptions(unicode_decode_error_handler="replace") def _randint(): @@ -101,9 +92,7 @@ def _maybe_add_read_preference(spec, read_preference): # for maximum backwards compatibility, don't add $readPreference for # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting # the secondaryOkay bit has the same effect). 
- if mode and ( - mode != ReadPreference.SECONDARY_PREFERRED.mode or - len(document) > 1): + if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1): if "$query" not in spec: spec = SON([("$query", spec)]) spec["$readPreference"] = document @@ -112,8 +101,7 @@ def _maybe_add_read_preference(spec, read_preference): def _convert_exception(exception): """Convert an Exception into a failure document for publishing.""" - return {'errmsg': str(exception), - 'errtype': exception.__class__.__name__} + return {"errmsg": str(exception), "errtype": exception.__class__.__name__} def _convert_write_result(operation, command, result): @@ -126,21 +114,17 @@ def _convert_write_result(operation, command, result): if errmsg: # The write was successful on at least the primary so don't return. if result.get("wtimeout"): - res["writeConcernError"] = {"errmsg": errmsg, - "code": 64, - "errInfo": {"wtimeout": True}} + res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} else: # The write failed. - error = {"index": 0, - "code": result.get("code", 8), - "errmsg": errmsg} + error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} if "errInfo" in result: error["errInfo"] = result["errInfo"] res["writeErrors"] = [error] return res if operation == "insert": # GLE result for insert is always 0 in most MongoDB versions. - res["n"] = len(command['documents']) + res["n"] = len(command["documents"]) elif operation == "update": if "upserted" in result: res["upserted"] = [{"index": 0, "_id": result["upserted"]}] @@ -149,102 +133,149 @@ def _convert_write_result(operation, command, result): elif result.get("updatedExisting") is False and affected == 1: # If _id is in both the update document *and* the query spec # the update document _id takes precedence. 
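A small sketch of the precedence rule described in that comment, calling the private `_convert_write_result` with made-up legacy data (illustration only):

from pymongo.message import _convert_write_result

command = {"updates": [{"q": {"_id": "query-id"}, "u": {"_id": "doc-id", "x": 1}}]}
result = {"ok": 1, "n": 1, "updatedExisting": False}  # legacy upsert result
converted = _convert_write_result("update", command, result)
# The update document's _id wins over the one in the query spec:
assert converted["upserted"] == [{"index": 0, "_id": "doc-id"}]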
- update = command['updates'][0] + update = command["updates"][0] _id = update["u"].get("_id", update["q"].get("_id")) res["upserted"] = [{"index": 0, "_id": _id}] return res -_OPTIONS = SON([ - ('tailable', 2), - ('oplogReplay', 8), - ('noCursorTimeout', 16), - ('awaitData', 32), - ('allowPartialResults', 128)]) - - -_MODIFIERS = SON([ - ('$query', 'filter'), - ('$orderby', 'sort'), - ('$hint', 'hint'), - ('$comment', 'comment'), - ('$maxScan', 'maxScan'), - ('$maxTimeMS', 'maxTimeMS'), - ('$max', 'max'), - ('$min', 'min'), - ('$returnKey', 'returnKey'), - ('$showRecordId', 'showRecordId'), - ('$showDiskLoc', 'showRecordId'), # <= MongoDb 3.0 - ('$snapshot', 'snapshot')]) - - -def _gen_find_command(coll, spec, projection, skip, limit, batch_size, options, - read_concern, collation=None, session=None, - allow_disk_use=None): +_OPTIONS = SON( + [ + ("tailable", 2), + ("oplogReplay", 8), + ("noCursorTimeout", 16), + ("awaitData", 32), + ("allowPartialResults", 128), + ] +) + + +_MODIFIERS = SON( + [ + ("$query", "filter"), + ("$orderby", "sort"), + ("$hint", "hint"), + ("$comment", "comment"), + ("$maxScan", "maxScan"), + ("$maxTimeMS", "maxTimeMS"), + ("$max", "max"), + ("$min", "min"), + ("$returnKey", "returnKey"), + ("$showRecordId", "showRecordId"), + ("$showDiskLoc", "showRecordId"), # <= MongoDb 3.0 + ("$snapshot", "snapshot"), + ] +) + + +def _gen_find_command( + coll, + spec, + projection, + skip, + limit, + batch_size, + options, + read_concern, + collation=None, + session=None, + allow_disk_use=None, +): """Generate a find command document.""" - cmd = SON([('find', coll)]) - if '$query' in spec: - cmd.update([(_MODIFIERS[key], val) if key in _MODIFIERS else (key, val) - for key, val in spec.items()]) - if '$explain' in cmd: - cmd.pop('$explain') - if '$readPreference' in cmd: - cmd.pop('$readPreference') + cmd = SON([("find", coll)]) + if "$query" in spec: + cmd.update( + [ + (_MODIFIERS[key], val) if key in _MODIFIERS else (key, val) + for key, val in spec.items() + ] + ) + if "$explain" in cmd: + cmd.pop("$explain") + if "$readPreference" in cmd: + cmd.pop("$readPreference") else: - cmd['filter'] = spec + cmd["filter"] = spec if projection: - cmd['projection'] = projection + cmd["projection"] = projection if skip: - cmd['skip'] = skip + cmd["skip"] = skip if limit: - cmd['limit'] = abs(limit) + cmd["limit"] = abs(limit) if limit < 0: - cmd['singleBatch'] = True + cmd["singleBatch"] = True if batch_size: - cmd['batchSize'] = batch_size + cmd["batchSize"] = batch_size if read_concern.level and not (session and session.in_transaction): - cmd['readConcern'] = read_concern.document + cmd["readConcern"] = read_concern.document if collation: - cmd['collation'] = collation + cmd["collation"] = collation if allow_disk_use is not None: - cmd['allowDiskUse'] = allow_disk_use + cmd["allowDiskUse"] = allow_disk_use if options: - cmd.update([(opt, True) - for opt, val in _OPTIONS.items() - if options & val]) + cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val]) return cmd def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms): """Generate a getMore command document.""" - cmd = SON([('getMore', cursor_id), - ('collection', coll)]) + cmd = SON([("getMore", cursor_id), ("collection", coll)]) if batch_size: - cmd['batchSize'] = batch_size + cmd["batchSize"] = batch_size if max_await_time_ms is not None: - cmd['maxTimeMS'] = max_await_time_ms + cmd["maxTimeMS"] = max_await_time_ms return cmd class _Query(object): """A query operation.""" - __slots__ = 
('flags', 'db', 'coll', 'ntoskip', 'spec', - 'fields', 'codec_options', 'read_preference', 'limit', - 'batch_size', 'name', 'read_concern', 'collation', - 'session', 'client', 'allow_disk_use', '_as_command', - 'exhaust') + __slots__ = ( + "flags", + "db", + "coll", + "ntoskip", + "spec", + "fields", + "codec_options", + "read_preference", + "limit", + "batch_size", + "name", + "read_concern", + "collation", + "session", + "client", + "allow_disk_use", + "_as_command", + "exhaust", + ) # For compatibility with the _GetMore class. sock_mgr = None cursor_id = None - def __init__(self, flags, db, coll, ntoskip, spec, fields, - codec_options, read_preference, limit, - batch_size, read_concern, collation, session, client, - allow_disk_use, exhaust): + def __init__( + self, + flags, + db, + coll, + ntoskip, + spec, + fields, + codec_options, + read_preference, + limit, + batch_size, + read_concern, + collation, + session, + client, + allow_disk_use, + exhaust, + ): self.flags = flags self.db = db self.coll = coll @@ -260,7 +291,7 @@ def __init__(self, flags, db, coll, ntoskip, spec, fields, self.session = session self.client = client self.allow_disk_use = allow_disk_use - self.name = 'find' + self.name = "find" self._as_command = None self.exhaust = exhaust @@ -276,10 +307,10 @@ def use_command(self, sock_info): use_find_cmd = True elif not self.read_concern.ok_for_legacy: raise ConfigurationError( - 'read concern level of %s is not valid ' - 'with a max wire version of %d.' - % (self.read_concern.level, - sock_info.max_wire_version)) + "read concern level of %s is not valid " + "with a max wire version of %d." + % (self.read_concern.level, sock_info.max_wire_version) + ) sock_info.validate_session(self.client, self.session) return use_find_cmd @@ -291,14 +322,23 @@ def as_command(self, sock_info): if self._as_command is not None: return self._as_command - explain = '$explain' in self.spec + explain = "$explain" in self.spec cmd = _gen_find_command( - self.coll, self.spec, self.fields, self.ntoskip, - self.limit, self.batch_size, self.flags, self.read_concern, - self.collation, self.session, self.allow_disk_use) + self.coll, + self.spec, + self.fields, + self.ntoskip, + self.limit, + self.batch_size, + self.flags, + self.read_concern, + self.collation, + self.session, + self.allow_disk_use, + ) if explain: - self.name = 'explain' - cmd = SON([('explain', cmd)]) + self.name = "explain" + cmd = SON([("explain", cmd)]) session = self.session sock_info.add_server_api(cmd) if session: @@ -309,10 +349,8 @@ def as_command(self, sock_info): sock_info.send_cluster_time(cmd, session, self.client) # Support auto encryption client = self.client - if (client._encrypter and - not client._encrypter._bypass_auto_encryption): - cmd = client._encrypter.encrypt( - self.db, cmd, False, self.codec_options) + if client._encrypter and not client._encrypter._bypass_auto_encryption: + cmd = client._encrypter.encrypt(self.db, cmd, False, self.codec_options) self._as_command = cmd, self.db return self._as_command @@ -330,9 +368,15 @@ def get_message(self, set_secondary_ok, sock_info, use_cmd=False): if use_cmd: spec = self.as_command(sock_info)[0] request_id, msg, size, _ = _op_msg( - 0, spec, self.db, self.read_preference, - set_secondary_ok, False, self.codec_options, - ctx=sock_info.compression_context) + 0, + spec, + self.db, + self.read_preference, + set_secondary_ok, + False, + self.codec_options, + ctx=sock_info.compression_context, + ) return request_id, msg, size # OP_QUERY treats ntoreturn of -1 and 1 the same, 
return @@ -346,26 +390,54 @@ def get_message(self, set_secondary_ok, sock_info, use_cmd=False): ntoreturn = self.limit if sock_info.is_mongos: - spec = _maybe_add_read_preference(spec, - self.read_preference) + spec = _maybe_add_read_preference(spec, self.read_preference) - return _query(flags, ns, self.ntoskip, ntoreturn, - spec, None if use_cmd else self.fields, - self.codec_options, ctx=sock_info.compression_context) + return _query( + flags, + ns, + self.ntoskip, + ntoreturn, + spec, + None if use_cmd else self.fields, + self.codec_options, + ctx=sock_info.compression_context, + ) class _GetMore(object): """A getmore operation.""" - __slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_await_time_ms', - 'codec_options', 'read_preference', 'session', 'client', - 'sock_mgr', '_as_command', 'exhaust') - - name = 'getMore' - - def __init__(self, db, coll, ntoreturn, cursor_id, codec_options, - read_preference, session, client, max_await_time_ms, - sock_mgr, exhaust): + __slots__ = ( + "db", + "coll", + "ntoreturn", + "cursor_id", + "max_await_time_ms", + "codec_options", + "read_preference", + "session", + "client", + "sock_mgr", + "_as_command", + "exhaust", + ) + + name = "getMore" + + def __init__( + self, + db, + coll, + ntoreturn, + cursor_id, + codec_options, + read_preference, + session, + client, + max_await_time_ms, + sock_mgr, + exhaust, + ): self.db = db self.coll = coll self.ntoreturn = ntoreturn @@ -399,9 +471,9 @@ def as_command(self, sock_info): if self._as_command is not None: return self._as_command - cmd = _gen_get_more_command(self.cursor_id, self.coll, - self.ntoreturn, - self.max_await_time_ms) + cmd = _gen_get_more_command( + self.cursor_id, self.coll, self.ntoreturn, self.max_await_time_ms + ) if self.session: self.session._apply_to(cmd, False, self.read_preference, sock_info) @@ -409,10 +481,8 @@ def as_command(self, sock_info): sock_info.send_cluster_time(cmd, self.session, self.client) # Support auto encryption client = self.client - if (client._encrypter and - not client._encrypter._bypass_auto_encryption): - cmd = client._encrypter.encrypt( - self.db, cmd, False, self.codec_options) + if client._encrypter and not client._encrypter._bypass_auto_encryption: + cmd = client._encrypter.encrypt(self.db, cmd, False, self.codec_options) self._as_command = cmd, self.db return self._as_command @@ -429,9 +499,15 @@ def get_message(self, dummy0, sock_info, use_cmd=False): else: flags = 0 request_id, msg, size, _ = _op_msg( - flags, spec, self.db, None, - False, False, self.codec_options, - ctx=sock_info.compression_context) + flags, + spec, + self.db, + None, + False, + False, + self.codec_options, + ctx=sock_info.compression_context, + ) return request_id, msg, size return _get_more(ns, self.ntoreturn, self.cursor_id, ctx) @@ -481,8 +557,7 @@ def __hash__(self): def __eq__(self, other): if isinstance(other, _CursorAddress): - return (tuple(self) == tuple(other) - and self.namespace == other.namespace) + return tuple(self) == tuple(other) and self.namespace == other.namespace return NotImplemented def __ne__(self, other): @@ -492,19 +567,21 @@ def __ne__(self, other): _pack_compression_header = struct.Struct(" max_message_size)) + doc_too_large = idx == 0 and (new_message_size > max_message_size) # When OP_MSG is used unacknowleged we have to check # document size client side or applications won't be notified. 
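That client-side size check is observable from the public API. A sketch, assuming a local mongod and the default 16 MiB BSON document limit:

from pymongo import MongoClient
from pymongo.errors import DocumentTooLarge
from pymongo.write_concern import WriteConcern

coll = MongoClient().test.data.with_options(write_concern=WriteConcern(w=0))
try:
    # Unacknowledged (w=0), so the size check happens client side and
    # raises instead of the server silently skipping the document.
    coll.insert_one({"blob": "x" * (17 * 1024 * 1024)})
except DocumentTooLarge:
    pass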
# Otherwise we let the server deal with documents that are too large # since ordered=False causes those documents to be skipped instead of # halting the bulk write operation. - unacked_doc_too_large = (not ack and (doc_length > max_bson_size)) + unacked_doc_too_large = not ack and (doc_length > max_bson_size) if doc_too_large or unacked_doc_too_large: write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large( - write_op, len(value), max_bson_size) + _raise_document_too_large(write_op, len(value), max_bson_size) # We have enough data, return this batch. if new_message_size > max_message_size: break @@ -990,37 +1144,31 @@ def _batched_op_msg_impl( return to_send, length -def _encode_batched_op_msg( - operation, command, docs, check_keys, ack, opts, ctx): +def _encode_batched_op_msg(operation, command, docs, check_keys, ack, opts, ctx): """Encode the next batched insert, update, or delete operation as OP_MSG. """ buf = _BytesIO() - to_send, _ = _batched_op_msg_impl( - operation, command, docs, check_keys, ack, opts, ctx, buf) + to_send, _ = _batched_op_msg_impl(operation, command, docs, check_keys, ack, opts, ctx, buf) return buf.getvalue(), to_send + + if _use_c: _encode_batched_op_msg = _cmessage._encode_batched_op_msg -def _batched_op_msg_compressed( - operation, command, docs, check_keys, ack, opts, ctx): +def _batched_op_msg_compressed(operation, command, docs, check_keys, ack, opts, ctx): """Create the next batched insert, update, or delete operation with OP_MSG, compressed. """ - data, to_send = _encode_batched_op_msg( - operation, command, docs, check_keys, ack, opts, ctx) + data, to_send = _encode_batched_op_msg(operation, command, docs, check_keys, ack, opts, ctx) - request_id, msg = _compress( - 2013, - data, - ctx.sock_info.compression_context) + request_id, msg = _compress(2013, data, ctx.sock_info.compression_context) return request_id, msg, to_send -def _batched_op_msg( - operation, command, docs, check_keys, ack, opts, ctx): +def _batched_op_msg(operation, command, docs, check_keys, ack, opts, ctx): """OP_MSG implementation entry point.""" buf = _BytesIO() @@ -1030,7 +1178,8 @@ def _batched_op_msg( buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") to_send, length = _batched_op_msg_impl( - operation, command, docs, check_keys, ack, opts, ctx, buf) + operation, command, docs, check_keys, ack, opts, ctx, buf + ) # Header - request id and message length buf.seek(4) @@ -1040,45 +1189,44 @@ def _batched_op_msg( buf.write(_pack_int(length)) return request_id, buf.getvalue(), to_send + + if _use_c: _batched_op_msg = _cmessage._batched_op_msg -def _do_batched_op_msg( - namespace, operation, command, docs, check_keys, opts, ctx): +def _do_batched_op_msg(namespace, operation, command, docs, check_keys, opts, ctx): """Create the next batched insert, update, or delete operation using OP_MSG. 
""" - command['$db'] = namespace.split('.', 1)[0] - if 'writeConcern' in command: - ack = bool(command['writeConcern'].get('w', 1)) + command["$db"] = namespace.split(".", 1)[0] + if "writeConcern" in command: + ack = bool(command["writeConcern"].get("w", 1)) else: ack = True if ctx.sock_info.compression_context: - return _batched_op_msg_compressed( - operation, command, docs, check_keys, ack, opts, ctx) - return _batched_op_msg( - operation, command, docs, check_keys, ack, opts, ctx) + return _batched_op_msg_compressed(operation, command, docs, check_keys, ack, opts, ctx) + return _batched_op_msg(operation, command, docs, check_keys, ack, opts, ctx) # End OP_MSG ----------------------------------------------------- -def _encode_batched_write_command( - namespace, operation, command, docs, check_keys, opts, ctx): - """Encode the next batched insert, update, or delete command. - """ +def _encode_batched_write_command(namespace, operation, command, docs, check_keys, opts, ctx): + """Encode the next batched insert, update, or delete command.""" buf = _BytesIO() to_send, _ = _batched_write_command_impl( - namespace, operation, command, docs, check_keys, opts, ctx, buf) + namespace, operation, command, docs, check_keys, opts, ctx, buf + ) return buf.getvalue(), to_send + + if _use_c: _encode_batched_write_command = _cmessage._encode_batched_write_command -def _batched_write_command_impl( - namespace, operation, command, docs, check_keys, opts, ctx, buf): +def _batched_write_command_impl(namespace, operation, command, docs, check_keys, opts, ctx, buf): """Create a batched OP_QUERY write command.""" max_bson_size = ctx.max_bson_size max_write_batch_size = ctx.max_write_batch_size @@ -1090,7 +1238,7 @@ def _batched_write_command_impl( # No options buf.write(_ZERO_32) # Namespace as C string - buf.write(namespace.encode('utf8')) + buf.write(namespace.encode("utf8")) buf.write(_ZERO_8) # Skip: 0, Limit: -1 buf.write(_SKIPLIM) @@ -1106,7 +1254,7 @@ def _batched_write_command_impl( try: buf.write(_OP_MAP[operation]) except KeyError: - raise InvalidOperation('Unknown command') + raise InvalidOperation("Unknown command") if operation in (_UPDATE, _DELETE): check_keys = False @@ -1117,18 +1265,16 @@ def _batched_write_command_impl( idx = 0 for doc in docs: # Encode the current operation - key = str(idx).encode('utf8') + key = str(idx).encode("utf8") value = encode(doc, check_keys, opts) # Is there enough room to add this document? max_cmd_size accounts for # the two trailing null bytes. 
doc_too_large = len(value) > max_cmd_size if doc_too_large: write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large( - write_op, len(value), max_bson_size) - enough_data = (idx >= 1 and - (buf.tell() + len(key) + len(value)) >= max_split_size) - enough_documents = (idx >= max_write_batch_size) + _raise_document_too_large(write_op, len(value), max_bson_size) + enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size + enough_documents = idx >= max_write_batch_size if enough_data or enough_documents: break buf.write(_BSONOBJ) @@ -1196,20 +1342,25 @@ def raw_response(self, cursor_id=None, user_fields=None): if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): raise NotPrimaryError(error_object["$err"], error_object) elif error_object.get("code") == 50: - raise ExecutionTimeout(error_object.get("$err"), - error_object.get("code"), - error_object) - raise OperationFailure("database error: %s" % - error_object.get("$err"), - error_object.get("code"), - error_object) + raise ExecutionTimeout( + error_object.get("$err"), error_object.get("code"), error_object + ) + raise OperationFailure( + "database error: %s" % error_object.get("$err"), + error_object.get("code"), + error_object, + ) if self.documents: return [self.documents] return [] - def unpack_response(self, cursor_id=None, - codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, - user_fields=None, legacy_response=False): + def unpack_response( + self, + cursor_id=None, + codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + user_fields=None, + legacy_response=False, + ): """Unpack a response from the database and decode the BSON document(s). Check the response for errors and unpack, returning a dictionary @@ -1228,8 +1379,7 @@ def unpack_response(self, cursor_id=None, self.raw_response(cursor_id) if legacy_response: return bson.decode_all(self.documents, codec_options) - return bson._decode_all_selective( - self.documents, codec_options, user_fields) + return bson._decode_all_selective(self.documents, codec_options, user_fields) def command_response(self, codec_options): """Unpack a command response.""" @@ -1280,13 +1430,17 @@ def raw_response(self, cursor_id=None, user_fields={}): user_fields is used to determine which fields must not be decoded """ inflated_response = _decode_selective( - RawBSONDocument(self.payload_document), user_fields, - DEFAULT_RAW_BSON_OPTIONS) + RawBSONDocument(self.payload_document), user_fields, DEFAULT_RAW_BSON_OPTIONS + ) return [inflated_response] - def unpack_response(self, cursor_id=None, - codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, - user_fields=None, legacy_response=False): + def unpack_response( + self, + cursor_id=None, + codec_options=_UNICODE_REPLACE_CODEC_OPTIONS, + user_fields=None, + legacy_response=False, + ): """Unpack a OP_MSG command response. :Parameters: @@ -1296,8 +1450,7 @@ def unpack_response(self, cursor_id=None, """ # If _OpMsg is in-use, this cannot be a legacy response. 
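To make the flag validation in `_OpMsg.unpack` below concrete, here is a standalone sketch of the same OP_MSG body layout (an int32 of flag bits, then a single type-0 section); the struct format and bit values follow the wire protocol spec rather than anything shown in this hunk:

import struct

_UNPACK_HEADER = struct.Struct("<IBi").unpack_from  # flags, section kind, doc length
CHECKSUM_PRESENT, MORE_TO_COME = 1, 2

def check_op_msg(msg: bytes) -> bytes:
    flags, payload_type, payload_size = _UNPACK_HEADER(msg)
    if flags & CHECKSUM_PRESENT:
        raise ValueError("unsupported OP_MSG flag checksumPresent")
    if flags & ~MORE_TO_COME:
        raise ValueError("unsupported OP_MSG flags: 0x%x" % flags)
    if payload_type != 0:
        raise ValueError("unsupported OP_MSG payload type: 0x%x" % payload_type)
    return msg[5:5 + payload_size]  # the single BSON payload document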
assert not legacy_response - return bson._decode_all_selective( - self.payload_document, codec_options, user_fields) + return bson._decode_all_selective(self.payload_document, codec_options, user_fields) def command_response(self, codec_options): """Unpack a command response.""" @@ -1318,17 +1471,12 @@ def unpack(cls, msg): flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) if flags != 0: if flags & cls.CHECKSUM_PRESENT: - raise ProtocolError( - "Unsupported OP_MSG flag checksumPresent: " - "0x%x" % (flags,)) + raise ProtocolError("Unsupported OP_MSG flag checksumPresent: " "0x%x" % (flags,)) if flags ^ cls.MORE_TO_COME: - raise ProtocolError( - "Unsupported OP_MSG flags: 0x%x" % (flags,)) + raise ProtocolError("Unsupported OP_MSG flags: 0x%x" % (flags,)) if first_payload_type != 0: - raise ProtocolError( - "Unsupported OP_MSG payload type: " - "0x%x" % (first_payload_type,)) + raise ProtocolError("Unsupported OP_MSG payload type: " "0x%x" % (first_payload_type,)) if len(msg) != first_payload_size + 5: raise ProtocolError("Unsupported OP_MSG reply: >1 section") diff --git a/pymongo/mongo_client.py b/pymongo/mongo_client.py index 70b34156e7..17e0a7cf3a 100644 --- a/pymongo/mongo_client.py +++ b/pymongo/mongo_client.py @@ -34,42 +34,46 @@ import contextlib import threading import weakref - from collections import defaultdict from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON -from pymongo import (common, - database, - helpers, - message, - periodic_executor, - uri_parser, - client_session) +from pymongo import ( + client_session, + common, + database, + helpers, + message, + periodic_executor, + uri_parser, +) from pymongo.change_stream import ClusterChangeStream from pymongo.client_options import ClientOptions from pymongo.command_cursor import CommandCursor -from pymongo.errors import (AutoReconnect, - BulkWriteError, - ConfigurationError, - ConnectionFailure, - InvalidOperation, - NotPrimaryError, - OperationFailure, - PyMongoError, - ServerSelectionTimeoutError) +from pymongo.errors import ( + AutoReconnect, + BulkWriteError, + ConfigurationError, + ConnectionFailure, + InvalidOperation, + NotPrimaryError, + OperationFailure, + PyMongoError, + ServerSelectionTimeoutError, +) from pymongo.pool import ConnectionClosedReason from pymongo.read_preferences import ReadPreference from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.topology import (Topology, - _ErrorContext) -from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.settings import TopologySettings -from pymongo.uri_parser import (_handle_option_deprecations, - _handle_security_options, - _normalize_options, - _check_options) +from pymongo.topology import Topology, _ErrorContext +from pymongo.topology_description import TOPOLOGY_TYPE +from pymongo.uri_parser import ( + _check_options, + _handle_option_deprecations, + _handle_security_options, + _normalize_options, +) from pymongo.write_concern import DEFAULT_WRITE_CONCERN @@ -83,21 +87,23 @@ class MongoClient(common.BaseObject): resources related to this, including background threads for monitoring, and connection pools. """ + HOST = "localhost" PORT = 27017 # Define order to retrieve options from ClientOptions for __repr__. # No host/port; these are retrieved from TopologySettings. 
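The host-handling rules being reformatted here are easiest to state by example; a short sketch, with host names invented for illustration:

from pymongo import MongoClient

# A list of "host:port" seeds is accepted:
client = MongoClient(["alpha.example.com:27017", "beta.example.com:27017"])
# A single URI may also name several hosts:
client = MongoClient("mongodb://alpha.example.com,beta.example.com/?replicaSet=rs0")
# But two URIs in one host list trip the ConfigurationError below:
# MongoClient(["mongodb://alpha.example.com", "mongodb://beta.example.com"])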
- _constructor_args = ('document_class', 'tz_aware', 'connect') + _constructor_args = ("document_class", "tz_aware", "connect") def __init__( - self, - host=None, - port=None, - document_class=dict, - tz_aware=None, - connect=None, - type_registry=None, - **kwargs): + self, + host=None, + port=None, + document_class=dict, + tz_aware=None, + connect=None, + type_registry=None, + **kwargs + ): """Client for a MongoDB instance, a replica set, or a set of mongoses. The client object is thread-safe and has connection-pooling built in. @@ -627,13 +633,15 @@ def __init__( client.__my_database__ """ - self.__init_kwargs = {'host': host, - 'port': port, - 'document_class': document_class, - 'tz_aware': tz_aware, - 'connect': connect, - 'type_registry': type_registry, - **kwargs} + self.__init_kwargs = { + "host": host, + "port": port, + "document_class": document_class, + "tz_aware": tz_aware, + "connect": connect, + "type_registry": type_registry, + **kwargs, + } if host is None: host = self.HOST @@ -646,13 +654,13 @@ def __init__( # _pool_class, _monitor_class, and _condition_class are for deep # customization of PyMongo, e.g. Motor. - pool_class = kwargs.pop('_pool_class', None) - monitor_class = kwargs.pop('_monitor_class', None) - condition_class = kwargs.pop('_condition_class', None) + pool_class = kwargs.pop("_pool_class", None) + monitor_class = kwargs.pop("_monitor_class", None) + condition_class = kwargs.pop("_condition_class", None) # Parse options passed as kwargs. keyword_opts = common._CaseInsensitiveDictionary(kwargs) - keyword_opts['document_class'] = document_class + keyword_opts["document_class"] = document_class seeds = set() username = None @@ -663,8 +671,7 @@ def __init__( srv_service_name = keyword_opts.get("srvservicename") srv_max_hosts = keyword_opts.get("srvmaxhosts") if len([h for h in host if "/" in h]) > 1: - raise ConfigurationError("host must not contain multiple MongoDB " - "URIs") + raise ConfigurationError("host must not contain multiple MongoDB " "URIs") for entity in host: # A hostname can only include a-z, 0-9, '-' and '.'. If we find a '/' # it must be a URI, @@ -674,12 +681,18 @@ def __init__( timeout = keyword_opts.get("connecttimeoutms") if timeout is not None: timeout = common.validate_timeout_or_none_or_zero( - keyword_opts.cased_key("connecttimeoutms"), timeout) + keyword_opts.cased_key("connecttimeoutms"), timeout + ) res = uri_parser.parse_uri( - entity, port, validate=True, warn=True, normalize=False, + entity, + port, + validate=True, + warn=True, + normalize=False, connect_timeout=timeout, srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts) + srv_max_hosts=srv_max_hosts, + ) seeds.update(res["nodelist"]) username = res["username"] or username password = res["password"] or password @@ -693,19 +706,20 @@ def __init__( # Add options with named keyword arguments to the parsed kwarg options. if type_registry is not None: - keyword_opts['type_registry'] = type_registry + keyword_opts["type_registry"] = type_registry if tz_aware is None: - tz_aware = opts.get('tz_aware', False) + tz_aware = opts.get("tz_aware", False) if connect is None: - connect = opts.get('connect', True) - keyword_opts['tz_aware'] = tz_aware - keyword_opts['connect'] = connect + connect = opts.get("connect", True) + keyword_opts["tz_aware"] = tz_aware + keyword_opts["connect"] = connect # Handle deprecated options in kwarg options. keyword_opts = _handle_option_deprecations(keyword_opts) # Validate kwarg options. 
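The merge order being validated here means keyword options win over the connection string, and both are matched case-insensitively; a short sketch:

from pymongo import MongoClient

client = MongoClient(
    "mongodb://localhost/?connectTimeoutMS=5000",
    connecttimeoutms=1000,  # case-insensitive kwarg overrides the URI's 5000
)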
- keyword_opts = common._CaseInsensitiveDictionary(dict(common.validate( - keyword_opts.cased_key(k), v) for k, v in keyword_opts.items())) + keyword_opts = common._CaseInsensitiveDictionary( + dict(common.validate(keyword_opts.cased_key(k), v) for k, v in keyword_opts.items()) + ) # Override connection string options with kwarg options. opts.update(keyword_opts) @@ -723,18 +737,19 @@ def __init__( # Username and password passed as kwargs override user info in URI. username = opts.get("username", username) password = opts.get("password", password) - self.__options = options = ClientOptions( - username, password, dbase, opts) + self.__options = options = ClientOptions(username, password, dbase, opts) self.__default_database_name = dbase self.__lock = threading.Lock() self.__kill_cursors_queue = [] self._event_listeners = options.pool_options._event_listeners - super(MongoClient, self).__init__(options.codec_options, - options.read_preference, - options.write_concern, - options.read_concern) + super(MongoClient, self).__init__( + options.codec_options, + options.read_preference, + options.write_concern, + options.read_concern, + ) self.__all_credentials = {} creds = options._credentials @@ -756,7 +771,7 @@ def __init__( direct_connection=options.direct_connection, load_balanced=options.load_balanced, srv_service_name=srv_service_name, - srv_max_hosts=srv_max_hosts + srv_max_hosts=srv_max_hosts, ) self._topology = Topology(self._topology_settings) @@ -772,7 +787,8 @@ def target(): interval=common.KILL_CURSOR_FREQUENCY, min_interval=common.MIN_HEARTBEAT_INTERVAL, target=target, - name="pymongo_kill_cursors_thread") + name="pymongo_kill_cursors_thread", + ) # We strongly reference the executor and it weakly references us via # this closure. When the client is freed, stop the executor soon. @@ -785,8 +801,8 @@ def target(): self._encrypter = None if self.__options.auto_encryption_opts: from pymongo.encryption import _Encrypter - self._encrypter = _Encrypter( - self, self.__options.auto_encryption_opts) + + self._encrypter = _Encrypter(self, self.__options.auto_encryption_opts) def _duplicate(self, **kwargs): args = self.__init_kwargs.copy() @@ -804,14 +820,22 @@ def _server_property(self, attr_name): the server may change. In such cases, store a local reference to a ServerDescription first, then use its properties. """ - server = self._topology.select_server( - writable_server_selector) + server = self._topology.select_server(writable_server_selector) return getattr(server.description, attr_name) - def watch(self, pipeline=None, full_document=None, resume_after=None, - max_await_time_ms=None, batch_size=None, collation=None, - start_at_operation_time=None, session=None, start_after=None): + def watch( + self, + pipeline=None, + full_document=None, + resume_after=None, + max_await_time_ms=None, + batch_size=None, + collation=None, + start_at_operation_time=None, + session=None, + start_after=None, + ): """Watch changes on this cluster. 
Performs an aggregation with an implicit initial ``$changeStream`` @@ -897,9 +921,17 @@ def watch(self, pipeline=None, full_document=None, resume_after=None, https://github.com/mongodb/specifications/blob/master/source/change-streams/change-streams.rst """ return ClusterChangeStream( - self.admin, pipeline, full_document, resume_after, max_await_time_ms, - batch_size, collation, start_at_operation_time, session, - start_after) + self.admin, + pipeline, + full_document, + resume_after, + max_await_time_ms, + batch_size, + collation, + start_at_operation_time, + session, + start_after, + ) @property def topology_description(self): @@ -938,17 +970,22 @@ def address(self): .. versionadded:: 3.0 """ topology_type = self._topology._description.topology_type - if (topology_type == TOPOLOGY_TYPE.Sharded and - len(self.topology_description.server_descriptions()) > 1): + if ( + topology_type == TOPOLOGY_TYPE.Sharded + and len(self.topology_description.server_descriptions()) > 1 + ): raise InvalidOperation( 'Cannot use "address" property when load balancing among' - ' mongoses, use "nodes" instead.') + ' mongoses, use "nodes" instead.' + ) - if topology_type not in (TOPOLOGY_TYPE.ReplicaSetWithPrimary, - TOPOLOGY_TYPE.Single, - TOPOLOGY_TYPE.LoadBalanced, - TOPOLOGY_TYPE.Sharded): + if topology_type not in ( + TOPOLOGY_TYPE.ReplicaSetWithPrimary, + TOPOLOGY_TYPE.Single, + TOPOLOGY_TYPE.LoadBalanced, + TOPOLOGY_TYPE.Sharded, + ): return None - return self._server_property('address') + return self._server_property("address") @property def primary(self): @@ -995,7 +1032,7 @@ def is_primary(self): connection is established or raise ServerSelectionTimeoutError if no server is available. """ - return self._server_property('is_writable') + return self._server_property("is_writable") @property def is_mongos(self): @@ -1003,7 +1040,7 @@ def is_mongos(self): connected, this will block until a connection is established or raise ServerSelectionTimeoutError if no server is available. """ - return self._server_property('server_type') == SERVER_TYPE.Mongos + return self._server_property("server_type") == SERVER_TYPE.Mongos @property def nodes(self): @@ -1035,17 +1072,16 @@ def _end_sessions(self, session_ids): try: # Use SocketInfo.command directly to avoid implicitly creating # another session. - with self._socket_for_reads( - ReadPreference.PRIMARY_PREFERRED, - None) as (sock_info, secondary_ok): + with self._socket_for_reads(ReadPreference.PRIMARY_PREFERRED, None) as ( + sock_info, + secondary_ok, + ): if not sock_info.supports_sessions: return for i in range(0, len(session_ids), common._MAX_END_SESSIONS): - spec = SON([('endSessions', - session_ids[i:i + common._MAX_END_SESSIONS])]) - sock_info.command( - 'admin', spec, secondary_ok=secondary_ok, client=self) + spec = SON([("endSessions", session_ids[i : i + common._MAX_END_SESSIONS])]) + sock_info.command("admin", spec, secondary_ok=secondary_ok, client=self) except PyMongoError: # Drivers MUST ignore any errors returned by the endSessions # command. @@ -1097,19 +1133,22 @@ def _get_socket(self, server, session): if in_txn and session._pinned_connection: yield session._pinned_connection return - with server.get_socket( - self.__all_credentials, handler=err_handler) as sock_info: + with server.get_socket(self.__all_credentials, handler=err_handler) as sock_info: # Pin this session to the selected server or connection.
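As a usage note for the watch() signature reformatted above: the returned change stream is both a context manager and an iterator, so the documented idiom is (the $match filter is illustrative):

    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017")
    # Watch the whole cluster, keeping only insert events.
    with client.watch([{"$match": {"operationType": "insert"}}]) as stream:
        for change in stream:
            print(change["fullDocument"])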
- if (in_txn and server.description.server_type in ( - SERVER_TYPE.Mongos, SERVER_TYPE.LoadBalancer)): + if in_txn and server.description.server_type in ( + SERVER_TYPE.Mongos, + SERVER_TYPE.LoadBalancer, + ): session._pin(server, sock_info) err_handler.contribute_socket(sock_info) - if (self._encrypter and - not self._encrypter._bypass_auto_encryption and - sock_info.max_wire_version < 8): + if ( + self._encrypter + and not self._encrypter._bypass_auto_encryption + and sock_info.max_wire_version < 8 + ): raise ConfigurationError( - 'Auto-encryption requires a minimum MongoDB version ' - 'of 4.2') + "Auto-encryption requires a minimum MongoDB version " "of 4.2" + ) yield sock_info def _select_server(self, server_selector, session, address=None): @@ -1132,8 +1171,7 @@ def _select_server(self, server_selector, session, address=None): # We're running a getMore or this session is pinned to a mongos. server = topology.select_server_by_address(address) if not server: - raise AutoReconnect('server %s:%d no longer available' - % address) + raise AutoReconnect("server %s:%d no longer available" % address) else: server = topology.select_server(server_selector) return server @@ -1162,7 +1200,8 @@ def _secondaryok_for_server(self, read_preference, server, session): with self._get_socket(server, session) as sock_info: secondary_ok = (single and not sock_info.is_mongos) or ( - read_preference != ReadPreference.PRIMARY) + read_preference != ReadPreference.PRIMARY + ) yield sock_info, secondary_ok @contextlib.contextmanager @@ -1180,12 +1219,12 @@ def _socket_for_reads(self, read_preference, session): with self._get_socket(server, session) as sock_info: secondary_ok = (single and not sock_info.is_mongos) or ( - read_preference != ReadPreference.PRIMARY) + read_preference != ReadPreference.PRIMARY + ) yield sock_info, secondary_ok def _should_pin_cursor(self, session): - return (self.__options.load_balanced and - not (session and session.in_transaction)) + return self.__options.load_balanced and not (session and session.in_transaction) def _run_operation(self, operation, unpack_res, address=None): """Run a _Query/_GetMore operation and return a Response. 
@@ -1198,24 +1237,28 @@ def _run_operation(self, operation, unpack_res, address=None): """ if operation.sock_mgr: server = self._select_server( - operation.read_preference, operation.session, address=address) + operation.read_preference, operation.session, address=address + ) with operation.sock_mgr.lock: - with _MongoClientErrorHandler( - self, server, operation.session) as err_handler: + with _MongoClientErrorHandler(self, server, operation.session) as err_handler: err_handler.contribute_socket(operation.sock_mgr.sock) return server.run_operation( - operation.sock_mgr.sock, operation, True, - self._event_listeners, unpack_res) + operation.sock_mgr.sock, operation, True, self._event_listeners, unpack_res + ) def _cmd(session, server, sock_info, secondary_ok): return server.run_operation( - sock_info, operation, secondary_ok, self._event_listeners, - unpack_res) + sock_info, operation, secondary_ok, self._event_listeners, unpack_res + ) return self._retryable_read( - _cmd, operation.read_preference, operation.session, - address=address, retryable=isinstance(operation, message._Query)) + _cmd, + operation.read_preference, + operation.session, + address=address, + retryable=isinstance(operation, message._Query), + ) def _retry_with_session(self, retryable, func, session, bulk): """Execute an operation with at most one consecutive retry @@ -1225,8 +1268,9 @@ def _retry_with_session(self, retryable, func, session, bulk): Re-raises any exception thrown by func(). """ - retryable = (retryable and self.options.retry_writes - and session and not session.in_transaction) + retryable = ( + retryable and self.options.retry_writes and session and not session.in_transaction + ) return self._retry_internal(retryable, func, session, bulk) def _retry_internal(self, retryable, func, session, bulk): @@ -1237,6 +1281,7 @@ def _retry_internal(self, retryable, func, session, bulk): def is_retrying(): return bulk.retrying if bulk else retrying + # Increment the transaction id up front to ensure any retry attempt # will use the proper txnNumber, even if server or socket selection # fails before the command can be sent. @@ -1249,8 +1294,8 @@ def is_retrying(): try: server = self._select_server(writable_server_selector, session) supports_session = ( - session is not None and - server.description.retryable_writes_supported) + session is not None and server.description.retryable_writes_supported + ) with self._get_socket(server, session) as sock_info: max_wire_version = sock_info.max_wire_version if retryable and not supports_session: @@ -1286,8 +1331,7 @@ def is_retrying(): retrying = True last_error = exc - def _retryable_read(self, func, read_pref, session, address=None, - retryable=True): + def _retryable_read(self, func, read_pref, session, address=None, retryable=True): """Execute an operation with at most one consecutive retry Returns func()'s return value on success. On error retries the same @@ -1295,18 +1339,19 @@ def _retryable_read(self, func, read_pref, session, address=None, Re-raises any exception thrown by func().
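The retry helpers above implement at-most-one-retry semantics: the transaction number is reserved before the first attempt, a single retry reuses it, and any further failure propagates. Stripped of server selection and sessions, the shape is roughly this (the error predicate is a stand-in for the driver's label and code checks):

    def retry_once(func, is_retryable_error):
        """Run func(); on a retryable error, retry exactly once."""
        try:
            return func()
        except Exception as exc:
            if not is_retryable_error(exc):
                raise
            # Second attempt with the same txnNumber; a second
            # failure propagates to the caller.
            return func()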
""" - retryable = (retryable and - self.options.retry_reads - and not (session and session.in_transaction)) + retryable = ( + retryable and self.options.retry_reads and not (session and session.in_transaction) + ) last_error = None retrying = False while True: try: - server = self._select_server( - read_pref, session, address=address) + server = self._select_server(read_pref, session, address=address) with self._secondaryok_for_server(read_pref, server, session) as ( - sock_info, secondary_ok): + sock_info, + secondary_ok, + ): if retrying and not retryable: # A retry is not possible because this server does # not support retryable reads, raise the last error. @@ -1354,35 +1399,38 @@ def __hash__(self): def _repr_helper(self): def option_repr(option, value): """Fix options whose __repr__ isn't usable in a constructor.""" - if option == 'document_class': + if option == "document_class": if value is dict: - return 'document_class=dict' + return "document_class=dict" else: - return 'document_class=%s.%s' % (value.__module__, - value.__name__) + return "document_class=%s.%s" % (value.__module__, value.__name__) if option in common.TIMEOUT_OPTIONS and value is not None: return "%s=%s" % (option, int(value * 1000)) - return '%s=%r' % (option, value) + return "%s=%r" % (option, value) # Host first... - options = ['host=%r' % [ - '%s:%d' % (host, port) if port is not None else host - for host, port in self._topology_settings.seeds]] + options = [ + "host=%r" + % [ + "%s:%d" % (host, port) if port is not None else host + for host, port in self._topology_settings.seeds + ] + ] # ... then everything in self._constructor_args... options.extend( - option_repr(key, self.__options._options[key]) - for key in self._constructor_args) + option_repr(key, self.__options._options[key]) for key in self._constructor_args + ) # ... then everything else. options.extend( option_repr(key, self.__options._options[key]) for key in self.__options._options - if key not in set(self._constructor_args) - and key != 'username' and key != 'password') - return ', '.join(options) + if key not in set(self._constructor_args) and key != "username" and key != "password" + ) + return ", ".join(options) def __repr__(self): - return ("MongoClient(%s)" % (self._repr_helper(),)) + return "MongoClient(%s)" % (self._repr_helper(),) def __getattr__(self, name): """Get a database by name. @@ -1393,10 +1441,11 @@ def __getattr__(self, name): :Parameters: - `name`: the name of the database to get """ - if name.startswith('_'): + if name.startswith("_"): raise AttributeError( "MongoClient has no attribute %r. To access the %s" - " database, use client[%r]." % (name, name, name)) + " database, use client[%r]." % (name, name, name) + ) return self.__getitem__(name) def __getitem__(self, name): @@ -1410,8 +1459,9 @@ def __getitem__(self, name): """ return database.Database(self, name) - def _cleanup_cursor(self, locks_allowed, cursor_id, address, sock_mgr, - session, explicit_session): + def _cleanup_cursor( + self, locks_allowed, cursor_id, address, sock_mgr, session, explicit_session + ): """Cleanup a cursor from cursor.close() or __del__. This method handles cleanup for Cursors/CommandCursors including any @@ -1432,12 +1482,9 @@ def _cleanup_cursor(self, locks_allowed, cursor_id, address, sock_mgr, # If this is an exhaust cursor and we haven't completely # exhausted the result set we *must* close the socket # to stop the server from sending more data. 
- sock_mgr.sock.close_socket( - ConnectionClosedReason.ERROR) + sock_mgr.sock.close_socket(ConnectionClosedReason.ERROR) else: - self._close_cursor_now( - cursor_id, address, session=session, - sock_mgr=sock_mgr) + self._close_cursor_now(cursor_id, address, session=session, sock_mgr=sock_mgr) if sock_mgr: sock_mgr.close() else: @@ -1451,8 +1498,7 @@ def _close_cursor_soon(self, cursor_id, address, sock_mgr=None): """Request that a cursor and/or connection be cleaned up soon.""" self.__kill_cursors_queue.append((address, cursor_id, sock_mgr)) - def _close_cursor_now(self, cursor_id, address=None, session=None, - sock_mgr=None): + def _close_cursor_now(self, cursor_id, address=None, session=None, sock_mgr=None): """Send a kill cursors message with the given id. The cursor is closed synchronously on the current thread. @@ -1464,11 +1510,9 @@ def _close_cursor_now(self, cursor_id, address=None, session=None, if sock_mgr: with sock_mgr.lock: # Cursor is pinned to LB outside of a transaction. - self._kill_cursor_impl( - [cursor_id], address, session, sock_mgr.sock) + self._kill_cursor_impl([cursor_id], address, session, sock_mgr.sock) else: - self._kill_cursors( - [cursor_id], address, self._get_topology(), session) + self._kill_cursors([cursor_id], address, self._get_topology(), session) except PyMongoError: # Make another attempt to kill the cursor later. self._close_cursor_soon(cursor_id, address) @@ -1488,8 +1532,8 @@ def _kill_cursors(self, cursor_ids, address, topology, session): def _kill_cursor_impl(self, cursor_ids, address, session, sock_info): namespace = address.namespace - db, coll = namespace.split('.', 1) - spec = SON([('killCursors', coll), ('cursors', cursor_ids)]) + db, coll = namespace.split(".", 1) + spec = SON([("killCursors", coll), ("cursors", cursor_ids)]) sock_info.command(db, spec, session=session, client=self) def _process_kill_cursors(self): @@ -1511,11 +1555,9 @@ def _process_kill_cursors(self): for address, cursor_id, sock_mgr in pinned_cursors: try: - self._cleanup_cursor(True, cursor_id, address, sock_mgr, - None, False) + self._cleanup_cursor(True, cursor_id, address, sock_mgr, None, False) except Exception as exc: - if (isinstance(exc, InvalidOperation) - and self._topology._closed): + if isinstance(exc, InvalidOperation) and self._topology._closed: # Raise the exception when client is closed so that it # can be caught in _process_periodic_tasks raise @@ -1527,11 +1569,9 @@ def _process_kill_cursors(self): topology = self._get_topology() for address, cursor_ids in address_to_cursor_ids.items(): try: - self._kill_cursors( - cursor_ids, address, topology, session=None) + self._kill_cursors(cursor_ids, address, topology, session=None) except Exception as exc: - if (isinstance(exc, InvalidOperation) and - self._topology._closed): + if isinstance(exc, InvalidOperation) and self._topology._closed: raise else: helpers._handle_exception() @@ -1553,13 +1593,11 @@ def __start_session(self, implicit, **kwargs): # Raises ConfigurationError if sessions are not supported. server_session = self._get_server_session() opts = client_session.SessionOptions(**kwargs) - return client_session.ClientSession( - self, server_session, opts, implicit) + return client_session.ClientSession(self, server_session, opts, implicit) - def start_session(self, - causal_consistency=None, - default_transaction_options=None, - snapshot=False): + def start_session( + self, causal_consistency=None, default_transaction_options=None, snapshot=False + ): """Start a logical session. 
This method takes the same parameters as @@ -1583,7 +1621,8 @@ def start_session(self, False, causal_consistency=causal_consistency, default_transaction_options=default_transaction_options, - snapshot=snapshot) + snapshot=snapshot, + ) def _get_server_session(self): """Internal: start or resume a _ServerSession.""" @@ -1636,17 +1675,17 @@ def _send_cluster_time(self, command, session): topology_time = self._topology.max_cluster_time() session_time = session.cluster_time if session else None if topology_time and session_time: - if topology_time['clusterTime'] > session_time['clusterTime']: + if topology_time["clusterTime"] > session_time["clusterTime"]: cluster_time = topology_time else: cluster_time = session_time else: cluster_time = topology_time or session_time if cluster_time: - command['$clusterTime'] = cluster_time + command["$clusterTime"] = cluster_time def _process_response(self, reply, session): - self._topology.receive_cluster_time(reply.get('$clusterTime')) + self._topology.receive_cluster_time(reply.get("$clusterTime")) if session is not None: session._process_response(reply) @@ -1660,9 +1699,9 @@ def server_info(self, session=None): .. versionchanged:: 3.6 Added ``session`` parameter. """ - return self.admin.command("buildinfo", - read_preference=ReadPreference.PRIMARY, - session=session) + return self.admin.command( + "buildinfo", read_preference=ReadPreference.PRIMARY, session=session + ) def list_databases(self, session=None, **kwargs): """Get a cursor over the databases of the connected server. @@ -1702,8 +1741,7 @@ def list_database_names(self, session=None): .. versionadded:: 3.6 """ - return [doc["name"] - for doc in self.list_databases(session, nameOnly=True)] + return [doc["name"] for doc in self.list_databases(session, nameOnly=True)] def drop_database(self, name_or_database, session=None): """Drop a database. @@ -1736,8 +1774,7 @@ def drop_database(self, name_or_database, session=None): name = name.name if not isinstance(name, str): - raise TypeError("name_or_database must be an instance " - "of str or a Database") + raise TypeError("name_or_database must be an instance " "of str or a Database") with self._socket_for_writes(session) as sock_info: self[name]._command( @@ -1746,10 +1783,17 @@ def drop_database(self, name_or_database, session=None): read_preference=ReadPreference.PRIMARY, write_concern=self._write_concern_for(session), parse_write_concern_error=True, - session=session) - - def get_default_database(self, default=None, codec_options=None, - read_preference=None, write_concern=None, read_concern=None): + session=session, + ) + + def get_default_database( + self, + default=None, + codec_options=None, + read_preference=None, + write_concern=None, + read_concern=None, + ): """Get the database named in the MongoDB connection URI. >>> uri = 'mongodb://host/my_database' @@ -1791,15 +1835,25 @@ def get_default_database(self, default=None, codec_options=None, Deprecated, use :meth:`get_database` instead. 
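For the default-database accessors above, the database named in the URI becomes the default, per the documented behavior:

    from pymongo import MongoClient

    client = MongoClient("mongodb://localhost:27017/my_database")
    db = client.get_default_database()  # the "my_database" Database
    same = client.get_database()        # name=None falls back to the same default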
""" if self.__default_database_name is None and default is None: - raise ConfigurationError( - 'No default database name defined or provided.') + raise ConfigurationError("No default database name defined or provided.") return database.Database( - self, self.__default_database_name or default, codec_options, - read_preference, write_concern, read_concern) + self, + self.__default_database_name or default, + codec_options, + read_preference, + write_concern, + read_concern, + ) - def get_database(self, name=None, codec_options=None, read_preference=None, - write_concern=None, read_concern=None): + def get_database( + self, + name=None, + codec_options=None, + read_preference=None, + write_concern=None, + read_concern=None, + ): """Get a :class:`~pymongo.database.Database` with the given name and options. @@ -1845,19 +1899,21 @@ def get_database(self, name=None, codec_options=None, read_preference=None, """ if name is None: if self.__default_database_name is None: - raise ConfigurationError('No default database defined') + raise ConfigurationError("No default database defined") name = self.__default_database_name return database.Database( - self, name, codec_options, read_preference, - write_concern, read_concern) + self, name, codec_options, read_preference, write_concern, read_concern + ) def _database_default_options(self, name): """Get a Database instance with the default settings.""" return self.get_database( - name, codec_options=DEFAULT_CODEC_OPTIONS, + name, + codec_options=DEFAULT_CODEC_OPTIONS, read_preference=ReadPreference.PRIMARY, - write_concern=DEFAULT_WRITE_CONCERN) + write_concern=DEFAULT_WRITE_CONCERN, + ) def __enter__(self): return self @@ -1879,7 +1935,7 @@ def _retryable_error_doc(exc): if isinstance(exc, BulkWriteError): # Check the last writeConcernError to determine if this # BulkWriteError is retryable. - wces = exc.details['writeConcernErrors'] + wces = exc.details["writeConcernErrors"] wce = wces[-1] if wces else None return wce if isinstance(exc, (NotPrimaryError, OperationFailure)): @@ -1890,18 +1946,18 @@ def _retryable_error_doc(exc): def _add_retryable_write_error(exc, max_wire_version): doc = _retryable_error_doc(exc) if doc: - code = doc.get('code', 0) + code = doc.get("code", 0) # retryWrites on MMAPv1 should raise an actionable error. - if (code == 20 and - str(exc).startswith("Transaction numbers")): + if code == 20 and str(exc).startswith("Transaction numbers"): errmsg = ( "This MongoDB deployment does not support " "retryable writes. Please add retryWrites=false " - "to your connection string.") + "to your connection string." + ) raise OperationFailure(errmsg, code, exc.details) if max_wire_version >= 9: # In MongoDB 4.4+, the server reports the error labels. - for label in doc.get('errorLabels', []): + for label in doc.get("errorLabels", []): exc._add_error_label(label) else: if code in helpers._RETRYABLE_ERROR_CODES: @@ -1909,16 +1965,23 @@ def _add_retryable_write_error(exc, max_wire_version): # Connection errors are always retryable except NotPrimaryError which is # handled above. 
- if (isinstance(exc, ConnectionFailure) and - not isinstance(exc, NotPrimaryError)): + if isinstance(exc, ConnectionFailure) and not isinstance(exc, NotPrimaryError): exc._add_error_label("RetryableWriteError") class _MongoClientErrorHandler(object): """Handle errors raised when executing an operation.""" - __slots__ = ('client', 'server_address', 'session', 'max_wire_version', - 'sock_generation', 'completed_handshake', 'service_id', - 'handled') + + __slots__ = ( + "client", + "server_address", + "session", + "max_wire_version", + "sock_generation", + "completed_handshake", + "service_id", + "handled", + ) def __init__(self, client, server, session): self.client = client @@ -1952,13 +2015,18 @@ def handle(self, exc_type, exc_val): self.session._server_session.mark_dirty() if issubclass(exc_type, PyMongoError): - if (exc_val.has_error_label("TransientTransactionError") or - exc_val.has_error_label("RetryableWriteError")): + if exc_val.has_error_label("TransientTransactionError") or exc_val.has_error_label( + "RetryableWriteError" + ): self.session._unpin() err_ctx = _ErrorContext( - exc_val, self.max_wire_version, self.sock_generation, - self.completed_handshake, self.service_id) + exc_val, + self.max_wire_version, + self.sock_generation, + self.completed_handshake, + self.service_id, + ) self.client._topology.handle_error(self.server_address, err_ctx) def __enter__(self): diff --git a/pymongo/monitor.py b/pymongo/monitor.py index a383e272cd..80dfce5356 100644 --- a/pymongo/monitor.py +++ b/pymongo/monitor.py @@ -20,9 +20,7 @@ import weakref from pymongo import common, periodic_executor -from pymongo.errors import (NotPrimaryError, - OperationFailure, - _OperationCancelled) +from pymongo.errors import NotPrimaryError, OperationFailure, _OperationCancelled from pymongo.hello import Hello from pymongo.periodic_executor import _shutdown_executors from pymongo.read_preferences import MovingAverage @@ -54,10 +52,8 @@ def target(): return True executor = periodic_executor.PeriodicExecutor( - interval=interval, - min_interval=min_interval, - target=target, - name=name) + interval=interval, min_interval=min_interval, target=target, name=name + ) self._executor = executor @@ -101,12 +97,7 @@ def request_check(self): class Monitor(MonitorBase): - def __init__( - self, - server_description, - topology, - pool, - topology_settings): + def __init__(self, server_description, topology, pool, topology_settings): """Class to monitor a MongoDB server on a background thread. Pass an initial ServerDescription, a Topology, a Pool, and @@ -119,7 +110,8 @@ def __init__( topology, "pymongo_server_monitor_thread", topology_settings.heartbeat_frequency, - common.MIN_HEARTBEAT_INTERVAL) + common.MIN_HEARTBEAT_INTERVAL, + ) self._server_description = server_description self._pool = pool self._settings = topology_settings @@ -128,8 +120,10 @@ def __init__( self._publish = pub and self._listeners.enabled_for_server_heartbeat self._cancel_context = None self._rtt_monitor = _RttMonitor( - topology, topology_settings, topology._create_pool_for_monitor( - server_description.address)) + topology, + topology_settings, + topology._create_pool_for_monitor(server_description.address), + ) self.heartbeater = None def cancel_check(self): @@ -179,7 +173,8 @@ def _run(self): _sanitize(exc) # Already closed the connection, wait for the next check. 
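For orientation in the monitor code around here: against MongoDB 4.4+ the monitor keeps an awaitable "hello" stream open and only falls back to one-shot polling for older servers or the initial handshake. A schematic of that dispatch, with hypothetical helper names:

    def next_hello(conn, prev_topology_version, heartbeat_frequency):
        if conn.more_to_come:
            # A streaming hello is already in flight; read the next reply.
            return read_streaming_reply(conn)  # hypothetical helper
        if conn.performed_handshake and prev_topology_version:
            # Start streaming: the server holds its reply until the
            # topology changes or heartbeat_frequency elapses.
            return send_awaitable_hello(
                conn, prev_topology_version, heartbeat_frequency
            )  # hypothetical helper
        # Initial handshake, or a pre-4.4 server: plain polling hello.
        return send_polling_hello(conn)  # hypothetical helper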
self._server_description = ServerDescription( - self._server_description.address, error=exc) + self._server_description.address, error=exc + ) if prev_sd.is_server_type_known: # Immediately retry since we've already waited 500ms to # discover that we've been cancelled. @@ -187,11 +182,14 @@ def _run(self): return # Update the Topology and clear the server pool on error. - self._topology.on_change(self._server_description, - reset_pool=self._server_description.error) - - if (self._server_description.is_server_type_known and - self._server_description.topology_version): + self._topology.on_change( + self._server_description, reset_pool=self._server_description.error + ) + + if ( + self._server_description.is_server_type_known + and self._server_description.topology_version + ): self._start_rtt_monitor() # Immediately check for the next streaming response. self._executor.skip_sleep() @@ -214,8 +212,7 @@ def _check_server(self): return self._check_once() except (OperationFailure, NotPrimaryError) as exc: # Update max cluster time even when hello fails. - self._topology.receive_cluster_time( - exc.details.get('$clusterTime')) + self._topology.receive_cluster_time(exc.details.get("$clusterTime")) raise except ReferenceError: raise @@ -226,8 +223,7 @@ def _check_server(self): duration = time.monotonic() - start if self._publish: awaited = sd.is_server_type_known and sd.topology_version - self._listeners.publish_server_heartbeat_failed( - address, duration, error, awaited) + self._listeners.publish_server_heartbeat_failed(address, duration, error, awaited) self._reset_connection() if isinstance(error, _OperationCancelled): raise @@ -252,11 +248,11 @@ def _check_once(self): if not response.awaitable: self._rtt_monitor.add_sample(round_trip_time) - sd = ServerDescription(address, response, - self._rtt_monitor.average()) + sd = ServerDescription(address, response, self._rtt_monitor.average()) if self._publish: self._listeners.publish_server_heartbeat_succeeded( - address, round_trip_time, response, response.awaitable) + address, round_trip_time, response, response.awaitable + ) return sd def _check_with_socket(self, conn): @@ -269,14 +265,14 @@ def _check_with_socket(self, conn): if conn.more_to_come: # Read the next streaming hello (MongoDB 4.4+). response = Hello(conn._next_reply(), awaitable=True) - elif (conn.performed_handshake and - self._server_description.topology_version): + elif conn.performed_handshake and self._server_description.topology_version: # Initiate streaming hello (MongoDB 4.4+). response = conn._hello( cluster_time, self._server_description.topology_version, self._settings.heartbeat_frequency, - None) + None, + ) else: # New connection handshake or polling hello (MongoDB <4.4). response = conn._hello(cluster_time, None, None, None) @@ -295,7 +291,8 @@ def __init__(self, topology, topology_settings): topology, "pymongo_srv_polling_thread", common.MIN_SRV_RESCAN_INTERVAL, - topology_settings.heartbeat_frequency) + topology_settings.heartbeat_frequency, + ) self._settings = topology_settings self._seedlist = self._settings._seeds self._fqdn = self._settings.fqdn @@ -316,9 +313,11 @@ def _get_seedlist(self): Returns a list of ServerDescriptions. 
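The SRV monitor above rescans DNS using the record TTL (never more often than MIN_SRV_RESCAN_INTERVAL) and treats an empty answer as a failure. Roughly what the resolver does, sketched directly against dnspython rather than PyMongo's internal _SrvResolver wrapper (assumes dnspython 2.x):

    import dns.resolver  # dnspython

    def srv_seedlist(fqdn, service_name="mongodb"):
        answer = dns.resolver.resolve("_%s._tcp.%s" % (service_name, fqdn), "SRV")
        seeds = [(rr.target.to_text(omit_final_dot=True), rr.port) for rr in answer]
        return seeds, answer.rrset.ttl  # hosts plus the TTL used as the rescan hint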
""" try: - resolver = _SrvResolver(self._fqdn, - self._settings.pool_options.connect_timeout, - self._settings.srv_service_name) + resolver = _SrvResolver( + self._fqdn, + self._settings.pool_options.connect_timeout, + self._settings.srv_service_name, + ) seedlist, ttl = resolver.get_hosts_and_min_ttl() if len(seedlist) == 0: # As per the spec: this should be treated as a failure. @@ -331,8 +330,7 @@ def _get_seedlist(self): self.request_check() return None else: - self._executor.update_interval( - max(ttl, common.MIN_SRV_RESCAN_INTERVAL)) + self._executor.update_interval(max(ttl, common.MIN_SRV_RESCAN_INTERVAL)) return seedlist @@ -346,7 +344,8 @@ def __init__(self, topology, topology_settings, pool): topology, "pymongo_server_rtt_thread", topology_settings.heartbeat_frequency, - common.MIN_HEARTBEAT_INTERVAL) + common.MIN_HEARTBEAT_INTERVAL, + ) self._pool = pool self._moving_average = MovingAverage() @@ -390,7 +389,7 @@ def _ping(self): """Run a "hello" command and return the RTT.""" with self._pool.get_socket({}) as sock_info: if self._executor._stopped: - raise Exception('_RttMonitor closed') + raise Exception("_RttMonitor closed") start = time.monotonic() sock_info.hello() return time.monotonic() - start diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index b877e19a23..fb9995ee32 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -185,10 +185,16 @@ def connection_checked_in(self, event): from pymongo.hello import HelloCompat from pymongo.helpers import _handle_exception -_Listeners = namedtuple('Listeners', - ('command_listeners', 'server_listeners', - 'server_heartbeat_listeners', 'topology_listeners', - 'cmap_listeners')) +_Listeners = namedtuple( + "Listeners", + ( + "command_listeners", + "server_listeners", + "server_heartbeat_listeners", + "topology_listeners", + "cmap_listeners", + ), +) _LISTENERS = _Listeners([], [], [], [], []) @@ -471,10 +477,12 @@ def _validate_event_listeners(option, listeners): raise TypeError("%s must be a list or tuple" % (option,)) for listener in listeners: if not isinstance(listener, _EventListener): - raise TypeError("Listeners for %s must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." % (option,)) + raise TypeError( + "Listeners for %s must be either a " + "CommandListener, ServerHeartbeatListener, " + "ServerListener, TopologyListener, or " + "ConnectionPoolListener." % (option,) + ) return listeners @@ -487,10 +495,12 @@ def register(listener): :class:`TopologyListener`, or :class:`ConnectionPoolListener`. """ if not isinstance(listener, _EventListener): - raise TypeError("Listeners for %s must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." % (listener,)) + raise TypeError( + "Listeners for %s must be either a " + "CommandListener, ServerHeartbeatListener, " + "ServerListener, TopologyListener, or " + "ConnectionPoolListener." % (listener,) + ) if isinstance(listener, CommandListener): _LISTENERS.command_listeners.append(listener) if isinstance(listener, ServerHeartbeatListener): @@ -502,19 +512,32 @@ def register(listener): if isinstance(listener, ConnectionPoolListener): _LISTENERS.cmap_listeners.append(listener) + # Note - to avoid bugs from forgetting which if these is all lowercase and # which are camelCase, and at the same time avoid having to add a test for # every command, use all lowercase here and test against command_name.lower(). 
_SENSITIVE_COMMANDS = set( - ["authenticate", "saslstart", "saslcontinue", "getnonce", "createuser", - "updateuser", "copydbgetnonce", "copydbsaslstart", "copydb"]) + [ + "authenticate", + "saslstart", + "saslcontinue", + "getnonce", + "createuser", + "updateuser", + "copydbgetnonce", + "copydbsaslstart", + "copydb", + ] +) # The "hello" command is also deemed sensitive when attempting speculative # authentication. def _is_speculative_authenticate(command_name, doc): - if (command_name.lower() in ('hello', HelloCompat.LEGACY_CMD) and - 'speculativeAuthenticate' in doc): + if ( + command_name.lower() in ("hello", HelloCompat.LEGACY_CMD) + and "speculativeAuthenticate" in doc + ): return True return False @@ -522,11 +545,9 @@ def _is_speculative_authenticate(command_name, doc): class _CommandEvent(object): """Base class for command events.""" - __slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id", - "__service_id") + __slots__ = ("__cmd_name", "__rqst_id", "__conn_id", "__op_id", "__service_id") - def __init__(self, command_name, request_id, connection_id, operation_id, - service_id=None): + def __init__(self, command_name, request_id, connection_id, operation_id, service_id=None): self.__cmd_name = command_name self.__rqst_id = request_id self.__conn_id = connection_id @@ -574,6 +595,7 @@ class CommandStartedEvent(_CommandEvent): - `operation_id`: An optional identifier for a series of related events. - `service_id`: The service_id this command was sent to, or ``None``. """ + __slots__ = ("__cmd", "__db") def __init__(self, command, database_name, *args, service_id=None): @@ -581,11 +603,9 @@ def __init__(self, command, database_name, *args, service_id=None): raise ValueError("%r is not a valid command" % (command,)) # Command name must be first key. command_name = next(iter(command)) - super(CommandStartedEvent, self).__init__( - command_name, *args, service_id=service_id) + super(CommandStartedEvent, self).__init__(command_name, *args, service_id=service_id) cmd_name, cmd_doc = command_name.lower(), command[command_name] - if (cmd_name in _SENSITIVE_COMMANDS or - _is_speculative_authenticate(cmd_name, command)): + if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, command): self.__cmd = {} else: self.__cmd = command @@ -602,12 +622,14 @@ def database_name(self): return self.__db def __repr__(self): - return ( - "<%s %s db: %r, command: %r, operation_id: %s, " - "service_id: %s>") % ( - self.__class__.__name__, self.connection_id, - self.database_name, self.command_name, self.operation_id, - self.service_id) + return ("<%s %s db: %r, command: %r, operation_id: %s, " "service_id: %s>") % ( + self.__class__.__name__, + self.connection_id, + self.database_name, + self.command_name, + self.operation_id, + self.service_id, + ) class CommandSucceededEvent(_CommandEvent): @@ -623,17 +645,25 @@ class CommandSucceededEvent(_CommandEvent): - `operation_id`: An optional identifier for a series of related events. - `service_id`: The service_id this command was sent to, or ``None``. 
""" + __slots__ = ("__duration_micros", "__reply") - def __init__(self, duration, reply, command_name, - request_id, connection_id, operation_id, service_id=None): + def __init__( + self, + duration, + reply, + command_name, + request_id, + connection_id, + operation_id, + service_id=None, + ): super(CommandSucceededEvent, self).__init__( - command_name, request_id, connection_id, operation_id, - service_id=service_id) + command_name, request_id, connection_id, operation_id, service_id=service_id + ) self.__duration_micros = _to_micros(duration) cmd_name = command_name.lower() - if (cmd_name in _SENSITIVE_COMMANDS or - _is_speculative_authenticate(cmd_name, reply)): + if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, reply): self.__reply = {} else: self.__reply = reply @@ -649,12 +679,14 @@ def reply(self): return self.__reply def __repr__(self): - return ( - "<%s %s command: %r, operation_id: %s, duration_micros: %s, " - "service_id: %s>") % ( - self.__class__.__name__, self.connection_id, - self.command_name, self.operation_id, self.duration_micros, - self.service_id) + return ("<%s %s command: %r, operation_id: %s, duration_micros: %s, " "service_id: %s>") % ( + self.__class__.__name__, + self.connection_id, + self.command_name, + self.operation_id, + self.duration_micros, + self.service_id, + ) class CommandFailedEvent(_CommandEvent): @@ -670,6 +702,7 @@ class CommandFailedEvent(_CommandEvent): - `operation_id`: An optional identifier for a series of related events. - `service_id`: The service_id this command was sent to, or ``None``. """ + __slots__ = ("__duration_micros", "__failure") def __init__(self, duration, failure, *args, service_id=None): @@ -690,14 +723,21 @@ def failure(self): def __repr__(self): return ( "<%s %s command: %r, operation_id: %s, duration_micros: %s, " - "failure: %r, service_id: %s>") % ( - self.__class__.__name__, self.connection_id, self.command_name, - self.operation_id, self.duration_micros, self.failure, - self.service_id) + "failure: %r, service_id: %s>" + ) % ( + self.__class__.__name__, + self.connection_id, + self.command_name, + self.operation_id, + self.duration_micros, + self.failure, + self.service_id, + ) class _PoolEvent(object): """Base class for pool events.""" + __slots__ = ("__address",) def __init__(self, address): @@ -711,7 +751,7 @@ def address(self): return self.__address def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self.__address) + return "%s(%r)" % (self.__class__.__name__, self.__address) class PoolCreatedEvent(_PoolEvent): @@ -723,6 +763,7 @@ class PoolCreatedEvent(_PoolEvent): .. versionadded:: 3.9 """ + __slots__ = ("__options",) def __init__(self, address, options): @@ -731,13 +772,11 @@ def __init__(self, address, options): @property def options(self): - """Any non-default pool options that were set on this Connection Pool. - """ + """Any non-default pool options that were set on this Connection Pool.""" return self.__options def __repr__(self): - return '%s(%r, %r)' % ( - self.__class__.__name__, self.address, self.__options) + return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__options) class PoolReadyEvent(_PoolEvent): @@ -749,6 +788,7 @@ class PoolReadyEvent(_PoolEvent): .. versionadded:: 4.0 """ + __slots__ = () @@ -762,6 +802,7 @@ class PoolClearedEvent(_PoolEvent): .. 
versionadded:: 3.9 """ + __slots__ = ("__service_id",) def __init__(self, address, service_id=None): @@ -779,8 +820,7 @@ def service_id(self): return self.__service_id def __repr__(self): - return '%s(%r, %r)' % ( - self.__class__.__name__, self.address, self.__service_id) + return "%s(%r, %r)" % (self.__class__.__name__, self.address, self.__service_id) class PoolClosedEvent(_PoolEvent): @@ -792,6 +832,7 @@ class PoolClosedEvent(_PoolEvent): .. versionadded:: 3.9 """ + __slots__ = () @@ -802,17 +843,17 @@ class ConnectionClosedReason(object): .. versionadded:: 3.9 """ - STALE = 'stale' + STALE = "stale" """The pool was cleared, making the connection no longer valid.""" - IDLE = 'idle' + IDLE = "idle" """The connection became stale by being idle for too long (maxIdleTimeMS). """ - ERROR = 'error' + ERROR = "error" """The connection experienced an error, making it no longer valid.""" - POOL_CLOSED = 'poolClosed' + POOL_CLOSED = "poolClosed" """The pool was closed, making the connection no longer valid.""" @@ -823,13 +864,13 @@ class ConnectionCheckOutFailedReason(object): .. versionadded:: 3.9 """ - TIMEOUT = 'timeout' + TIMEOUT = "timeout" """The connection check out attempt exceeded the specified timeout.""" - POOL_CLOSED = 'poolClosed' + POOL_CLOSED = "poolClosed" """The pool was previously closed, and cannot provide new connections.""" - CONN_ERROR = 'connectionError' + CONN_ERROR = "connectionError" """The connection check out attempt experienced an error while setting up a new connection. """ @@ -837,6 +878,7 @@ class ConnectionCheckOutFailedReason(object): class _ConnectionEvent(object): """Private base class for some connection events.""" + __slots__ = ("__address", "__connection_id") def __init__(self, address, connection_id): @@ -856,8 +898,7 @@ def connection_id(self): return self.__connection_id def __repr__(self): - return '%s(%r, %r)' % ( - self.__class__.__name__, self.__address, self.__connection_id) + return "%s(%r, %r)" % (self.__class__.__name__, self.__address, self.__connection_id) class ConnectionCreatedEvent(_ConnectionEvent): @@ -873,6 +914,7 @@ class ConnectionCreatedEvent(_ConnectionEvent): .. versionadded:: 3.9 """ + __slots__ = () @@ -886,6 +928,7 @@ class ConnectionReadyEvent(_ConnectionEvent): .. versionadded:: 3.9 """ + __slots__ = () @@ -900,6 +943,7 @@ class ConnectionClosedEvent(_ConnectionEvent): .. versionadded:: 3.9 """ + __slots__ = ("__reason",) def __init__(self, address, connection_id, reason): @@ -916,9 +960,12 @@ def reason(self): return self.__reason def __repr__(self): - return '%s(%r, %r, %r)' % ( - self.__class__.__name__, self.address, self.connection_id, - self.__reason) + return "%s(%r, %r, %r)" % ( + self.__class__.__name__, + self.address, + self.connection_id, + self.__reason, + ) class ConnectionCheckOutStartedEvent(object): @@ -930,6 +977,7 @@ class ConnectionCheckOutStartedEvent(object): .. versionadded:: 3.9 """ + __slots__ = ("__address",) def __init__(self, address): @@ -943,7 +991,7 @@ def address(self): return self.__address def __repr__(self): - return '%s(%r)' % (self.__class__.__name__, self.__address) + return "%s(%r)" % (self.__class__.__name__, self.__address) class ConnectionCheckOutFailedEvent(object): @@ -956,6 +1004,7 @@ class ConnectionCheckOutFailedEvent(object): .. 
versionadded:: 3.9 """ + __slots__ = ("__address", "__reason") def __init__(self, address, reason): @@ -979,8 +1028,7 @@ def reason(self): return self.__reason def __repr__(self): - return '%s(%r, %r)' % ( - self.__class__.__name__, self.__address, self.__reason) + return "%s(%r, %r)" % (self.__class__.__name__, self.__address, self.__reason) class ConnectionCheckedOutEvent(_ConnectionEvent): @@ -993,6 +1041,7 @@ class ConnectionCheckedOutEvent(_ConnectionEvent): .. versionadded:: 3.9 """ + __slots__ = () @@ -1006,6 +1055,7 @@ class ConnectionCheckedInEvent(_ConnectionEvent): .. versionadded:: 3.9 """ + __slots__ = () @@ -1030,7 +1080,10 @@ def topology_id(self): def __repr__(self): return "<%s %s topology_id: %s>" % ( - self.__class__.__name__, self.server_address, self.topology_id) + self.__class__.__name__, + self.server_address, + self.topology_id, + ) class ServerDescriptionChangedEvent(_ServerEvent): @@ -1039,7 +1092,7 @@ class ServerDescriptionChangedEvent(_ServerEvent): .. versionadded:: 3.3 """ - __slots__ = ('__previous_description', '__new_description') + __slots__ = ("__previous_description", "__new_description") def __init__(self, previous_description, new_description, *args): super(ServerDescriptionChangedEvent, self).__init__(*args) @@ -1060,8 +1113,11 @@ def new_description(self): def __repr__(self): return "<%s %s changed from: %s, to: %s>" % ( - self.__class__.__name__, self.server_address, - self.previous_description, self.new_description) + self.__class__.__name__, + self.server_address, + self.previous_description, + self.new_description, + ) class ServerOpeningEvent(_ServerEvent): @@ -1085,7 +1141,7 @@ class ServerClosedEvent(_ServerEvent): class TopologyEvent(object): """Base class for topology description events.""" - __slots__ = ('__topology_id') + __slots__ = "__topology_id" def __init__(self, topology_id): self.__topology_id = topology_id @@ -1096,8 +1152,7 @@ def topology_id(self): return self.__topology_id def __repr__(self): - return "<%s topology_id: %s>" % ( - self.__class__.__name__, self.topology_id) + return "<%s topology_id: %s>" % (self.__class__.__name__, self.topology_id) class TopologyDescriptionChangedEvent(TopologyEvent): @@ -1106,9 +1161,9 @@ class TopologyDescriptionChangedEvent(TopologyEvent): .. versionadded:: 3.3 """ - __slots__ = ('__previous_description', '__new_description') + __slots__ = ("__previous_description", "__new_description") - def __init__(self, previous_description, new_description, *args): + def __init__(self, previous_description, new_description, *args): super(TopologyDescriptionChangedEvent, self).__init__(*args) self.__previous_description = previous_description self.__new_description = new_description @@ -1127,8 +1182,11 @@ def new_description(self): def __repr__(self): return "<%s topology_id: %s changed from: %s, to: %s>" % ( - self.__class__.__name__, self.topology_id, - self.previous_description, self.new_description) + self.__class__.__name__, + self.topology_id, + self.previous_description, + self.new_description, + ) class TopologyOpenedEvent(TopologyEvent): @@ -1152,7 +1210,7 @@ class TopologyClosedEvent(TopologyEvent): class _ServerHeartbeatEvent(object): """Base class for server heartbeat events.""" - __slots__ = ('__connection_id') + __slots__ = "__connection_id" def __init__(self, connection_id): self.__connection_id = connection_id @@ -1182,7 +1240,7 @@ class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): .. 
versionadded:: 3.3 """ - __slots__ = ('__duration', '__reply', '__awaited') + __slots__ = ("__duration", "__reply", "__awaited") def __init__(self, duration, reply, connection_id, awaited=False): super(ServerHeartbeatSucceededEvent, self).__init__(connection_id) @@ -1212,8 +1270,12 @@ def awaited(self): def __repr__(self): return "<%s %s duration: %s, awaited: %s, reply: %s>" % ( - self.__class__.__name__, self.connection_id, - self.duration, self.awaited, self.reply) + self.__class__.__name__, + self.connection_id, + self.duration, + self.awaited, + self.reply, + ) class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): @@ -1223,7 +1285,7 @@ class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): .. versionadded:: 3.3 """ - __slots__ = ('__duration', '__reply', '__awaited') + __slots__ = ("__duration", "__reply", "__awaited") def __init__(self, duration, reply, connection_id, awaited=False): super(ServerHeartbeatFailedEvent, self).__init__(connection_id) @@ -1253,8 +1315,12 @@ def awaited(self): def __repr__(self): return "<%s %s duration: %s, awaited: %s, reply: %r>" % ( - self.__class__.__name__, self.connection_id, - self.duration, self.awaited, self.reply) + self.__class__.__name__, + self.connection_id, + self.duration, + self.awaited, + self.reply, + ) class _EventListeners(object): @@ -1265,6 +1331,7 @@ class _EventListeners(object): :Parameters: - `listeners`: A list of event listeners. """ + def __init__(self, listeners): self.__command_listeners = _LISTENERS.command_listeners[:] self.__server_listeners = _LISTENERS.server_listeners[:] @@ -1286,8 +1353,7 @@ def __init__(self, listeners): self.__cmap_listeners.append(lst) self.__enabled_for_commands = bool(self.__command_listeners) self.__enabled_for_server = bool(self.__server_listeners) - self.__enabled_for_server_heartbeat = bool( - self.__server_heartbeat_listeners) + self.__enabled_for_server_heartbeat = bool(self.__server_heartbeat_listeners) self.__enabled_for_topology = bool(self.__topology_listeners) self.__enabled_for_cmap = bool(self.__cmap_listeners) @@ -1318,15 +1384,17 @@ def enabled_for_cmap(self): def event_listeners(self): """List of registered event listeners.""" - return (self.__command_listeners + - self.__server_heartbeat_listeners + - self.__server_listeners + - self.__topology_listeners + - self.__cmap_listeners) - - def publish_command_start(self, command, database_name, - request_id, connection_id, op_id=None, - service_id=None): + return ( + self.__command_listeners + + self.__server_heartbeat_listeners + + self.__server_listeners + + self.__topology_listeners + + self.__cmap_listeners + ) + + def publish_command_start( + self, command, database_name, request_id, connection_id, op_id=None, service_id=None + ): """Publish a CommandStartedEvent to all command listeners. 
:Parameters: @@ -1342,18 +1410,25 @@ def publish_command_start(self, command, database_name, if op_id is None: op_id = request_id event = CommandStartedEvent( - command, database_name, request_id, connection_id, op_id, - service_id=service_id) + command, database_name, request_id, connection_id, op_id, service_id=service_id + ) for subscriber in self.__command_listeners: try: subscriber.started(event) except Exception: _handle_exception() - def publish_command_success(self, duration, reply, command_name, - request_id, connection_id, op_id=None, - service_id=None, - speculative_hello=False): + def publish_command_success( + self, + duration, + reply, + command_name, + request_id, + connection_id, + op_id=None, + service_id=None, + speculative_hello=False, + ): """Publish a CommandSucceededEvent to all command listeners. :Parameters: @@ -1374,17 +1449,24 @@ def publish_command_success(self, duration, reply, command_name, # speculativeAuthenticate. reply = {} event = CommandSucceededEvent( - duration, reply, command_name, request_id, connection_id, op_id, - service_id) + duration, reply, command_name, request_id, connection_id, op_id, service_id + ) for subscriber in self.__command_listeners: try: subscriber.succeeded(event) except Exception: _handle_exception() - def publish_command_failure(self, duration, failure, command_name, - request_id, connection_id, op_id=None, - service_id=None): + def publish_command_failure( + self, + duration, + failure, + command_name, + request_id, + connection_id, + op_id=None, + service_id=None, + ): """Publish a CommandFailedEvent to all command listeners. :Parameters: @@ -1401,8 +1483,8 @@ def publish_command_failure(self, duration, failure, command_name, if op_id is None: op_id = request_id event = CommandFailedEvent( - duration, failure, command_name, request_id, connection_id, op_id, - service_id=service_id) + duration, failure, command_name, request_id, connection_id, op_id, service_id=service_id + ) for subscriber in self.__command_listeners: try: subscriber.failed(event) @@ -1423,8 +1505,7 @@ def publish_server_heartbeat_started(self, connection_id): except Exception: _handle_exception() - def publish_server_heartbeat_succeeded(self, connection_id, duration, - reply, awaited): + def publish_server_heartbeat_succeeded(self, connection_id, duration, reply, awaited): """Publish a ServerHeartbeatSucceededEvent to all server heartbeat listeners. @@ -1434,17 +1515,15 @@ def publish_server_heartbeat_succeeded(self, connection_id, duration, resolution for the platform. - `reply`: The command reply. - `awaited`: True if the response was awaited. - """ - event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, - awaited) + """ + event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, awaited) for subscriber in self.__server_heartbeat_listeners: try: subscriber.succeeded(event) except Exception: _handle_exception() - def publish_server_heartbeat_failed(self, connection_id, duration, reply, - awaited): + def publish_server_heartbeat_failed(self, connection_id, duration, reply, awaited): """Publish a ServerHeartbeatFailedEvent to all server heartbeat listeners. @@ -1454,9 +1533,8 @@ def publish_server_heartbeat_failed(self, connection_id, duration, reply, resolution for the platform. - `reply`: The command reply. - `awaited`: True if the response was awaited. 
- """ - event = ServerHeartbeatFailedEvent(duration, reply, connection_id, - awaited) + """ + event = ServerHeartbeatFailedEvent(duration, reply, connection_id, awaited) for subscriber in self.__server_heartbeat_listeners: try: subscriber.failed(event) @@ -1493,9 +1571,9 @@ def publish_server_closed(self, server_address, topology_id): except Exception: _handle_exception() - def publish_server_description_changed(self, previous_description, - new_description, server_address, - topology_id): + def publish_server_description_changed( + self, previous_description, new_description, server_address, topology_id + ): """Publish a ServerDescriptionChangedEvent to all server listeners. :Parameters: @@ -1505,9 +1583,9 @@ def publish_server_description_changed(self, previous_description, - `topology_id`: A unique identifier for the topology this server is a part of. """ - event = ServerDescriptionChangedEvent(previous_description, - new_description, server_address, - topology_id) + event = ServerDescriptionChangedEvent( + previous_description, new_description, server_address, topology_id + ) for subscriber in self.__server_listeners: try: subscriber.description_changed(event) @@ -1542,8 +1620,9 @@ def publish_topology_closed(self, topology_id): except Exception: _handle_exception() - def publish_topology_description_changed(self, previous_description, - new_description, topology_id): + def publish_topology_description_changed( + self, previous_description, new_description, topology_id + ): """Publish a TopologyDescriptionChangedEvent to all topology listeners. :Parameters: @@ -1552,8 +1631,7 @@ def publish_topology_description_changed(self, previous_description, - `topology_id`: A unique identifier for the topology this server is a part of. """ - event = TopologyDescriptionChangedEvent(previous_description, - new_description, topology_id) + event = TopologyDescriptionChangedEvent(previous_description, new_description, topology_id) for subscriber in self.__topology_listeners: try: subscriber.description_changed(event) @@ -1561,8 +1639,7 @@ def publish_topology_description_changed(self, previous_description, _handle_exception() def publish_pool_created(self, address, options): - """Publish a :class:`PoolCreatedEvent` to all pool listeners. - """ + """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" event = PoolCreatedEvent(address, options) for subscriber in self.__cmap_listeners: try: @@ -1571,8 +1648,7 @@ def publish_pool_created(self, address, options): _handle_exception() def publish_pool_ready(self, address): - """Publish a :class:`PoolReadyEvent` to all pool listeners. - """ + """Publish a :class:`PoolReadyEvent` to all pool listeners.""" event = PoolReadyEvent(address) for subscriber in self.__cmap_listeners: try: @@ -1581,8 +1657,7 @@ def publish_pool_ready(self, address): _handle_exception() def publish_pool_cleared(self, address, service_id): - """Publish a :class:`PoolClearedEvent` to all pool listeners. - """ + """Publish a :class:`PoolClearedEvent` to all pool listeners.""" event = PoolClearedEvent(address, service_id) for subscriber in self.__cmap_listeners: try: @@ -1591,8 +1666,7 @@ def publish_pool_cleared(self, address, service_id): _handle_exception() def publish_pool_closed(self, address): - """Publish a :class:`PoolClosedEvent` to all pool listeners. 
- """ + """Publish a :class:`PoolClosedEvent` to all pool listeners.""" event = PoolClosedEvent(address) for subscriber in self.__cmap_listeners: try: @@ -1612,8 +1686,7 @@ def publish_connection_created(self, address, connection_id): _handle_exception() def publish_connection_ready(self, address, connection_id): - """Publish a :class:`ConnectionReadyEvent` to all connection listeners. - """ + """Publish a :class:`ConnectionReadyEvent` to all connection listeners.""" event = ConnectionReadyEvent(address, connection_id) for subscriber in self.__cmap_listeners: try: diff --git a/pymongo/network.py b/pymongo/network.py index 7ec6540dd4..5141827e0e 100644 --- a/pymongo/network.py +++ b/pymongo/network.py @@ -20,36 +20,48 @@ import struct import time - from bson import _decode_all_selective - from pymongo import helpers, message from pymongo.common import MAX_MESSAGE_SIZE -from pymongo.compression_support import decompress, _NO_COMPRESSION -from pymongo.errors import (NotPrimaryError, - OperationFailure, - ProtocolError, - _OperationCancelled) +from pymongo.compression_support import _NO_COMPRESSION, decompress +from pymongo.errors import ( + NotPrimaryError, + OperationFailure, + ProtocolError, + _OperationCancelled, +) from pymongo.message import _UNPACK_REPLY, _OpMsg from pymongo.monitoring import _is_speculative_authenticate from pymongo.socket_checker import _errno_from_exception - _UNPACK_HEADER = struct.Struct(" max_bson_size): + if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: message._raise_document_too_large(name, size, max_bson_size) else: request_id, msg, size = message._query( - flags, ns, 0, -1, spec, None, codec_options, check_keys, - compression_ctx) + flags, ns, 0, -1, spec, None, codec_options, check_keys, compression_ctx + ) - if (max_bson_size is not None - and size > max_bson_size + message._COMMAND_OVERHEAD): - message._raise_document_too_large( - name, size, max_bson_size + message._COMMAND_OVERHEAD) + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) if publish: encoding_duration = datetime.datetime.now() - start - listeners.publish_command_start(orig, dbname, request_id, address, - service_id=sock_info.service_id) + listeners.publish_command_start( + orig, dbname, request_id, address, service_id=sock_info.service_id + ) start = datetime.datetime.now() try: @@ -149,15 +164,19 @@ def command(sock_info, dbname, spec, secondary_ok, is_mongos, reply = receive_message(sock_info, request_id) sock_info.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response( - codec_options=codec_options, user_fields=user_fields) + codec_options=codec_options, user_fields=user_fields + ) response_doc = unpacked_docs[0] if client: client._process_response(response_doc, session) if check: helpers._check_command_response( - response_doc, sock_info.max_wire_version, allowable_errors, - parse_write_concern_error=parse_write_concern_error) + response_doc, + sock_info.max_wire_version, + allowable_errors, + parse_write_concern_error=parse_write_concern_error, + ) except Exception as exc: if publish: duration = (datetime.datetime.now() - start) + encoding_duration @@ -166,25 +185,31 @@ def command(sock_info, dbname, spec, secondary_ok, is_mongos, else: failure = message._convert_exception(exc) listeners.publish_command_failure( - duration, failure, name, request_id, address, - service_id=sock_info.service_id) + duration, failure, 
name, request_id, address, service_id=sock_info.service_id + ) raise if publish: duration = (datetime.datetime.now() - start) + encoding_duration listeners.publish_command_success( - duration, response_doc, name, request_id, address, + duration, + response_doc, + name, + request_id, + address, service_id=sock_info.service_id, - speculative_hello=speculative_hello) + speculative_hello=speculative_hello, + ) if client and client._encrypter and reply: decrypted = client._encrypter.decrypt(reply.raw_command_response()) - response_doc = _decode_all_selective(decrypted, codec_options, - user_fields)[0] + response_doc = _decode_all_selective(decrypted, codec_options, user_fields)[0] return response_doc + _UNPACK_COMPRESSION_HEADER = struct.Struct(" max_message_size: - raise ProtocolError("Message length (%r) is larger than server max " - "message size (%r)" % (length, max_message_size)) + raise ProtocolError( + "Message length (%r) is larger than server max " + "message size (%r)" % (length, max_message_size) + ) if op_code == 2012: op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER( - _receive_data_on_socket(sock_info, 9, deadline)) - data = decompress( - _receive_data_on_socket(sock_info, length - 25, deadline), - compressor_id) + _receive_data_on_socket(sock_info, 9, deadline) + ) + data = decompress(_receive_data_on_socket(sock_info, length - 25, deadline), compressor_id) else: data = _receive_data_on_socket(sock_info, length - 16, deadline) try: unpack_reply = _UNPACK_REPLY[op_code] except KeyError: - raise ProtocolError("Got opcode %r but expected " - "%r" % (op_code, _UNPACK_REPLY.keys())) + raise ProtocolError("Got opcode %r but expected " "%r" % (op_code, _UNPACK_REPLY.keys())) return unpack_reply(data) @@ -234,7 +260,7 @@ def wait_for_read(sock_info, deadline): sock = sock_info.sock while True: # SSLSocket can have buffered data which won't be caught by select. 
- if hasattr(sock, 'pending') and sock.pending() > 0: + if hasattr(sock, "pending") and sock.pending() > 0: readable = True else: # Wait up to 500ms for the socket to become readable and then @@ -243,15 +269,15 @@ def wait_for_read(sock_info, deadline): timeout = max(min(deadline - time.monotonic(), _POLL_TIMEOUT), 0.001) else: timeout = _POLL_TIMEOUT - readable = sock_info.socket_checker.select( - sock, read=True, timeout=timeout) + readable = sock_info.socket_checker.select(sock, read=True, timeout=timeout) if context.cancelled: - raise _OperationCancelled('hello cancelled') + raise _OperationCancelled("hello cancelled") if readable: return if deadline and time.monotonic() > deadline: raise socket.timeout("timed out") + def _receive_data_on_socket(sock_info, length, deadline): buf = bytearray(length) mv = memoryview(buf) diff --git a/pymongo/ocsp_cache.py b/pymongo/ocsp_cache.py index c2c24c4ab0..0b380bc168 100644 --- a/pymongo/ocsp_cache.py +++ b/pymongo/ocsp_cache.py @@ -21,9 +21,11 @@ class _OCSPCache(object): """A cache for OCSP responses.""" - CACHE_KEY_TYPE = namedtuple('OcspResponseCacheKey', - ['hash_algorithm', 'issuer_name_hash', - 'issuer_key_hash', 'serial_number']) + + CACHE_KEY_TYPE = namedtuple( + "OcspResponseCacheKey", + ["hash_algorithm", "issuer_name_hash", "issuer_key_hash", "serial_number"], + ) def __init__(self): self._data = {} @@ -35,7 +37,8 @@ def _get_cache_key(self, ocsp_request): hash_algorithm=ocsp_request.hash_algorithm.name.lower(), issuer_name_hash=ocsp_request.issuer_name_hash, issuer_key_hash=ocsp_request.issuer_key_hash, - serial_number=ocsp_request.serial_number) + serial_number=ocsp_request.serial_number, + ) def __setitem__(self, key, value): """Add/update a cache entry. @@ -56,15 +59,13 @@ def __setitem__(self, key, value): return # Do nothing if the response is invalid. - if not (value.this_update <= _datetime.utcnow() - < value.next_update): + if not (value.this_update <= _datetime.utcnow() < value.next_update): return # Cache new response OR update cached response if new response # has longer validity. cached_value = self._data.get(cache_key, None) - if (cached_value is None or - cached_value.next_update < value.next_update): + if cached_value is None or cached_value.next_update < value.next_update: self._data[cache_key] = value def __getitem__(self, item): @@ -79,8 +80,7 @@ def __getitem__(self, item): value = self._data[cache_key] # Return cached response if it is still valid. 
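# Illustrative sketch (not part of the patch): both __setitem__ and
# __getitem__ gate on the responder-declared validity window, i.e. an entry
# is usable only while this_update <= utcnow() < next_update. A standalone
# model of that freshness test, with hypothetical names:
#
#     from datetime import datetime, timedelta
#
#     def _is_fresh(this_update, next_update):
#         return this_update <= datetime.utcnow() < next_update
#
#     now = datetime.utcnow()
#     assert _is_fresh(now - timedelta(minutes=5), now + timedelta(hours=1))
#     assert not _is_fresh(now - timedelta(hours=2), now - timedelta(hours=1))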
- if (value.this_update <= _datetime.utcnow() < - value.next_update): + if value.this_update <= _datetime.utcnow() < value.next_update: return value self._data.pop(cache_key, None) diff --git a/pymongo/ocsp_support.py b/pymongo/ocsp_support.py index 1a983e0af8..369055ea8d 100644 --- a/pymongo/ocsp_support.py +++ b/pymongo/ocsp_support.py @@ -16,41 +16,35 @@ import logging as _logging import re as _re - from datetime import datetime as _datetime from cryptography.exceptions import InvalidSignature as _InvalidSignature from cryptography.hazmat.backends import default_backend as _default_backend -from cryptography.hazmat.primitives.asymmetric.dsa import ( - DSAPublicKey as _DSAPublicKey) +from cryptography.hazmat.primitives.asymmetric.dsa import DSAPublicKey as _DSAPublicKey +from cryptography.hazmat.primitives.asymmetric.ec import ECDSA as _ECDSA from cryptography.hazmat.primitives.asymmetric.ec import ( - ECDSA as _ECDSA, - EllipticCurvePublicKey as _EllipticCurvePublicKey) -from cryptography.hazmat.primitives.asymmetric.padding import ( - PKCS1v15 as _PKCS1v15) -from cryptography.hazmat.primitives.asymmetric.rsa import ( - RSAPublicKey as _RSAPublicKey) -from cryptography.hazmat.primitives.hashes import ( - Hash as _Hash, - SHA1 as _SHA1) -from cryptography.hazmat.primitives.serialization import ( - Encoding as _Encoding, - PublicFormat as _PublicFormat) -from cryptography.x509 import ( - AuthorityInformationAccess as _AuthorityInformationAccess, - ExtendedKeyUsage as _ExtendedKeyUsage, - ExtensionNotFound as _ExtensionNotFound, - load_pem_x509_certificate as _load_pem_x509_certificate, - TLSFeature as _TLSFeature, - TLSFeatureType as _TLSFeatureType) + EllipticCurvePublicKey as _EllipticCurvePublicKey, +) +from cryptography.hazmat.primitives.asymmetric.padding import PKCS1v15 as _PKCS1v15 +from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey as _RSAPublicKey +from cryptography.hazmat.primitives.hashes import SHA1 as _SHA1 +from cryptography.hazmat.primitives.hashes import Hash as _Hash +from cryptography.hazmat.primitives.serialization import Encoding as _Encoding +from cryptography.hazmat.primitives.serialization import PublicFormat as _PublicFormat +from cryptography.x509 import AuthorityInformationAccess as _AuthorityInformationAccess +from cryptography.x509 import ExtendedKeyUsage as _ExtendedKeyUsage +from cryptography.x509 import ExtensionNotFound as _ExtensionNotFound +from cryptography.x509 import TLSFeature as _TLSFeature +from cryptography.x509 import TLSFeatureType as _TLSFeatureType +from cryptography.x509 import load_pem_x509_certificate as _load_pem_x509_certificate +from cryptography.x509.ocsp import OCSPCertStatus as _OCSPCertStatus +from cryptography.x509.ocsp import OCSPRequestBuilder as _OCSPRequestBuilder +from cryptography.x509.ocsp import OCSPResponseStatus as _OCSPResponseStatus +from cryptography.x509.ocsp import load_der_ocsp_response as _load_der_ocsp_response from cryptography.x509.oid import ( AuthorityInformationAccessOID as _AuthorityInformationAccessOID, - ExtendedKeyUsageOID as _ExtendedKeyUsageOID) -from cryptography.x509.ocsp import ( - load_der_ocsp_response as _load_der_ocsp_response, - OCSPCertStatus as _OCSPCertStatus, - OCSPRequestBuilder as _OCSPRequestBuilder, - OCSPResponseStatus as _OCSPResponseStatus) +) +from cryptography.x509.oid import ExtendedKeyUsageOID as _ExtendedKeyUsageOID from requests import post as _post from requests.exceptions import RequestException as _RequestException @@ -62,21 +56,20 @@ _LOGGER = 
_logging.getLogger(__name__) _CERT_REGEX = _re.compile( - b'-----BEGIN CERTIFICATE[^\r\n]+.+?-----END CERTIFICATE[^\r\n]+', - _re.DOTALL) + b"-----BEGIN CERTIFICATE[^\r\n]+.+?-----END CERTIFICATE[^\r\n]+", _re.DOTALL +) def _load_trusted_ca_certs(cafile): """Parse the tlsCAFile into a list of certificates.""" - with open(cafile, 'rb') as f: + with open(cafile, "rb") as f: data = f.read() # Load all the certs in the file. trusted_ca_certs = [] backend = _default_backend() for cert_data in _re.findall(_CERT_REGEX, data): - trusted_ca_certs.append( - _load_pem_x509_certificate(cert_data, backend)) + trusted_ca_certs.append(_load_pem_x509_certificate(cert_data, backend)) return trusted_ca_certs @@ -128,14 +121,11 @@ def _public_key_hash(cert): # (excluding the tag and length fields)" # https://stackoverflow.com/a/46309453/600498 if isinstance(public_key, _RSAPublicKey): - pbytes = public_key.public_bytes( - _Encoding.DER, _PublicFormat.PKCS1) + pbytes = public_key.public_bytes(_Encoding.DER, _PublicFormat.PKCS1) elif isinstance(public_key, _EllipticCurvePublicKey): - pbytes = public_key.public_bytes( - _Encoding.X962, _PublicFormat.UncompressedPoint) + pbytes = public_key.public_bytes(_Encoding.X962, _PublicFormat.UncompressedPoint) else: - pbytes = public_key.public_bytes( - _Encoding.DER, _PublicFormat.SubjectPublicKeyInfo) + pbytes = public_key.public_bytes(_Encoding.DER, _PublicFormat.SubjectPublicKeyInfo) digest = _Hash(_SHA1(), backend=_default_backend()) digest.update(pbytes) return digest.finalize() @@ -143,16 +133,18 @@ def _public_key_hash(cert): def _get_certs_by_key_hash(certificates, issuer, responder_key_hash): return [ - cert for cert in certificates - if _public_key_hash(cert) == responder_key_hash and - cert.issuer == issuer.subject] + cert + for cert in certificates + if _public_key_hash(cert) == responder_key_hash and cert.issuer == issuer.subject + ] def _get_certs_by_name(certificates, issuer, responder_name): return [ - cert for cert in certificates - if cert.subject == responder_name and - cert.issuer == issuer.subject] + cert + for cert in certificates + if cert.subject == responder_name and cert.issuer == issuer.subject + ] def _verify_response_signature(issuer, response): @@ -190,10 +182,11 @@ def _verify_response_signature(issuer, response): _LOGGER.debug("Delegate not authorized for OCSP signing") return 0 if not _verify_signature( - issuer.public_key(), - responder_cert.signature, - responder_cert.signature_hash_algorithm, - responder_cert.tbs_certificate_bytes): + issuer.public_key(), + responder_cert.signature, + responder_cert.signature_hash_algorithm, + responder_cert.tbs_certificate_bytes, + ): _LOGGER.debug("Delegate signature verification failed") return 0 # RFC6960, Section 3.2, Number 2 @@ -201,7 +194,8 @@ def _verify_response_signature(issuer, response): responder_cert.public_key(), response.signature, response.signature_hash_algorithm, - response.tbs_response_bytes) + response.tbs_response_bytes, + ) if not ret: _LOGGER.debug("Response signature verification failed") return ret @@ -245,8 +239,9 @@ def _get_ocsp_response(cert, issuer, uri, ocsp_response_cache): response = _post( uri, data=ocsp_request.public_bytes(_Encoding.DER), - headers={'Content-Type': 'application/ocsp-request'}, - timeout=5) + headers={"Content-Type": "application/ocsp-request"}, + timeout=5, + ) except _RequestException as exc: _LOGGER.debug("HTTP request failed: %s", exc) return None @@ -254,8 +249,7 @@ def _get_ocsp_response(cert, issuer, uri, ocsp_response_cache): 
_LOGGER.debug("HTTP request returned %d", response.status_code) return None ocsp_response = _load_der_ocsp_response(response.content) - _LOGGER.debug( - "OCSP response status: %r", ocsp_response.response_status) + _LOGGER.debug("OCSP response status: %r", ocsp_response.response_status) if ocsp_response.response_status != _OCSPResponseStatus.SUCCESSFUL: return None # RFC6960, Section 3.2, Number 1. Only relevant if we need to @@ -299,7 +293,7 @@ def _ocsp_callback(conn, ocsp_bytes, user_data): ocsp_response_cache = user_data.ocsp_response_cache # No stapled OCSP response - if ocsp_bytes == b'': + if ocsp_bytes == b"": _LOGGER.debug("Peer did not staple an OCSP response") if must_staple: _LOGGER.debug("Must-staple cert with no stapled response, hard fail.") @@ -314,9 +308,11 @@ def _ocsp_callback(conn, ocsp_bytes, user_data): _LOGGER.debug("No authority access information, soft fail") # No stapled OCSP response, no responder URI, soft fail. return 1 - uris = [desc.access_location.value - for desc in ext.value - if desc.access_method == _AuthorityInformationAccessOID.OCSP] + uris = [ + desc.access_location.value + for desc in ext.value + if desc.access_method == _AuthorityInformationAccessOID.OCSP + ] if not uris: _LOGGER.debug("No OCSP URI, soft fail") # No responder URI, soft fail. @@ -329,8 +325,7 @@ def _ocsp_callback(conn, ocsp_bytes, user_data): # successful, valid responses with a certificate status of REVOKED. for uri in uris: _LOGGER.debug("Trying %s", uri) - response = _get_ocsp_response( - cert, issuer, uri, ocsp_response_cache) + response = _get_ocsp_response(cert, issuer, uri, ocsp_response_cache) if response is None: # The endpoint didn't respond in time, or the response was # unsuccessful or didn't match the request, or the response @@ -350,8 +345,7 @@ def _ocsp_callback(conn, ocsp_bytes, user_data): _LOGGER.debug("No issuer cert?") return 0 response = _load_der_ocsp_response(ocsp_bytes) - _LOGGER.debug( - "OCSP response status: %r", response.response_status) + _LOGGER.debug("OCSP response status: %r", response.response_status) # This happens in _request_ocsp when there is no stapled response so # we know if we can compare serial numbers for the request and response. 
if response.response_status != _OCSPResponseStatus.SUCCESSFUL: diff --git a/pymongo/operations.py b/pymongo/operations.py index b5d670e0ff..02ebdf2a79 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -15,8 +15,8 @@ """Operation class definitions.""" from pymongo import helpers -from pymongo.common import validate_boolean, validate_is_mapping, validate_list from pymongo.collation import validate_collation_or_none +from pymongo.common import validate_boolean, validate_is_mapping, validate_list from pymongo.helpers import _gen_index_name, _index_document, _index_list @@ -90,16 +90,14 @@ def __init__(self, filter, collation=None, hint=None): def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete(self._filter, 1, collation=self._collation, - hint=self._hint) + bulkobj.add_delete(self._filter, 1, collation=self._collation, hint=self._hint) def __repr__(self): return "DeleteOne(%r, %r)" % (self._filter, self._collation) def __eq__(self, other): if type(other) == type(self): - return ((other._filter, other._collation) == - (self._filter, self._collation)) + return (other._filter, other._collation) == (self._filter, self._collation) return NotImplemented def __ne__(self, other): @@ -144,16 +142,14 @@ def __init__(self, filter, collation=None, hint=None): def _add_to_bulk(self, bulkobj): """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete(self._filter, 0, collation=self._collation, - hint=self._hint) + bulkobj.add_delete(self._filter, 0, collation=self._collation, hint=self._hint) def __repr__(self): return "DeleteMany(%r, %r)" % (self._filter, self._collation) def __eq__(self, other): if type(other) == type(self): - return ((other._filter, other._collation) == - (self._filter, self._collation)) + return (other._filter, other._collation) == (self._filter, self._collation) return NotImplemented def __ne__(self, other): @@ -165,8 +161,7 @@ class ReplaceOne(object): __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") - def __init__(self, filter, replacement, upsert=False, collation=None, - hint=None): + def __init__(self, filter, replacement, upsert=False, collation=None, hint=None): """Create a ReplaceOne instance. For use with :meth:`~pymongo.collection.Collection.bulk_write`. 
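For context on the classes being reformatted in this file: DeleteOne, DeleteMany, ReplaceOne, UpdateOne, and UpdateMany are the request objects that Collection.bulk_write() consumes; each simply captures its arguments until _add_to_bulk() hands them to the batching machinery. A minimal usage sketch (the local client and the example documents are illustrative assumptions, not taken from this patch):

    from pymongo import DeleteOne, MongoClient, ReplaceOne, UpdateOne

    coll = MongoClient().test.things  # assumes a local mongod

    # Each request object is inert; bulk_write() batches the whole list
    # into as few delete/update commands as possible.
    result = coll.bulk_write([
        UpdateOne({"_id": 1}, {"$inc": {"n": 1}}, upsert=True),
        ReplaceOne({"_id": 2}, {"_id": 2, "n": 0}, upsert=True),
        DeleteOne({"_id": 3}),
    ])
    print(result.upserted_ids, result.deleted_count)
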
@@ -207,15 +202,19 @@ def __init__(self, filter, replacement, upsert=False, collation=None,
 
     def _add_to_bulk(self, bulkobj):
         """Add this operation to the _Bulk instance `bulkobj`."""
-        bulkobj.add_replace(self._filter, self._doc, self._upsert,
-                            collation=self._collation, hint=self._hint)
+        bulkobj.add_replace(
+            self._filter, self._doc, self._upsert, collation=self._collation, hint=self._hint
+        )
 
     def __eq__(self, other):
         if type(other) == type(self):
-            return (
-                (other._filter, other._doc, other._upsert, other._collation,
-                 other._hint) == (self._filter, self._doc, self._upsert,
-                                  self._collation, other._hint))
+            return (other._filter, other._doc, other._upsert, other._collation, other._hint) == (
+                self._filter,
+                self._doc,
+                self._upsert,
+                self._collation,
+                self._hint,
+            )
         return NotImplemented
 
     def __ne__(self, other):
@@ -223,15 +222,19 @@ def __ne__(self, other):
 
     def __repr__(self):
         return "%s(%r, %r, %r, %r, %r)" % (
-            self.__class__.__name__, self._filter, self._doc, self._upsert,
-            self._collation, self._hint)
+            self.__class__.__name__,
+            self._filter,
+            self._doc,
+            self._upsert,
+            self._collation,
+            self._hint,
+        )
 
 
 class _UpdateOp(object):
     """Private base class for update operations."""
 
-    __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters",
-                 "_hint")
+    __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint")
 
     def __init__(self, filter, doc, upsert, collation, array_filters, hint):
         if filter is not None:
@@ -244,7 +247,6 @@ def __init__(self, filter, doc, upsert, collation, array_filters, hint):
             if not isinstance(hint, str):
                 hint = helpers._index_document(hint)
 
-
         self._filter = filter
         self._doc = doc
         self._upsert = upsert
@@ -255,10 +257,20 @@ def __init__(self, filter, doc, upsert, collation, array_filters, hint):
     def __eq__(self, other):
         if type(other) == type(self):
             return (
-                (other._filter, other._doc, other._upsert, other._collation,
-                 other._array_filters, other._hint) ==
-                (self._filter, self._doc, self._upsert, self._collation,
-                 self._array_filters, self._hint))
+                other._filter,
+                other._doc,
+                other._upsert,
+                other._collation,
+                other._array_filters,
+                other._hint,
+            ) == (
+                self._filter,
+                self._doc,
+                self._upsert,
+                self._collation,
+                self._array_filters,
+                self._hint,
+            )
         return NotImplemented
 
     def __ne__(self, other):
@@ -266,8 +278,14 @@ def __ne__(self, other):
 
     def __repr__(self):
         return "%s(%r, %r, %r, %r, %r, %r)" % (
-            self.__class__.__name__, self._filter, self._doc, self._upsert,
-            self._collation, self._array_filters, self._hint)
+            self.__class__.__name__,
+            self._filter,
+            self._doc,
+            self._upsert,
+            self._collation,
+            self._array_filters,
+            self._hint,
+        )
 
 
 class UpdateOne(_UpdateOp):
@@ -275,8 +293,7 @@ class UpdateOne(_UpdateOp):
 
     __slots__ = ()
 
-    def __init__(self, filter, update, upsert=False, collation=None,
-                 array_filters=None, hint=None):
+    def __init__(self, filter, update, upsert=False, collation=None, array_filters=None, hint=None):
         """Represents an update_one operation.
 
         For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@@ -307,15 +324,19 @@ def __init__(self, filter, update, upsert=False, collation=None,
         .. versionchanged:: 3.5
            Added the `collation` option.
         """
-        super(UpdateOne, self).__init__(filter, update, upsert, collation,
-                                        array_filters, hint)
+        super(UpdateOne, self).__init__(filter, update, upsert, collation, array_filters, hint)
 
     def _add_to_bulk(self, bulkobj):
         """Add this operation to the _Bulk instance `bulkobj`."""
-        bulkobj.add_update(self._filter, self._doc, False, self._upsert,
-                           collation=self._collation,
-                           array_filters=self._array_filters,
-                           hint=self._hint)
+        bulkobj.add_update(
+            self._filter,
+            self._doc,
+            False,
+            self._upsert,
+            collation=self._collation,
+            array_filters=self._array_filters,
+            hint=self._hint,
+        )
 
 
 class UpdateMany(_UpdateOp):
@@ -323,8 +344,7 @@ class UpdateMany(_UpdateOp):
 
     __slots__ = ()
 
-    def __init__(self, filter, update, upsert=False, collation=None,
-                 array_filters=None, hint=None):
+    def __init__(self, filter, update, upsert=False, collation=None, array_filters=None, hint=None):
         """Create an UpdateMany instance.
 
         For use with :meth:`~pymongo.collection.Collection.bulk_write`.
@@ -355,15 +375,19 @@ def __init__(self, filter, update, upsert=False, collation=None,
         .. versionchanged:: 3.5
            Added the `collation` option.
         """
-        super(UpdateMany, self).__init__(filter, update, upsert, collation,
-                                         array_filters, hint)
+        super(UpdateMany, self).__init__(filter, update, upsert, collation, array_filters, hint)
 
     def _add_to_bulk(self, bulkobj):
         """Add this operation to the _Bulk instance `bulkobj`."""
-        bulkobj.add_update(self._filter, self._doc, True, self._upsert,
-                           collation=self._collation,
-                           array_filters=self._array_filters,
-                           hint=self._hint)
+        bulkobj.add_update(
+            self._filter,
+            self._doc,
+            True,
+            self._upsert,
+            collation=self._collation,
+            array_filters=self._array_filters,
+            hint=self._hint,
+        )
 
 
 class IndexModel(object):
@@ -436,10 +460,10 @@ def __init__(self, keys, **kwargs):
         if "name" not in kwargs:
             kwargs["name"] = _gen_index_name(keys)
         kwargs["key"] = _index_document(keys)
-        collation = validate_collation_or_none(kwargs.pop('collation', None))
+        collation = validate_collation_or_none(kwargs.pop("collation", None))
         self.__document = kwargs
         if collation is not None:
-            self.__document['collation'] = collation
+            self.__document["collation"] = collation
 
     @property
     def document(self):
diff --git a/pymongo/periodic_executor.py b/pymongo/periodic_executor.py
index e1690ee9b1..9fce713dea 100644
--- a/pymongo/periodic_executor.py
+++ b/pymongo/periodic_executor.py
@@ -21,7 +21,7 @@
 class PeriodicExecutor(object):
     def __init__(self, interval, min_interval, target, name=None):
-        """"Run a target function periodically on a background thread.
+        """Run a target function periodically on a background thread.
 
         If the target's return value is false, the executor stops.
@@ -49,8 +49,7 @@ def __init__(self, interval, min_interval, target, name=None):
         self._lock = threading.Lock()
 
     def __repr__(self):
-        return '<%s(name=%s) object at 0x%x>' % (
-            self.__class__.__name__, self._name, id(self))
+        return "<%s(name=%s) object at 0x%x>" % (self.__class__.__name__, self._name, id(self))
 
     def open(self):
         """Start. Multiple calls have no effect.
diff --git a/pymongo/pool.py b/pymongo/pool.py index 2ae2576250..a7913f184a 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -18,8 +18,8 @@ import ipaddress import os import platform -import ssl import socket +import ssl import sys import threading import time @@ -27,41 +27,43 @@ from bson import DEFAULT_CODEC_OPTIONS from bson.son import SON -from pymongo import auth, helpers, __version__ +from pymongo import __version__, auth, helpers from pymongo.client_session import _validate_session_write_concern -from pymongo.common import (MAX_BSON_SIZE, - MAX_CONNECTING, - MAX_IDLE_TIME_SEC, - MAX_MESSAGE_SIZE, - MAX_POOL_SIZE, - MAX_WIRE_VERSION, - MAX_WRITE_BATCH_SIZE, - MIN_POOL_SIZE, - ORDERED_TYPES, - WAIT_QUEUE_TIMEOUT) -from pymongo.errors import (AutoReconnect, - _CertificateError, - ConnectionFailure, - ConfigurationError, - InvalidOperation, - DocumentTooLarge, - NetworkTimeout, - NotPrimaryError, - OperationFailure, - PyMongoError) -from pymongo.hello import HelloCompat, Hello -from pymongo.monitoring import (ConnectionCheckOutFailedReason, - ConnectionClosedReason) -from pymongo.network import (command, - receive_message) +from pymongo.common import ( + MAX_BSON_SIZE, + MAX_CONNECTING, + MAX_IDLE_TIME_SEC, + MAX_MESSAGE_SIZE, + MAX_POOL_SIZE, + MAX_WIRE_VERSION, + MAX_WRITE_BATCH_SIZE, + MIN_POOL_SIZE, + ORDERED_TYPES, + WAIT_QUEUE_TIMEOUT, +) +from pymongo.errors import ( + AutoReconnect, + ConfigurationError, + ConnectionFailure, + DocumentTooLarge, + InvalidOperation, + NetworkTimeout, + NotPrimaryError, + OperationFailure, + PyMongoError, + _CertificateError, +) +from pymongo.hello import Hello, HelloCompat +from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason +from pymongo.network import command, receive_message from pymongo.read_preferences import ReadPreference from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker -from pymongo.ssl_support import ( - SSLError as _SSLError, - HAS_SNI as _HAVE_SNI, - IPADDR_SAFE as _IPADDR_SAFE) +from pymongo.ssl_support import HAS_SNI as _HAVE_SNI +from pymongo.ssl_support import IPADDR_SAFE as _IPADDR_SAFE +from pymongo.ssl_support import SSLError as _SSLError + # For SNI support. According to RFC6066, section 3, IPv4 and IPv6 literals are # not permitted for SNI hostname. 
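RFC 6066, cited in the comment above, is why the pool refuses to send IP literals as the SNI hostname. The `is_ip_address` helper in the next hunk boils down to a try/except around the stdlib parser; a self-contained equivalent, reconstructed from the visible except clause (a sketch, not a verbatim copy of the module):

    import ipaddress

    def is_ip_address(address):
        # ip_address() parses IPv4 and IPv6 literals and raises ValueError
        # (or UnicodeError for some malformed input) for anything else.
        try:
            ipaddress.ip_address(address)
            return True
        except (ValueError, UnicodeError):
            return False

    assert is_ip_address("127.0.0.1") and is_ip_address("::1")
    assert not is_ip_address("example.com")
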
@@ -72,12 +74,15 @@ def is_ip_address(address): except (ValueError, UnicodeError): return False + try: - from fcntl import fcntl, F_GETFD, F_SETFD, FD_CLOEXEC + from fcntl import F_GETFD, F_SETFD, FD_CLOEXEC, fcntl + def _set_non_inheritable_non_atomic(fd): """Set the close-on-exec flag on the given file descriptor.""" flags = fcntl(fd, F_GETFD) fcntl(fd, F_SETFD, flags | FD_CLOEXEC) + except ImportError: # Windows, various platforms we don't claim to support # (Jython, IronPython, ...), systems that don't provide @@ -86,11 +91,12 @@ def _set_non_inheritable_non_atomic(dummy): """Dummy function for platforms that don't provide fcntl.""" pass + _MAX_TCP_KEEPIDLE = 120 _MAX_TCP_KEEPINTVL = 10 _MAX_TCP_KEEPCNT = 9 -if sys.platform == 'win32': +if sys.platform == "win32": try: import _winreg as winreg except ImportError: @@ -108,8 +114,8 @@ def _query(key, name, default): try: with winreg.OpenKey( - winreg.HKEY_LOCAL_MACHINE, - r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters") as key: + winreg.HKEY_LOCAL_MACHINE, r"SYSTEM\CurrentControlSet\Services\Tcpip\Parameters" + ) as key: _WINDOWS_TCP_IDLE_MS = _query(key, "KeepAliveTime", 7200000) _WINDOWS_TCP_INTERVAL_MS = _query(key, "KeepAliveInterval", 1000) except OSError: @@ -120,13 +126,12 @@ def _query(key, name, default): def _set_keepalive_times(sock): idle_ms = min(_WINDOWS_TCP_IDLE_MS, _MAX_TCP_KEEPIDLE * 1000) - interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, - _MAX_TCP_KEEPINTVL * 1000) - if (idle_ms < _WINDOWS_TCP_IDLE_MS or - interval_ms < _WINDOWS_TCP_INTERVAL_MS): - sock.ioctl(socket.SIO_KEEPALIVE_VALS, - (1, idle_ms, interval_ms)) + interval_ms = min(_WINDOWS_TCP_INTERVAL_MS, _MAX_TCP_KEEPINTVL * 1000) + if idle_ms < _WINDOWS_TCP_IDLE_MS or interval_ms < _WINDOWS_TCP_INTERVAL_MS: + sock.ioctl(socket.SIO_KEEPALIVE_VALS, (1, idle_ms, interval_ms)) + else: + def _set_tcp_option(sock, tcp_option, max_value): if hasattr(socket, tcp_option): sockopt = getattr(socket, tcp_option) @@ -141,88 +146,106 @@ def _set_tcp_option(sock, tcp_option, max_value): pass def _set_keepalive_times(sock): - _set_tcp_option(sock, 'TCP_KEEPIDLE', _MAX_TCP_KEEPIDLE) - _set_tcp_option(sock, 'TCP_KEEPINTVL', _MAX_TCP_KEEPINTVL) - _set_tcp_option(sock, 'TCP_KEEPCNT', _MAX_TCP_KEEPCNT) + _set_tcp_option(sock, "TCP_KEEPIDLE", _MAX_TCP_KEEPIDLE) + _set_tcp_option(sock, "TCP_KEEPINTVL", _MAX_TCP_KEEPINTVL) + _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) -_METADATA = SON([ - ('driver', SON([('name', 'PyMongo'), ('version', __version__)])), -]) -if sys.platform.startswith('linux'): +_METADATA = SON( + [ + ("driver", SON([("name", "PyMongo"), ("version", __version__)])), + ] +) + +if sys.platform.startswith("linux"): # platform.linux_distribution was deprecated in Python 3.5 # and removed in Python 3.8. Starting in Python 3.5 it # raises DeprecationWarning # DeprecationWarning: dist() and linux_distribution() functions are deprecated in Python 3.5 _name = platform.system() - _METADATA['os'] = SON([ - ('type', _name), - ('name', _name), - ('architecture', platform.machine()), - # Kernel version (e.g. 4.4.0-17-generic). - ('version', platform.release()) - ]) -elif sys.platform == 'darwin': - _METADATA['os'] = SON([ - ('type', platform.system()), - ('name', platform.system()), - ('architecture', platform.machine()), - # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin - # kernel version. 
- ('version', platform.mac_ver()[0]) - ]) -elif sys.platform == 'win32': - _METADATA['os'] = SON([ - ('type', platform.system()), - # "Windows XP", "Windows 7", "Windows 10", etc. - ('name', ' '.join((platform.system(), platform.release()))), - ('architecture', platform.machine()), - # Windows patch level (e.g. 5.1.2600-SP3) - ('version', '-'.join(platform.win32_ver()[1:3])) - ]) -elif sys.platform.startswith('java'): + _METADATA["os"] = SON( + [ + ("type", _name), + ("name", _name), + ("architecture", platform.machine()), + # Kernel version (e.g. 4.4.0-17-generic). + ("version", platform.release()), + ] + ) +elif sys.platform == "darwin": + _METADATA["os"] = SON( + [ + ("type", platform.system()), + ("name", platform.system()), + ("architecture", platform.machine()), + # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin + # kernel version. + ("version", platform.mac_ver()[0]), + ] + ) +elif sys.platform == "win32": + _METADATA["os"] = SON( + [ + ("type", platform.system()), + # "Windows XP", "Windows 7", "Windows 10", etc. + ("name", " ".join((platform.system(), platform.release()))), + ("architecture", platform.machine()), + # Windows patch level (e.g. 5.1.2600-SP3) + ("version", "-".join(platform.win32_ver()[1:3])), + ] + ) +elif sys.platform.startswith("java"): _name, _ver, _arch = platform.java_ver()[-1] - _METADATA['os'] = SON([ - # Linux, Windows 7, Mac OS X, etc. - ('type', _name), - ('name', _name), - # x86, x86_64, AMD64, etc. - ('architecture', _arch), - # Linux kernel version, OSX version, etc. - ('version', _ver) - ]) + _METADATA["os"] = SON( + [ + # Linux, Windows 7, Mac OS X, etc. + ("type", _name), + ("name", _name), + # x86, x86_64, AMD64, etc. + ("architecture", _arch), + # Linux kernel version, OSX version, etc. + ("version", _ver), + ] + ) else: # Get potential alias (e.g. 
SunOS 5.11 becomes Solaris 2.11) - _aliased = platform.system_alias( - platform.system(), platform.release(), platform.version()) - _METADATA['os'] = SON([ - ('type', platform.system()), - ('name', ' '.join([part for part in _aliased[:2] if part])), - ('architecture', platform.machine()), - ('version', _aliased[2]) - ]) - -if platform.python_implementation().startswith('PyPy'): - _METADATA['platform'] = ' '.join( - (platform.python_implementation(), - '.'.join(map(str, sys.pypy_version_info)), - '(Python %s)' % '.'.join(map(str, sys.version_info)))) -elif sys.platform.startswith('java'): - _METADATA['platform'] = ' '.join( - (platform.python_implementation(), - '.'.join(map(str, sys.version_info)), - '(%s)' % ' '.join((platform.system(), platform.release())))) + _aliased = platform.system_alias(platform.system(), platform.release(), platform.version()) + _METADATA["os"] = SON( + [ + ("type", platform.system()), + ("name", " ".join([part for part in _aliased[:2] if part])), + ("architecture", platform.machine()), + ("version", _aliased[2]), + ] + ) + +if platform.python_implementation().startswith("PyPy"): + _METADATA["platform"] = " ".join( + ( + platform.python_implementation(), + ".".join(map(str, sys.pypy_version_info)), + "(Python %s)" % ".".join(map(str, sys.version_info)), + ) + ) +elif sys.platform.startswith("java"): + _METADATA["platform"] = " ".join( + ( + platform.python_implementation(), + ".".join(map(str, sys.version_info)), + "(%s)" % " ".join((platform.system(), platform.release())), + ) + ) else: - _METADATA['platform'] = ' '.join( - (platform.python_implementation(), - '.'.join(map(str, sys.version_info)))) + _METADATA["platform"] = " ".join( + (platform.python_implementation(), ".".join(map(str, sys.version_info))) + ) # If the first getaddrinfo call of this interpreter's life is on a thread, # while the main thread holds the import lock, getaddrinfo deadlocks trying # to import the IDNA codec. Import it here, where presumably we're on the # main thread, to avoid the deadlock. See PYTHON-607. -'foo'.encode('idna') +"foo".encode("idna") # Remove after PYTHON-2712 _MOCK_SERVICE_ID = False @@ -233,14 +256,14 @@ def _raise_connection_failure(address, error, msg_prefix=None): host, port = address # If connecting to a Unix socket, port will be None. if port is not None: - msg = '%s:%d: %s' % (host, port, error) + msg = "%s:%d: %s" % (host, port, error) else: - msg = '%s: %s' % (host, error) + msg = "%s: %s" % (host, error) if msg_prefix: msg = msg_prefix + msg if isinstance(error, socket.timeout): raise NetworkTimeout(msg) from error - elif isinstance(error, _SSLError) and 'timed out' in str(error): + elif isinstance(error, _SSLError) and "timed out" in str(error): # Eventlet does not distinguish TLS network timeouts from other # SSLErrors (https://github.com/eventlet/eventlet/issues/692). 
# Luckily, we can work around this limitation because the phrase @@ -268,24 +291,45 @@ class PoolOptions(object): """ - __slots__ = ('__max_pool_size', '__min_pool_size', - '__max_idle_time_seconds', - '__connect_timeout', '__socket_timeout', - '__wait_queue_timeout', - '__ssl_context', '__tls_allow_invalid_hostnames', - '__event_listeners', '__appname', '__driver', '__metadata', - '__compression_settings', '__max_connecting', - '__pause_enabled', '__server_api', '__load_balanced') - - def __init__(self, max_pool_size=MAX_POOL_SIZE, - min_pool_size=MIN_POOL_SIZE, - max_idle_time_seconds=MAX_IDLE_TIME_SEC, connect_timeout=None, - socket_timeout=None, wait_queue_timeout=WAIT_QUEUE_TIMEOUT, - ssl_context=None, - tls_allow_invalid_hostnames=False, - event_listeners=None, appname=None, driver=None, - compression_settings=None, max_connecting=MAX_CONNECTING, - pause_enabled=True, server_api=None, load_balanced=None): + __slots__ = ( + "__max_pool_size", + "__min_pool_size", + "__max_idle_time_seconds", + "__connect_timeout", + "__socket_timeout", + "__wait_queue_timeout", + "__ssl_context", + "__tls_allow_invalid_hostnames", + "__event_listeners", + "__appname", + "__driver", + "__metadata", + "__compression_settings", + "__max_connecting", + "__pause_enabled", + "__server_api", + "__load_balanced", + ) + + def __init__( + self, + max_pool_size=MAX_POOL_SIZE, + min_pool_size=MIN_POOL_SIZE, + max_idle_time_seconds=MAX_IDLE_TIME_SEC, + connect_timeout=None, + socket_timeout=None, + wait_queue_timeout=WAIT_QUEUE_TIMEOUT, + ssl_context=None, + tls_allow_invalid_hostnames=False, + event_listeners=None, + appname=None, + driver=None, + compression_settings=None, + max_connecting=MAX_CONNECTING, + pause_enabled=True, + server_api=None, + load_balanced=None, + ): self.__max_pool_size = max_pool_size self.__min_pool_size = min_pool_size self.__max_idle_time_seconds = max_idle_time_seconds @@ -304,7 +348,7 @@ def __init__(self, max_pool_size=MAX_POOL_SIZE, self.__load_balanced = load_balanced self.__metadata = copy.deepcopy(_METADATA) if appname: - self.__metadata['application'] = {'name': appname} + self.__metadata["application"] = {"name": appname} # Combine the "driver" MongoClient option with PyMongo's info, like: # { @@ -316,14 +360,17 @@ def __init__(self, max_pool_size=MAX_POOL_SIZE, # } if driver: if driver.name: - self.__metadata['driver']['name'] = "%s|%s" % ( - _METADATA['driver']['name'], driver.name) + self.__metadata["driver"]["name"] = "%s|%s" % ( + _METADATA["driver"]["name"], + driver.name, + ) if driver.version: - self.__metadata['driver']['version'] = "%s|%s" % ( - _METADATA['driver']['version'], driver.version) + self.__metadata["driver"]["version"] = "%s|%s" % ( + _METADATA["driver"]["version"], + driver.version, + ) if driver.platform: - self.__metadata['platform'] = "%s|%s" % ( - _METADATA['platform'], driver.platform) + self.__metadata["platform"] = "%s|%s" % (_METADATA["platform"], driver.platform) @property def non_default_options(self): @@ -333,15 +380,15 @@ def non_default_options(self): """ opts = {} if self.__max_pool_size != MAX_POOL_SIZE: - opts['maxPoolSize'] = self.__max_pool_size + opts["maxPoolSize"] = self.__max_pool_size if self.__min_pool_size != MIN_POOL_SIZE: - opts['minPoolSize'] = self.__min_pool_size + opts["minPoolSize"] = self.__min_pool_size if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: - opts['maxIdleTimeMS'] = self.__max_idle_time_seconds * 1000 + opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000 if self.__wait_queue_timeout != 
WAIT_QUEUE_TIMEOUT: - opts['waitQueueTimeoutMS'] = self.__wait_queue_timeout * 1000 + opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000 if self.__max_connecting != MAX_CONNECTING: - opts['maxConnecting'] = self.__max_connecting + opts["maxConnecting"] = self.__max_connecting return opts @property @@ -387,14 +434,12 @@ def max_idle_time_seconds(self): @property def connect_timeout(self): - """How long a connection can take to be opened before timing out. - """ + """How long a connection can take to be opened before timing out.""" return self.__connect_timeout @property def socket_timeout(self): - """How long a send or receive on a socket can take before timing out. - """ + """How long a send or receive on a socket can take before timing out.""" return self.__socket_timeout @property @@ -406,32 +451,27 @@ def wait_queue_timeout(self): @property def _ssl_context(self): - """An SSLContext instance or None. - """ + """An SSLContext instance or None.""" return self.__ssl_context @property def tls_allow_invalid_hostnames(self): - """If True skip ssl.match_hostname. - """ + """If True skip ssl.match_hostname.""" return self.__tls_allow_invalid_hostnames @property def _event_listeners(self): - """An instance of pymongo.monitoring._EventListeners. - """ + """An instance of pymongo.monitoring._EventListeners.""" return self.__event_listeners @property def appname(self): - """The application name, for sending with hello in server handshake. - """ + """The application name, for sending with hello in server handshake.""" return self.__appname @property def driver(self): - """Driver name and version, for sending with hello in handshake. - """ + """Driver name and version, for sending with hello in handshake.""" return self.__driver @property @@ -440,36 +480,31 @@ def _compression_settings(self): @property def metadata(self): - """A dict of metadata about the application, driver, os, and platform. - """ + """A dict of metadata about the application, driver, os, and platform.""" return self.__metadata.copy() @property def server_api(self): - """A pymongo.server_api.ServerApi or None. - """ + """A pymongo.server_api.ServerApi or None.""" return self.__server_api @property def load_balanced(self): - """True if this Pool is configured in load balanced mode. - """ + """True if this Pool is configured in load balanced mode.""" return self.__load_balanced def _negotiate_creds(all_credentials): - """Return one credential that needs mechanism negotiation, if any. - """ + """Return one credential that needs mechanism negotiation, if any.""" if all_credentials: for creds in all_credentials.values(): - if creds.mechanism == 'DEFAULT' and creds.username: + if creds.mechanism == "DEFAULT" and creds.username: return creds return None def _speculative_context(all_credentials): - """Return the _AuthContext to use for speculative auth, if any. 
- """ + """Return the _AuthContext to use for speculative auth, if any.""" if all_credentials and len(all_credentials) == 1: creds = next(iter(all_credentials.values())) return auth._AuthContext.from_credentials(creds) @@ -499,6 +534,7 @@ class SocketInfo(object): - `address`: the server's (host, port) - `id`: the id of this socket in it's pool """ + def __init__(self, sock, pool, address, id): self.pool_ref = weakref.ref(pool) self.sock = sock @@ -565,65 +601,60 @@ def hello_cmd(self): if self.opts.server_api or self.hello_ok: return SON([(HelloCompat.CMD, 1)]) else: - return SON([(HelloCompat.LEGACY_CMD, 1), ('helloOk', True)]) + return SON([(HelloCompat.LEGACY_CMD, 1), ("helloOk", True)]) def hello(self, all_credentials=None): return self._hello(None, None, None, all_credentials) - def _hello(self, cluster_time, topology_version, - heartbeat_frequency, all_credentials): + def _hello(self, cluster_time, topology_version, heartbeat_frequency, all_credentials): cmd = self.hello_cmd() performing_handshake = not self.performed_handshake awaitable = False if performing_handshake: self.performed_handshake = True - cmd['client'] = self.opts.metadata + cmd["client"] = self.opts.metadata if self.compression_settings: - cmd['compression'] = self.compression_settings.compressors + cmd["compression"] = self.compression_settings.compressors if self.opts.load_balanced: - cmd['loadBalanced'] = True + cmd["loadBalanced"] = True elif topology_version is not None: - cmd['topologyVersion'] = topology_version - cmd['maxAwaitTimeMS'] = int(heartbeat_frequency*1000) + cmd["topologyVersion"] = topology_version + cmd["maxAwaitTimeMS"] = int(heartbeat_frequency * 1000) awaitable = True # If connect_timeout is None there is no timeout. if self.opts.connect_timeout: - self.sock.settimeout( - self.opts.connect_timeout + heartbeat_frequency) + self.sock.settimeout(self.opts.connect_timeout + heartbeat_frequency) if not performing_handshake and cluster_time is not None: - cmd['$clusterTime'] = cluster_time + cmd["$clusterTime"] = cluster_time # XXX: Simplify in PyMongo 4.0 when all_credentials is always a single # unchangeable value per MongoClient. creds = _negotiate_creds(all_credentials) if creds: - cmd['saslSupportedMechs'] = creds.source + '.' + creds.username + cmd["saslSupportedMechs"] = creds.source + "." + creds.username auth_ctx = _speculative_context(all_credentials) if auth_ctx: - cmd['speculativeAuthenticate'] = auth_ctx.speculate_command() + cmd["speculativeAuthenticate"] = auth_ctx.speculate_command() - doc = self.command('admin', cmd, publish_events=False, - exhaust_allowed=awaitable) + doc = self.command("admin", cmd, publish_events=False, exhaust_allowed=awaitable) # PYTHON-2712 will remove this topologyVersion fallback logic. 
if self.opts.load_balanced and _MOCK_SERVICE_ID: - process_id = doc.get('topologyVersion', {}).get('processId') - doc.setdefault('serviceId', process_id) + process_id = doc.get("topologyVersion", {}).get("processId") + doc.setdefault("serviceId", process_id) if not self.opts.load_balanced: - doc.pop('serviceId', None) + doc.pop("serviceId", None) hello = Hello(doc, awaitable=awaitable) self.is_writable = hello.is_writable self.max_wire_version = hello.max_wire_version self.max_bson_size = hello.max_bson_size self.max_message_size = hello.max_message_size self.max_write_batch_size = hello.max_write_batch_size - self.supports_sessions = ( - hello.logical_session_timeout_minutes is not None) + self.supports_sessions = hello.logical_session_timeout_minutes is not None self.hello_ok = hello.hello_ok self.is_mongos = hello.server_type == SERVER_TYPE.Mongos if performing_handshake and self.compression_settings: - ctx = self.compression_settings.get_compression_context( - hello.compressors) + ctx = self.compression_settings.get_compression_context(hello.compressors) self.compression_context = ctx self.op_msg_enabled = True @@ -636,8 +667,9 @@ def _hello(self, cluster_time, topology_version, if self.opts.load_balanced: if not hello.service_id: raise ConfigurationError( - 'Driver attempted to initialize in load balancing mode,' - ' but the server does not support this mode') + "Driver attempted to initialize in load balancing mode," + " but the server does not support this mode" + ) self.service_id = hello.service_id self.generation = self.pool_gen.get(self.service_id) return hello @@ -650,23 +682,30 @@ def _next_reply(self): helpers._check_command_response(response_doc, self.max_wire_version) # Remove after PYTHON-2712. if not self.opts.load_balanced: - response_doc.pop('serviceId', None) + response_doc.pop("serviceId", None) return response_doc - def command(self, dbname, spec, secondary_ok=False, - read_preference=ReadPreference.PRIMARY, - codec_options=DEFAULT_CODEC_OPTIONS, check=True, - allowable_errors=None, check_keys=False, - read_concern=None, - write_concern=None, - parse_write_concern_error=False, - collation=None, - session=None, - client=None, - retryable_write=False, - publish_events=True, - user_fields=None, - exhaust_allowed=False): + def command( + self, + dbname, + spec, + secondary_ok=False, + read_preference=ReadPreference.PRIMARY, + codec_options=DEFAULT_CODEC_OPTIONS, + check=True, + allowable_errors=None, + check_keys=False, + read_concern=None, + write_concern=None, + parse_write_concern_error=False, + collation=None, + session=None, + client=None, + retryable_write=False, + publish_events=True, + user_fields=None, + exhaust_allowed=False, + ): """Execute a command or raise an error. 
:Parameters: @@ -698,36 +737,45 @@ def command(self, dbname, spec, secondary_ok=False, if not isinstance(spec, ORDERED_TYPES): spec = SON(spec) - if not (write_concern is None or write_concern.acknowledged or - collation is None): - raise ConfigurationError( - 'Collation is unsupported for unacknowledged writes.') - if (write_concern and - not write_concern.is_server_default): - spec['writeConcern'] = write_concern.document + if not (write_concern is None or write_concern.acknowledged or collation is None): + raise ConfigurationError("Collation is unsupported for unacknowledged writes.") + if write_concern and not write_concern.is_server_default: + spec["writeConcern"] = write_concern.document self.add_server_api(spec) if session: - session._apply_to(spec, retryable_write, read_preference, - self) + session._apply_to(spec, retryable_write, read_preference, self) self.send_cluster_time(spec, session, client) listeners = self.listeners if publish_events else None unacknowledged = write_concern and not write_concern.acknowledged if self.op_msg_enabled: self._raise_if_not_writable(unacknowledged) try: - return command(self, dbname, spec, secondary_ok, - self.is_mongos, read_preference, codec_options, - session, client, check, allowable_errors, - self.address, check_keys, listeners, - self.max_bson_size, read_concern, - parse_write_concern_error=parse_write_concern_error, - collation=collation, - compression_ctx=self.compression_context, - use_op_msg=self.op_msg_enabled, - unacknowledged=unacknowledged, - user_fields=user_fields, - exhaust_allowed=exhaust_allowed) + return command( + self, + dbname, + spec, + secondary_ok, + self.is_mongos, + read_preference, + codec_options, + session, + client, + check, + allowable_errors, + self.address, + check_keys, + listeners, + self.max_bson_size, + read_concern, + parse_write_concern_error=parse_write_concern_error, + collation=collation, + compression_ctx=self.compression_context, + use_op_msg=self.op_msg_enabled, + unacknowledged=unacknowledged, + user_fields=user_fields, + exhaust_allowed=exhaust_allowed, + ) except (OperationFailure, NotPrimaryError): raise # Catch socket.error, KeyboardInterrupt, etc. and close ourselves. @@ -739,12 +787,11 @@ def send_message(self, message, max_doc_size): If a network exception is raised, the socket is closed. """ - if (self.max_bson_size is not None - and max_doc_size > self.max_bson_size): + if self.max_bson_size is not None and max_doc_size > self.max_bson_size: raise DocumentTooLarge( "BSON document too large (%d bytes) - the connected server " - "supports BSON document sizes up to %d bytes." % - (max_doc_size, self.max_bson_size)) + "supports BSON document sizes up to %d bytes." % (max_doc_size, self.max_bson_size) + ) try: self.sock.sendall(message) @@ -767,8 +814,7 @@ def _raise_if_not_writable(self, unacknowledged): """ if unacknowledged and not self.is_writable: # Write won't succeed, bail as if we'd received a not primary error. - raise NotPrimaryError("not primary", { - "ok": 0, "errmsg": "not primary", "code": 10107}) + raise NotPrimaryError("not primary", {"ok": 0, "errmsg": "not primary", "code": 10107}) def unack_write(self, msg, max_doc_size): """Send unack OP_MSG. 
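One behavior worth noting from the command() hunk above: collation cannot be combined with an unacknowledged write concern, and the driver raises ConfigurationError before anything is sent. That surfaces through the public API roughly as follows (the local client and collection names are illustrative assumptions):

    from pymongo import MongoClient
    from pymongo.collation import Collation
    from pymongo.errors import ConfigurationError
    from pymongo.write_concern import WriteConcern

    coll = MongoClient().test.docs.with_options(write_concern=WriteConcern(w=0))
    try:
        # w=0 writes get no server reply, so a collation could neither be
        # applied verifiably nor report errors; the driver refuses up front.
        coll.delete_one({"x": "a"}, collation=Collation(locale="en_US"))
    except ConfigurationError as exc:
        print(exc)  # Collation is unsupported for unacknowledged writes.
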
@@ -842,8 +888,8 @@ def validate_session(self, client, session): if session: if session._client is not client: raise InvalidOperation( - 'Can only use session with the MongoClient that' - ' started it') + "Can only use session with the MongoClient that" " started it" + ) def close_socket(self, reason): """Close this connection with a reason.""" @@ -851,8 +897,7 @@ def close_socket(self, reason): return self._close_socket() if reason and self.enabled_for_cmap: - self.listeners.publish_connection_closed( - self.address, self.id, reason) + self.listeners.publish_connection_closed(self.address, self.id, reason) def _close_socket(self): """Close this connection.""" @@ -932,7 +977,7 @@ def __repr__(self): return "SocketInfo(%s)%s at %s" % ( repr(self.sock), self.closed and " CLOSED" or "", - id(self) + id(self), ) @@ -946,10 +991,9 @@ def _create_connection(address, options): host, port = address # Check if dealing with a unix domain socket - if host.endswith('.sock'): + if host.endswith(".sock"): if not hasattr(socket, "AF_UNIX"): - raise ConnectionFailure("UNIX-sockets are not supported " - "on this system") + raise ConnectionFailure("UNIX-sockets are not supported " "on this system") sock = socket.socket(socket.AF_UNIX) # SOCK_CLOEXEC not supported for Unix sockets. _set_non_inheritable_non_atomic(sock.fileno()) @@ -964,7 +1008,7 @@ def _create_connection(address, options): # is 'localhost' (::1 is fine). Avoids slow connect issues # like PYTHON-356. family = socket.AF_INET - if socket.has_ipv6 and host != 'localhost': + if socket.has_ipv6 and host != "localhost": family = socket.AF_UNSPEC err = None @@ -974,8 +1018,7 @@ def _create_connection(address, options): # number of platforms (newer Linux and *BSD). Starting with CPython 3.4 # all file descriptors are created non-inheritable. See PEP 446. try: - sock = socket.socket( - af, socktype | getattr(socket, 'SOCK_CLOEXEC', 0), proto) + sock = socket.socket(af, socktype | getattr(socket, "SOCK_CLOEXEC", 0), proto) except socket.error: # Can SOCK_CLOEXEC be defined even if the kernel doesn't support # it? @@ -1000,7 +1043,7 @@ def _create_connection(address, options): # host with an OS/kernel or Python interpreter that doesn't # support IPv6. The test case is Jython2.5.1 which doesn't # support IPv6 at all. - raise socket.error('getaddrinfo failed') + raise socket.error("getaddrinfo failed") def _configured_socket(address, options): @@ -1038,9 +1081,11 @@ def _configured_socket(address, options): # failures alike. Permanent handshake failures, like protocol # mismatch, will be turned into ServerSelectionTimeoutErrors later. _raise_connection_failure(address, exc, "SSL handshake failed: ") - if (ssl_context.verify_mode and not - getattr(ssl_context, "check_hostname", False) and - not options.tls_allow_invalid_hostnames): + if ( + ssl_context.verify_mode + and not getattr(ssl_context, "check_hostname", False) + and not options.tls_allow_invalid_hostnames + ): try: ssl.match_hostname(sock.getpeercert(), hostname=host) except _CertificateError: @@ -1055,6 +1100,7 @@ class _PoolClosedError(PyMongoError): """Internal error raised when a thread tries to get a connection from a closed pool. """ + pass @@ -1133,9 +1179,10 @@ def __init__(self, address, options, handshake=True): self.handshake = handshake # Don't publish events in Monitor pools. 
self.enabled_for_cmap = ( - self.handshake and - self.opts._event_listeners is not None and - self.opts._event_listeners.enabled_for_cmap) + self.handshake + and self.opts._event_listeners is not None + and self.opts._event_listeners.enabled_for_cmap + ) # The first portion of the wait queue. # Enforces: maxPoolSize @@ -1144,7 +1191,7 @@ def __init__(self, address, options, handshake=True): self.requests = 0 self.max_pool_size = self.opts.max_pool_size if not self.max_pool_size: - self.max_pool_size = float('inf') + self.max_pool_size = float("inf") # The second portion of the wait queue. # Enforces: maxConnecting # Also used for: clearing the wait queue @@ -1153,7 +1200,8 @@ def __init__(self, address, options, handshake=True): self._pending = 0 if self.enabled_for_cmap: self.opts._event_listeners.publish_pool_created( - self.address, self.opts.non_default_options) + self.address, self.opts.non_default_options + ) # Similar to active_sockets but includes threads in the wait queue. self.operation_count = 0 # Retain references to pinned connections to prevent the CPython GC @@ -1180,8 +1228,7 @@ def _reset(self, close, pause=True, service_id=None): with self.size_cond: if self.closed: return - if (self.opts.pause_enabled and pause and - not self.opts.load_balanced): + if self.opts.pause_enabled and pause and not self.opts.load_balanced: old_state, self.state = self.state, PoolState.PAUSED self.gen.inc(service_id) newpid = os.getpid() @@ -1219,8 +1266,7 @@ def _reset(self, close, pause=True, service_id=None): listeners.publish_pool_closed(self.address) else: if old_state != PoolState.PAUSED and self.enabled_for_cmap: - listeners.publish_pool_cleared(self.address, - service_id=service_id) + listeners.publish_pool_cleared(self.address, service_id=service_id) for sock_info in sockets: sock_info.close_socket(ConnectionClosedReason.STALE) @@ -1258,16 +1304,17 @@ def remove_stale_sockets(self, reference_generation, all_credentials): if self.opts.max_idle_time_seconds is not None: with self.lock: - while (self.sockets and - self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds): + while ( + self.sockets + and self.sockets[-1].idle_time_seconds() > self.opts.max_idle_time_seconds + ): sock_info = self.sockets.pop() sock_info.close_socket(ConnectionClosedReason.IDLE) while True: with self.size_cond: # There are enough sockets in the pool. 
- if (len(self.sockets) + self.active_sockets >= - self.opts.min_pool_size): + if len(self.sockets) + self.active_sockets >= self.opts.min_pool_size: return if self.requests >= self.opts.min_pool_size: return @@ -1321,7 +1368,8 @@ def connect(self, all_credentials=None): except BaseException as error: if self.enabled_for_cmap: listeners.publish_connection_closed( - self.address, conn_id, ConnectionClosedReason.ERROR) + self.address, conn_id, ConnectionClosedReason.ERROR + ) if isinstance(error, (IOError, OSError, _SSLError)): _raise_connection_failure(self.address, error) @@ -1370,8 +1418,7 @@ def get_socket(self, all_credentials, handler=None): sock_info = self._get_socket(all_credentials) if self.enabled_for_cmap: - listeners.publish_connection_checked_out( - self.address, sock_info.id) + listeners.publish_connection_checked_out(self.address, sock_info.id) try: yield sock_info except: @@ -1403,9 +1450,9 @@ def _raise_if_not_ready(self, emit_event): if self.state != PoolState.READY: if self.enabled_for_cmap and emit_event: self.opts._event_listeners.publish_connection_check_out_failed( - self.address, ConnectionCheckOutFailedReason.CONN_ERROR) - _raise_connection_failure( - self.address, AutoReconnect('connection pool paused')) + self.address, ConnectionCheckOutFailedReason.CONN_ERROR + ) + _raise_connection_failure(self.address, AutoReconnect("connection pool paused")) def _get_socket(self, all_credentials): """Get or create a SocketInfo. Can raise ConnectionFailure.""" @@ -1418,10 +1465,11 @@ def _get_socket(self, all_credentials): if self.closed: if self.enabled_for_cmap: self.opts._event_listeners.publish_connection_check_out_failed( - self.address, ConnectionCheckOutFailedReason.POOL_CLOSED) + self.address, ConnectionCheckOutFailedReason.POOL_CLOSED + ) raise _PoolClosedError( - 'Attempted to check out a connection from closed connection ' - 'pool') + "Attempted to check out a connection from closed connection " "pool" + ) with self.lock: self.operation_count += 1 @@ -1458,13 +1506,11 @@ def _get_socket(self, all_credentials): # to be checked back into the pool. with self._max_connecting_cond: self._raise_if_not_ready(emit_event=False) - while not (self.sockets or - self._pending < self._max_connecting): + while not (self.sockets or self._pending < self._max_connecting): if not _cond_wait(self._max_connecting_cond, deadline): # Timed out, notify the next thread to ensure a # timeout doesn't consume the condition. - if (self.sockets or - self._pending < self._max_connecting): + if self.sockets or self._pending < self._max_connecting: self._max_connecting_cond.notify() emitted_event = True self._raise_wait_queue_timeout() @@ -1498,7 +1544,8 @@ def _get_socket(self, all_credentials): if self.enabled_for_cmap and not emitted_event: self.opts._event_listeners.publish_connection_check_out_failed( - self.address, ConnectionCheckOutFailedReason.CONN_ERROR) + self.address, ConnectionCheckOutFailedReason.CONN_ERROR + ) raise sock_info.active = True @@ -1528,14 +1575,13 @@ def return_socket(self, sock_info): # CMAP requires the closed event be emitted after the check in. if self.enabled_for_cmap: listeners.publish_connection_closed( - self.address, sock_info.id, - ConnectionClosedReason.ERROR) + self.address, sock_info.id, ConnectionClosedReason.ERROR + ) else: with self.lock: # Hold the lock to ensure this section does not race with # Pool.reset(). 
- if self.stale_generation(sock_info.generation, - sock_info.service_id): + if self.stale_generation(sock_info.generation, sock_info.service_id): sock_info.close_socket(ConnectionClosedReason.STALE) else: sock_info.update_last_checkin_time() @@ -1570,14 +1616,16 @@ def _perished(self, sock_info): """ idle_time_seconds = sock_info.idle_time_seconds() # If socket is idle, open a new one. - if (self.opts.max_idle_time_seconds is not None and - idle_time_seconds > self.opts.max_idle_time_seconds): + if ( + self.opts.max_idle_time_seconds is not None + and idle_time_seconds > self.opts.max_idle_time_seconds + ): sock_info.close_socket(ConnectionClosedReason.IDLE) return True - if (self._check_interval_seconds is not None and ( - 0 == self._check_interval_seconds or - idle_time_seconds > self._check_interval_seconds)): + if self._check_interval_seconds is not None and ( + 0 == self._check_interval_seconds or idle_time_seconds > self._check_interval_seconds + ): if sock_info.socket_closed(): sock_info.close_socket(ConnectionClosedReason.ERROR) return True @@ -1592,20 +1640,28 @@ def _raise_wait_queue_timeout(self): listeners = self.opts._event_listeners if self.enabled_for_cmap: listeners.publish_connection_check_out_failed( - self.address, ConnectionCheckOutFailedReason.TIMEOUT) + self.address, ConnectionCheckOutFailedReason.TIMEOUT + ) if self.opts.load_balanced: other_ops = self.active_sockets - self.ncursors - self.ntxns raise ConnectionFailure( - 'Timeout waiting for connection from the connection pool. ' - 'maxPoolSize: %s, connections in use by cursors: %s, ' - 'connections in use by transactions: %s, connections in use ' - 'by other operations: %s, wait_queue_timeout: %s' % ( - self.opts.max_pool_size, self.ncursors, self.ntxns, - other_ops, self.opts.wait_queue_timeout)) + "Timeout waiting for connection from the connection pool. " + "maxPoolSize: %s, connections in use by cursors: %s, " + "connections in use by transactions: %s, connections in use " + "by other operations: %s, wait_queue_timeout: %s" + % ( + self.opts.max_pool_size, + self.ncursors, + self.ntxns, + other_ops, + self.opts.wait_queue_timeout, + ) + ) raise ConnectionFailure( - 'Timed out while checking out a connection from connection pool. ' - 'maxPoolSize: %s, wait_queue_timeout: %s' % ( - self.opts.max_pool_size, self.opts.wait_queue_timeout)) + "Timed out while checking out a connection from connection pool. 
" + "maxPoolSize: %s, wait_queue_timeout: %s" + % (self.opts.max_pool_size, self.opts.wait_queue_timeout) + ) def __del__(self): # Avoid ResourceWarnings in Python 3 diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index f7c53a59e5..95c736637d 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -20,32 +20,27 @@ import ssl as _stdlibssl import sys as _sys import time as _time - from errno import EINTR as _EINTR - from ipaddress import ip_address as _ip_address from cryptography.x509 import load_der_x509_certificate as _load_der_x509_certificate -from OpenSSL import crypto as _crypto, SSL as _SSL -from service_identity.pyopenssl import ( - verify_hostname as _verify_hostname, - verify_ip_address as _verify_ip_address) -from service_identity import ( - CertificateError as _SICertificateError, - VerificationError as _SIVerificationError) - -from pymongo.errors import ( - _CertificateError, - ConfigurationError as _ConfigurationError) -from pymongo.ocsp_support import ( - _load_trusted_ca_certs, - _ocsp_callback) +from OpenSSL import SSL as _SSL +from OpenSSL import crypto as _crypto +from service_identity import CertificateError as _SICertificateError +from service_identity import VerificationError as _SIVerificationError +from service_identity.pyopenssl import verify_hostname as _verify_hostname +from service_identity.pyopenssl import verify_ip_address as _verify_ip_address + +from pymongo.errors import ConfigurationError as _ConfigurationError +from pymongo.errors import _CertificateError from pymongo.ocsp_cache import _OCSPCache -from pymongo.socket_checker import ( - _errno_from_exception, SocketChecker as _SocketChecker) +from pymongo.ocsp_support import _load_trusted_ca_certs, _ocsp_callback +from pymongo.socket_checker import SocketChecker as _SocketChecker +from pymongo.socket_checker import _errno_from_exception try: import certifi + _HAVE_CERTIFI = True except ImportError: _HAVE_CERTIFI = False @@ -70,11 +65,11 @@ _VERIFY_MAP = { _stdlibssl.CERT_NONE: _SSL.VERIFY_NONE, _stdlibssl.CERT_OPTIONAL: _SSL.VERIFY_PEER, - _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT + _stdlibssl.CERT_REQUIRED: _SSL.VERIFY_PEER | _SSL.VERIFY_FAIL_IF_NO_PEER_CERT, } -_REVERSE_VERIFY_MAP = dict( - (value, key) for key, value in _VERIFY_MAP.items()) +_REVERSE_VERIFY_MAP = dict((value, key) for key, value in _VERIFY_MAP.items()) + def _is_ip_address(address): try: @@ -83,22 +78,21 @@ def _is_ip_address(address): except (ValueError, UnicodeError): return False + # According to the docs for Connection.send it can raise # WantX509LookupError and should be retried. 
-_RETRY_ERRORS = ( - _SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError) +_RETRY_ERRORS = (_SSL.WantReadError, _SSL.WantWriteError, _SSL.WantX509LookupError) def _ragged_eof(exc): """Return True if the OpenSSL.SSL.SysCallError is a ragged EOF.""" - return exc.args == (-1, 'Unexpected EOF') + return exc.args == (-1, "Unexpected EOF") # https://github.com/pyca/pyopenssl/issues/168 # https://github.com/pyca/pyopenssl/issues/176 # https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets class _sslConn(_SSL.Connection): - def __init__(self, ctx, sock, suppress_ragged_eofs): self.socket_checker = _SocketChecker() self.suppress_ragged_eofs = suppress_ragged_eofs @@ -112,8 +106,7 @@ def _call(self, call, *args, **kwargs): try: return call(*args, **kwargs) except _RETRY_ERRORS: - self.socket_checker.select( - self, True, True, timeout) + self.socket_checker.select(self, True, True, timeout) if timeout and _time.monotonic() - start > timeout: raise _socket.timeout("timed out") continue @@ -146,8 +139,7 @@ def sendall(self, buf, flags=0): sent = 0 while total_sent < total_length: try: - sent = self._call( - super(_sslConn, self).send, view[total_sent:], flags) + sent = self._call(super(_sslConn, self).send, view[total_sent:], flags) # XXX: It's not clear if this can actually happen. PyOpenSSL # doesn't appear to have any interrupt handling, nor any interrupt # errors for OpenSSL connections. @@ -164,6 +156,7 @@ def sendall(self, buf, flags=0): class _CallbackData(object): """Data class which is passed to the OCSP callback.""" + def __init__(self): self.trusted_ca_certs = None self.check_ocsp_endpoint = None @@ -175,7 +168,7 @@ class SSLContext(object): context. """ - __slots__ = ('_protocol', '_ctx', '_callback_data', '_check_hostname') + __slots__ = ("_protocol", "_ctx", "_callback_data", "_check_hostname") def __init__(self, protocol): self._protocol = protocol @@ -187,8 +180,7 @@ def __init__(self, protocol): # side configuration and wrap_socket tries to support both client and # server side sockets. self._callback_data.check_ocsp_endpoint = True - self._ctx.set_ocsp_client_callback( - callback=_ocsp_callback, data=self._callback_data) + self._ctx.set_ocsp_client_callback(callback=_ocsp_callback, data=self._callback_data) @property def protocol(self): @@ -206,12 +198,14 @@ def __get_verify_mode(self): def __set_verify_mode(self, value): """Setter for verify_mode.""" + def _cb(connobj, x509obj, errnum, errdepth, retcode): # It seems we don't need to do anything here. Twisted doesn't, # and OpenSSL's SSL_CTX_set_verify let's you pass NULL # for the callback option. It's weird that PyOpenSSL requires # this. return retcode + self._ctx.set_verify(_VERIFY_MAP[value], _cb) verify_mode = property(__get_verify_mode, __set_verify_mode) @@ -234,8 +228,7 @@ def __set_check_ocsp_endpoint(self, value): raise TypeError("check_ocsp must be True or False") self._callback_data.check_ocsp_endpoint = value - check_ocsp_endpoint = property(__get_check_ocsp_endpoint, - __set_check_ocsp_endpoint) + check_ocsp_endpoint = property(__get_check_ocsp_endpoint, __set_check_ocsp_endpoint) def __get_options(self): # Calling set_options adds the option to the existing bitmask and @@ -263,11 +256,13 @@ def load_cert_chain(self, certfile, keyfile=None, password=None): # https://github.com/python/cpython/blob/v3.8.0/Modules/_ssl.c#L3930-L3971 # Password callback MUST be set first or it will be ignored. 
if password: + def _pwcb(max_length, prompt_twice, user_data): # XXX:We could check the password length against what OpenSSL # tells us is the max, but we can't raise an exception, so... # warn? - return password.encode('utf-8') + return password.encode("utf-8") + self._ctx.set_passwd_cb(_pwcb) self._ctx.use_certificate_chain_file(certfile) self._ctx.use_privatekey_file(keyfile or certfile) @@ -290,7 +285,8 @@ def _load_certifi(self): "tlsAllowInvalidCertificates is False but no system " "CA certificates could be loaded. Please install the " "certifi package, or provide a path to a CA file using " - "the tlsCAFile option") + "the tlsCAFile option" + ) def _load_wincerts(self, store): """Attempt to load CA certs from Windows trust store.""" @@ -300,8 +296,8 @@ def _load_wincerts(self, store): if encoding == "x509_asn": if trust is True or oid in trust: cert_store.add_cert( - _crypto.X509.from_cryptography( - _load_der_x509_certificate(cert))) + _crypto.X509.from_cryptography(_load_der_x509_certificate(cert)) + ) def load_default_certs(self): """A PyOpenSSL version of load_default_certs from CPython.""" @@ -310,7 +306,7 @@ def load_default_certs(self): # https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_default_verify_paths if _sys.platform == "win32": try: - for storename in ('CA', 'ROOT'): + for storename in ("CA", "ROOT"): self._load_wincerts(storename) except PermissionError: # Fall back to certifi @@ -326,10 +322,15 @@ def set_default_verify_paths(self): # but not that same as CPython's. self._ctx.set_default_verify_paths() - def wrap_socket(self, sock, server_side=False, - do_handshake_on_connect=True, - suppress_ragged_eofs=True, - server_hostname=None, session=None): + def wrap_socket( + self, + sock, + server_side=False, + do_handshake_on_connect=True, + suppress_ragged_eofs=True, + server_hostname=None, + session=None, + ): """Wrap an existing Python socket sock and return a TLS socket object. """ @@ -343,7 +344,7 @@ def wrap_socket(self, sock, server_side=False, if server_hostname and not _is_ip_address(server_hostname): # XXX: Do this in a callback registered with # SSLContext.set_info_callback? See Twisted for an example. - ssl_conn.set_tlsext_host_name(server_hostname.encode('idna')) + ssl_conn.set_tlsext_host_name(server_hostname.encode("idna")) if self.verify_mode != _stdlibssl.CERT_NONE: # Request a stapled OCSP response. 
ssl_conn.request_ocsp() diff --git a/pymongo/read_concern.py b/pymongo/read_concern.py index 7e9cc4485c..1c1f010494 100644 --- a/pymongo/read_concern.py +++ b/pymongo/read_concern.py @@ -33,8 +33,7 @@ def __init__(self, level=None): if level is None or isinstance(level, str): self.__level = level else: - raise TypeError( - 'level must be a string or None.') + raise TypeError("level must be a string or None.") @property def level(self): @@ -45,7 +44,7 @@ def level(self): def ok_for_legacy(self): """Return ``True`` if this read concern is compatible with old wire protocol versions.""" - return self.level is None or self.level == 'local' + return self.level is None or self.level == "local" @property def document(self): @@ -57,7 +56,7 @@ def document(self): """ doc = {} if self.__level: - doc['level'] = self.level + doc["level"] = self.level return doc def __eq__(self, other): @@ -67,8 +66,8 @@ def __eq__(self, other): def __repr__(self): if self.level: - return 'ReadConcern(%s)' % self.level - return 'ReadConcern()' + return "ReadConcern(%s)" % self.level + return "ReadConcern()" DEFAULT_READ_CONCERN = ReadConcern() diff --git a/pymongo/read_preferences.py b/pymongo/read_preferences.py index c60240822d..1753df61ab 100644 --- a/pymongo/read_preferences.py +++ b/pymongo/read_preferences.py @@ -18,9 +18,10 @@ from pymongo import max_staleness_selectors from pymongo.errors import ConfigurationError -from pymongo.server_selectors import (member_with_tags_server_selector, - secondary_with_tags_server_selector) - +from pymongo.server_selectors import ( + member_with_tags_server_selector, + secondary_with_tags_server_selector, +) _PRIMARY = 0 _PRIMARY_PREFERRED = 1 @@ -30,41 +31,40 @@ _MONGOS_MODES = ( - 'primary', - 'primaryPreferred', - 'secondary', - 'secondaryPreferred', - 'nearest', + "primary", + "primaryPreferred", + "secondary", + "secondaryPreferred", + "nearest", ) def _validate_tag_sets(tag_sets): - """Validate tag sets for a MongoClient. - """ + """Validate tag sets for a MongoClient.""" if tag_sets is None: return tag_sets if not isinstance(tag_sets, list): - raise TypeError(( - "Tag sets %r invalid, must be a list") % (tag_sets,)) + raise TypeError(("Tag sets %r invalid, must be a list") % (tag_sets,)) if len(tag_sets) == 0: - raise ValueError(( - "Tag sets %r invalid, must be None or contain at least one set of" - " tags") % (tag_sets,)) + raise ValueError( + ("Tag sets %r invalid, must be None or contain at least one set of" " tags") + % (tag_sets,) + ) for tags in tag_sets: if not isinstance(tags, abc.Mapping): raise TypeError( "Tag set %r invalid, must be an instance of dict, " "bson.son.SON or other type that inherits from " - "collection.Mapping" % (tags,)) + "collection.Mapping" % (tags,) + ) return tag_sets def _invalid_max_staleness_msg(max_staleness): - return ("maxStalenessSeconds must be a positive integer, not %s" % - max_staleness) + return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness # Some duplication with common.py to avoid import cycle. @@ -94,11 +94,9 @@ def _validate_hedge(hedge): class _ServerMode(object): - """Base class for all read preferences. 
- """ + """Base class for all read preferences.""" - __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", - "__hedge") + __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") def __init__(self, mode, tag_sets=None, max_staleness=-1, hedge=None): self.__mongos_mode = _MONGOS_MODES[mode] @@ -109,33 +107,29 @@ def __init__(self, mode, tag_sets=None, max_staleness=-1, hedge=None): @property def name(self): - """The name of this read preference. - """ + """The name of this read preference.""" return self.__class__.__name__ @property def mongos_mode(self): - """The mongos mode of this read preference. - """ + """The mongos mode of this read preference.""" return self.__mongos_mode @property def document(self): - """Read preference as a document. - """ - doc = {'mode': self.__mongos_mode} + """Read preference as a document.""" + doc = {"mode": self.__mongos_mode} if self.__tag_sets not in (None, [{}]): - doc['tags'] = self.__tag_sets + doc["tags"] = self.__tag_sets if self.__max_staleness != -1: - doc['maxStalenessSeconds'] = self.__max_staleness + doc["maxStalenessSeconds"] = self.__max_staleness if self.__hedge not in (None, {}): - doc['hedge'] = self.__hedge + doc["hedge"] = self.__hedge return doc @property def mode(self): - """The mode of this read preference instance. - """ + """The mode of this read preference instance.""" return self.__mode @property @@ -199,14 +193,20 @@ def min_wire_version(self): def __repr__(self): return "%s(tag_sets=%r, max_staleness=%r, hedge=%r)" % ( - self.name, self.__tag_sets, self.__max_staleness, self.__hedge) + self.name, + self.__tag_sets, + self.__max_staleness, + self.__hedge, + ) def __eq__(self, other): if isinstance(other, _ServerMode): - return (self.mode == other.mode and - self.tag_sets == other.tag_sets and - self.max_staleness == other.max_staleness and - self.hedge == other.hedge) + return ( + self.mode == other.mode + and self.tag_sets == other.tag_sets + and self.max_staleness == other.max_staleness + and self.hedge == other.hedge + ) return NotImplemented def __ne__(self, other): @@ -217,18 +217,20 @@ def __getstate__(self): Needed explicitly because __slots__() defined. 
""" - return {'mode': self.__mode, - 'tag_sets': self.__tag_sets, - 'max_staleness': self.__max_staleness, - 'hedge': self.__hedge} + return { + "mode": self.__mode, + "tag_sets": self.__tag_sets, + "max_staleness": self.__max_staleness, + "hedge": self.__hedge, + } def __setstate__(self, value): """Restore from pickling.""" - self.__mode = value['mode'] + self.__mode = value["mode"] self.__mongos_mode = _MONGOS_MODES[self.__mode] - self.__tag_sets = _validate_tag_sets(value['tag_sets']) - self.__max_staleness = _validate_max_staleness(value['max_staleness']) - self.__hedge = _validate_hedge(value['hedge']) + self.__tag_sets = _validate_tag_sets(value["tag_sets"]) + self.__max_staleness = _validate_max_staleness(value["max_staleness"]) + self.__hedge = _validate_hedge(value["hedge"]) class Primary(_ServerMode): @@ -290,8 +292,7 @@ class PrimaryPreferred(_ServerMode): __slots__ = () def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): - super(PrimaryPreferred, self).__init__( - _PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) + super(PrimaryPreferred, self).__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) def __call__(self, selection): """Apply this read preference to Selection.""" @@ -299,9 +300,8 @@ def __call__(self, selection): return selection.primary_selection else: return secondary_with_tags_server_selector( - self.tag_sets, - max_staleness_selectors.select( - self.max_staleness, selection)) + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) class Secondary(_ServerMode): @@ -330,15 +330,13 @@ class Secondary(_ServerMode): __slots__ = () def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): - super(Secondary, self).__init__( - _SECONDARY, tag_sets, max_staleness, hedge) + super(Secondary, self).__init__(_SECONDARY, tag_sets, max_staleness, hedge) def __call__(self, selection): """Apply this read preference to Selection.""" return secondary_with_tags_server_selector( - self.tag_sets, - max_staleness_selectors.select( - self.max_staleness, selection)) + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) class SecondaryPreferred(_ServerMode): @@ -372,14 +370,14 @@ class SecondaryPreferred(_ServerMode): def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): super(SecondaryPreferred, self).__init__( - _SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) + _SECONDARY_PREFERRED, tag_sets, max_staleness, hedge + ) def __call__(self, selection): """Apply this read preference to Selection.""" secondaries = secondary_with_tags_server_selector( - self.tag_sets, - max_staleness_selectors.select( - self.max_staleness, selection)) + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) if secondaries: return secondaries @@ -413,39 +411,36 @@ class Nearest(_ServerMode): __slots__ = () def __init__(self, tag_sets=None, max_staleness=-1, hedge=None): - super(Nearest, self).__init__( - _NEAREST, tag_sets, max_staleness, hedge) + super(Nearest, self).__init__(_NEAREST, tag_sets, max_staleness, hedge) def __call__(self, selection): """Apply this read preference to Selection.""" return member_with_tags_server_selector( - self.tag_sets, - max_staleness_selectors.select( - self.max_staleness, selection)) + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) -_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, - Secondary, SecondaryPreferred, Nearest) +_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) 
def make_read_preference(mode, tag_sets, max_staleness=-1): if mode == _PRIMARY: if tag_sets not in (None, [{}]): - raise ConfigurationError("Read preference primary " - "cannot be combined with tags") + raise ConfigurationError("Read preference primary " "cannot be combined with tags") if max_staleness != -1: - raise ConfigurationError("Read preference primary cannot be " - "combined with maxStalenessSeconds") + raise ConfigurationError( + "Read preference primary cannot be " "combined with maxStalenessSeconds" + ) return Primary() return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) _MODES = ( - 'PRIMARY', - 'PRIMARY_PREFERRED', - 'SECONDARY', - 'SECONDARY_PREFERRED', - 'NEAREST', + "PRIMARY", + "PRIMARY_PREFERRED", + "SECONDARY", + "SECONDARY_PREFERRED", + "NEAREST", ) @@ -499,6 +494,7 @@ class ReadPreference(object): - ``NEAREST``: Read from any shard member. """ + PRIMARY = Primary() PRIMARY_PREFERRED = PrimaryPreferred() SECONDARY = Secondary() @@ -507,13 +503,13 @@ class ReadPreference(object): def read_pref_mode_from_name(name): - """Get the read preference mode from mongos/uri name. - """ + """Get the read preference mode from mongos/uri name.""" return _MONGOS_MODES.index(name) class MovingAverage(object): """Tracks an exponentially-weighted moving average.""" + def __init__(self): self.average = None diff --git a/pymongo/response.py b/pymongo/response.py index 3094399da6..1369eac4e0 100644 --- a/pymongo/response.py +++ b/pymongo/response.py @@ -16,11 +16,9 @@ class Response(object): - __slots__ = ('_data', '_address', '_request_id', '_duration', - '_from_command', '_docs') + __slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs") - def __init__(self, data, address, request_id, duration, from_command, - docs): + def __init__(self, data, address, request_id, duration, from_command, docs): """Represent a response from the server. :Parameters: @@ -69,10 +67,11 @@ def docs(self): class PinnedResponse(Response): - __slots__ = ('_socket_info', '_more_to_come') + __slots__ = ("_socket_info", "_more_to_come") - def __init__(self, data, address, socket_info, request_id, duration, - from_command, docs, more_to_come): + def __init__( + self, data, address, socket_info, request_id, duration, from_command, docs, more_to_come + ): """Represent a response to an exhaust cursor's initial query. :Parameters: @@ -87,11 +86,9 @@ def __init__(self, data, address, socket_info, request_id, duration, - `more_to_come`: Bool indicating whether cursor is ready to be exhausted. """ - super(PinnedResponse, self).__init__(data, - address, - request_id, - duration, - from_command, docs) + super(PinnedResponse, self).__init__( + data, address, request_id, duration, from_command, docs + ) self._socket_info = socket_info self._more_to_come = more_to_come diff --git a/pymongo/results.py b/pymongo/results.py index a5025e9f48..c33cd75998 100644 --- a/pymongo/results.py +++ b/pymongo/results.py @@ -28,10 +28,12 @@ def __init__(self, acknowledged): def _raise_if_unacknowledged(self, property_name): """Raise an exception on property access if unacknowledged.""" if not self.__acknowledged: - raise InvalidOperation("A value for %s is not available when " - "the write is unacknowledged. Check the " - "acknowledged attribute to avoid this " - "error." % (property_name,)) + raise InvalidOperation( + "A value for %s is not available when " + "the write is unacknowledged. Check the " + "acknowledged attribute to avoid this " + "error." 
% (property_name,) + ) @property def acknowledged(self): @@ -54,8 +56,7 @@ def acknowledged(self): class InsertOneResult(_WriteResult): - """The return type for :meth:`~pymongo.collection.Collection.insert_one`. - """ + """The return type for :meth:`~pymongo.collection.Collection.insert_one`.""" __slots__ = ("__inserted_id", "__acknowledged") @@ -70,8 +71,7 @@ def inserted_id(self): class InsertManyResult(_WriteResult): - """The return type for :meth:`~pymongo.collection.Collection.insert_many`. - """ + """The return type for :meth:`~pymongo.collection.Collection.insert_many`.""" __slots__ = ("__inserted_ids", "__acknowledged") @@ -222,5 +222,6 @@ def upserted_ids(self): """A map of operation index to the _id of the upserted document.""" self._raise_if_unacknowledged("upserted_ids") if self.__bulk_api_result: - return dict((upsert["index"], upsert["_id"]) - for upsert in self.bulk_api_result["upserted"]) + return dict( + (upsert["index"], upsert["_id"]) for upsert in self.bulk_api_result["upserted"] + ) diff --git a/pymongo/saslprep.py b/pymongo/saslprep.py index 08a780c055..1619cdb5bb 100644 --- a/pymongo/saslprep.py +++ b/pymongo/saslprep.py @@ -19,16 +19,20 @@ import stringprep except ImportError: HAVE_STRINGPREP = False + def saslprep(data): """SASLprep dummy""" if isinstance(data, str): raise TypeError( "The stringprep module is not available. Usernames and " - "passwords must be instances of bytes.") + "passwords must be instances of bytes." + ) return data + else: HAVE_STRINGPREP = True import unicodedata + # RFC4013 section 2.3 prohibited output. _PROHIBITED = ( # A strict reading of RFC 4013 requires table c12 here, but @@ -42,7 +46,8 @@ def saslprep(data): stringprep.in_table_c6, stringprep.in_table_c7, stringprep.in_table_c8, - stringprep.in_table_c9) + stringprep.in_table_c9, + ) def saslprep(data, prohibit_unassigned_code_points=True): """An implementation of RFC4013 SASLprep. 
@@ -75,12 +80,12 @@ def saslprep(data, prohibit_unassigned_code_points=True): in_table_c12 = stringprep.in_table_c12 in_table_b1 = stringprep.in_table_b1 data = "".join( - ["\u0020" if in_table_c12(elt) else elt - for elt in data if not in_table_b1(elt)]) + ["\u0020" if in_table_c12(elt) else elt for elt in data if not in_table_b1(elt)] + ) # RFC3454 section 2, step 2 - Normalize # RFC4013 section 2.2 normalization - data = unicodedata.ucd_3_2_0.normalize('NFKC', data) + data = unicodedata.ucd_3_2_0.normalize("NFKC", data) in_table_d1 = stringprep.in_table_d1 if in_table_d1(data[0]): @@ -101,7 +106,6 @@ def saslprep(data, prohibit_unassigned_code_points=True): # RFC3454 section 2, step 3 and 4 - Prohibit and check bidi for char in data: if any(in_table(char) for in_table in prohibited): - raise ValueError( - "SASLprep: failed prohibited character check") + raise ValueError("SASLprep: failed prohibited character check") return data diff --git a/pymongo/server.py b/pymongo/server.py index 0a487e8c41..b4d9e8ceec 100644 --- a/pymongo/server.py +++ b/pymongo/server.py @@ -17,19 +17,19 @@ from datetime import datetime from bson import _decode_all_selective - from pymongo.errors import NotPrimaryError, OperationFailure from pymongo.helpers import _check_command_response from pymongo.message import _convert_exception, _OpMsg -from pymongo.response import Response, PinnedResponse +from pymongo.response import PinnedResponse, Response from pymongo.server_type import SERVER_TYPE -_CURSOR_DOC_FIELDS = {'cursor': {'firstBatch': 1, 'nextBatch': 1}} +_CURSOR_DOC_FIELDS = {"cursor": {"firstBatch": 1, "nextBatch": 1}} class Server(object): - def __init__(self, server_description, pool, monitor, topology_id=None, - listeners=None, events=None): + def __init__( + self, server_description, pool, monitor, topology_id=None, listeners=None, events=None + ): """Represent one MongoDB server.""" self._description = server_description self._pool = pool @@ -59,8 +59,12 @@ def close(self): Reconnect with open(). """ if self._publish: - self._events.put((self._listener.publish_server_closed, - (self._description.address, self._topology_id))) + self._events.put( + ( + self._listener.publish_server_closed, + (self._description.address, self._topology_id), + ) + ) self._monitor.close() self._pool.reset_without_pause() @@ -68,8 +72,7 @@ def request_check(self): """Check the server's state soon.""" self._monitor.request_check() - def run_operation(self, sock_info, operation, set_secondary_okay, listeners, - unpack_res): + def run_operation(self, sock_info, operation, set_secondary_okay, listeners, unpack_res): """Run a _Query or _GetMore operation and return a Response object. 
This method is used only to run _Query/_GetMore operations from @@ -89,20 +92,18 @@ def run_operation(self, sock_info, operation, set_secondary_okay, listeners, start = datetime.now() use_cmd = operation.use_command(sock_info) - more_to_come = (operation.sock_mgr - and operation.sock_mgr.more_to_come) + more_to_come = operation.sock_mgr and operation.sock_mgr.more_to_come if more_to_come: request_id = 0 else: - message = operation.get_message( - set_secondary_okay, sock_info, use_cmd) + message = operation.get_message(set_secondary_okay, sock_info, use_cmd) request_id, data, max_doc_size = self._split_message(message) if publish: cmd, dbn = operation.as_command(sock_info) listeners.publish_command_start( - cmd, dbn, request_id, sock_info.address, - service_id=sock_info.service_id) + cmd, dbn, request_id, sock_info.address, service_id=sock_info.service_id + ) start = datetime.now() try: @@ -119,10 +120,13 @@ def run_operation(self, sock_info, operation, set_secondary_okay, listeners, else: user_fields = None legacy_response = True - docs = unpack_res(reply, operation.cursor_id, - operation.codec_options, - legacy_response=legacy_response, - user_fields=user_fields) + docs = unpack_res( + reply, + operation.cursor_id, + operation.codec_options, + legacy_response=legacy_response, + user_fields=user_fields, + ) if use_cmd: first = docs[0] operation.client._process_response(first, operation.session) @@ -135,9 +139,13 @@ def run_operation(self, sock_info, operation, set_secondary_okay, listeners, else: failure = _convert_exception(exc) listeners.publish_command_failure( - duration, failure, operation.name, - request_id, sock_info.address, - service_id=sock_info.service_id) + duration, + failure, + operation.name, + request_id, + sock_info.address, + service_id=sock_info.service_id, + ) raise if publish: @@ -149,25 +157,26 @@ def run_operation(self, sock_info, operation, set_secondary_okay, listeners, elif operation.name == "explain": res = docs[0] if docs else {} else: - res = {"cursor": {"id": reply.cursor_id, - "ns": operation.namespace()}, - "ok": 1} + res = {"cursor": {"id": reply.cursor_id, "ns": operation.namespace()}, "ok": 1} if operation.name == "find": res["cursor"]["firstBatch"] = docs else: res["cursor"]["nextBatch"] = docs listeners.publish_command_success( - duration, res, operation.name, request_id, - sock_info.address, service_id=sock_info.service_id) + duration, + res, + operation.name, + request_id, + sock_info.address, + service_id=sock_info.service_id, + ) # Decrypt response. 
client = operation.client if client and client._encrypter: if use_cmd: - decrypted = client._encrypter.decrypt( - reply.raw_command_response()) - docs = _decode_all_selective( - decrypted, operation.codec_options, user_fields) + decrypted = client._encrypter.decrypt(reply.raw_command_response()) + docs = _decode_all_selective(decrypted, operation.codec_options, user_fields) if client._should_pin_cursor(operation.session) or operation.exhaust: sock_info.pin_cursor() @@ -188,7 +197,8 @@ def run_operation(self, sock_info, operation, set_secondary_okay, listeners, request_id=request_id, from_command=use_cmd, docs=docs, - more_to_come=more_to_come) + more_to_come=more_to_come, + ) else: response = Response( data=reply, @@ -196,7 +206,8 @@ def run_operation(self, sock_info, operation, set_secondary_okay, listeners, duration=duration, request_id=request_id, from_command=use_cmd, - docs=docs) + docs=docs, + ) return response @@ -230,4 +241,4 @@ def _split_message(self, message): return request_id, data, 0 def __repr__(self): - return '<%s %r>' % (self.__class__.__name__, self._description) + return "<%s %r>" % (self.__class__.__name__, self._description) diff --git a/pymongo/server_api.py b/pymongo/server_api.py index 4a1b925ca9..110406366a 100644 --- a/pymongo/server_api.py +++ b/pymongo/server_api.py @@ -97,6 +97,7 @@ class ServerApiVersion: class ServerApi(object): """MongoDB Versioned API.""" + def __init__(self, version, strict=None, deprecation_errors=None): """Options to configure MongoDB Versioned API. @@ -116,12 +117,13 @@ def __init__(self, version, strict=None, deprecation_errors=None): if strict is not None and not isinstance(strict, bool): raise TypeError( "Wrong type for ServerApi strict, value must be an instance " - "of bool, not %s" % (type(strict),)) - if (deprecation_errors is not None and - not isinstance(deprecation_errors, bool)): + "of bool, not %s" % (type(strict),) + ) + if deprecation_errors is not None and not isinstance(deprecation_errors, bool): raise TypeError( "Wrong type for ServerApi deprecation_errors, value must be " - "an instance of bool, not %s" % (type(deprecation_errors),)) + "an instance of bool, not %s" % (type(deprecation_errors),) + ) self._version = version self._strict = strict self._deprecation_errors = deprecation_errors @@ -161,8 +163,8 @@ def _add_to_command(cmd, server_api): """ if not server_api: return - cmd['apiVersion'] = server_api.version + cmd["apiVersion"] = server_api.version if server_api.strict is not None: - cmd['apiStrict'] = server_api.strict + cmd["apiStrict"] = server_api.strict if server_api.deprecation_errors is not None: - cmd['apiDeprecationErrors'] = server_api.deprecation_errors + cmd["apiDeprecationErrors"] = server_api.deprecation_errors diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 2cbf6d63cd..abfd588d12 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -17,8 +17,8 @@ import time from bson import EPOCH_NAIVE -from pymongo.server_type import SERVER_TYPE from pymongo.hello import Hello +from pymongo.server_type import SERVER_TYPE class ServerDescription(object): @@ -32,20 +32,32 @@ class ServerDescription(object): """ __slots__ = ( - '_address', '_server_type', '_all_hosts', '_tags', '_replica_set_name', - '_primary', '_max_bson_size', '_max_message_size', - '_max_write_batch_size', '_min_wire_version', '_max_wire_version', - '_round_trip_time', '_me', '_is_writable', '_is_readable', - '_ls_timeout_minutes', '_error', '_set_version', '_election_id', - 
'_cluster_time', '_last_write_date', '_last_update_time', - '_topology_version') - - def __init__( - self, - address, - hello=None, - round_trip_time=None, - error=None): + "_address", + "_server_type", + "_all_hosts", + "_tags", + "_replica_set_name", + "_primary", + "_max_bson_size", + "_max_message_size", + "_max_write_batch_size", + "_min_wire_version", + "_max_wire_version", + "_round_trip_time", + "_me", + "_is_writable", + "_is_readable", + "_ls_timeout_minutes", + "_error", + "_set_version", + "_election_id", + "_cluster_time", + "_last_write_date", + "_last_update_time", + "_topology_version", + ) + + def __init__(self, address, hello=None, round_trip_time=None, error=None): self._address = address if not hello: hello = Hello({}) @@ -72,8 +84,8 @@ def __init__( self._error = error self._topology_version = hello.topology_version if error: - if hasattr(error, 'details') and isinstance(error.details, dict): - self._topology_version = error.details.get('topologyVersion') + if hasattr(error, "details") and isinstance(error.details, dict): + self._topology_version = error.details.get("topologyVersion") if hello.last_write_date: # Convert from datetime to seconds. @@ -204,10 +216,10 @@ def is_server_type_known(self): @property def retryable_writes_supported(self): """Checks if this server supports retryable writes.""" - return (( - self._ls_timeout_minutes is not None and - self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)) - or self._server_type == SERVER_TYPE.LoadBalancer) + return ( + self._ls_timeout_minutes is not None + and self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary) + ) or self._server_type == SERVER_TYPE.LoadBalancer @property def retryable_reads_supported(self): @@ -225,20 +237,21 @@ def to_unknown(self, error=None): def __eq__(self, other): if isinstance(other, ServerDescription): - return ((self._address == other.address) and - (self._server_type == other.server_type) and - (self._min_wire_version == other.min_wire_version) and - (self._max_wire_version == other.max_wire_version) and - (self._me == other.me) and - (self._all_hosts == other.all_hosts) and - (self._tags == other.tags) and - (self._replica_set_name == other.replica_set_name) and - (self._set_version == other.set_version) and - (self._election_id == other.election_id) and - (self._primary == other.primary) and - (self._ls_timeout_minutes == - other.logical_session_timeout_minutes) and - (self._error == other.error)) + return ( + (self._address == other.address) + and (self._server_type == other.server_type) + and (self._min_wire_version == other.min_wire_version) + and (self._max_wire_version == other.max_wire_version) + and (self._me == other.me) + and (self._all_hosts == other.all_hosts) + and (self._tags == other.tags) + and (self._replica_set_name == other.replica_set_name) + and (self._set_version == other.set_version) + and (self._election_id == other.election_id) + and (self._primary == other.primary) + and (self._ls_timeout_minutes == other.logical_session_timeout_minutes) + and (self._error == other.error) + ) return NotImplemented @@ -246,12 +259,16 @@ def __ne__(self, other): return not self == other def __repr__(self): - errmsg = '' + errmsg = "" if self.error: - errmsg = ', error=%r' % (self.error,) + errmsg = ", error=%r" % (self.error,) return "<%s %s server_type: %s, rtt: %s%s>" % ( - self.__class__.__name__, self.address, self.server_type_name, - self.round_trip_time, errmsg) + self.__class__.__name__, + self.address, + self.server_type_name, + 
self.round_trip_time, + errmsg, + ) # For unittesting only. Use under no circumstances! _host_to_round_trip_time = {} diff --git a/pymongo/server_selectors.py b/pymongo/server_selectors.py index cc18450ad8..313566cb83 100644 --- a/pymongo/server_selectors.py +++ b/pymongo/server_selectors.py @@ -29,32 +29,28 @@ def from_topology_description(cls, topology_description): primary = sd break - return Selection(topology_description, - topology_description.known_servers, - topology_description.common_wire_version, - primary) - - def __init__(self, - topology_description, - server_descriptions, - common_wire_version, - primary): + return Selection( + topology_description, + topology_description.known_servers, + topology_description.common_wire_version, + primary, + ) + + def __init__(self, topology_description, server_descriptions, common_wire_version, primary): self.topology_description = topology_description self.server_descriptions = server_descriptions self.primary = primary self.common_wire_version = common_wire_version def with_server_descriptions(self, server_descriptions): - return Selection(self.topology_description, - server_descriptions, - self.common_wire_version, - self.primary) + return Selection( + self.topology_description, server_descriptions, self.common_wire_version, self.primary + ) def secondary_with_max_last_write_date(self): secondaries = secondary_server_selector(self) if secondaries.server_descriptions: - return max(secondaries.server_descriptions, - key=lambda sd: sd.last_write_date) + return max(secondaries.server_descriptions, key=lambda sd: sd.last_write_date) @property def primary_selection(self): @@ -82,30 +78,31 @@ def any_server_selector(selection): def readable_server_selector(selection): return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.is_readable]) + [s for s in selection.server_descriptions if s.is_readable] + ) def writable_server_selector(selection): return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.is_writable]) + [s for s in selection.server_descriptions if s.is_writable] + ) def secondary_server_selector(selection): return selection.with_server_descriptions( - [s for s in selection.server_descriptions - if s.server_type == SERVER_TYPE.RSSecondary]) + [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSSecondary] + ) def arbiter_server_selector(selection): return selection.with_server_descriptions( - [s for s in selection.server_descriptions - if s.server_type == SERVER_TYPE.RSArbiter]) + [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSArbiter] + ) def writable_preferred_server_selector(selection): """Like PrimaryPreferred but doesn't use tags or latency.""" - return (writable_server_selector(selection) or - secondary_server_selector(selection)) + return writable_server_selector(selection) or secondary_server_selector(selection) def apply_single_tag_set(tag_set, selection): @@ -116,6 +113,7 @@ def apply_single_tag_set(tag_set, selection): The empty tag set {} matches any server. 
""" + def tags_match(server_tags): for key, value in tag_set.items(): if key not in server_tags or server_tags[key] != value: @@ -124,7 +122,8 @@ def tags_match(server_tags): return True return selection.with_server_descriptions( - [s for s in selection.server_descriptions if tags_match(s.tags)]) + [s for s in selection.server_descriptions if tags_match(s.tags)] + ) def apply_tag_sets(tag_sets, selection): diff --git a/pymongo/server_type.py b/pymongo/server_type.py index 101f9dba4c..8be4667bcd 100644 --- a/pymongo/server_type.py +++ b/pymongo/server_type.py @@ -16,8 +16,17 @@ from collections import namedtuple - -SERVER_TYPE = namedtuple('ServerType', - ['Unknown', 'Mongos', 'RSPrimary', 'RSSecondary', - 'RSArbiter', 'RSOther', 'RSGhost', - 'Standalone', 'LoadBalancer'])(*range(9)) +SERVER_TYPE = namedtuple( + "ServerType", + [ + "Unknown", + "Mongos", + "RSPrimary", + "RSSecondary", + "RSArbiter", + "RSOther", + "RSGhost", + "Standalone", + "LoadBalancer", + ], +)(*range(9)) diff --git a/pymongo/settings.py b/pymongo/settings.py index d17b5e8b86..2bd2527cdf 100644 --- a/pymongo/settings.py +++ b/pymongo/settings.py @@ -27,32 +27,35 @@ class TopologySettings(object): - def __init__(self, - seeds=None, - replica_set_name=None, - pool_class=None, - pool_options=None, - monitor_class=None, - condition_class=None, - local_threshold_ms=LOCAL_THRESHOLD_MS, - server_selection_timeout=SERVER_SELECTION_TIMEOUT, - heartbeat_frequency=common.HEARTBEAT_FREQUENCY, - server_selector=None, - fqdn=None, - direct_connection=False, - load_balanced=None, - srv_service_name=common.SRV_SERVICE_NAME, - srv_max_hosts=0): + def __init__( + self, + seeds=None, + replica_set_name=None, + pool_class=None, + pool_options=None, + monitor_class=None, + condition_class=None, + local_threshold_ms=LOCAL_THRESHOLD_MS, + server_selection_timeout=SERVER_SELECTION_TIMEOUT, + heartbeat_frequency=common.HEARTBEAT_FREQUENCY, + server_selector=None, + fqdn=None, + direct_connection=False, + load_balanced=None, + srv_service_name=common.SRV_SERVICE_NAME, + srv_max_hosts=0, + ): """Represent MongoClient's configuration. Take a list of (host, port) pairs and optional replica set name. """ if heartbeat_frequency < common.MIN_HEARTBEAT_INTERVAL: raise ConfigurationError( - "heartbeatFrequencyMS cannot be less than %d" % ( - common.MIN_HEARTBEAT_INTERVAL * 1000,)) + "heartbeatFrequencyMS cannot be less than %d" + % (common.MIN_HEARTBEAT_INTERVAL * 1000,) + ) - self._seeds = seeds or [('localhost', 27017)] + self._seeds = seeds or [("localhost", 27017)] self._replica_set_name = replica_set_name self._pool_class = pool_class or pool.Pool self._pool_options = pool_options or PoolOptions() @@ -71,7 +74,7 @@ def __init__(self, self._topology_id = ObjectId() # Store the allocation traceback to catch unclosed clients in the # test suite. 
- self._stack = ''.join(traceback.format_stack()) + self._stack = "".join(traceback.format_stack()) @property def seeds(self): @@ -153,6 +156,4 @@ def get_topology_type(self): def get_server_descriptions(self): """Initial dict of (address, ServerDescription) for all seeds.""" - return dict([ - (address, ServerDescription(address)) - for address in self.seeds]) + return dict([(address, ServerDescription(address)) for address in self.seeds]) diff --git a/pymongo/socket_checker.py b/pymongo/socket_checker.py index 48f168be48..3d95b0e497 100644 --- a/pymongo/socket_checker.py +++ b/pymongo/socket_checker.py @@ -20,12 +20,12 @@ # PYTHON-2320: Jython does not fully support poll on SSL sockets, # https://bugs.jython.org/issue2900 -_HAVE_POLL = hasattr(select, "poll") and not sys.platform.startswith('java') +_HAVE_POLL = hasattr(select, "poll") and not sys.platform.startswith("java") _SelectError = getattr(select, "error", OSError) def _errno_from_exception(exc): - if hasattr(exc, 'errno'): + if hasattr(exc, "errno"): return exc.errno if exc.args: return exc.args[0] @@ -33,7 +33,6 @@ def _errno_from_exception(exc): class SocketChecker(object): - def __init__(self): if _HAVE_POLL: self._poller = select.poll() @@ -80,8 +79,7 @@ def select(self, sock, read=False, write=False, timeout=0): raise def socket_closed(self, sock): - """Return True if we know socket has been closed, False otherwise. - """ + """Return True if we know socket has been closed, False otherwise.""" try: return self.select(sock, read=True) except (RuntimeError, KeyError): diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index 69e075aec4..fe2dd49aa0 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -19,6 +19,7 @@ try: from dns import resolver + _HAVE_DNSPYTHON = True except ImportError: _HAVE_DNSPYTHON = False @@ -26,6 +27,7 @@ from pymongo.common import CONNECT_TIMEOUT from pymongo.errors import ConfigurationError + # dnspython can return bytes or str from various parts # of its API depending on version. We always want str. def maybe_decode(text): @@ -36,19 +38,21 @@ def maybe_decode(text): # PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. def _resolve(*args, **kwargs): - if hasattr(resolver, 'resolve'): + if hasattr(resolver, "resolve"): # dnspython >= 2 return resolver.resolve(*args, **kwargs) # dnspython 1.X return resolver.query(*args, **kwargs) + _INVALID_HOST_MSG = ( "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. " - "Did you mean to use 'mongodb://'?") + "Did you mean to use 'mongodb://'?" 
+) + class _SrvResolver(object): - def __init__(self, fqdn, - connect_timeout, srv_service_name, srv_max_hosts=0): + def __init__(self, fqdn, connect_timeout, srv_service_name, srv_max_hosts=0): self.__fqdn = fqdn self.__srv = srv_service_name self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT @@ -70,23 +74,21 @@ def __init__(self, fqdn, def get_options(self): try: - results = _resolve(self.__fqdn, 'TXT', - lifetime=self.__connect_timeout) + results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) except (resolver.NoAnswer, resolver.NXDOMAIN): # No TXT records return None except Exception as exc: raise ConfigurationError(str(exc)) if len(results) > 1: - raise ConfigurationError('Only one TXT record is supported') - return ( - b'&'.join([b''.join(res.strings) for res in results])).decode( - 'utf-8') + raise ConfigurationError("Only one TXT record is supported") + return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") def _resolve_uri(self, encapsulate_errors): try: - results = _resolve('_' + self.__srv + '._tcp.' + self.__fqdn, - 'SRV', lifetime=self.__connect_timeout) + results = _resolve( + "_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout + ) except Exception as exc: if not encapsulate_errors: # Raise the original error. @@ -100,13 +102,13 @@ def _get_srv_response_and_hosts(self, encapsulate_errors): # Construct address tuples nodes = [ - (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) - for res in results] + (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) for res in results + ] # Validate hosts for node in nodes: try: - nlist = node[0].split(".")[1:][-self.__slen:] + nlist = node[0].split(".")[1:][-self.__slen :] except Exception: raise ConfigurationError("Invalid SRV host: %s" % (node[0],)) if self.__plist != nlist: diff --git a/pymongo/ssl_context.py b/pymongo/ssl_context.py index 2f35676f87..e546105141 100644 --- a/pymongo/ssl_context.py +++ b/pymongo/ssl_context.py @@ -32,6 +32,7 @@ SSLError = _ssl.SSLError from ssl import SSLContext + if hasattr(_ssl, "VERIFY_CRL_CHECK_LEAF"): from ssl import VERIFY_CRL_CHECK_LEAF # Python 3.7 uses OpenSSL's hostname matching implementation diff --git a/pymongo/ssl_support.py b/pymongo/ssl_support.py index 5826f95801..1dd73e5483 100644 --- a/pymongo/ssl_support.py +++ b/pymongo/ssl_support.py @@ -36,13 +36,20 @@ # import the ssl module even if we're only using it for this purpose. 
import ssl as _stdlibssl from ssl import CERT_NONE, CERT_REQUIRED + HAS_SNI = _ssl.HAS_SNI IPADDR_SAFE = _ssl.IS_PYOPENSSL or sys.version_info[:2] >= (3, 7) SSLError = _ssl.SSLError - def get_ssl_context(certfile, passphrase, ca_certs, crlfile, - allow_invalid_certificates, allow_invalid_hostnames, - disable_ocsp_endpoint_check): + def get_ssl_context( + certfile, + passphrase, + ca_certs, + crlfile, + allow_invalid_certificates, + allow_invalid_hostnames, + disable_ocsp_endpoint_check, + ): """Create and return an SSLContext object.""" verify_mode = CERT_NONE if allow_invalid_certificates else CERT_REQUIRED ctx = _ssl.SSLContext(_ssl.PROTOCOL_SSLv23) @@ -67,12 +74,10 @@ def get_ssl_context(certfile, passphrase, ca_certs, crlfile, try: ctx.load_cert_chain(certfile, None, passphrase) except _ssl.SSLError as exc: - raise ConfigurationError( - "Private key doesn't match certificate: %s" % (exc,)) + raise ConfigurationError("Private key doesn't match certificate: %s" % (exc,)) if crlfile is not None: if _ssl.IS_PYOPENSSL: - raise ConfigurationError( - "tlsCRLFile cannot be used with PyOpenSSL") + raise ConfigurationError("tlsCRLFile cannot be used with PyOpenSSL") # Match the server's behavior. ctx.verify_flags = getattr(_ssl, "VERIFY_CRL_CHECK_LEAF", 0) ctx.load_verify_locations(crlfile) @@ -82,9 +87,12 @@ def get_ssl_context(certfile, passphrase, ca_certs, crlfile, ctx.load_default_certs() ctx.verify_mode = verify_mode return ctx + else: + class SSLError(Exception): pass + HAS_SNI = False IPADDR_SAFE = False diff --git a/pymongo/topology.py b/pymongo/topology.py index 6f26cff617..8a66613afd 100644 --- a/pymongo/topology.py +++ b/pymongo/topology.py @@ -22,34 +22,39 @@ import warnings import weakref -from pymongo import (common, - helpers, - periodic_executor) +from pymongo import common, helpers, periodic_executor from pymongo.client_session import _ServerSessionPool -from pymongo.errors import (ConnectionFailure, - ConfigurationError, - NetworkTimeout, - NotPrimaryError, - OperationFailure, - PyMongoError, - ServerSelectionTimeoutError, - WriteError, - InvalidOperation) +from pymongo.errors import ( + ConfigurationError, + ConnectionFailure, + InvalidOperation, + NetworkTimeout, + NotPrimaryError, + OperationFailure, + PyMongoError, + ServerSelectionTimeoutError, + WriteError, +) from pymongo.hello import Hello from pymongo.monitor import SrvMonitor from pymongo.pool import PoolOptions from pymongo.server import Server from pymongo.server_description import ServerDescription -from pymongo.server_selectors import (any_server_selector, - arbiter_server_selector, - secondary_server_selector, - readable_server_selector, - writable_server_selector, - Selection) -from pymongo.topology_description import (updated_topology_description, - _updated_topology_description_srv_polling, - TopologyDescription, - SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE) +from pymongo.server_selectors import ( + Selection, + any_server_selector, + arbiter_server_selector, + readable_server_selector, + secondary_server_selector, + writable_server_selector, +) +from pymongo.topology_description import ( + SRV_POLLING_TOPOLOGIES, + TOPOLOGY_TYPE, + TopologyDescription, + _updated_topology_description_srv_polling, + updated_topology_description, +) def process_events_queue(queue_ref): @@ -71,6 +76,7 @@ def process_events_queue(queue_ref): class Topology(object): """Monitor a topology of one or more servers.""" + def __init__(self, topology_settings): self._topology_id = topology_settings._topology_id self._listeners = 
topology_settings._pool_options._event_listeners @@ -86,8 +92,7 @@ def __init__(self, topology_settings): self._events = queue.Queue(maxsize=100) if self._publish_tp: - self._events.put((self._listeners.publish_topology_opened, - (self._topology_id,))) + self._events.put((self._listeners.publish_topology_opened, (self._topology_id,))) self._settings = topology_settings topology_description = TopologyDescription( topology_settings.get_topology_type(), @@ -95,20 +100,24 @@ def __init__(self, topology_settings): topology_settings.replica_set_name, None, None, - topology_settings) + topology_settings, + ) self._description = topology_description if self._publish_tp: - initial_td = TopologyDescription(TOPOLOGY_TYPE.Unknown, {}, None, - None, None, self._settings) - self._events.put(( - self._listeners.publish_topology_description_changed, - (initial_td, self._description, self._topology_id))) + initial_td = TopologyDescription( + TOPOLOGY_TYPE.Unknown, {}, None, None, None, self._settings + ) + self._events.put( + ( + self._listeners.publish_topology_description_changed, + (initial_td, self._description, self._topology_id), + ) + ) for seed in topology_settings.seeds: if self._publish_server: - self._events.put((self._listeners.publish_server_opened, - (seed, self._topology_id))) + self._events.put((self._listeners.publish_server_opened, (seed, self._topology_id))) # Store the seed list to help diagnose errors in _error_message(). self._seed_addresses = list(topology_description.server_descriptions()) @@ -122,6 +131,7 @@ def __init__(self, topology_settings): self._session_pool = _ServerSessionPool() if self._publish_server or self._publish_tp: + def target(): return process_events_queue(weak) @@ -129,7 +139,8 @@ def target(): interval=common.EVENTS_QUEUE_FREQUENCY, min_interval=common.MIN_HEARTBEAT_INTERVAL, target=target, - name="pymongo_events_thread") + name="pymongo_events_thread", + ) # We strongly reference the executor and it weakly references # the queue via this closure. When the topology is freed, stop @@ -139,8 +150,7 @@ def target(): executor.open() self._srv_monitor = None - if (self._settings.fqdn is not None and - not self._settings.load_balanced): + if self._settings.fqdn is not None and not self._settings.load_balanced: self._srv_monitor = SrvMonitor(self, self._settings) def open(self): @@ -163,7 +173,8 @@ def open(self): "MongoClient opened before fork. Create MongoClient only " "after forking. See PyMongo's documentation for details: " "https://pymongo.readthedocs.io/en/stable/faq.html#" - "is-pymongo-fork-safe") + "is-pymongo-fork-safe" + ) with self._lock: # Reset the session pool to avoid duplicate sessions in # the child process. @@ -172,10 +183,7 @@ def open(self): with self._lock: self._ensure_opened() - def select_servers(self, - selector, - server_selection_timeout=None, - address=None): + def select_servers(self, selector, server_selection_timeout=None, address=None): """Return a list of Servers matching selector, or time out. :Parameters: @@ -197,25 +205,25 @@ def select_servers(self, server_timeout = server_selection_timeout with self._lock: - server_descriptions = self._select_servers_loop( - selector, server_timeout, address) + server_descriptions = self._select_servers_loop(selector, server_timeout, address) - return [self.get_server_by_address(sd.address) - for sd in server_descriptions] + return [self.get_server_by_address(sd.address) for sd in server_descriptions] def _select_servers_loop(self, selector, timeout, address): """select_servers() guts. 
     def _select_servers_loop(self, selector, timeout, address):
         """select_servers() guts. Hold the lock when calling this."""
         now = time.monotonic()
         end_time = now + timeout
         server_descriptions = self._description.apply_selector(
-            selector, address, custom_selector=self._settings.server_selector)
+            selector, address, custom_selector=self._settings.server_selector
+        )
 
         while not server_descriptions:
             # No suitable servers.
             if timeout == 0 or now > end_time:
                 raise ServerSelectionTimeoutError(
-                    "%s, Timeout: %ss, Topology Description: %r" %
-                    (self._error_message(selector), timeout, self.description))
+                    "%s, Timeout: %ss, Topology Description: %r"
+                    % (self._error_message(selector), timeout, self.description)
+                )
 
             self._ensure_opened()
             self._request_check_all()
@@ -228,19 +236,15 @@ def _select_servers_loop(self, selector, timeout, address):
             self._description.check_compatible()
             now = time.monotonic()
             server_descriptions = self._description.apply_selector(
-                selector, address,
-                custom_selector=self._settings.server_selector)
+                selector, address, custom_selector=self._settings.server_selector
+            )
 
         self._description.check_compatible()
         return server_descriptions
 
-    def select_server(self,
-                      selector,
-                      server_selection_timeout=None,
-                      address=None):
+    def select_server(self, selector, server_selection_timeout=None, address=None):
         """Like select_servers, but choose a random server if several match."""
-        servers = self.select_servers(
-            selector, server_selection_timeout, address)
+        servers = self.select_servers(selector, server_selection_timeout, address)
         if len(servers) == 1:
             return servers[0]
         server1, server2 = random.sample(servers, 2)
@@ -249,8 +253,7 @@ def select_server(self,
         else:
             return server2
 
-    def select_server_by_address(self, address,
-                                 server_selection_timeout=None):
+    def select_server_by_address(self, address, server_selection_timeout=None):
         """Return a Server for "address", reconnecting if necessary.
 
         If the server's type is not known, request an immediate check of all
@@ -268,9 +271,7 @@ def select_server_by_address(self, address,
         Raises exc:`ServerSelectionTimeoutError` after
         `server_selection_timeout` if no matching servers are found.
         """
-        return self.select_server(any_server_selector,
-                                  server_selection_timeout,
-                                  address)
+        return self.select_server(any_server_selector, server_selection_timeout, address)
 
     def _process_change(self, server_description, reset_pool=False):
         """Process a new ServerDescription on an opened topology.
@@ -283,39 +284,43 @@ def _process_change(self, server_description, reset_pool=False):
             # This is a stale hello response. Ignore it.
             return
 
-        new_td = updated_topology_description(
-            self._description, server_description)
+        new_td = updated_topology_description(self._description, server_description)
 
         # CMAP: Ensure the pool is "ready" when the server is selectable.
-        if (server_description.is_readable
-                or (server_description.is_server_type_known and
-                    new_td.topology_type == TOPOLOGY_TYPE.Single)):
+        if server_description.is_readable or (
+            server_description.is_server_type_known and new_td.topology_type == TOPOLOGY_TYPE.Single
+        ):
             server = self._servers.get(server_description.address)
             if server:
                 server.pool.ready()
 
-        suppress_event = ((self._publish_server or self._publish_tp)
-                          and sd_old == server_description)
+        suppress_event = (self._publish_server or self._publish_tp) and sd_old == server_description
         if self._publish_server and not suppress_event:
-            self._events.put((
-                self._listeners.publish_server_description_changed,
-                (sd_old, server_description,
-                 server_description.address, self._topology_id)))
+            self._events.put(
+                (
+                    self._listeners.publish_server_description_changed,
+                    (sd_old, server_description, server_description.address, self._topology_id),
+                )
+            )
 
         self._description = new_td
         self._update_servers()
         self._receive_cluster_time_no_lock(server_description.cluster_time)
 
         if self._publish_tp and not suppress_event:
-            self._events.put((
-                self._listeners.publish_topology_description_changed,
-                (td_old, self._description, self._topology_id)))
+            self._events.put(
+                (
+                    self._listeners.publish_topology_description_changed,
+                    (td_old, self._description, self._topology_id),
+                )
+            )
 
         # Shutdown SRV polling for unsupported cluster types.
         # This is only applicable if the old topology was Unknown, and the
         # new one is something other than Unknown or Sharded.
-        if self._srv_monitor and (td_old.topology_type == TOPOLOGY_TYPE.Unknown
-                                  and self._description.topology_type not in
-                                  SRV_POLLING_TOPOLOGIES):
+        if self._srv_monitor and (
+            td_old.topology_type == TOPOLOGY_TYPE.Unknown
+            and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
+        ):
             self._srv_monitor.close()
 
         # Clear the pool from a failed heartbeat.
@@ -339,8 +344,7 @@ def on_change(self, server_description, reset_pool=False):
             # once. Check if it's still in the description or if some state-
             # change removed it. E.g., we got a host list from the primary
             # that didn't include this server.
-            if (self._opened and
-                    self._description.has_server(server_description.address)):
+            if self._opened and self._description.has_server(server_description.address):
                 self._process_change(server_description, reset_pool)
 
     def _process_srv_update(self, seedlist):
@@ -348,15 +352,17 @@ def _process_srv_update(self, seedlist):
         Hold the lock when calling this.
         """
         td_old = self._description
-        self._description = _updated_topology_description_srv_polling(
-            self._description, seedlist)
+        self._description = _updated_topology_description_srv_polling(self._description, seedlist)
 
         self._update_servers()
 
         if self._publish_tp:
-            self._events.put((
-                self._listeners.publish_topology_description_changed,
-                (td_old, self._description, self._topology_id)))
+            self._events.put(
+                (
+                    self._listeners.publish_topology_description_changed,
+                    (td_old, self._description, self._topology_id),
+                )
+            )
 
     def on_srv_update(self, seedlist):
         """Process a new list of nodes obtained from scanning SRV records."""
@@ -393,8 +399,10 @@ def _get_replica_set_members(self, selector):
         # Implemented here in Topology instead of MongoClient, so it can lock.
         with self._lock:
             topology_type = self._description.topology_type
-            if topology_type not in (TOPOLOGY_TYPE.ReplicaSetWithPrimary,
-                                     TOPOLOGY_TYPE.ReplicaSetNoPrimary):
+            if topology_type not in (
+                TOPOLOGY_TYPE.ReplicaSetWithPrimary,
+                TOPOLOGY_TYPE.ReplicaSetNoPrimary,
+            ):
                 return set()
 
             return set([sd.address for sd in selector(self._new_selection())])
@@ -420,9 +428,10 @@ def _receive_cluster_time_no_lock(self, cluster_time):
         # value of the clusterTime embedded field."
         if cluster_time:
             # ">" uses bson.timestamp.Timestamp's comparison operator.
-            if (not self._max_cluster_time
-                    or cluster_time['clusterTime'] >
-                    self._max_cluster_time['clusterTime']):
+            if (
+                not self._max_cluster_time
+                or cluster_time["clusterTime"] > self._max_cluster_time["clusterTime"]
+            ):
                 self._max_cluster_time = cluster_time
 
     def receive_cluster_time(self, cluster_time):
@@ -451,8 +460,7 @@ def update_pool(self, all_credentials):
             # Only update pools for data-bearing servers.
             for sd in self.data_bearing_servers():
                 server = self._servers[sd.address]
-                servers.append((server,
-                                server.pool.gen.get_overall()))
+                servers.append((server, server.pool.gen.get_overall()))
 
         for server, generation in servers:
             try:
@@ -465,7 +473,7 @@ def update_pool(self, all_credentials):
     def close(self):
         """Clear pools and terminate monitors. Topology does not reopen on
         demand. Any further operations will raise
-        :exc:`~.errors.InvalidOperation`. """
+        :exc:`~.errors.InvalidOperation`."""
         with self._lock:
             for server in self._servers.values():
                 server.close()
@@ -485,8 +493,7 @@ def close(self):
 
         # Publish only after releasing the lock.
         if self._publish_tp:
-            self._events.put((self._listeners.publish_topology_closed,
-                              (self._topology_id,)))
+            self._events.put((self._listeners.publish_topology_closed, (self._topology_id,)))
         if self._publish_server or self._publish_tp:
             self.__events_executor.close()
 
@@ -507,19 +514,16 @@ def _check_session_support(self):
             if self._description.topology_type == TOPOLOGY_TYPE.Single:
                 if not self._description.has_known_servers:
                     self._select_servers_loop(
-                        any_server_selector,
-                        self._settings.server_selection_timeout,
-                        None)
+                        any_server_selector, self._settings.server_selection_timeout, None
+                    )
             elif not self._description.readable_servers:
                 self._select_servers_loop(
-                    readable_server_selector,
-                    self._settings.server_selection_timeout,
-                    None)
+                    readable_server_selector, self._settings.server_selection_timeout, None
+                )
 
             session_timeout = self._description.logical_session_timeout_minutes
             if session_timeout is None:
-                raise ConfigurationError(
-                    "Sessions are not supported by this MongoDB deployment")
+                raise ConfigurationError("Sessions are not supported by this MongoDB deployment")
             return session_timeout
 
     def get_server_session(self):
@@ -530,15 +534,15 @@ def get_server_session(self):
                 session_timeout = self._check_session_support()
             else:
                 # Sessions never time out in load balanced mode.
-                session_timeout = float('inf')
+                session_timeout = float("inf")
             return self._session_pool.get_server_session(session_timeout)
 
     def return_server_session(self, server_session, lock):
         if lock:
             with self._lock:
                 self._session_pool.return_server_session(
-                    server_session,
-                    self._description.logical_session_timeout_minutes)
+                    server_session, self._description.logical_session_timeout_minutes
+                )
         else:
             # Called from a __del__ method, can't use a lock.
             self._session_pool.return_server_session_no_lock(server_session)
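The clusterTime bookkeeping in _receive_cluster_time_no_lock() leans on bson.timestamp.Timestamp ordering, as its comment notes. A small illustration with public bson types (values are made up):

    from bson.timestamp import Timestamp

    cached = {"clusterTime": Timestamp(1510963030, 1)}
    gossiped = {"clusterTime": Timestamp(1510963030, 2)}

    # Timestamps compare as (time, inc) pairs, so the gossiped value only
    # replaces the cached one when it is strictly newer.
    assert gossiped["clusterTime"] > cached["clusterTime"]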
@@ -567,16 +571,17 @@ def _ensure_opened(self):
                 self.__events_executor.open()
 
             # Start the SRV polling thread.
-            if self._srv_monitor and (self.description.topology_type in
-                                      SRV_POLLING_TOPOLOGIES):
+            if self._srv_monitor and (self.description.topology_type in SRV_POLLING_TOPOLOGIES):
                 self._srv_monitor.open()
 
             if self._settings.load_balanced:
                 # Emit initial SDAM events for load balancer mode.
-                self._process_change(ServerDescription(
-                    self._seed_addresses[0],
-                    Hello({'ok': 1, 'serviceId': self._topology_id,
-                           'maxWireVersion': 13})))
+                self._process_change(
+                    ServerDescription(
+                        self._seed_addresses[0],
+                        Hello({"ok": 1, "serviceId": self._topology_id, "maxWireVersion": 13}),
+                    )
+                )
 
         # Ensure that the monitors are open.
         for server in self._servers.values():
@@ -588,8 +593,7 @@ def _is_stale_error(self, address, err_ctx):
             # Another thread removed this server from the topology.
             return True
 
-        if server._pool.stale_generation(
-                err_ctx.sock_generation, err_ctx.service_id):
+        if server._pool.stale_generation(err_ctx.sock_generation, err_ctx.service_id):
             # This is an outdated error from a previous pool version.
             return True
 
@@ -597,9 +601,9 @@ def _is_stale_error(self, address, err_ctx):
         cur_tv = server.description.topology_version
         error = err_ctx.error
         error_tv = None
-        if error and hasattr(error, 'details'):
+        if error and hasattr(error, "details"):
             if isinstance(error.details, dict):
-                error_tv = error.details.get('topologyVersion')
+                error_tv = error.details.get("topologyVersion")
 
         return _is_stale_error_topology_version(cur_tv, error_tv)
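_is_stale_error() ends by delegating to _is_stale_error_topology_version(), defined near the bottom of this file. Restating that rule standalone, with hypothetical values (real topologyVersion documents carry an ObjectId processId):

    def is_stale(current_tv, error_tv):
        # Missing on either side means "not stale"; different processes are
        # incomparable; otherwise an error counter <= the current one is stale.
        if current_tv is None or error_tv is None:
            return False
        if current_tv["processId"] != error_tv["processId"]:
            return False
        return current_tv["counter"] >= error_tv["counter"]

    assert is_stale({"processId": 1, "counter": 5}, {"processId": 1, "counter": 4})
    assert not is_stale({"processId": 1, "counter": 5}, {"processId": 2, "counter": 9})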
@@ -611,8 +615,7 @@ def _handle_error(self, address, err_ctx):
         error = err_ctx.error
         exc_type = type(error)
         service_id = err_ctx.service_id
-        if (issubclass(exc_type, NetworkTimeout) and
-                err_ctx.completed_handshake):
+        if issubclass(exc_type, NetworkTimeout) and err_ctx.completed_handshake:
             # The socket has been closed. Don't reset the server.
             # Server Discovery And Monitoring Spec: "When an application
             # operation fails because of any network error besides a socket
@@ -630,12 +633,12 @@ def _handle_error(self, address, err_ctx):
             # as Unknown and request an immediate check of the server.
             # Otherwise, we clear the connection pool, mark the server as
             # Unknown and request an immediate check of the server.
-            if hasattr(error, 'code'):
+            if hasattr(error, "code"):
                 err_code = error.code
             else:
                 # Default error code if one does not exist.
                 default = 10107 if isinstance(error, NotPrimaryError) else None
-                err_code = error.details.get('code', default)
+                err_code = error.details.get("code", default)
             if err_code in helpers._NOT_PRIMARY_CODES:
                 is_shutting_down = err_code in helpers._SHUTDOWN_CODES
                 # Mark server Unknown, clear the pool, and request check.
@@ -688,7 +691,8 @@ def _update_servers(self):
                     server_description=sd,
                     topology=self,
                     pool=self._create_pool_for_monitor(address),
-                    topology_settings=self._settings)
+                    topology_settings=self._settings,
+                )
 
                 weak = None
                 if self._publish_server:
@@ -699,7 +703,8 @@ def _update_servers(self):
                     monitor=monitor,
                     topology_id=self._topology_id,
                     listeners=self._listeners,
-                    events=weak)
+                    events=weak,
+                )
 
                 self._servers[address] = server
                 server.open()
@@ -710,8 +715,7 @@ def _update_servers(self):
                 self._servers[address].description = sd
                 # Update is_writable value of the pool, if it changed.
                 if was_writable != sd.is_writable:
-                    self._servers[address].pool.update_is_writable(
-                        sd.is_writable)
+                    self._servers[address].pool.update_is_writable(sd.is_writable)
 
         for address, server in list(self._servers.items()):
             if not self._description.has_server(address):
@@ -739,8 +743,7 @@ def _create_pool_for_monitor(self, address):
             server_api=options.server_api,
         )
 
-        return self._settings.pool_class(address, monitor_pool_options,
-                                         handshake=False)
+        return self._settings.pool_class(address, monitor_pool_options, handshake=False)
 
     def _error_message(self, selector):
         """Format an error message if server selection fails.
@@ -749,22 +752,23 @@ def _error_message(self, selector):
         """
         is_replica_set = self._description.topology_type in (
             TOPOLOGY_TYPE.ReplicaSetWithPrimary,
-            TOPOLOGY_TYPE.ReplicaSetNoPrimary)
+            TOPOLOGY_TYPE.ReplicaSetNoPrimary,
+        )
 
         if is_replica_set:
-            server_plural = 'replica set members'
+            server_plural = "replica set members"
         elif self._description.topology_type == TOPOLOGY_TYPE.Sharded:
-            server_plural = 'mongoses'
+            server_plural = "mongoses"
         else:
-            server_plural = 'servers'
+            server_plural = "servers"
 
         if self._description.known_servers:
             # We've connected, but no servers match the selector.
             if selector is writable_server_selector:
                 if is_replica_set:
-                    return 'No primary available for writes'
+                    return "No primary available for writes"
                 else:
-                    return 'No %s available for writes' % server_plural
+                    return "No %s available for writes" % server_plural
             else:
                 return 'No %s match selector "%s"' % (server_plural, selector)
         else:
@@ -774,9 +778,11 @@ def _error_message(self, selector):
             if is_replica_set:
                 # We removed all servers because of the wrong setName?
                 return 'No %s available for replica set name "%s"' % (
-                    server_plural, self._settings.replica_set_name)
+                    server_plural,
+                    self._settings.replica_set_name,
+                )
             else:
-                return 'No %s available' % server_plural
+                return "No %s available" % server_plural
 
         # 1 or more servers, all Unknown. Are they unknown for one reason?
         error = servers[0].error
@@ -784,32 +790,29 @@ def _error_message(self, selector):
         if same:
             if error is None:
                 # We're still discovering.
-                return 'No %s found yet' % server_plural
+                return "No %s found yet" % server_plural
 
-            if (is_replica_set and not
-                    set(addresses).intersection(self._seed_addresses)):
+            if is_replica_set and not set(addresses).intersection(self._seed_addresses):
                 # We replaced our seeds with new hosts but can't reach any.
                 return (
-                    'Could not reach any servers in %s. Replica set is'
-                    ' configured with internal hostnames or IPs?' %
-                    addresses)
+                    "Could not reach any servers in %s. Replica set is"
+                    " configured with internal hostnames or IPs?" % addresses
+                )
 
             return str(error)
         else:
-            return ','.join(str(server.error) for server in servers
-                            if server.error)
+            return ",".join(str(server.error) for server in servers if server.error)
 
     def __repr__(self):
-        msg = ''
+        msg = ""
         if not self._opened:
-            msg = 'CLOSED '
-        return '<%s %s%r>' % (self.__class__.__name__, msg, self._description)
+            msg = "CLOSED "
+        return "<%s %s%r>" % (self.__class__.__name__, msg, self._description)
 
     def eq_props(self):
         """The properties to use for MongoClient/Topology equality checks."""
         ts = self._settings
-        return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn,
-                ts.srv_service_name)
+        return (tuple(sorted(ts.seeds)), ts.replica_set_name, ts.fqdn, ts.srv_service_name)
 
     def __eq__(self, other):
         if isinstance(other, self.__class__):
@@ -822,8 +825,8 @@ def __hash__(self):
 
 class _ErrorContext(object):
     """An error with context for SDAM error handling."""
-    def __init__(self, error, max_wire_version, sock_generation,
-                 completed_handshake, service_id):
+
+    def __init__(self, error, max_wire_version, sock_generation, completed_handshake, service_id):
         self.error = error
         self.max_wire_version = max_wire_version
         self.sock_generation = sock_generation
@@ -835,9 +838,9 @@ def _is_stale_error_topology_version(current_tv, error_tv):
     """Return True if the error's topologyVersion is <= current."""
     if current_tv is None or error_tv is None:
         return False
-    if current_tv['processId'] != error_tv['processId']:
+    if current_tv["processId"] != error_tv["processId"]:
         return False
-    return current_tv['counter'] >= error_tv['counter']
+    return current_tv["counter"] >= error_tv["counter"]
 
 
 def _is_stale_server_description(current_sd, new_sd):
@@ -845,6 +848,6 @@ def _is_stale_server_description(current_sd, new_sd):
     current_tv, new_tv = current_sd.topology_version, new_sd.topology_version
     if current_tv is None or new_tv is None:
         return False
-    if current_tv['processId'] != new_tv['processId']:
+    if current_tv["processId"] != new_tv["processId"]:
         return False
-    return current_tv['counter'] > new_tv['counter']
+    return current_tv["counter"] > new_tv["counter"]
diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py
index 4fe897dcef..e8fdeb5276 100644
--- a/pymongo/topology_description.py
+++ b/pymongo/topology_description.py
@@ -24,24 +24,33 @@
 from pymongo.server_selectors import Selection
 from pymongo.server_type import SERVER_TYPE
 
-
 # Enumeration for various kinds of MongoDB cluster topologies.
-TOPOLOGY_TYPE = namedtuple('TopologyType', [
-    'Single', 'ReplicaSetNoPrimary', 'ReplicaSetWithPrimary', 'Sharded',
-    'Unknown', 'LoadBalanced'])(*range(6))
+TOPOLOGY_TYPE = namedtuple(
+    "TopologyType",
+    [
+        "Single",
+        "ReplicaSetNoPrimary",
+        "ReplicaSetWithPrimary",
+        "Sharded",
+        "Unknown",
+        "LoadBalanced",
+    ],
+)(*range(6))
 
 # Topologies compatible with SRV record polling.
 SRV_POLLING_TOPOLOGIES = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
 
 
 class TopologyDescription(object):
-    def __init__(self,
-                 topology_type,
-                 server_descriptions,
-                 replica_set_name,
-                 max_set_version,
-                 max_election_id,
-                 topology_settings):
+    def __init__(
+        self,
+        topology_type,
+        server_descriptions,
+        replica_set_name,
+        max_set_version,
+        max_election_id,
+        topology_settings,
+    ):
         """Representation of a deployment of MongoDB servers.
 
         :Parameters:
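The reformatted TOPOLOGY_TYPE constant above instantiates a namedtuple with range(6), producing a lightweight enum. Its behavior, for reference:

    from collections import namedtuple

    TOPOLOGY_TYPE = namedtuple(
        "TopologyType",
        ["Single", "ReplicaSetNoPrimary", "ReplicaSetWithPrimary", "Sharded", "Unknown", "LoadBalanced"],
    )(*range(6))

    assert TOPOLOGY_TYPE.Single == 0 and TOPOLOGY_TYPE.Sharded == 3
    # Membership tests such as the SRV-polling check read naturally:
    assert TOPOLOGY_TYPE.Unknown in (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)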
@@ -77,12 +86,12 @@ def __init__(self,
         readable_servers = self.readable_servers
         if not readable_servers:
             self._ls_timeout_minutes = None
-        elif any(s.logical_session_timeout_minutes is None
-                 for s in readable_servers):
+        elif any(s.logical_session_timeout_minutes is None for s in readable_servers):
             self._ls_timeout_minutes = None
         else:
-            self._ls_timeout_minutes = min(s.logical_session_timeout_minutes
-                                           for s in readable_servers)
+            self._ls_timeout_minutes = min(
+                s.logical_session_timeout_minutes for s in readable_servers
+            )
 
     def _init_incompatible_err(self):
         """Internal compatibility check for non-load balanced topologies."""
@@ -95,28 +104,39 @@ def _init_incompatible_err(self):
             server_too_new = (
                 # Server too new.
                 s.min_wire_version is not None
-                and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION)
+                and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION
+            )
 
             server_too_old = (
                 # Server too old.
                 s.max_wire_version is not None
-                and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION)
+                and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION
+            )
 
             if server_too_new:
                 self._incompatible_err = (
                     "Server at %s:%d requires wire version %d, but this "
                     "version of PyMongo only supports up to %d."
-                    % (s.address[0], s.address[1],
-                       s.min_wire_version, common.MAX_SUPPORTED_WIRE_VERSION))
+                    % (
+                        s.address[0],
+                        s.address[1],
+                        s.min_wire_version,
+                        common.MAX_SUPPORTED_WIRE_VERSION,
+                    )
+                )
             elif server_too_old:
                 self._incompatible_err = (
                     "Server at %s:%d reports wire version %d, but this "
                     "version of PyMongo requires at least %d (MongoDB %s)."
-                    % (s.address[0], s.address[1],
-                       s.max_wire_version,
-                       common.MIN_SUPPORTED_WIRE_VERSION,
-                       common.MIN_SUPPORTED_SERVER_VERSION))
+                    % (
+                        s.address[0],
+                        s.address[1],
+                        s.max_wire_version,
+                        common.MIN_SUPPORTED_WIRE_VERSION,
+                        common.MIN_SUPPORTED_SERVER_VERSION,
+                    )
+                )
                 break
@@ -145,8 +165,7 @@ def reset(self):
         topology_type = self._topology_type
 
         # The default ServerDescription's type is Unknown.
-        sds = dict((address, ServerDescription(address))
-                   for address in self._server_descriptions)
+        sds = dict((address, ServerDescription(address)) for address in self._server_descriptions)
 
         return TopologyDescription(
             topology_type,
@@ -154,7 +173,8 @@ def reset(self):
             self._replica_set_name,
             self._max_set_version,
             self._max_election_id,
-            self._topology_settings)
+            self._topology_settings,
+        )
 
     def server_descriptions(self):
         """Dict of (address,
@@ -197,14 +217,12 @@ def logical_session_timeout_minutes(self):
     @property
     def known_servers(self):
         """List of Servers of types besides Unknown."""
-        return [s for s in self._server_descriptions.values()
-                if s.is_server_type_known]
+        return [s for s in self._server_descriptions.values() if s.is_server_type_known]
 
     @property
     def has_known_servers(self):
         """Whether there are any Servers of types besides Unknown."""
-        return any(s for s in self._server_descriptions.values()
-                   if s.is_server_type_known)
+        return any(s for s in self._server_descriptions.values() if s.is_server_type_known)
 
     @property
     def readable_servers(self):
@@ -232,11 +250,11 @@ def _apply_local_threshold(self, selection):
         if not selection:
             return []
         # Round trip time in seconds.
-        fastest = min(
-            s.round_trip_time for s in selection.server_descriptions)
+        fastest = min(s.round_trip_time for s in selection.server_descriptions)
         threshold = self._topology_settings.local_threshold_ms / 1000.0
-        return [s for s in selection.server_descriptions
-                if (s.round_trip_time - fastest) <= threshold]
+        return [
+            s for s in selection.server_descriptions if (s.round_trip_time - fastest) <= threshold
+        ]
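_apply_local_threshold() keeps every server within localThresholdMS of the fastest round-trip time. A worked example of the same filter with made-up RTTs and a 15 ms threshold:

    rtts = {"a:27017": 0.010, "b:27017": 0.012, "c:27017": 0.080}
    threshold = 15 / 1000.0  # localThresholdMS converted to seconds
    fastest = min(rtts.values())
    nearest = [host for host, rtt in rtts.items() if (rtt - fastest) <= threshold]
    assert nearest == ["a:27017", "b:27017"]  # "c" is 70 ms behind the fastest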
     def apply_selector(self, selector, address=None, custom_selector=None):
         """List of servers matching the provided selector(s).
@@ -254,19 +272,17 @@ def apply_selector(self, selector, address=None, custom_selector=None):
 
         .. versionadded:: 3.4
         """
-        if getattr(selector, 'min_wire_version', 0):
+        if getattr(selector, "min_wire_version", 0):
             common_wv = self.common_wire_version
             if common_wv and common_wv < selector.min_wire_version:
                 raise ConfigurationError(
                     "%s requires min wire version %d, but topology's min"
-                    " wire version is %d" % (selector,
-                                             selector.min_wire_version,
-                                             common_wv))
+                    " wire version is %d" % (selector, selector.min_wire_version, common_wv)
+                )
 
         if self.topology_type == TOPOLOGY_TYPE.Unknown:
             return []
-        elif self.topology_type in (TOPOLOGY_TYPE.Single,
-                                    TOPOLOGY_TYPE.LoadBalanced):
+        elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced):
             # Ignore selectors for standalone and load balancer mode.
             return self.known_servers
         elif address:
@@ -282,7 +298,8 @@ def apply_selector(self, selector, address=None, custom_selector=None):
         # Apply custom selector followed by localThresholdMS.
         if custom_selector is not None and selection:
             selection = selection.with_server_descriptions(
-                custom_selector(selection.server_descriptions))
+                custom_selector(selection.server_descriptions)
+            )
         return self._apply_local_threshold(selection)
 
     def has_readable_server(self, read_preference=ReadPreference.PRIMARY):
@@ -314,11 +331,13 @@ def has_writable_server(self):
 
     def __repr__(self):
         # Sort the servers by address.
-        servers = sorted(self._server_descriptions.values(),
-                         key=lambda sd: sd.address)
+        servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address)
         return "<%s id: %s, topology_type: %s, servers: %r>" % (
-            self.__class__.__name__, self._topology_settings._topology_id,
-            self.topology_type_name, servers)
+            self.__class__.__name__,
+            self._topology_settings._topology_id,
+            self.topology_type_name,
+            servers,
+        )
 
 
 # If topology type is Unknown and we receive a hello response, what should
@@ -362,12 +381,12 @@ def updated_topology_description(topology_description, server_description):
 
     if topology_type == TOPOLOGY_TYPE.Single:
         # Set server type to Unknown if replica set name does not match.
-        if (set_name is not None and
-                set_name != server_description.replica_set_name):
+        if set_name is not None and set_name != server_description.replica_set_name:
             error = ConfigurationError(
                 "client is configured to connect to a replica set named "
-                "'%s' but this node belongs to a set named '%s'" % (
-                    set_name, server_description.replica_set_name))
+                "'%s' but this node belongs to a set named '%s'"
+                % (set_name, server_description.replica_set_name)
+            )
             sds[address] = server_description.to_unknown(error=error)
         # Single type never changes.
         return TopologyDescription(
@@ -376,7 +395,8 @@ def updated_topology_description(topology_description, server_description):
             set_name,
             max_set_version,
             max_election_id,
-            topology_description._topology_settings)
+            topology_description._topology_settings,
+        )
 
     if topology_type == TOPOLOGY_TYPE.Unknown:
         if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer):
@@ -397,21 +417,14 @@ def updated_topology_description(topology_description, server_description):
             sds.pop(address)
 
         elif server_type == SERVER_TYPE.RSPrimary:
-            (topology_type,
-             set_name,
-             max_set_version,
-             max_election_id) = _update_rs_from_primary(sds,
-                                                        set_name,
-                                                        server_description,
-                                                        max_set_version,
-                                                        max_election_id)
-
-        elif server_type in (
-                SERVER_TYPE.RSSecondary,
-                SERVER_TYPE.RSArbiter,
-                SERVER_TYPE.RSOther):
+            (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
+                sds, set_name, server_description, max_set_version, max_election_id
+            )
+
+        elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
             topology_type, set_name = _update_rs_no_primary_from_member(
-                sds, set_name, server_description)
+                sds, set_name, server_description
+            )
 
     elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
         if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
@@ -419,33 +432,26 @@ def updated_topology_description(topology_description, server_description):
             topology_type = _check_has_primary(sds)
 
         elif server_type == SERVER_TYPE.RSPrimary:
-            (topology_type,
-             set_name,
-             max_set_version,
-             max_election_id) = _update_rs_from_primary(sds,
-                                                        set_name,
-                                                        server_description,
-                                                        max_set_version,
-                                                        max_election_id)
-
-        elif server_type in (
-                SERVER_TYPE.RSSecondary,
-                SERVER_TYPE.RSArbiter,
-                SERVER_TYPE.RSOther):
-            topology_type = _update_rs_with_primary_from_member(
-                sds, set_name, server_description)
+            (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
+                sds, set_name, server_description, max_set_version, max_election_id
+            )
+
+        elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
+            topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description)
 
         else:
             # Server type is Unknown or RSGhost: did we just lose the primary?
             topology_type = _check_has_primary(sds)
 
     # Return updated copy.
-    return TopologyDescription(topology_type,
-                               sds,
-                               set_name,
-                               max_set_version,
-                               max_election_id,
-                               topology_description._topology_settings)
+    return TopologyDescription(
+        topology_type,
+        sds,
+        set_name,
+        max_set_version,
+        max_election_id,
+        topology_description._topology_settings,
+    )
 
 
 def _updated_topology_description_srv_polling(topology_description, seedlist):
@@ -463,7 +469,6 @@ def _updated_topology_description_srv_polling(topology_description, seedlist):
     if set(sds.keys()) == set(seedlist):
         return topology_description
 
-
     # Remove SDs corresponding to servers no longer part of the SRV record.
     for address in list(sds.keys()):
         if address not in seedlist:
@@ -486,15 +491,13 @@ def _updated_topology_description_srv_polling(topology_description, seedlist):
         topology_description.replica_set_name,
         topology_description.max_set_version,
         topology_description.max_election_id,
-        topology_description._topology_settings)
+        topology_description._topology_settings,
+    )
 
 
 def _update_rs_from_primary(
-        sds,
-        replica_set_name,
-        server_description,
-        max_set_version,
-        max_election_id):
+    sds, replica_set_name, server_description, max_set_version, max_election_id
+):
     """Update topology description from a primary's hello response.
 
     Pass in a dict of ServerDescriptions, current replica set name, the
@@ -511,35 +514,33 @@ def _update_rs_from_primary(
         # We found a primary but it doesn't have the replica_set_name
         # provided by the user.
         sds.pop(server_description.address)
-        return (_check_has_primary(sds),
-                replica_set_name,
-                max_set_version,
-                max_election_id)
+        return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id)
 
     max_election_tuple = max_set_version, max_election_id
     if None not in server_description.election_tuple:
-        if (None not in max_election_tuple and
-                max_election_tuple > server_description.election_tuple):
+        if (
+            None not in max_election_tuple
+            and max_election_tuple > server_description.election_tuple
+        ):
             # Stale primary, set to type Unknown.
             sds[server_description.address] = server_description.to_unknown()
-            return (_check_has_primary(sds),
-                    replica_set_name,
-                    max_set_version,
-                    max_election_id)
+            return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id)
 
         max_election_id = server_description.election_id
 
-    if (server_description.set_version is not None and
-            (max_set_version is None or
-             server_description.set_version > max_set_version)):
+    if server_description.set_version is not None and (
+        max_set_version is None or server_description.set_version > max_set_version
+    ):
         max_set_version = server_description.set_version
 
     # We've heard from the primary. Is it the same primary as before?
     for server in sds.values():
-        if (server.server_type is SERVER_TYPE.RSPrimary
-                and server.address != server_description.address):
+        if (
+            server.server_type is SERVER_TYPE.RSPrimary
+            and server.address != server_description.address
+        ):
             # Reset old primary's type to Unknown.
             sds[server.address] = server.to_unknown()
@@ -558,16 +559,10 @@ def _update_rs_from_primary(
 
     # If the host list differs from the seed list, we may not have a primary
     # after all.
-    return (_check_has_primary(sds),
-            replica_set_name,
-            max_set_version,
-            max_election_id)
+    return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id)
 
 
-def _update_rs_with_primary_from_member(
-        sds,
-        replica_set_name,
-        server_description):
+def _update_rs_with_primary_from_member(sds, replica_set_name, server_description):
     """RS with known primary. Process a response from a non-primary.
 
     Pass in a dict of ServerDescriptions, current replica set name, and the
@@ -579,18 +574,14 @@ def _update_rs_with_primary_from_member(
 
     if replica_set_name != server_description.replica_set_name:
         sds.pop(server_description.address)
-    elif (server_description.me and
-          server_description.address != server_description.me):
+    elif server_description.me and server_description.address != server_description.me:
         sds.pop(server_description.address)
 
     # Had this member been the primary?
     return _check_has_primary(sds)
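_update_rs_from_primary() detects stale primaries by comparing (setVersion, electionId) tuples lexicographically. With hypothetical values:

    from bson.objectid import ObjectId

    max_election_tuple = (2, ObjectId("6" * 24))
    reported = (2, ObjectId("5" * 24))
    # A primary reporting an older tuple than the maximum seen so far is
    # stale and gets marked Unknown by the logic above.
    assert reported < max_election_tuple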
 
 
-def _update_rs_no_primary_from_member(
-        sds,
-        replica_set_name,
-        server_description):
+def _update_rs_no_primary_from_member(sds, replica_set_name, server_description):
     """RS without known primary. Update from a non-primary's response.
 
     Pass in a dict of ServerDescriptions, current replica set name, and the
@@ -612,8 +603,7 @@ def _update_rs_no_primary_from_member(
     if address not in sds:
         sds[address] = ServerDescription(address)
 
-    if (server_description.me and
-            server_description.address != server_description.me):
+    if server_description.me and server_description.address != server_description.me:
         sds.pop(server_description.address)
 
     return topology_type, replica_set_name
diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py
index 8c43d51770..3b9d069781 100644
--- a/pymongo/uri_parser.py
+++ b/pymongo/uri_parser.py
@@ -15,23 +15,24 @@
 """Tools to parse and validate a MongoDB URI."""
 import re
-import warnings
 import sys
-
+import warnings
 from urllib.parse import unquote_plus
 
 from pymongo.client_options import _parse_ssl_options
 from pymongo.common import (
+    INTERNAL_URI_OPTION_NAME_MAP,
     SRV_SERVICE_NAME,
-    get_validated_options, INTERNAL_URI_OPTION_NAME_MAP,
-    URI_OPTIONS_DEPRECATION_MAP, _CaseInsensitiveDictionary)
+    URI_OPTIONS_DEPRECATION_MAP,
+    _CaseInsensitiveDictionary,
+    get_validated_options,
+)
 from pymongo.errors import ConfigurationError, InvalidURI
 from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver
 
-
-SCHEME = 'mongodb://'
+SCHEME = "mongodb://"
 SCHEME_LEN = len(SCHEME)
-SRV_SCHEME = 'mongodb+srv://'
+SRV_SCHEME = "mongodb+srv://"
 SRV_SCHEME_LEN = len(SRV_SCHEME)
 DEFAULT_PORT = 27017
 
@@ -44,14 +45,15 @@ def _unquoted_percent(s):
     and '%E2%85%A8' but cannot have unquoted percent like '%foo'.
     """
     for i in range(len(s)):
-        if s[i] == '%':
-            sub = s[i:i+3]
+        if s[i] == "%":
+            sub = s[i : i + 3]
             # If unquoting yields the same string this means there was an
             # unquoted %.
             if unquote_plus(sub) == sub:
                 return True
     return False
 
+
 def parse_userinfo(userinfo):
     """Validates the format of user information in a MongoDB URI.
     Reserved characters that are gen-delimiters (":", "/", "?", "#", "[",
@@ -63,10 +65,11 @@ def parse_userinfo(userinfo):
     :Paramaters:
         - `userinfo`: A string of the form <username>:<password>
     """
-    if ('@' in userinfo or userinfo.count(':') > 1 or
-            _unquoted_percent(userinfo)):
-        raise InvalidURI("Username and password must be escaped according to "
-                         "RFC 3986, use urllib.parse.quote_plus")
+    if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo):
+        raise InvalidURI(
+            "Username and password must be escaped according to "
+            "RFC 3986, use urllib.parse.quote_plus"
+        )
 
     user, _, passwd = userinfo.partition(":")
     # No password is expected with GSSAPI authentication.
@@ -88,14 +91,14 @@ def parse_ipv6_literal_host(entity, default_port):
         - `default_port`: The port number to use when one wasn't
           specified in entity.
     """
-    if entity.find(']') == -1:
-        raise ValueError("an IPv6 address literal must be "
-                         "enclosed in '[' and ']' according "
-                         "to RFC 2732.")
-    i = entity.find(']:')
+    if entity.find("]") == -1:
+        raise ValueError(
+            "an IPv6 address literal must be " "enclosed in '[' and ']' according " "to RFC 2732."
+        )
+    i = entity.find("]:")
     if i == -1:
         return entity[1:-1], default_port
-    return entity[1: i], entity[i + 2:]
+    return entity[1:i], entity[i + 2 :]
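A few illustrative calls against the bracketing rules in parse_ipv6_literal_host() above and parse_host() in the next hunk (internal helpers, not stable API; assumes pymongo is importable):

    from pymongo.uri_parser import parse_host

    assert parse_host("example.com") == ("example.com", 27017)
    assert parse_host("EXAMPLE.com:27018") == ("example.com", 27018)  # host is lowercased
    assert parse_host("[::1]:27019") == ("::1", 27019)  # RFC 2732 IPv6 literal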
 
 
 def parse_host(entity, default_port=DEFAULT_PORT):
@@ -112,21 +115,22 @@ def parse_host(entity, default_port=DEFAULT_PORT):
     """
     host = entity
     port = default_port
-    if entity[0] == '[':
+    if entity[0] == "[":
         host, port = parse_ipv6_literal_host(entity, default_port)
     elif entity.endswith(".sock"):
         return entity, default_port
-    elif entity.find(':') != -1:
-        if entity.count(':') > 1:
-            raise ValueError("Reserved characters such as ':' must be "
-                             "escaped according RFC 2396. An IPv6 "
-                             "address literal must be enclosed in '[' "
-                             "and ']' according to RFC 2732.")
-        host, port = host.split(':', 1)
+    elif entity.find(":") != -1:
+        if entity.count(":") > 1:
+            raise ValueError(
+                "Reserved characters such as ':' must be "
+                "escaped according RFC 2396. An IPv6 "
+                "address literal must be enclosed in '[' "
+                "and ']' according to RFC 2732."
+            )
+        host, port = host.split(":", 1)
 
     if isinstance(port, str):
         if not port.isdigit() or int(port) > 65535 or int(port) <= 0:
-            raise ValueError("Port must be an integer between 0 and 65535: %s"
-                             % (port,))
+            raise ValueError("Port must be an integer between 0 and 65535: %s" % (port,))
         port = int(port)
 
     # Normalize hostname to lowercase, since DNS is case-insensitive:
@@ -140,7 +144,8 @@ def parse_host(entity, default_port=DEFAULT_PORT):
 _IMPLICIT_TLSINSECURE_OPTS = {
     "tlsallowinvalidcertificates",
     "tlsallowinvalidhostnames",
-    "tlsdisableocspendpointcheck"}
+    "tlsdisableocspendpointcheck",
+}
 
 
 def _parse_options(opts, delim):
@@ -150,12 +155,12 @@ def _parse_options(opts, delim):
     options = _CaseInsensitiveDictionary()
     for uriopt in opts.split(delim):
         key, value = uriopt.split("=")
-        if key.lower() == 'readpreferencetags':
+        if key.lower() == "readpreferencetags":
             options.setdefault(key, []).append(value)
         else:
             if key in options:
                 warnings.warn("Duplicate URI option '%s'." % (key,))
-            if key.lower() == 'authmechanismproperties':
+            if key.lower() == "authmechanismproperties":
                 val = value
             else:
                 val = unquote_plus(value)
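_parse_options() accumulates repeated readPreferenceTags values into a list and warns on any other duplicated key. For example (internal helper; values are kept unquoted exactly as the code above does):

    from pymongo.uri_parser import _parse_options

    opts = _parse_options("readPreferenceTags=dc:ny&readPreferenceTags=dc:sf", "&")
    assert opts["readpreferencetags"] == ["dc:ny", "dc:sf"]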
@@ -173,49 +178,47 @@ def _handle_security_options(options):
         MongoDB URI options.
     """
     # Implicitly defined options must not be explicitly specified.
-    tlsinsecure = options.get('tlsinsecure')
+    tlsinsecure = options.get("tlsinsecure")
     if tlsinsecure is not None:
         for opt in _IMPLICIT_TLSINSECURE_OPTS:
             if opt in options:
-                err_msg = ("URI options %s and %s cannot be specified "
-                           "simultaneously.")
-                raise InvalidURI(err_msg % (
-                    options.cased_key('tlsinsecure'), options.cased_key(opt)))
+                err_msg = "URI options %s and %s cannot be specified " "simultaneously."
+                raise InvalidURI(
+                    err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt))
+                )
 
     # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options.
-    tlsallowinvalidcerts = options.get('tlsallowinvalidcertificates')
+    tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates")
     if tlsallowinvalidcerts is not None:
-        if 'tlsdisableocspendpointcheck' in options:
-            err_msg = ("URI options %s and %s cannot be specified "
-                       "simultaneously.")
-            raise InvalidURI(err_msg % (
-                'tlsallowinvalidcertificates', options.cased_key(
-                    'tlsdisableocspendpointcheck')))
+        if "tlsdisableocspendpointcheck" in options:
+            err_msg = "URI options %s and %s cannot be specified " "simultaneously."
+            raise InvalidURI(
+                err_msg
+                % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck"))
+            )
         if tlsallowinvalidcerts is True:
-            options['tlsdisableocspendpointcheck'] = True
+            options["tlsdisableocspendpointcheck"] = True
 
     # Handle co-occurence of CRL and OCSP-related options.
-    tlscrlfile = options.get('tlscrlfile')
+    tlscrlfile = options.get("tlscrlfile")
     if tlscrlfile is not None:
-        for opt in ('tlsinsecure', 'tlsallowinvalidcertificates',
-                    'tlsdisableocspendpointcheck'):
+        for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"):
             if options.get(opt) is True:
-                err_msg = ("URI option %s=True cannot be specified when "
-                           "CRL checking is enabled.")
+                err_msg = "URI option %s=True cannot be specified when " "CRL checking is enabled."
                 raise InvalidURI(err_msg % (opt,))
 
-    if 'ssl' in options and 'tls' in options:
+    if "ssl" in options and "tls" in options:
+
         def truth_value(val):
-            if val in ('true', 'false'):
-                return val == 'true'
+            if val in ("true", "false"):
+                return val == "true"
             if isinstance(val, bool):
                 return val
             return val
-        if truth_value(options.get('ssl')) != truth_value(options.get('tls')):
-            err_msg = ("Can not specify conflicting values for URI options %s "
-                       "and %s.")
-            raise InvalidURI(err_msg % (
-                options.cased_key('ssl'), options.cased_key('tls')))
+
+        if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
+            err_msg = "Can not specify conflicting values for URI options %s " "and %s."
+            raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
 
     return options
@@ -232,26 +235,30 @@ def _handle_option_deprecations(options):
     for optname in list(options):
         if optname in URI_OPTIONS_DEPRECATION_MAP:
             mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
-            if mode == 'renamed':
+            if mode == "renamed":
                 newoptname = message
                 if newoptname in options:
-                    warn_msg = ("Deprecated option '%s' ignored in favor of "
-                                "'%s'.")
+                    warn_msg = "Deprecated option '%s' ignored in favor of " "'%s'."
                     warnings.warn(
-                        warn_msg % (options.cased_key(optname),
-                                    options.cased_key(newoptname)),
-                        DeprecationWarning, stacklevel=2)
+                        warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
+                        DeprecationWarning,
+                        stacklevel=2,
+                    )
                     options.pop(optname)
                     continue
                 warn_msg = "Option '%s' is deprecated, use '%s' instead."
                 warnings.warn(
                     warn_msg % (options.cased_key(optname), newoptname),
-                    DeprecationWarning, stacklevel=2)
-            elif mode == 'removed':
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+            elif mode == "removed":
                 warn_msg = "Option '%s' is deprecated. %s."
                 warnings.warn(
                     warn_msg % (options.cased_key(optname), message),
-                    DeprecationWarning, stacklevel=2)
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
 
     return options
@@ -265,7 +272,7 @@ def _normalize_options(options):
         MongoDB URI options.
     """
     # Expand the tlsInsecure option.
-    tlsinsecure = options.get('tlsinsecure')
+    tlsinsecure = options.get("tlsinsecure")
     if tlsinsecure is not None:
         for opt in _IMPLICIT_TLSINSECURE_OPTS:
             # Implicit options are logically the same as tlsInsecure.
@@ -333,9 +340,8 @@ def split_options(opts, validate=True, warn=False, normalize=True):
     if validate:
         options = validate_options(options, warn)
-        if options.get('authsource') == '':
-            raise InvalidURI(
-                "the authSource database cannot be an empty string")
+        if options.get("authsource") == "":
+            raise InvalidURI("the authSource database cannot be an empty string")
 
     return options
@@ -354,13 +360,12 @@ def split_hosts(hosts, default_port=DEFAULT_PORT):
         for a host.
     """
     nodes = []
-    for entity in hosts.split(','):
+    for entity in hosts.split(","):
         if not entity:
-            raise ConfigurationError("Empty host "
-                                     "(or extra comma in host list).")
+            raise ConfigurationError("Empty host " "(or extra comma in host list).")
         port = default_port
         # Unix socket entities don't have ports
-        if entity.endswith('.sock'):
+        if entity.endswith(".sock"):
             port = None
         nodes.append(parse_host(entity, port))
     return nodes
@@ -368,34 +373,37 @@ def split_hosts(hosts, default_port=DEFAULT_PORT):
 
 # Prohibited characters in database name. DB names also can't have ".", but for
 # backward-compat we allow "db.collection" in URI.
-_BAD_DB_CHARS = re.compile('[' + re.escape(r'/ "$') + ']')
+_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
 
 _ALLOWED_TXT_OPTS = frozenset(
-    ['authsource', 'authSource', 'replicaset', 'replicaSet', 'loadbalanced',
-     'loadBalanced'])
+    ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
+)
 
 
 def _check_options(nodes, options):
     # Ensure directConnection was not True if there are multiple seeds.
-    if len(nodes) > 1 and options.get('directconnection'):
-        raise ConfigurationError(
-            'Cannot specify multiple hosts with directConnection=true')
+    if len(nodes) > 1 and options.get("directconnection"):
+        raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
 
-    if options.get('loadbalanced'):
+    if options.get("loadbalanced"):
         if len(nodes) > 1:
-            raise ConfigurationError(
-                'Cannot specify multiple hosts with loadBalanced=true')
-        if options.get('directconnection'):
-            raise ConfigurationError(
-                'Cannot specify directConnection=true with loadBalanced=true')
-        if options.get('replicaset'):
-            raise ConfigurationError(
-                'Cannot specify replicaSet with loadBalanced=true')
-
-
-def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False,
-              normalize=True, connect_timeout=None, srv_service_name=None,
-              srv_max_hosts=None):
+            raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
+        if options.get("directconnection"):
+            raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
+        if options.get("replicaset"):
+            raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
+
+
+def parse_uri(
+    uri,
+    default_port=DEFAULT_PORT,
+    validate=True,
+    warn=False,
+    normalize=True,
+    connect_timeout=None,
+    srv_service_name=None,
+    srv_max_hosts=None,
+):
     """Parse and validate a MongoDB URI.
 
     Returns a dict of the form::
" + "To fix this error install pymongo with the srv extra:\n " + '%s -m pip install "pymongo[srv]"' % (python_path) + ) is_srv = True scheme_free = uri[SRV_SCHEME_LEN:] else: - raise InvalidURI("Invalid URI scheme: URI must " - "begin with '%s' or '%s'" % (SCHEME, SRV_SCHEME)) + raise InvalidURI( + "Invalid URI scheme: URI must " "begin with '%s' or '%s'" % (SCHEME, SRV_SCHEME) + ) if not scheme_free: raise InvalidURI("Must provide at least one hostname or IP.") @@ -472,21 +482,20 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, collection = None options = _CaseInsensitiveDictionary() - host_part, _, path_part = scheme_free.partition('/') + host_part, _, path_part = scheme_free.partition("/") if not host_part: host_part = path_part path_part = "" - if not path_part and '?' in host_part: - raise InvalidURI("A '/' is required between " - "the host list and any options.") + if not path_part and "?" in host_part: + raise InvalidURI("A '/' is required between " "the host list and any options.") if path_part: - dbase, _, opts = path_part.partition('?') + dbase, _, opts = path_part.partition("?") if dbase: dbase = unquote_plus(dbase) - if '.' in dbase: - dbase, collection = dbase.split('.', 1) + if "." in dbase: + dbase, collection = dbase.split(".", 1) if _BAD_DB_CHARS.search(dbase): raise InvalidURI('Bad database name "%s"' % dbase) else: @@ -496,77 +505,74 @@ def parse_uri(uri, default_port=DEFAULT_PORT, validate=True, warn=False, options.update(split_options(opts, validate, warn, normalize)) if srv_service_name is None: srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) - if '@' in host_part: - userinfo, _, hosts = host_part.rpartition('@') + if "@" in host_part: + userinfo, _, hosts = host_part.rpartition("@") user, passwd = parse_userinfo(userinfo) else: hosts = host_part - if '/' in hosts: - raise InvalidURI("Any '/' in a unix domain socket must be" - " percent-encoded: %s" % host_part) + if "/" in hosts: + raise InvalidURI( + "Any '/' in a unix domain socket must be" " percent-encoded: %s" % host_part + ) hosts = unquote_plus(hosts) fqdn = None srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") if is_srv: - if options.get('directConnection'): + if options.get("directConnection"): raise ConfigurationError( - "Cannot specify directConnection=true with " - "%s URIs" % (SRV_SCHEME,)) + "Cannot specify directConnection=true with " "%s URIs" % (SRV_SCHEME,) + ) nodes = split_hosts(hosts, default_port=None) if len(nodes) != 1: - raise InvalidURI( - "%s URIs must include one, " - "and only one, hostname" % (SRV_SCHEME,)) + raise InvalidURI("%s URIs must include one, " "and only one, hostname" % (SRV_SCHEME,)) fqdn, port = nodes[0] if port is not None: - raise InvalidURI( - "%s URIs must not include a port number" % (SRV_SCHEME,)) + raise InvalidURI("%s URIs must not include a port number" % (SRV_SCHEME,)) # Use the connection timeout. connectTimeoutMS passed as a keyword # argument overrides the same option passed in the connection string. 
     fqdn = None
     srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
     if is_srv:
-        if options.get('directConnection'):
+        if options.get("directConnection"):
             raise ConfigurationError(
-                "Cannot specify directConnection=true with "
-                "%s URIs" % (SRV_SCHEME,))
+                "Cannot specify directConnection=true with " "%s URIs" % (SRV_SCHEME,)
+            )
         nodes = split_hosts(hosts, default_port=None)
         if len(nodes) != 1:
-            raise InvalidURI(
-                "%s URIs must include one, "
-                "and only one, hostname" % (SRV_SCHEME,))
+            raise InvalidURI("%s URIs must include one, " "and only one, hostname" % (SRV_SCHEME,))
         fqdn, port = nodes[0]
         if port is not None:
-            raise InvalidURI(
-                "%s URIs must not include a port number" % (SRV_SCHEME,))
+            raise InvalidURI("%s URIs must not include a port number" % (SRV_SCHEME,))
 
         # Use the connection timeout. connectTimeoutMS passed as a keyword
         # argument overrides the same option passed in the connection string.
         connect_timeout = connect_timeout or options.get("connectTimeoutMS")
-        dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name,
-                                    srv_max_hosts)
+        dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts)
         nodes = dns_resolver.get_hosts()
         dns_options = dns_resolver.get_options()
         if dns_options:
-            parsed_dns_options = split_options(
-                dns_options, validate, warn, normalize)
+            parsed_dns_options = split_options(dns_options, validate, warn, normalize)
             if set(parsed_dns_options) - _ALLOWED_TXT_OPTS:
                 raise ConfigurationError(
-                    "Only authSource, replicaSet, and loadBalanced are "
-                    "supported from DNS")
+                    "Only authSource, replicaSet, and loadBalanced are " "supported from DNS"
+                )
             for opt, val in parsed_dns_options.items():
                 if opt not in options:
                     options[opt] = val
         if options.get("loadBalanced") and srv_max_hosts:
-            raise InvalidURI(
-                "You cannot specify loadBalanced with srvMaxHosts")
+            raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts")
         if options.get("replicaSet") and srv_max_hosts:
             raise InvalidURI("You cannot specify replicaSet with srvMaxHosts")
         if "tls" not in options and "ssl" not in options:
-            options["tls"] = True if validate else 'true'
+            options["tls"] = True if validate else "true"
     elif not is_srv and options.get("srvServiceName") is not None:
-        raise ConfigurationError("The srvServiceName option is only allowed "
-                                 "with 'mongodb+srv://' URIs")
+        raise ConfigurationError(
+            "The srvServiceName option is only allowed " "with 'mongodb+srv://' URIs"
+        )
     elif not is_srv and srv_max_hosts:
-        raise ConfigurationError("The srvMaxHosts option is only allowed "
-                                 "with 'mongodb+srv://' URIs")
+        raise ConfigurationError(
+            "The srvMaxHosts option is only allowed " "with 'mongodb+srv://' URIs"
+        )
     else:
         nodes = split_hosts(hosts, default_port=default_port)
 
     _check_options(nodes, options)
 
     return {
-        'nodelist': nodes,
-        'username': user,
-        'password': passwd,
-        'database': dbase,
-        'collection': collection,
-        'options': options,
-        'fqdn': fqdn
+        "nodelist": nodes,
+        "username": user,
+        "password": passwd,
+        "database": dbase,
+        "collection": collection,
+        "options": options,
+        "fqdn": fqdn,
     }
@@ -575,37 +581,39 @@ def _parse_kms_tls_options(kms_tls_options):
     if not kms_tls_options:
         return {}
     if not isinstance(kms_tls_options, dict):
-        raise TypeError('kms_tls_options must be a dict')
+        raise TypeError("kms_tls_options must be a dict")
     contexts = {}
     for provider, opts in kms_tls_options.items():
         if not isinstance(opts, dict):
             raise TypeError(f'kms_tls_options["{provider}"] must be a dict')
-        opts.setdefault('tls', True)
+        opts.setdefault("tls", True)
         opts = _CaseInsensitiveDictionary(opts)
         opts = _handle_security_options(opts)
         opts = _normalize_options(opts)
         opts = validate_options(opts)
         ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts)
         if ssl_context is None:
-            raise ConfigurationError('TLS is required for KMS providers')
+            raise ConfigurationError("TLS is required for KMS providers")
         if allow_invalid_hostnames:
-            raise ConfigurationError('Insecure TLS options prohibited')
-
-        for n in ['tlsInsecure',
-                  'tlsAllowInvalidCertificates',
-                  'tlsAllowInvalidHostnames',
-                  'tlsDisableOCSPEndpointCheck',
-                  'tlsDisableCertificateRevocationCheck']:
+            raise ConfigurationError("Insecure TLS options prohibited")
+
+        for n in [
+            "tlsInsecure",
+            "tlsAllowInvalidCertificates",
+            "tlsAllowInvalidHostnames",
+            "tlsDisableOCSPEndpointCheck",
+            "tlsDisableCertificateRevocationCheck",
+        ]:
             if n in opts:
-                raise ConfigurationError(
-                    f'Insecure TLS options prohibited: {n}')
+                raise ConfigurationError(f"Insecure TLS options prohibited: {n}")
         contexts[provider] = ssl_context
     return contexts
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     import pprint
     import sys
+
     try:
         pprint.pprint(parse_uri(sys.argv[1]))
     except InvalidURI as exc:
diff --git a/pymongo/write_concern.py b/pymongo/write_concern.py
index ebc997c0db..4427c7e065 100644
--- a/pymongo/write_concern.py
+++ b/pymongo/write_concern.py
@@ -67,8 +67,7 @@ def __init__(self, w=None, wtimeout=None, j=None, fsync=None):
             if not isinstance(fsync, bool):
                 raise TypeError("fsync must be True or False")
             if j and fsync:
-                raise ConfigurationError("Can't set both j "
-                                         "and fsync at the same time")
+                raise ConfigurationError("Can't set both j " "and fsync at the same time")
             self.__document["fsync"] = fsync
 
         if w == 0 and j is True:
@@ -108,8 +107,7 @@ def acknowledged(self):
         return self.__acknowledged
 
     def __repr__(self):
-        return ("WriteConcern(%s)" % (
-            ", ".join("%s=%s" % kvt for kvt in self.__document.items()),))
+        return "WriteConcern(%s)" % (", ".join("%s=%s" % kvt for kvt in self.__document.items()),)
 
     def __eq__(self, other):
         if isinstance(other, WriteConcern):
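For reference, the constraints reformatted above mean j=True and fsync=True are mutually exclusive, and the repr simply joins the document's fields. For example:

    from pymongo.write_concern import WriteConcern

    wc = WriteConcern(w="majority", wtimeout=1000)
    assert wc.document == {"w": "majority", "wtimeout": 1000}
    assert wc.acknowledged
    repr(wc)  # e.g. "WriteConcern(wtimeout=1000, w=majority)"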
'some_module.test_suite')"), ("failfast", "f", "Stop running tests on first failure or error"), - ("xunit-output=", "x", - "Generate a results directory with XUnit XML format") + ("xunit-output=", "x", "Generate a results directory with XUnit XML format"), ] def initialize_options(self): @@ -84,44 +82,42 @@ def initialize_options(self): def finalize_options(self): if self.test_suite is None and self.test_module is None: - self.test_module = 'test' + self.test_module = "test" elif self.test_module is not None and self.test_suite is not None: - raise Exception( - "You may specify a module or suite, but not both" - ) + raise Exception("You may specify a module or suite, but not both") def run(self): # Installing required packages, running egg_info and build_ext are # part of normal operation for setuptools.command.test.test if self.distribution.install_requires: - self.distribution.fetch_build_eggs( - self.distribution.install_requires) + self.distribution.fetch_build_eggs(self.distribution.install_requires) if self.distribution.tests_require: self.distribution.fetch_build_eggs(self.distribution.tests_require) if self.xunit_output: self.distribution.fetch_build_eggs(["unittest-xml-reporting"]) - self.run_command('egg_info') - build_ext_cmd = self.reinitialize_command('build_ext') + self.run_command("egg_info") + build_ext_cmd = self.reinitialize_command("build_ext") build_ext_cmd.inplace = 1 - self.run_command('build_ext') + self.run_command("build_ext") # Construct a TextTestRunner directly from the unittest imported from # test, which creates a TestResult that supports the 'addSkip' method. # setuptools will by default create a TextTestRunner that uses the old # TestResult class. - from test import unittest, PymongoTestRunner, test_cases + from test import PymongoTestRunner, test_cases, unittest + if self.test_suite is None: all_tests = unittest.defaultTestLoader.discover(self.test_module) suite = unittest.TestSuite() - suite.addTests(sorted(test_cases(all_tests), - key=lambda x: x.__module__)) + suite.addTests(sorted(test_cases(all_tests), key=lambda x: x.__module__)) else: - suite = unittest.defaultTestLoader.loadTestsFromName( - self.test_suite) + suite = unittest.defaultTestLoader.loadTestsFromName(self.test_suite) if self.xunit_output: from test import PymongoXMLTestRunner - runner = PymongoXMLTestRunner(verbosity=2, failfast=self.failfast, - output=self.xunit_output) + + runner = PymongoXMLTestRunner( + verbosity=2, failfast=self.failfast, output=self.xunit_output + ) else: runner = PymongoTestRunner(verbosity=2, failfast=self.failfast) result = runner.run(suite) @@ -132,8 +128,7 @@ class doc(Command): description = "generate or test documentation" - user_options = [("test", "t", - "run doctests instead of generating documentation")] + user_options = [("test", "t", "run doctests instead of generating documentation")] boolean_options = ["test"] @@ -146,16 +141,13 @@ def finalize_options(self): def run(self): if not _HAVE_SPHINX: - raise RuntimeError( - "You must install Sphinx to build or test the documentation.") + raise RuntimeError("You must install Sphinx to build or test the documentation.") if self.test: - path = os.path.join( - os.path.abspath('.'), "doc", "_build", "doctest") + path = os.path.join(os.path.abspath("."), "doc", "_build", "doctest") mode = "doctest" else: - path = os.path.join( - os.path.abspath('.'), "doc", "_build", version) + path = os.path.join(os.path.abspath("."), "doc", "_build", version) mode = "html" try: @@ -168,7 +160,7 @@ def run(self): # sphinx.main 
calls sys.exit when sphinx.build_main exists. # Call build_main directly so we can check status and print # the full path to the built docs. - if hasattr(sphinx, 'build_main'): + if hasattr(sphinx, "build_main"): status = sphinx.build_main(sphinx_args) else: status = sphinx.main(sphinx_args) @@ -176,8 +168,9 @@ def run(self): if status: raise RuntimeError("documentation step '%s' failed" % (mode,)) - sys.stdout.write("\nDocumentation step '%s' performed, results here:\n" - " %s/\n" % (mode, path)) + sys.stdout.write( + "\nDocumentation step '%s' performed, results here:\n" " %s/\n" % (mode, path) + ) class custom_build_ext(build_ext): @@ -234,11 +227,14 @@ def run(self): build_ext.run(self) except Exception: e = sys.exc_info()[1] - sys.stdout.write('%s\n' % str(e)) - warnings.warn(self.warning_message % ("Extension modules", - "There was an issue with " - "your platform configuration" - " - see above.")) + sys.stdout.write("%s\n" % str(e)) + warnings.warn( + self.warning_message + % ( + "Extension modules", + "There was an issue with " "your platform configuration" " - see above.", + ) + ) def build_extension(self, ext): name = ext.name @@ -246,68 +242,75 @@ def build_extension(self, ext): build_ext.build_extension(self, ext) except Exception: e = sys.exc_info()[1] - sys.stdout.write('%s\n' % str(e)) - warnings.warn(self.warning_message % ("The %s extension " - "module" % (name,), - "The output above " - "this warning shows how " - "the compilation " - "failed.")) - -ext_modules = [Extension('bson._cbson', - include_dirs=['bson'], - sources=['bson/_cbsonmodule.c', - 'bson/time64.c', - 'bson/buffer.c', - 'bson/encoding_helpers.c']), - Extension('pymongo._cmessage', - include_dirs=['bson'], - sources=['pymongo/_cmessagemodule.c', - 'bson/buffer.c'])] + sys.stdout.write("%s\n" % str(e)) + warnings.warn( + self.warning_message + % ( + "The %s extension " "module" % (name,), + "The output above " "this warning shows how " "the compilation " "failed.", + ) + ) + + +ext_modules = [ + Extension( + "bson._cbson", + include_dirs=["bson"], + sources=[ + "bson/_cbsonmodule.c", + "bson/time64.c", + "bson/buffer.c", + "bson/encoding_helpers.c", + ], + ), + Extension( + "pymongo._cmessage", + include_dirs=["bson"], + sources=["pymongo/_cmessagemodule.c", "bson/buffer.c"], + ), +] # PyOpenSSL 17.0.0 introduced support for OCSP. 17.1.0 introduced # a related feature we need. 17.2.0 fixes a bug # in set_default_verify_paths we should really avoid. # service_identity 18.1.0 introduced support for IP addr matching. pyopenssl_reqs = ["pyopenssl>=17.2.0", "requests<3.0.0", "service_identity>=18.1.0"] -if sys.platform in ('win32', 'darwin'): +if sys.platform in ("win32", "darwin"): # Fallback to certifi on Windows if we can't load CA certs from the system # store and just use certifi on macOS. 
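The custom_build_ext hunk above is a packaging pattern worth calling out: the C extensions are treated as optional, so a compiler failure degrades to a pure-Python install instead of aborting setup.py. A minimal sketch of the same idea, assuming setuptools (class name and message are illustrative; pymongo's version uses a longer warning template):

```python
# Minimal sketch of an "optional C extension" build command.
import warnings

from setuptools.command.build_ext import build_ext


class optional_build_ext(build_ext):
    """Build C extensions, but treat compiler failures as non-fatal."""

    def build_extension(self, ext):
        try:
            build_ext.build_extension(self, ext)
        except Exception as exc:
            # Fall back to the pure-Python implementation with a warning.
            warnings.warn("Failed to build %s: %s" % (ext.name, exc))
```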
# https://www.pyopenssl.org/en/stable/api/ssl.html#OpenSSL.SSL.Context.set_default_verify_paths - pyopenssl_reqs.append('certifi') + pyopenssl_reqs.append("certifi") extras_require = { - 'encryption': ['pymongocrypt>=1.2.0,<2.0.0'], - 'ocsp': pyopenssl_reqs, - 'snappy': ['python-snappy'], - 'zstd': ['zstandard'], - 'aws': ['pymongo-auth-aws<2.0.0'], - 'srv': ["dnspython>=1.16.0,<3.0.0"], + "encryption": ["pymongocrypt>=1.2.0,<2.0.0"], + "ocsp": pyopenssl_reqs, + "snappy": ["python-snappy"], + "zstd": ["zstandard"], + "aws": ["pymongo-auth-aws<2.0.0"], + "srv": ["dnspython>=1.16.0,<3.0.0"], } # GSSAPI extras -if sys.platform == 'win32': - extras_require['gssapi'] = ["winkerberos>=0.5.0"] +if sys.platform == "win32": + extras_require["gssapi"] = ["winkerberos>=0.5.0"] else: - extras_require['gssapi'] = ["pykerberos"] + extras_require["gssapi"] = ["pykerberos"] -extra_opts = { - "packages": ["bson", "pymongo", "gridfs"] -} +extra_opts = {"packages": ["bson", "pymongo", "gridfs"]} if "--no_ext" in sys.argv: sys.argv.remove("--no_ext") -elif (sys.platform.startswith("java") or - sys.platform == "cli" or - "PyPy" in sys.version): - sys.stdout.write(""" +elif sys.platform.startswith("java") or sys.platform == "cli" or "PyPy" in sys.version: + sys.stdout.write( + """ *****************************************************\n The optional C extensions are currently not supported\n by this python implementation.\n *****************************************************\n -""") +""" + ) else: - extra_opts['ext_modules'] = ext_modules + extra_opts["ext_modules"] = ext_modules setup( name="pymongo", @@ -337,10 +340,9 @@ def build_extension(self, ext): "Programming Language :: Python :: 3.10", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", - "Topic :: Database"], - cmdclass={"build_ext": custom_build_ext, - "doc": doc, - "test": test}, + "Topic :: Database", + ], + cmdclass={"build_ext": custom_build_ext, "doc": doc, "test": test}, extras_require=extras_require, **extra_opts ) diff --git a/test/__init__.py b/test/__init__.py index ab53b7fdc5..41fd958ce5 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -27,6 +27,7 @@ try: from xmlrunner import XMLTestRunner + HAVE_XML = True # ValueError is raised when version 3+ is installed on Jython 2.7. except (ImportError, ValueError): @@ -34,17 +35,18 @@ try: import ipaddress + HAVE_IPADDRESS = True except ImportError: HAVE_IPADDRESS = False from contextlib import contextmanager from functools import wraps +from test.version import Version from unittest import SkipTest import pymongo import pymongo.errors - from bson.son import SON from pymongo import common, message from pymongo.common import partition_node @@ -52,7 +54,6 @@ from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl from pymongo.uri_parser import parse_uri -from test.version import Version if HAVE_SSL: import ssl @@ -61,36 +62,34 @@ # Enable the fault handler to dump the traceback of each running thread # after a segfault. import faulthandler + faulthandler.enable() except ImportError: pass # Enable debug output for uncollectable objects. PyPy does not have set_debug. 
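The gc hunk that follows is a small portability dance: PyPy has no gc.set_debug, and DEBUG_OBJECTS/DEBUG_INSTANCES only exist on Python 2, so every flag is feature-tested. The payoff is that anything the collector cannot free lands in gc.garbage, which teardown() later turns into a hard test failure. A self-contained sketch:

```python
# Defensive gc setup, as in the hunk below, plus what it buys the suite.
import gc

if hasattr(gc, "set_debug"):  # PyPy has no set_debug
    # getattr degrades the Python-2-only flags to 0, a bitwise no-op.
    gc.set_debug(
        gc.DEBUG_UNCOLLECTABLE
        | getattr(gc, "DEBUG_OBJECTS", 0)
        | getattr(gc, "DEBUG_INSTANCES", 0)
    )

gc.collect()
assert not gc.garbage, gc.garbage  # uncollectable objects fail the run
```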
-if hasattr(gc, 'set_debug'): +if hasattr(gc, "set_debug"): gc.set_debug( - gc.DEBUG_UNCOLLECTABLE | - getattr(gc, 'DEBUG_OBJECTS', 0) | - getattr(gc, 'DEBUG_INSTANCES', 0)) + gc.DEBUG_UNCOLLECTABLE | getattr(gc, "DEBUG_OBJECTS", 0) | getattr(gc, "DEBUG_INSTANCES", 0) + ) # The host and port of a single mongod or mongos, or the seed host # for a replica set. -host = os.environ.get("DB_IP", 'localhost') +host = os.environ.get("DB_IP", "localhost") port = int(os.environ.get("DB_PORT", 27017)) db_user = os.environ.get("DB_USER", "user") db_pwd = os.environ.get("DB_PASSWORD", "password") -CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'certificates') -CLIENT_PEM = os.environ.get('CLIENT_PEM', - os.path.join(CERT_PATH, 'client.pem')) -CA_PEM = os.environ.get('CA_PEM', os.path.join(CERT_PATH, 'ca.pem')) +CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates") +CLIENT_PEM = os.environ.get("CLIENT_PEM", os.path.join(CERT_PATH, "client.pem")) +CA_PEM = os.environ.get("CA_PEM", os.path.join(CERT_PATH, "ca.pem")) TLS_OPTIONS = dict(tls=True) if CLIENT_PEM: - TLS_OPTIONS['tlsCertificateKeyFile'] = CLIENT_PEM + TLS_OPTIONS["tlsCertificateKeyFile"] = CLIENT_PEM if CA_PEM: - TLS_OPTIONS['tlsCAFile'] = CA_PEM + TLS_OPTIONS["tlsCAFile"] = CA_PEM COMPRESSORS = os.environ.get("COMPRESSORS") MONGODB_API_VERSION = os.environ.get("MONGODB_API_VERSION") @@ -101,20 +100,21 @@ if TEST_LOADBALANCER: # Remove after PYTHON-2712 from pymongo import pool + pool._MOCK_SERVICE_ID = True res = parse_uri(SINGLE_MONGOS_LB_URI) - host, port = res['nodelist'][0] - db_user = res['username'] or db_user - db_pwd = res['password'] or db_pwd + host, port = res["nodelist"][0] + db_user = res["username"] or db_user + db_pwd = res["password"] or db_pwd elif TEST_SERVERLESS: TEST_LOADBALANCER = True res = parse_uri(SINGLE_MONGOS_LB_URI) - host, port = res['nodelist'][0] - db_user = res['username'] or db_user - db_pwd = res['password'] or db_pwd - TLS_OPTIONS = {'tls': True} + host, port = res["nodelist"][0] + db_user = res["username"] or db_user + db_pwd = res["password"] or db_pwd + TLS_OPTIONS = {"tls": True} # Spec says serverless tests must be run with compression. - COMPRESSORS = COMPRESSORS or 'zlib' + COMPRESSORS = COMPRESSORS or "zlib" def is_server_resolvable(): @@ -123,7 +123,7 @@ def is_server_resolvable(): socket.setdefaulttimeout(1) try: try: - socket.gethostbyname('server') + socket.gethostbyname("server") return True except socket.error: return False @@ -132,22 +132,23 @@ def is_server_resolvable(): def _create_user(authdb, user, pwd=None, roles=None, **kwargs): - cmd = SON([('createUser', user)]) + cmd = SON([("createUser", user)]) # X509 doesn't use a password if pwd: - cmd['pwd'] = pwd - cmd['roles'] = roles or ['root'] + cmd["pwd"] = pwd + cmd["roles"] = roles or ["root"] cmd.update(**kwargs) return authdb.command(cmd) class client_knobs(object): def __init__( - self, - heartbeat_frequency=None, - min_heartbeat_interval=None, - kill_cursor_frequency=None, - events_queue_frequency=None): + self, + heartbeat_frequency=None, + min_heartbeat_interval=None, + kill_cursor_frequency=None, + events_queue_frequency=None, + ): self.heartbeat_frequency = heartbeat_frequency self.min_heartbeat_interval = min_heartbeat_interval self.kill_cursor_frequency = kill_cursor_frequency @@ -179,7 +180,7 @@ def enable(self): common.EVENTS_QUEUE_FREQUENCY = self.events_queue_frequency self._enabled = True # Store the allocation traceback to catch non-disabled client_knobs. 
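client_knobs, being reformatted here, is a save-override-restore context manager for module-level tuning constants; it also records the stack at enable time so a leaked (never-disabled) knob can report where it came from. A stripped-down, hypothetical generalization of the pattern:

```python
# Generic version of the client_knobs idea; names are illustrative.
import traceback


class knobs(object):
    def __init__(self, module, **overrides):
        self._module = module
        self._overrides = overrides
        self._saved = {}
        self._stack = None  # where the override was enabled

    def __enter__(self):
        self._stack = "".join(traceback.format_stack())
        for name, value in self._overrides.items():
            self._saved[name] = getattr(self._module, name)
            setattr(self._module, name, value)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for name, value in self._saved.items():
            setattr(self._module, name, value)


# Usage sketch: with knobs(common, HEARTBEAT_FREQUENCY=0.001): ...
```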
- self._stack = ''.join(traceback.format_stack()) + self._stack = "".join(traceback.format_stack()) def __enter__(self): self.enable() @@ -200,6 +201,7 @@ def make_wrapper(f): def wrap(*args, **kwargs): with self: return f(*args, **kwargs) + return wrap return make_wrapper(func) @@ -207,20 +209,23 @@ def wrap(*args, **kwargs): def __del__(self): if self._enabled: msg = ( - 'ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY=%s, ' - 'MIN_HEARTBEAT_INTERVAL=%s, KILL_CURSOR_FREQUENCY=%s, ' - 'EVENTS_QUEUE_FREQUENCY=%s, stack:\n%s' % ( + "ERROR: client_knobs still enabled! HEARTBEAT_FREQUENCY=%s, " + "MIN_HEARTBEAT_INTERVAL=%s, KILL_CURSOR_FREQUENCY=%s, " + "EVENTS_QUEUE_FREQUENCY=%s, stack:\n%s" + % ( common.HEARTBEAT_FREQUENCY, common.MIN_HEARTBEAT_INTERVAL, common.KILL_CURSOR_FREQUENCY, common.EVENTS_QUEUE_FREQUENCY, - self._stack)) + self._stack, + ) + ) self.disable() raise Exception(msg) def _all_users(db): - return set(u['user'] for u in db.command('usersInfo').get('users', [])) + return set(u["user"] for u in db.command("usersInfo").get("users", [])) class ClientContext(object): @@ -267,10 +272,10 @@ def client_options(self): """Return the MongoClient options for creating a duplicate client.""" opts = client_context.default_client_options.copy() if client_context.auth_enabled: - opts['username'] = db_user - opts['password'] = db_pwd + opts["username"] = db_user + opts["password"] = db_pwd if self.replica_set_name: - opts['replicaSet'] = self.replica_set_name + opts["replicaSet"] = self.replica_set_name return opts @property @@ -281,29 +286,26 @@ def hello(self): def _connect(self, host, port, **kwargs): # Jython takes a long time to connect. - if sys.platform.startswith('java'): + if sys.platform.startswith("java"): timeout_ms = 10000 else: timeout_ms = 5000 kwargs.update(self.default_client_options) - client = pymongo.MongoClient( - host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs) + client = pymongo.MongoClient(host, port, serverSelectionTimeoutMS=timeout_ms, **kwargs) try: try: client.admin.command(HelloCompat.LEGACY_CMD) # Can we connect? except pymongo.errors.OperationFailure as exc: # SERVER-32063 self.connection_attempts.append( - 'connected client %r, but legacy hello failed: %s' % ( - client, exc)) + "connected client %r, but legacy hello failed: %s" % (client, exc) + ) else: - self.connection_attempts.append( - 'successfully connected client %r' % (client,)) + self.connection_attempts.append("successfully connected client %r" % (client,)) # If connected, then return client with default timeout return pymongo.MongoClient(host, port, **kwargs) except pymongo.errors.ConnectionFailure as exc: - self.connection_attempts.append( - 'failed to connect client %r: %s' % (client, exc)) + self.connection_attempts.append("failed to connect client %r: %s" % (client, exc)) return None finally: client.close() @@ -314,12 +316,11 @@ def _init_client(self): if self.client is not None: # Return early when connected to dataLake as mongohoused does not # support the getCmdLineOpts command and is tested without TLS. 
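_connect above uses a probe-then-connect idiom: a throwaway client with a short serverSelectionTimeoutMS decides reachability quickly, and only then is a client with default timeouts handed back. A minimal sketch of the same idea, using "ping" where the real code issues the legacy hello command:

```python
import pymongo
import pymongo.errors


def probe_then_connect(host, port, **kwargs):
    # Short timeout so an unreachable server fails fast in test setup.
    candidate = pymongo.MongoClient(
        host, port, serverSelectionTimeoutMS=5000, **kwargs
    )
    try:
        candidate.admin.command("ping")  # can we connect at all?
    except pymongo.errors.ConnectionFailure:
        return None
    finally:
        candidate.close()
    # Reachable: return a client with the default timeouts.
    return pymongo.MongoClient(host, port, **kwargs)
```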
- build_info = self.client.admin.command('buildInfo') - if 'dataLake' in build_info: + build_info = self.client.admin.command("buildInfo") + if "dataLake" in build_info: self.is_data_lake = True self.auth_enabled = True - self.client = self._connect( - host, port, username=db_user, password=db_pwd) + self.client = self._connect(host, port, username=db_user, password=db_pwd) self.connected = True return @@ -338,10 +339,10 @@ def _init_client(self): self.auth_enabled = True else: try: - self.cmd_line = self.client.admin.command('getCmdLineOpts') + self.cmd_line = self.client.admin.command("getCmdLineOpts") except pymongo.errors.OperationFailure as e: - msg = e.details.get('errmsg', '') - if e.code == 13 or 'unauthorized' in msg or 'login' in msg: + msg = e.details.get("errmsg", "") + if e.code == 13 or "unauthorized" in msg or "login" in msg: # Unauthorized. self.auth_enabled = True else: @@ -356,26 +357,30 @@ def _init_client(self): _create_user(self.client.admin, db_user, db_pwd) self.client = self._connect( - host, port, username=db_user, password=db_pwd, + host, + port, + username=db_user, + password=db_pwd, replicaSet=self.replica_set_name, - **self.default_client_options) + **self.default_client_options + ) # May not have this if OperationFailure was raised earlier. - self.cmd_line = self.client.admin.command('getCmdLineOpts') + self.cmd_line = self.client.admin.command("getCmdLineOpts") if self.serverless: self.server_status = {} else: - self.server_status = self.client.admin.command('serverStatus') + self.server_status = self.client.admin.command("serverStatus") if self.storage_engine == "mmapv1": # MMAPv1 does not support retryWrites=True. - self.default_client_options['retryWrites'] = False + self.default_client_options["retryWrites"] = False hello = self.hello - self.sessions_enabled = 'logicalSessionTimeoutMinutes' in hello + self.sessions_enabled = "logicalSessionTimeoutMinutes" in hello - if 'setName' in hello: - self.replica_set_name = str(hello['setName']) + if "setName" in hello: + self.replica_set_name = str(hello["setName"]) self.is_rs = True if self.auth_enabled: # It doesn't matter which member we use as the seed here. @@ -385,23 +390,19 @@ def _init_client(self): username=db_user, password=db_pwd, replicaSet=self.replica_set_name, - **self.default_client_options) + **self.default_client_options + ) else: self.client = pymongo.MongoClient( - host, - port, - replicaSet=self.replica_set_name, - **self.default_client_options) + host, port, replicaSet=self.replica_set_name, **self.default_client_options + ) # Get the authoritative hello result from the primary. 
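The node-set derivation that follows relies on hello listing replica set members as "host:port" strings under hosts, passives, and arbiters. pymongo.common.partition_node (imported at the top of this file) normalizes them; here is a simplified stand-in that ignores IPv6 literals:

```python
def partition_node(node):
    """Split 'host:port' into (host, port); the real helper also
    handles IPv6 literals such as '[::1]:27017'."""
    host, _, port = node.partition(":")
    return host, int(port) if port else 27017


hello = {
    "hosts": ["a.example.com:27017", "b.example.com:27017"],
    "passives": [],
    "arbiters": ["c.example.com:27018"],
}
members = (
    hello.get("hosts", []) + hello.get("passives", []) + hello.get("arbiters", [])
)
nodes = {partition_node(n.lower()) for n in members}
assert ("c.example.com", 27018) in nodes
```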
self._hello = None hello = self.hello - nodes = [partition_node(node.lower()) - for node in hello.get('hosts', [])] - nodes.extend([partition_node(node.lower()) - for node in hello.get('passives', [])]) - nodes.extend([partition_node(node.lower()) - for node in hello.get('arbiters', [])]) + nodes = [partition_node(node.lower()) for node in hello.get("hosts", [])] + nodes.extend([partition_node(node.lower()) for node in hello.get("passives", [])]) + nodes.extend([partition_node(node.lower()) for node in hello.get("arbiters", [])]) self.nodes = set(nodes) else: self.nodes = set([(host, port)]) @@ -410,38 +411,36 @@ def _init_client(self): if self.serverless: self.server_parameters = { - 'requireApiVersion': False, - 'enableTestCommands': True, + "requireApiVersion": False, + "enableTestCommands": True, } self.test_commands_enabled = True self.has_ipv6 = False else: - self.server_parameters = self.client.admin.command( - 'getParameter', '*') - if 'enableTestCommands=1' in self.cmd_line['argv']: + self.server_parameters = self.client.admin.command("getParameter", "*") + if "enableTestCommands=1" in self.cmd_line["argv"]: self.test_commands_enabled = True - elif 'parsed' in self.cmd_line: - params = self.cmd_line['parsed'].get('setParameter', []) - if 'enableTestCommands=1' in params: + elif "parsed" in self.cmd_line: + params = self.cmd_line["parsed"].get("setParameter", []) + if "enableTestCommands=1" in params: self.test_commands_enabled = True else: - params = self.cmd_line['parsed'].get('setParameter', {}) - if params.get('enableTestCommands') == '1': + params = self.cmd_line["parsed"].get("setParameter", {}) + if params.get("enableTestCommands") == "1": self.test_commands_enabled = True self.has_ipv6 = self._server_started_with_ipv6() - self.is_mongos = (self.hello.get('msg') == 'isdbgrid') + self.is_mongos = self.hello.get("msg") == "isdbgrid" if self.is_mongos: address = self.client.address self.mongoses.append(address) if not self.serverless: # Check for another mongos on the next port. next_address = address[0], address[1] + 1 - mongos_client = self._connect( - *next_address, **self.default_client_options) + mongos_client = self._connect(*next_address, **self.default_client_options) if mongos_client: hello = mongos_client.admin.command(HelloCompat.LEGACY_CMD) - if hello.get('msg') == 'isdbgrid': + if hello.get("msg") == "isdbgrid": self.mongoses.append(next_address) def init(self): @@ -450,7 +449,7 @@ def init(self): self._init_client() def connection_attempt_info(self): - return '\n'.join(self.connection_attempts) + return "\n".join(self.connection_attempts) @property def host(self): @@ -487,17 +486,19 @@ def storage_engine(self): def _check_user_provided(self): """Return True if db_user/db_password is already an admin user.""" client = pymongo.MongoClient( - host, port, + host, + port, username=db_user, password=db_pwd, serverSelectionTimeoutMS=100, - **self.default_client_options) + **self.default_client_options + ) try: return db_user in _all_users(client.admin) except pymongo.errors.OperationFailure as e: - msg = e.details.get('errmsg', '') - if e.code == 18 or 'auth fails' in msg: + msg = e.details.get("errmsg", "") + if e.code == 18 or "auth fails" in msg: # Auth failed. 
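_check_user_provided, continued below, probes whether the configured admin user already exists by running usersInfo and classifying the failure; MongoDB error code 18 (AuthenticationFailed) means the credentials are not set up yet. Roughly, as a sketch:

```python
import pymongo
import pymongo.errors


def user_exists(client, username):
    try:
        users = client.admin.command("usersInfo").get("users", [])
        return any(u["user"] == username for u in users)
    except pymongo.errors.OperationFailure as exc:
        if exc.code == 18:  # AuthenticationFailed: user not set up yet
            return False
        raise
```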
return False else: @@ -505,31 +506,30 @@ def _check_user_provided(self): def _server_started_with_auth(self): # MongoDB >= 2.0 - if 'parsed' in self.cmd_line: - parsed = self.cmd_line['parsed'] + if "parsed" in self.cmd_line: + parsed = self.cmd_line["parsed"] # MongoDB >= 2.6 - if 'security' in parsed: - security = parsed['security'] + if "security" in parsed: + security = parsed["security"] # >= rc3 - if 'authorization' in security: - return security['authorization'] == 'enabled' + if "authorization" in security: + return security["authorization"] == "enabled" # < rc3 - return (security.get('auth', False) or - bool(security.get('keyFile'))) - return parsed.get('auth', False) or bool(parsed.get('keyFile')) + return security.get("auth", False) or bool(security.get("keyFile")) + return parsed.get("auth", False) or bool(parsed.get("keyFile")) # Legacy - argv = self.cmd_line['argv'] - return '--auth' in argv or '--keyFile' in argv + argv = self.cmd_line["argv"] + return "--auth" in argv or "--keyFile" in argv def _server_started_with_ipv6(self): if not socket.has_ipv6: return False - if 'parsed' in self.cmd_line: - if not self.cmd_line['parsed'].get('net', {}).get('ipv6'): + if "parsed" in self.cmd_line: + if not self.cmd_line["parsed"].get("net", {}).get("ipv6"): return False else: - if '--ipv6' not in self.cmd_line['argv']: + if "--ipv6" not in self.cmd_line["argv"]: return False # The server was started with --ipv6. Is there an IPv6 route to it? @@ -549,101 +549,107 @@ def wrap(*args, **kwargs): self.init() # Always raise SkipTest if we can't connect to MongoDB if not self.connected: - raise SkipTest( - "Cannot connect to MongoDB on %s" % (self.pair,)) + raise SkipTest("Cannot connect to MongoDB on %s" % (self.pair,)) if condition(): return f(*args, **kwargs) raise SkipTest(msg) + return wrap if func is None: + def decorate(f): return make_wrapper(f) + return decorate return make_wrapper(func) def create_user(self, dbname, user, pwd=None, roles=None, **kwargs): - kwargs['writeConcern'] = {'w': self.w} + kwargs["writeConcern"] = {"w": self.w} return _create_user(self.client[dbname], user, pwd, roles, **kwargs) def drop_user(self, dbname, user): - self.client[dbname].command( - 'dropUser', user, writeConcern={'w': self.w}) + self.client[dbname].command("dropUser", user, writeConcern={"w": self.w}) def require_connection(self, func): """Run a test only if we can connect to MongoDB.""" return self._require( lambda: True, # _require checks if we're connected "Cannot connect to MongoDB on %s" % (self.pair,), - func=func) + func=func, + ) def require_data_lake(self, func): """Run a test only if we are connected to Atlas Data Lake.""" return self._require( lambda: self.is_data_lake, "Not connected to Atlas Data Lake on %s" % (self.pair,), - func=func) + func=func, + ) def require_no_mmap(self, func): """Run a test only if the server is not using the MMAPv1 storage engine. Only works for standalone and replica sets; tests are - run regardless of storage engine on sharded clusters. 
""" + run regardless of storage engine on sharded clusters.""" + def is_not_mmap(): if self.is_mongos: return True - return self.storage_engine != 'mmapv1' + return self.storage_engine != "mmapv1" - return self._require( - is_not_mmap, "Storage engine must not be MMAPv1", func=func) + return self._require(is_not_mmap, "Storage engine must not be MMAPv1", func=func) def require_version_min(self, *ver): """Run a test only if the server version is at least ``version``.""" other_version = Version(*ver) - return self._require(lambda: self.version >= other_version, - "Server version must be at least %s" - % str(other_version)) + return self._require( + lambda: self.version >= other_version, + "Server version must be at least %s" % str(other_version), + ) def require_version_max(self, *ver): """Run a test only if the server version is at most ``version``.""" other_version = Version(*ver) - return self._require(lambda: self.version <= other_version, - "Server version must be at most %s" - % str(other_version)) + return self._require( + lambda: self.version <= other_version, + "Server version must be at most %s" % str(other_version), + ) def require_auth(self, func): """Run a test only if the server is running with auth enabled.""" - return self._require(lambda: self.auth_enabled, - "Authentication is not enabled on the server", - func=func) + return self._require( + lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func + ) def require_no_auth(self, func): """Run a test only if the server is running without auth enabled.""" - return self._require(lambda: not self.auth_enabled, - "Authentication must not be enabled on the server", - func=func) + return self._require( + lambda: not self.auth_enabled, + "Authentication must not be enabled on the server", + func=func, + ) def require_replica_set(self, func): """Run a test only if the client is connected to a replica set.""" - return self._require(lambda: self.is_rs, - "Not connected to a replica set", - func=func) + return self._require(lambda: self.is_rs, "Not connected to a replica set", func=func) def require_secondaries_count(self, count): """Run a test only if the client is connected to a replica set that has `count` secondaries. 
""" + def sec_count(): return 0 if not self.client else len(self.client.secondaries) - return self._require(lambda: sec_count() >= count, - "Not enough secondaries available") + + return self._require(lambda: sec_count() >= count, "Not enough secondaries available") @property def supports_secondary_read_pref(self): if self.has_secondaries: return True if self.is_mongos: - shard = self.client.config.shards.find_one()['host'] - num_members = shard.count(',') + 1 + shard = self.client.config.shards.find_one()["host"] + num_members = shard.count(",") + 1 return num_members > 1 return False @@ -651,90 +657,94 @@ def require_secondary_read_pref(self): """Run a test only if the client is connected to a cluster that supports secondary read preference """ - return self._require(lambda: self.supports_secondary_read_pref, - "This cluster does not support secondary read " - "preference") + return self._require( + lambda: self.supports_secondary_read_pref, + "This cluster does not support secondary read " "preference", + ) def require_no_replica_set(self, func): """Run a test if the client is *not* connected to a replica set.""" return self._require( - lambda: not self.is_rs, - "Connected to a replica set, not a standalone mongod", - func=func) + lambda: not self.is_rs, "Connected to a replica set, not a standalone mongod", func=func + ) def require_ipv6(self, func): """Run a test only if the client can connect to a server via IPv6.""" - return self._require(lambda: self.has_ipv6, - "No IPv6", - func=func) + return self._require(lambda: self.has_ipv6, "No IPv6", func=func) def require_no_mongos(self, func): """Run a test only if the client is not connected to a mongos.""" - return self._require(lambda: not self.is_mongos, - "Must be connected to a mongod, not a mongos", - func=func) + return self._require( + lambda: not self.is_mongos, "Must be connected to a mongod, not a mongos", func=func + ) def require_mongos(self, func): """Run a test only if the client is connected to a mongos.""" - return self._require(lambda: self.is_mongos, - "Must be connected to a mongos", - func=func) + return self._require(lambda: self.is_mongos, "Must be connected to a mongos", func=func) def require_multiple_mongoses(self, func): """Run a test only if the client is connected to a sharded cluster that has 2 mongos nodes.""" - return self._require(lambda: len(self.mongoses) > 1, - "Must have multiple mongoses available", - func=func) + return self._require( + lambda: len(self.mongoses) > 1, "Must have multiple mongoses available", func=func + ) def require_standalone(self, func): """Run a test only if the client is connected to a standalone.""" - return self._require(lambda: not (self.is_mongos or self.is_rs), - "Must be connected to a standalone", - func=func) + return self._require( + lambda: not (self.is_mongos or self.is_rs), + "Must be connected to a standalone", + func=func, + ) def require_no_standalone(self, func): """Run a test only if the client is not connected to a standalone.""" - return self._require(lambda: self.is_mongos or self.is_rs, - "Must be connected to a replica set or mongos", - func=func) + return self._require( + lambda: self.is_mongos or self.is_rs, + "Must be connected to a replica set or mongos", + func=func, + ) def require_load_balancer(self, func): """Run a test only if the client is connected to a load balancer.""" - return self._require(lambda: self.load_balancer, - "Must be connected to a load balancer", - func=func) + return self._require( + lambda: self.load_balancer, "Must be 
connected to a load balancer", func=func + ) def require_no_load_balancer(self, func): - """Run a test only if the client is not connected to a load balancer. - """ - return self._require(lambda: not self.load_balancer, - "Must not be connected to a load balancer", - func=func) + """Run a test only if the client is not connected to a load balancer.""" + return self._require( + lambda: not self.load_balancer, "Must not be connected to a load balancer", func=func + ) def is_topology_type(self, topologies): - unknown = set(topologies) - {'single', 'replicaset', 'sharded', - 'sharded-replicaset', 'load-balanced'} + unknown = set(topologies) - { + "single", + "replicaset", + "sharded", + "sharded-replicaset", + "load-balanced", + } if unknown: - raise AssertionError('Unknown topologies: %r' % (unknown,)) + raise AssertionError("Unknown topologies: %r" % (unknown,)) if self.load_balancer: - if 'load-balanced' in topologies: + if "load-balanced" in topologies: return True return False - if 'single' in topologies and not (self.is_mongos or self.is_rs): + if "single" in topologies and not (self.is_mongos or self.is_rs): return True - if 'replicaset' in topologies and self.is_rs: + if "replicaset" in topologies and self.is_rs: return True - if 'sharded' in topologies and self.is_mongos: + if "sharded" in topologies and self.is_mongos: return True - if 'sharded-replicaset' in topologies and self.is_mongos: + if "sharded-replicaset" in topologies and self.is_mongos: shards = list(client_context.client.config.shards.find()) for shard in shards: # For a 3-member RS-backed sharded cluster, shard['host'] # will be 'replicaName/ip1:port1,ip2:port2,ip3:port3' # Otherwise it will be 'ip1:port1' - host_spec = shard['host'] - if not len(host_spec.split('/')) > 1: + host_spec = shard["host"] + if not len(host_spec.split("/")) > 1: return False return True return False @@ -743,76 +753,80 @@ def require_cluster_type(self, topologies=[]): """Run a test only if the client is connected to a cluster that conforms to one of the specified topologies. 
Acceptable topologies are 'single', 'replicaset', and 'sharded'.""" + def _is_valid_topology(): return self.is_topology_type(topologies) - return self._require( - _is_valid_topology, - "Cluster type not in %s" % (topologies)) + + return self._require(_is_valid_topology, "Cluster type not in %s" % (topologies)) def require_test_commands(self, func): """Run a test only if the server has test commands enabled.""" - return self._require(lambda: self.test_commands_enabled, - "Test commands must be enabled", - func=func) + return self._require( + lambda: self.test_commands_enabled, "Test commands must be enabled", func=func + ) def require_failCommand_fail_point(self, func): """Run a test only if the server supports the failCommand fail point.""" - return self._require(lambda: self.supports_failCommand_fail_point, - "failCommand fail point must be supported", - func=func) + return self._require( + lambda: self.supports_failCommand_fail_point, + "failCommand fail point must be supported", + func=func, + ) def require_failCommand_appName(self, func): """Run a test only if the server supports the failCommand appName.""" # SERVER-47195 - return self._require(lambda: (self.test_commands_enabled and - self.version >= (4, 4, -1)), - "failCommand appName must be supported", - func=func) + return self._require( + lambda: (self.test_commands_enabled and self.version >= (4, 4, -1)), + "failCommand appName must be supported", + func=func, + ) def require_failCommand_blockConnection(self, func): - """Run a test only if the server supports failCommand blockConnection. - """ + """Run a test only if the server supports failCommand blockConnection.""" return self._require( - lambda: (self.test_commands_enabled and ( - (not self.is_mongos and self.version >= (4, 2, 9)) or - (self.is_mongos and self.version >= (4, 4)))), + lambda: ( + self.test_commands_enabled + and ( + (not self.is_mongos and self.version >= (4, 2, 9)) + or (self.is_mongos and self.version >= (4, 4)) + ) + ), "failCommand blockConnection is not supported", - func=func) + func=func, + ) def require_tls(self, func): """Run a test only if the client can connect over TLS.""" - return self._require(lambda: self.tls, - "Must be able to connect via TLS", - func=func) + return self._require(lambda: self.tls, "Must be able to connect via TLS", func=func) def require_no_tls(self, func): """Run a test only if the client can connect over TLS.""" - return self._require(lambda: not self.tls, - "Must be able to connect without TLS", - func=func) + return self._require(lambda: not self.tls, "Must be able to connect without TLS", func=func) def require_tlsCertificateKeyFile(self, func): """Run a test only if the client can connect with tlsCertificateKeyFile.""" - return self._require(lambda: self.tlsCertificateKeyFile, - "Must be able to connect with tlsCertificateKeyFile", - func=func) + return self._require( + lambda: self.tlsCertificateKeyFile, + "Must be able to connect with tlsCertificateKeyFile", + func=func, + ) def require_server_resolvable(self, func): """Run a test only if the hostname 'server' is resolvable.""" - return self._require(lambda: self.server_is_resolvable, - "No hosts entry for 'server'. Cannot validate " - "hostname in the certificate", - func=func) + return self._require( + lambda: self.server_is_resolvable, + "No hosts entry for 'server'. 
Cannot validate " "hostname in the certificate", + func=func, + ) def require_sessions(self, func): """Run a test only if the deployment supports sessions.""" - return self._require(lambda: self.sessions_enabled, - "Sessions not supported", - func=func) + return self._require(lambda: self.sessions_enabled, "Sessions not supported", func=func) def supports_retryable_writes(self): - if self.storage_engine == 'mmapv1': + if self.storage_engine == "mmapv1": return False if not self.sessions_enabled: return False @@ -820,12 +834,14 @@ def supports_retryable_writes(self): def require_retryable_writes(self, func): """Run a test only if the deployment supports retryable writes.""" - return self._require(self.supports_retryable_writes, - "This server does not support retryable writes", - func=func) + return self._require( + self.supports_retryable_writes, + "This server does not support retryable writes", + func=func, + ) def supports_transactions(self): - if self.storage_engine == 'mmapv1': + if self.storage_engine == "mmapv1": return False if self.version.at_least(4, 1, 8): @@ -841,28 +857,28 @@ def require_transactions(self, func): *Might* because this does not test the storage engine or FCV. """ - return self._require(self.supports_transactions, - "Transactions are not supported", - func=func) + return self._require( + self.supports_transactions, "Transactions are not supported", func=func + ) def require_no_api_version(self, func): """Skip this test when testing with requireApiVersion.""" - return self._require(lambda: not MONGODB_API_VERSION, - "This test does not work with requireApiVersion", - func=func) + return self._require( + lambda: not MONGODB_API_VERSION, + "This test does not work with requireApiVersion", + func=func, + ) def mongos_seeds(self): - return ','.join('%s:%s' % address for address in self.mongoses) + return ",".join("%s:%s" % address for address in self.mongoses) @property def supports_failCommand_fail_point(self): """Does the server support the failCommand fail point?""" if self.is_mongos: - return (self.version.at_least(4, 1, 5) and - self.test_commands_enabled) + return self.version.at_least(4, 1, 5) and self.test_commands_enabled else: - return (self.version.at_least(4, 0) and - self.test_commands_enabled) + return self.version.at_least(4, 0) and self.test_commands_enabled @property def requires_hint_with_min_max_queries(self): @@ -872,11 +888,11 @@ def requires_hint_with_min_max_queries(self): @property def max_bson_size(self): - return self.hello['maxBsonObjectSize'] + return self.hello["maxBsonObjectSize"] @property def max_write_batch_size(self): - return self.hello['maxWriteBatchSize'] + return self.hello["maxWriteBatchSize"] # Reusable client context @@ -885,13 +901,13 @@ def max_write_batch_size(self): def sanitize_cmd(cmd): cp = cmd.copy() - cp.pop('$clusterTime', None) - cp.pop('$db', None) - cp.pop('$readPreference', None) - cp.pop('lsid', None) + cp.pop("$clusterTime", None) + cp.pop("$db", None) + cp.pop("$readPreference", None) + cp.pop("lsid", None) if MONGODB_API_VERSION: # Versioned api parameters - cp.pop('apiVersion', None) + cp.pop("apiVersion", None) # OP_MSG encoding may move the payload type one field to the # end of the command. Do the same here. 
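sanitize_cmd, being reformatted here, exists so monitoring tests can compare commands for equality: fields that vary per run (session ids, cluster times, read preference, versioned-API fields) are stripped first. The field-stripping half of the idea in isolation (the real function additionally reorders the OP_MSG payload field, which this sketch skips):

```python
VOLATILE = ("$clusterTime", "$db", "$readPreference", "lsid", "apiVersion")


def sanitize(cmd):
    cleaned = cmd.copy()
    for field in VOLATILE:
        cleaned.pop(field, None)
    return cleaned


assert sanitize({"ping": 1, "lsid": {"id": "..."}}) == {"ping": 1}
```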
name = next(iter(cp)) @@ -906,8 +922,8 @@ def sanitize_cmd(cmd): def sanitize_reply(reply): cp = reply.copy() - cp.pop('$clusterTime', None) - cp.pop('operationTime', None) + cp.pop("$clusterTime", None) + cp.pop("operationTime", None) return cp @@ -920,14 +936,15 @@ def assertEqualReply(self, expected, actual, msg=None): @contextmanager def fail_point(self, command_args): - cmd_on = SON([('configureFailPoint', 'failCommand')]) + cmd_on = SON([("configureFailPoint", "failCommand")]) cmd_on.update(command_args) client_context.client.admin.command(cmd_on) try: yield finally: client_context.client.admin.command( - 'configureFailPoint', cmd_on['configureFailPoint'], mode='off') + "configureFailPoint", cmd_on["configureFailPoint"], mode="off" + ) class IntegrationTest(PyMongoTestCase): @@ -936,16 +953,14 @@ class IntegrationTest(PyMongoTestCase): @classmethod @client_context.require_connection def setUpClass(cls): - if (client_context.load_balancer and - not getattr(cls, 'RUN_ON_LOAD_BALANCER', False)): - raise SkipTest('this test does not support load balancers') - if (client_context.serverless and - not getattr(cls, 'RUN_ON_SERVERLESS', False)): - raise SkipTest('this test does not support serverless') + if client_context.load_balancer and not getattr(cls, "RUN_ON_LOAD_BALANCER", False): + raise SkipTest("this test does not support load balancers") + if client_context.serverless and not getattr(cls, "RUN_ON_SERVERLESS", False): + raise SkipTest("this test does not support serverless") cls.client = client_context.client cls.db = cls.client.pymongo_test if client_context.auth_enabled: - cls.credentials = {'username': db_user, 'password': db_pwd} + cls.credentials = {"username": db_user, "password": db_pwd} else: cls.credentials = {} @@ -981,9 +996,7 @@ def setUpClass(cls): def setUp(self): super(MockClientTest, self).setUp() - self.client_knobs = client_knobs( - heartbeat_frequency=0.001, - min_heartbeat_interval=0.001) + self.client_knobs = client_knobs(heartbeat_frequency=0.001, min_heartbeat_interval=0.001) self.client_knobs.enable() @@ -1002,9 +1015,9 @@ def _get_executors(topology): executors = [] for server in topology._servers.values(): # Some MockMonitor do not have an _executor. - if hasattr(server._monitor, '_executor'): + if hasattr(server._monitor, "_executor"): executors.append(server._monitor._executor) - if hasattr(server._monitor, '_rtt_monitor'): + if hasattr(server._monitor, "_rtt_monitor"): executors.append(server._monitor._rtt_monitor._executor) executors.append(topology._Topology__events_executor) if topology._srv_monitor: @@ -1016,14 +1029,17 @@ def _get_executors(topology): def all_executors_stopped(topology): running = [e for e in _get_executors(topology) if not e._stopped] if running: - print(' Topology %s has THREADS RUNNING: %s, created at: %s' % ( - topology, running, topology._settings._stack)) + print( + " Topology %s has THREADS RUNNING: %s, created at: %s" + % (topology, running, topology._settings._stack) + ) return False return True def print_unclosed_clients(): from pymongo.topology import Topology + processed = set() # Call collect to manually cleanup any would-be gc'd clients to avoid # false positives. 
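print_unclosed_clients and all_executors_stopped implement a leak check: collect garbage first so clients that were about to be freed do not count, then scan live objects for topologies whose executor threads are still running. The core scan, generalized and hypothetical:

```python
import gc


def find_unclosed(cls, is_closed):
    gc.collect()  # drop would-be-garbage first to avoid false positives
    return [
        obj
        for obj in gc.get_objects()  # unavailable on Jython
        if isinstance(obj, cls) and not is_closed(obj)
    ]
```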
@@ -1043,11 +1059,11 @@ def print_unclosed_clients(): def teardown(): garbage = [] for g in gc.garbage: - garbage.append('GARBAGE: %r' % (g,)) - garbage.append(' gc.get_referents: %r' % (gc.get_referents(g),)) - garbage.append(' gc.get_referrers: %r' % (gc.get_referrers(g),)) + garbage.append("GARBAGE: %r" % (g,)) + garbage.append(" gc.get_referents: %r" % (gc.get_referents(g),)) + garbage.append(" gc.get_referrers: %r" % (gc.get_referrers(g),)) if garbage: - assert False, '\n'.join(garbage) + assert False, "\n".join(garbage) c = client_context.client if c: if not client_context.is_data_lake: @@ -1060,7 +1076,7 @@ def teardown(): c.close() # Jython does not support gc.get_objects. - if not sys.platform.startswith('java'): + if not sys.platform.startswith("java"): print_unclosed_clients() @@ -1073,6 +1089,7 @@ def run(self, test): if HAVE_XML: + class PymongoXMLTestRunner(XMLTestRunner): def run(self, test): setup() @@ -1103,17 +1120,21 @@ def clear_warning_registry(): class SystemCertsPatcher(object): def __init__(self, ca_certs): - if (ssl.OPENSSL_VERSION.lower().startswith('libressl') and - sys.platform == 'darwin' and not _ssl.IS_PYOPENSSL): + if ( + ssl.OPENSSL_VERSION.lower().startswith("libressl") + and sys.platform == "darwin" + and not _ssl.IS_PYOPENSSL + ): raise SkipTest( "LibreSSL on OSX doesn't support setting CA certificates " - "using SSL_CERT_FILE environment variable.") - self.original_certs = os.environ.get('SSL_CERT_FILE') + "using SSL_CERT_FILE environment variable." + ) + self.original_certs = os.environ.get("SSL_CERT_FILE") # Tell OpenSSL where CA certificates live. - os.environ['SSL_CERT_FILE'] = ca_certs + os.environ["SSL_CERT_FILE"] = ca_certs def disable(self): if self.original_certs is None: - os.environ.pop('SSL_CERT_FILE') + os.environ.pop("SSL_CERT_FILE") else: - os.environ['SSL_CERT_FILE'] = self.original_certs + os.environ["SSL_CERT_FILE"] = self.original_certs diff --git a/test/atlas/test_connection.py b/test/atlas/test_connection.py index 1ad84068ed..cad2b10683 100644 --- a/test/atlas/test_connection.py +++ b/test/atlas/test_connection.py @@ -17,7 +17,6 @@ import os import sys import unittest - from collections import defaultdict sys.path[0:0] = [""] @@ -27,6 +26,7 @@ try: import dns + HAS_DNS = True except ImportError: HAS_DNS = False @@ -57,59 +57,59 @@ def connect(uri): raise Exception("Must set env variable to test.") client = pymongo.MongoClient(uri) # No TLS error - client.admin.command('ping') + client.admin.command("ping") # No auth error client.test.test.count_documents({}) class TestAtlasConnect(unittest.TestCase): - @unittest.skipUnless(HAS_SNI, 'Free tier requires SNI support') + @unittest.skipUnless(HAS_SNI, "Free tier requires SNI support") def test_free_tier(self): - connect(URIS['ATLAS_FREE']) + connect(URIS["ATLAS_FREE"]) def test_replica_set(self): - connect(URIS['ATLAS_REPL']) + connect(URIS["ATLAS_REPL"]) def test_sharded_cluster(self): - connect(URIS['ATLAS_SHRD']) + connect(URIS["ATLAS_SHRD"]) def test_tls_11(self): - connect(URIS['ATLAS_TLS11']) + connect(URIS["ATLAS_TLS11"]) def test_tls_12(self): - connect(URIS['ATLAS_TLS12']) + connect(URIS["ATLAS_TLS12"]) def test_serverless(self): - connect(URIS['ATLAS_SERVERLESS']) + connect(URIS["ATLAS_SERVERLESS"]) def connect_srv(self, uri): connect(uri) - self.assertIn('mongodb+srv://', uri) + self.assertIn("mongodb+srv://", uri) - @unittest.skipUnless(HAS_SNI, 'Free tier requires SNI support') - @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, 'SRV requires dnspython') + 
@unittest.skipUnless(HAS_SNI, "Free tier requires SNI support") + @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, "SRV requires dnspython") def test_srv_free_tier(self): - self.connect_srv(URIS['ATLAS_SRV_FREE']) + self.connect_srv(URIS["ATLAS_SRV_FREE"]) - @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, 'SRV requires dnspython') + @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, "SRV requires dnspython") def test_srv_replica_set(self): - self.connect_srv(URIS['ATLAS_SRV_REPL']) + self.connect_srv(URIS["ATLAS_SRV_REPL"]) - @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, 'SRV requires dnspython') + @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, "SRV requires dnspython") def test_srv_sharded_cluster(self): - self.connect_srv(URIS['ATLAS_SRV_SHRD']) + self.connect_srv(URIS["ATLAS_SRV_SHRD"]) - @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, 'SRV requires dnspython') + @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, "SRV requires dnspython") def test_srv_tls_11(self): - self.connect_srv(URIS['ATLAS_SRV_TLS11']) + self.connect_srv(URIS["ATLAS_SRV_TLS11"]) - @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, 'SRV requires dnspython') + @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, "SRV requires dnspython") def test_srv_tls_12(self): - self.connect_srv(URIS['ATLAS_SRV_TLS12']) + self.connect_srv(URIS["ATLAS_SRV_TLS12"]) - @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, 'SRV requires dnspython') + @unittest.skipUnless(HAS_DNS or MUST_TEST_SRV, "SRV requires dnspython") def test_srv_serverless(self): - self.connect_srv(URIS['ATLAS_SRV_SERVERLESS']) + self.connect_srv(URIS["ATLAS_SRV_SERVERLESS"]) def test_uniqueness(self): """Ensure that we don't accidentally duplicate the test URIs.""" @@ -117,11 +117,12 @@ def test_uniqueness(self): for name, uri in URIS.items(): if uri: uri_to_names[uri].append(name) - duplicates = [names for names in uri_to_names.values() - if len(names) > 1] - self.assertFalse(duplicates, 'Error: the following env variables have ' - 'duplicate values: %s' % (duplicates,)) + duplicates = [names for names in uri_to_names.values() if len(names) > 1] + self.assertFalse( + duplicates, + "Error: the following env variables have " "duplicate values: %s" % (duplicates,), + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index 0522201097..4ddaefeacf 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -26,24 +26,24 @@ class TestAuthAWS(unittest.TestCase): - @classmethod def setUpClass(cls): - cls.uri = os.environ['MONGODB_URI'] + cls.uri = os.environ["MONGODB_URI"] def test_should_fail_without_credentials(self): - if '@' not in self.uri: - self.skipTest('MONGODB_URI already has no credentials') + if "@" not in self.uri: + self.skipTest("MONGODB_URI already has no credentials") - hosts = ['%s:%s' % addr for addr in parse_uri(self.uri)['nodelist']] + hosts = ["%s:%s" % addr for addr in parse_uri(self.uri)["nodelist"]] self.assertTrue(hosts) with MongoClient(hosts) as client: with self.assertRaises(OperationFailure): client.aws.test.find_one() def test_should_fail_incorrect_credentials(self): - with MongoClient(self.uri, username='fake', password='fake', - authMechanism='MONGODB-AWS') as client: + with MongoClient( + self.uri, username="fake", password="fake", authMechanism="MONGODB-AWS" + ) as client: with self.assertRaises(OperationFailure): client.get_database().test.find_one() @@ -52,5 +52,5 @@ def test_connect_uri(self): client.get_database().test.find_one() -if 
__name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/crud_v2_format.py b/test/crud_v2_format.py index dbdea40d46..4118dfef9f 100644 --- a/test/crud_v2_format.py +++ b/test/crud_v2_format.py @@ -33,22 +33,22 @@ def allowable_errors(self, op): def get_scenario_db_name(self, scenario_def): """Crud spec says database_name is optional.""" - return scenario_def.get('database_name', self.TEST_DB) + return scenario_def.get("database_name", self.TEST_DB) def get_scenario_coll_name(self, scenario_def): """Crud spec says collection_name is optional.""" - return scenario_def.get('collection_name', self.TEST_COLLECTION) + return scenario_def.get("collection_name", self.TEST_COLLECTION) def get_object_name(self, op): """Crud spec says object is optional and defaults to 'collection'.""" - return op.get('object', 'collection') + return op.get("object", "collection") def get_outcome_coll_name(self, outcome, collection): """Crud spec says outcome has an optional 'collection.name'.""" - return outcome['collection'].get('name', collection.name) + return outcome["collection"].get("name", collection.name) def setup_scenario(self, scenario_def): """Allow specs to override a test's setup.""" # PYTHON-1935 Only create the collection if there is data to insert. - if scenario_def['data']: + if scenario_def["data"]: super(TestCrudV2, self).setup_scenario(scenario_def) diff --git a/test/mockupdb/operations.py b/test/mockupdb/operations.py index 47890f80ee..5d99c1e886 100644 --- a/test/mockupdb/operations.py +++ b/test/mockupdb/operations.py @@ -16,14 +16,13 @@ from mockupdb import * from mockupdb import OpMsgReply + from pymongo import ReadPreference -__all__ = ['operations', 'upgrades'] +__all__ = ["operations", "upgrades"] -Operation = namedtuple( - 'operation', - ['name', 'function', 'reply', 'op_type', 'not_master']) +Operation = namedtuple("operation", ["name", "function", "reply", "op_type", "not_master"]) """Client operations on MongoDB. Each has a human-readable name, a function that actually executes a test, and @@ -52,64 +51,71 @@ sharded cluster (PYTHON-868). 
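operations.py, below, is a table-driven test fixture: every client operation is described once as an Operation namedtuple (name, callable, canned reply, routing behavior) and the mockupdb tests iterate over the table rather than hand-writing one test per operation. The shape of the table, reduced to one entry:

```python
from collections import namedtuple

Operation = namedtuple("Operation", ["name", "function", "reply", "op_type"])

OPERATIONS = [
    Operation(
        "find_one",
        lambda client: client.db.collection.find_one(),
        reply={"cursor": {"id": 0, "firstBatch": []}},
        op_type="may-use-secondary",
    ),
]

for op in OPERATIONS:
    # Each entry drives one mock-server conversation in the real tests.
    print(op.name, op.op_type)
```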
""" -not_master_reply = OpMsgReply(ok=0, errmsg='not master') +not_master_reply = OpMsgReply(ok=0, errmsg="not master") operations = [ Operation( - 'find_one', + "find_one", lambda client: client.db.collection.find_one(), - reply={'cursor': {'id': 0, 'firstBatch': []}}, - op_type='may-use-secondary', - not_master=not_master_reply), + reply={"cursor": {"id": 0, "firstBatch": []}}, + op_type="may-use-secondary", + not_master=not_master_reply, + ), Operation( - 'count', + "count", lambda client: client.db.collection.count_documents({}), - reply={'n': 1}, - op_type='may-use-secondary', - not_master=not_master_reply), + reply={"n": 1}, + op_type="may-use-secondary", + not_master=not_master_reply, + ), Operation( - 'aggregate', + "aggregate", lambda client: client.db.collection.aggregate([]), - reply={'cursor': {'id': 0, 'firstBatch': []}}, - op_type='may-use-secondary', - not_master=not_master_reply), + reply={"cursor": {"id": 0, "firstBatch": []}}, + op_type="may-use-secondary", + not_master=not_master_reply, + ), Operation( - 'options', + "options", lambda client: client.db.collection.options(), - reply={'cursor': {'id': 0, 'firstBatch': []}}, - op_type='must-use-primary', - not_master=not_master_reply), + reply={"cursor": {"id": 0, "firstBatch": []}}, + op_type="must-use-primary", + not_master=not_master_reply, + ), Operation( - 'command', - lambda client: client.db.command('foo'), - reply={'ok': 1}, - op_type='must-use-primary', # Ignores client's read preference. - not_master=not_master_reply), + "command", + lambda client: client.db.command("foo"), + reply={"ok": 1}, + op_type="must-use-primary", # Ignores client's read preference. + not_master=not_master_reply, + ), Operation( - 'secondary command', - lambda client: - client.db.command('foo', read_preference=ReadPreference.SECONDARY), - reply={'ok': 1}, - op_type='always-use-secondary', - not_master=OpReply(ok=0, errmsg='node is recovering')), + "secondary command", + lambda client: client.db.command("foo", read_preference=ReadPreference.SECONDARY), + reply={"ok": 1}, + op_type="always-use-secondary", + not_master=OpReply(ok=0, errmsg="node is recovering"), + ), Operation( - 'listIndexes', + "listIndexes", lambda client: client.db.collection.index_information(), - reply={'cursor': {'id': 0, 'firstBatch': []}}, - op_type='must-use-primary', - not_master=not_master_reply), + reply={"cursor": {"id": 0, "firstBatch": []}}, + op_type="must-use-primary", + not_master=not_master_reply, + ), ] _ops_by_name = dict([(op.name, op) for op in operations]) -Upgrade = namedtuple('Upgrade', - ['name', 'function', 'old', 'new', 'wire_version']) +Upgrade = namedtuple("Upgrade", ["name", "function", "old", "new", "wire_version"]) upgrades = [ - Upgrade('estimated_document_count', - lambda client: client.db.collection.estimated_document_count(), - old=OpMsg('count', 'collection', namespace='db'), - new=OpMsg('aggregate', 'collection', namespace='db'), - wire_version=12), + Upgrade( + "estimated_document_count", + lambda client: client.db.collection.estimated_document_count(), + old=OpMsg("count", "collection", namespace="db"), + new=OpMsg("aggregate", "collection", namespace="db"), + wire_version=12, + ), ] diff --git a/test/mockupdb/test_auth_recovering_member.py b/test/mockupdb/test_auth_recovering_member.py index 6fb983b37f..33d33da24c 100755 --- a/test/mockupdb/test_auth_recovering_member.py +++ b/test/mockupdb/test_auth_recovering_member.py @@ -12,31 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the 
License. +import unittest + from mockupdb import MockupDB + from pymongo import MongoClient from pymongo.errors import ServerSelectionTimeoutError -import unittest - class TestAuthRecoveringMember(unittest.TestCase): def test_auth_recovering_member(self): # Test that we don't attempt auth against a recovering RS member. server = MockupDB() - server.autoresponds('ismaster', { - 'minWireVersion': 2, - 'maxWireVersion': 6, - 'ismaster': False, - 'secondary': False, - 'setName': 'rs'}) + server.autoresponds( + "ismaster", + { + "minWireVersion": 2, + "maxWireVersion": 6, + "ismaster": False, + "secondary": False, + "setName": "rs", + }, + ) server.run() self.addCleanup(server.stop) - client = MongoClient(server.uri, - replicaSet='rs', - serverSelectionTimeoutMS=100, - socketTimeoutMS=100) + client = MongoClient( + server.uri, replicaSet="rs", serverSelectionTimeoutMS=100, socketTimeoutMS=100 + ) self.addCleanup(client.close) @@ -46,5 +50,6 @@ def test_auth_recovering_member(self): with self.assertRaises(ServerSelectionTimeoutError): client.db.command("ping") -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_cluster_time.py b/test/mockupdb/test_cluster_time.py index 858e32a0fa..e6d8c2126c 100644 --- a/test/mockupdb/test_cluster_time.py +++ b/test/mockupdb/test_cluster_time.py @@ -14,15 +14,13 @@ """Test $clusterTime handling.""" -from bson import Timestamp -from mockupdb import going, MockupDB -from pymongo import (MongoClient, - InsertOne, - UpdateOne, - DeleteMany) - import unittest +from mockupdb import MockupDB, going + +from bson import Timestamp +from pymongo import DeleteMany, InsertOne, MongoClient, UpdateOne + class TestClusterTime(unittest.TestCase): def cluster_time_conversation(self, callback, replies): @@ -31,10 +29,13 @@ def cluster_time_conversation(self, callback, replies): # First test all commands include $clusterTime with wire version 6. 
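cluster_time_conversation below is the MockupDB request/reply style in miniature: the mock server advertises a $clusterTime in the handshake, then each client command is intercepted and checked for the gossiped value. A trimmed, self-contained version of one exchange (assumes the mockupdb package):

```python
from mockupdb import MockupDB, going

from bson import Timestamp
from pymongo import MongoClient

cluster_time = Timestamp(0, 0)
server = MockupDB(
    auto_ismaster={
        "minWireVersion": 0,
        "maxWireVersion": 6,
        "$clusterTime": {"clusterTime": cluster_time},
    }
)
server.run()

client = MongoClient(server.uri)
with going(client.db.command, "ping"):
    request = server.receives()
    # The client must gossip back the clusterTime it learned on handshake.
    assert request["$clusterTime"]["clusterTime"] == cluster_time
    request.reply({"ok": 1})

client.close()
server.stop()
```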
responder = server.autoresponds( - 'ismaster', - {'minWireVersion': 0, - 'maxWireVersion': 6, - '$clusterTime': {'clusterTime': cluster_time}}) + "ismaster", + { + "minWireVersion": 0, + "maxWireVersion": 6, + "$clusterTime": {"clusterTime": cluster_time}, + }, + ) server.run() self.addCleanup(server.stop) @@ -45,39 +46,35 @@ def cluster_time_conversation(self, callback, replies): with going(callback, client): for reply in replies: request = server.receives() - self.assertIn('$clusterTime', request) - self.assertEqual(request['$clusterTime']['clusterTime'], - cluster_time) - cluster_time = Timestamp(cluster_time.time, - cluster_time.inc + 1) - reply['$clusterTime'] = {'clusterTime': cluster_time} + self.assertIn("$clusterTime", request) + self.assertEqual(request["$clusterTime"]["clusterTime"], cluster_time) + cluster_time = Timestamp(cluster_time.time, cluster_time.inc + 1) + reply["$clusterTime"] = {"clusterTime": cluster_time} request.reply(reply) def test_command(self): def callback(client): - client.db.command('ping') - client.db.command('ping') + client.db.command("ping") + client.db.command("ping") - self.cluster_time_conversation(callback, [{'ok': 1}] * 2) + self.cluster_time_conversation(callback, [{"ok": 1}] * 2) def test_bulk(self): def callback(client): - client.db.collection.bulk_write([ - InsertOne({}), - InsertOne({}), - UpdateOne({}, {'$inc': {'x': 1}}), - DeleteMany({})]) + client.db.collection.bulk_write( + [InsertOne({}), InsertOne({}), UpdateOne({}, {"$inc": {"x": 1}}), DeleteMany({})] + ) self.cluster_time_conversation( callback, - [{'ok': 1, 'nInserted': 2}, - {'ok': 1, 'nModified': 1}, - {'ok': 1, 'nDeleted': 2}]) + [{"ok": 1, "nInserted": 2}, {"ok": 1, "nModified": 1}, {"ok": 1, "nDeleted": 2}], + ) batches = [ - {'cursor': {'id': 123, 'firstBatch': [{'a': 1}]}}, - {'cursor': {'id': 123, 'nextBatch': [{'a': 2}]}}, - {'cursor': {'id': 0, 'nextBatch': [{'a': 3}]}}] + {"cursor": {"id": 123, "firstBatch": [{"a": 1}]}}, + {"cursor": {"id": 123, "nextBatch": [{"a": 2}]}}, + {"cursor": {"id": 0, "nextBatch": [{"a": 3}]}}, + ] def test_cursor(self): def callback(client): @@ -95,13 +92,15 @@ def test_explain(self): def callback(client): client.db.collection.find().explain() - self.cluster_time_conversation(callback, [{'ok': 1}]) + self.cluster_time_conversation(callback, [{"ok": 1}]) def test_monitor(self): cluster_time = Timestamp(0, 0) - reply = {'minWireVersion': 0, - 'maxWireVersion': 6, - '$clusterTime': {'clusterTime': cluster_time}} + reply = { + "minWireVersion": 0, + "maxWireVersion": 6, + "$clusterTime": {"clusterTime": cluster_time}, + } server = MockupDB() server.run() @@ -110,55 +109,52 @@ def test_monitor(self): client = MongoClient(server.uri, heartbeatFrequencyMS=500) self.addCleanup(client.close) - request = server.receives('ismaster') + request = server.receives("ismaster") # No $clusterTime in first ismaster, only in subsequent ones - self.assertNotIn('$clusterTime', request) + self.assertNotIn("$clusterTime", request) request.ok(reply) # Next exchange: client returns first clusterTime, we send the second. 
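The monitor portion of this test leans on the gossip rule that drivers track the highest $clusterTime observed and resend it, even when it was learned from a command error. bson.Timestamp compares as a (time, inc) pair, so "highest" is plain max():

```python
from bson import Timestamp

seen = Timestamp(100, 1)
from_error_reply = Timestamp(100, 2)  # e.g. from a code-211 error document
seen = max(seen, from_error_reply)
assert seen == Timestamp(100, 2)
```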
-        request = server.receives('ismaster')
-        self.assertIn('$clusterTime', request)
-        self.assertEqual(request['$clusterTime']['clusterTime'],
-                         cluster_time)
-        cluster_time = Timestamp(cluster_time.time,
-                                 cluster_time.inc + 1)
-        reply['$clusterTime'] = {'clusterTime': cluster_time}
+        request = server.receives("ismaster")
+        self.assertIn("$clusterTime", request)
+        self.assertEqual(request["$clusterTime"]["clusterTime"], cluster_time)
+        cluster_time = Timestamp(cluster_time.time, cluster_time.inc + 1)
+        reply["$clusterTime"] = {"clusterTime": cluster_time}
         request.reply(reply)

         # Third exchange: client returns second clusterTime.
-        request = server.receives('ismaster')
-        self.assertEqual(request['$clusterTime']['clusterTime'],
-                         cluster_time)
+        request = server.receives("ismaster")
+        self.assertEqual(request["$clusterTime"]["clusterTime"], cluster_time)

         # Return command error with a new clusterTime.
-        cluster_time = Timestamp(cluster_time.time,
-                                 cluster_time.inc + 1)
-        error = {'ok': 0,
-                 'code': 211,
-                 'errmsg': 'Cache Reader No keys found for HMAC ...',
-                 '$clusterTime': {'clusterTime': cluster_time}}
+        cluster_time = Timestamp(cluster_time.time, cluster_time.inc + 1)
+        error = {
+            "ok": 0,
+            "code": 211,
+            "errmsg": "Cache Reader No keys found for HMAC ...",
+            "$clusterTime": {"clusterTime": cluster_time},
+        }
         request.reply(error)

         # PyMongo 3.11+ closes the monitoring connection on command errors.
         # Fourth exchange: the Monitor closes the connection and runs the
         # handshake on a new connection.
-        request = server.receives('ismaster')
+        request = server.receives("ismaster")
         # No $clusterTime in first ismaster, only in subsequent ones
-        self.assertNotIn('$clusterTime', request)
+        self.assertNotIn("$clusterTime", request)
         # Reply without $clusterTime.
-        reply.pop('$clusterTime')
+        reply.pop("$clusterTime")
         request.reply(reply)

         # Fifth exchange: the Monitor attempt uses the clusterTime from
         # the previous isMaster error.
-        request = server.receives('ismaster')
-        self.assertEqual(request['$clusterTime']['clusterTime'],
-                         cluster_time)
+        request = server.receives("ismaster")
+        self.assertEqual(request["$clusterTime"]["clusterTime"], cluster_time)
         request.reply(reply)

         client.close()

-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/test/mockupdb/test_cursor_namespace.py b/test/mockupdb/test_cursor_namespace.py
index 10605601cf..22e9aa9f39 100644
--- a/test/mockupdb/test_cursor_namespace.py
+++ b/test/mockupdb/test_cursor_namespace.py
@@ -14,16 +14,17 @@
 """Test list_indexes with more than one batch."""

-from mockupdb import going, MockupDB
-from pymongo import MongoClient
-
 import unittest

+from mockupdb import MockupDB, going
+
+from pymongo import MongoClient
+

 class TestCursorNamespace(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6})
+        cls.server = MockupDB(auto_ismaster={"maxWireVersion": 6})
         cls.server.run()
         cls.client = MongoClient(cls.server.uri)
@@ -34,44 +35,49 @@ def tearDownClass(cls):

     def _test_cursor_namespace(self, cursor_op, command):
         with going(cursor_op) as docs:
-            request = self.server.receives(
-                **{command: 'collection', 'namespace': 'test'})
+            request = self.server.receives(**{command: "collection", "namespace": "test"})

             # Respond with a different namespace.
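
The reply that follows advertises a cursor namespace ('ns') that differs from the one queried, and the subsequent getMore is expected to target the returned namespace. Splitting an 'ns' value reduces to partitioning on the first dot, since collection names may themselves contain dots; a sketch under that assumption (the function name is hypothetical):

def parse_cursor_ns(ns):
    # Only the first dot separates database from collection:
    # 'different_db.different.coll' -> ('different_db', 'different.coll').
    database, _, collection = ns.partition(".")
    return database, collection

assert parse_cursor_ns("different_db.different.coll") == ("different_db", "different.coll")
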
- request.reply({'cursor': { - 'firstBatch': [{'doc': 1}], - 'id': 123, - 'ns': 'different_db.different.coll'}}) + request.reply( + { + "cursor": { + "firstBatch": [{"doc": 1}], + "id": 123, + "ns": "different_db.different.coll", + } + } + ) # Client uses the namespace we returned. request = self.server.receives( - getMore=123, namespace='different_db', - collection='different.coll') + getMore=123, namespace="different_db", collection="different.coll" + ) - request.reply({'cursor': { - 'nextBatch': [{'doc': 2}], - 'id': 0}}) + request.reply({"cursor": {"nextBatch": [{"doc": 2}], "id": 0}}) - self.assertEqual([{'doc': 1}, {'doc': 2}], docs()) + self.assertEqual([{"doc": 1}, {"doc": 2}], docs()) def test_aggregate_cursor(self): def op(): return list(self.client.test.collection.aggregate([])) - self._test_cursor_namespace(op, 'aggregate') + + self._test_cursor_namespace(op, "aggregate") def test_find_cursor(self): def op(): return list(self.client.test.collection.find()) - self._test_cursor_namespace(op, 'find') + + self._test_cursor_namespace(op, "find") def test_list_indexes(self): def op(): return list(self.client.test.collection.list_indexes()) - self._test_cursor_namespace(op, 'listIndexes') + + self._test_cursor_namespace(op, "listIndexes") class TestKillCursorsNamespace(unittest.TestCase): @classmethod def setUpClass(cls): - cls.server = MockupDB(auto_ismaster={'maxWireVersion': 6}) + cls.server = MockupDB(auto_ismaster={"maxWireVersion": 6}) cls.server.run() cls.client = MongoClient(cls.server.uri) @@ -82,39 +88,47 @@ def tearDownClass(cls): def _test_killCursors_namespace(self, cursor_op, command): with going(cursor_op): - request = self.server.receives( - **{command: 'collection', 'namespace': 'test'}) + request = self.server.receives(**{command: "collection", "namespace": "test"}) # Respond with a different namespace. - request.reply({'cursor': { - 'firstBatch': [{'doc': 1}], - 'id': 123, - 'ns': 'different_db.different.coll'}}) + request.reply( + { + "cursor": { + "firstBatch": [{"doc": 1}], + "id": 123, + "ns": "different_db.different.coll", + } + } + ) # Client uses the namespace we returned for killCursors. 
- request = self.server.receives(**{ - 'killCursors': 'different.coll', - 'cursors': [123], - '$db': 'different_db'}) - request.reply({ - 'ok': 1, - 'cursorsKilled': [123], - 'cursorsNotFound': [], - 'cursorsAlive': [], - 'cursorsUnknown': []}) + request = self.server.receives( + **{"killCursors": "different.coll", "cursors": [123], "$db": "different_db"} + ) + request.reply( + { + "ok": 1, + "cursorsKilled": [123], + "cursorsNotFound": [], + "cursorsAlive": [], + "cursorsUnknown": [], + } + ) def test_aggregate_killCursor(self): def op(): cursor = self.client.test.collection.aggregate([], batchSize=1) next(cursor) cursor.close() - self._test_killCursors_namespace(op, 'aggregate') + + self._test_killCursors_namespace(op, "aggregate") def test_find_killCursor(self): def op(): cursor = self.client.test.collection.find(batch_size=1) next(cursor) cursor.close() - self._test_killCursors_namespace(op, 'find') + + self._test_killCursors_namespace(op, "find") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_getmore_sharded.py b/test/mockupdb/test_getmore_sharded.py index 2b3a1fd6ce..d0d7a56263 100644 --- a/test/mockupdb/test_getmore_sharded.py +++ b/test/mockupdb/test_getmore_sharded.py @@ -20,10 +20,10 @@ except ImportError: from Queue import Queue -from mockupdb import MockupDB, going - import unittest +from mockupdb import MockupDB, going + class TestGetmoreSharded(unittest.TestCase): def test_getmore_sharded(self): @@ -33,20 +33,22 @@ def test_getmore_sharded(self): q = Queue() for server in servers: server.subscribe(q.put) - server.autoresponds('ismaster', ismaster=True, msg='isdbgrid', - minWireVersion=2, maxWireVersion=6) + server.autoresponds( + "ismaster", ismaster=True, msg="isdbgrid", minWireVersion=2, maxWireVersion=6 + ) server.run() self.addCleanup(server.stop) - client = MongoClient('mongodb://%s:%d,%s:%d' % ( - servers[0].host, servers[0].port, - servers[1].host, servers[1].port)) + client = MongoClient( + "mongodb://%s:%d,%s:%d" + % (servers[0].host, servers[0].port, servers[1].host, servers[1].port) + ) self.addCleanup(client.close) collection = client.db.collection cursor = collection.find() with going(next, cursor): query = q.get(timeout=1) - query.replies({'cursor': {'id': 123, 'firstBatch': [{}]}}) + query.replies({"cursor": {"id": 123, "firstBatch": [{}]}}) # 10 batches, all getMores go to same server. for i in range(1, 10): @@ -54,9 +56,8 @@ def test_getmore_sharded(self): getmore = q.get(timeout=1) self.assertEqual(query.server, getmore.server) cursor_id = 123 if i < 9 else 0 - getmore.replies({'cursor': {'id': cursor_id, - 'nextBatch': [{}]}}) + getmore.replies({"cursor": {"id": cursor_id, "nextBatch": [{}]}}) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_handshake.py b/test/mockupdb/test_handshake.py index 621f01728f..5c6dcd3b22 100644 --- a/test/mockupdb/test_handshake.py +++ b/test/mockupdb/test_handshake.py @@ -13,23 +13,25 @@ # limitations under the License. 
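
The _check_handshake_data helper below pins down the shape of the metadata PyMongo sends under the 'client' key of the first ismaster on each connection. An illustrative document of that shape; the version, os, and platform values here are invented examples, and only the fields the assertions touch are guaranteed:

# Illustrative handshake metadata; values are examples, not exact
# PyMongo output. The test requires these keys to be present and
# 'application'/'driver' to carry exactly this content.
example_client_metadata = {
    "application": {"name": "my app"},  # from appname='my app'
    "driver": {"name": "PyMongo", "version": "3.x"},  # example version string
    "os": {"type": "Linux"},  # checked only for presence
    "platform": "CPython 3.x",  # checked only for presence
}
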
-from mockupdb import MockupDB, OpReply, OpMsg, absent, Command, go -from pymongo import MongoClient, version as pymongo_version -from pymongo.errors import OperationFailure - import unittest +from mockupdb import Command, MockupDB, OpMsg, OpReply, absent, go + +from pymongo import MongoClient +from pymongo import version as pymongo_version +from pymongo.errors import OperationFailure + def _check_handshake_data(request): - assert 'client' in request - data = request['client'] + assert "client" in request + data = request["client"] - assert data['application'] == {'name': 'my app'} - assert data['driver'] == {'name': 'PyMongo', 'version': pymongo_version} + assert data["application"] == {"name": "my app"} + assert data["driver"] == {"name": "PyMongo", "version": pymongo_version} # Keep it simple, just check these fields exist. - assert 'os' in data - assert 'platform' in data + assert "os" in data + assert "platform" in data class TestHandshake(unittest.TestCase): @@ -40,63 +42,66 @@ def test_client_handshake_data(self): self.addCleanup(server.stop) hosts = [server.address_string for server in (primary, secondary)] - primary_response = OpReply('ismaster', True, - setName='rs', hosts=hosts, - minWireVersion=2, maxWireVersion=6) - error_response = OpReply( - 0, errmsg='Cache Reader No keys found for HMAC ...', code=211) - - secondary_response = OpReply('ismaster', False, - setName='rs', hosts=hosts, - secondary=True, - minWireVersion=2, maxWireVersion=6) - - client = MongoClient(primary.uri, - replicaSet='rs', - appname='my app', - heartbeatFrequencyMS=500) # Speed up the test. + primary_response = OpReply( + "ismaster", True, setName="rs", hosts=hosts, minWireVersion=2, maxWireVersion=6 + ) + error_response = OpReply(0, errmsg="Cache Reader No keys found for HMAC ...", code=211) + + secondary_response = OpReply( + "ismaster", + False, + setName="rs", + hosts=hosts, + secondary=True, + minWireVersion=2, + maxWireVersion=6, + ) + + client = MongoClient( + primary.uri, replicaSet="rs", appname="my app", heartbeatFrequencyMS=500 + ) # Speed up the test. self.addCleanup(client.close) # New monitoring sockets send data during handshake. - heartbeat = primary.receives('ismaster') + heartbeat = primary.receives("ismaster") _check_handshake_data(heartbeat) heartbeat.ok(primary_response) - heartbeat = secondary.receives('ismaster') + heartbeat = secondary.receives("ismaster") _check_handshake_data(heartbeat) heartbeat.ok(secondary_response) # Subsequent heartbeats have no client data. - primary.receives('ismaster', 1, client=absent).ok(error_response) - secondary.receives('ismaster', 1, client=absent).ok(error_response) + primary.receives("ismaster", 1, client=absent).ok(error_response) + secondary.receives("ismaster", 1, client=absent).ok(error_response) # The heartbeat retry (on a new connection) does have client data. - heartbeat = primary.receives('ismaster') + heartbeat = primary.receives("ismaster") _check_handshake_data(heartbeat) heartbeat.ok(primary_response) - heartbeat = secondary.receives('ismaster') + heartbeat = secondary.receives("ismaster") _check_handshake_data(heartbeat) heartbeat.ok(secondary_response) # Still no client data. - primary.receives('ismaster', 1, client=absent).ok(primary_response) - secondary.receives('ismaster', 1, client=absent).ok(secondary_response) + primary.receives("ismaster", 1, client=absent).ok(primary_response) + secondary.receives("ismaster", 1, client=absent).ok(secondary_response) # After a disconnect, next ismaster has client data again. 
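
Metadata is per-connection state, so the hangup sequence below forces a fresh socket whose first ismaster carries 'client' again. A toy model of that rule, matching what these assertions encode (the class is hypothetical, not PyMongo's implementation):

class HandshakeState:
    # Toy model: one instance per socket; metadata goes only in the
    # first ismaster sent on that socket.
    def __init__(self):
        self.sent_handshake = False

    def build_ismaster(self, metadata):
        cmd = {"ismaster": 1}
        if not self.sent_handshake:
            cmd["client"] = metadata
            self.sent_handshake = True
        return cmd

conn = HandshakeState()
assert "client" in conn.build_ismaster({"driver": {"name": "PyMongo"}})
assert "client" not in conn.build_ismaster({"driver": {"name": "PyMongo"}})
# A disconnect discards the HandshakeState, so a new connection
# starts over and sends metadata again.
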
- primary.receives('ismaster', 1, client=absent).hangup() - heartbeat = primary.receives('ismaster') + primary.receives("ismaster", 1, client=absent).hangup() + heartbeat = primary.receives("ismaster") _check_handshake_data(heartbeat) heartbeat.ok(primary_response) - secondary.autoresponds('ismaster', secondary_response) + secondary.autoresponds("ismaster", secondary_response) # Start a command, so the client opens an application socket. - future = go(client.db.command, 'whatever') + future = go(client.db.command, "whatever") for request in primary: - if request.matches(Command('ismaster')): + if request.matches(Command("ismaster")): if request.client_port == heartbeat.client_port: # This is the monitor again, keep going. request.ok(primary_response) @@ -106,7 +111,7 @@ def test_client_handshake_data(self): request.ok(primary_response) else: # Command succeeds. - request.assert_matches(OpMsg('whatever')) + request.assert_matches(OpMsg("whatever")) request.ok() assert future() return @@ -116,40 +121,42 @@ def test_client_handshake_saslSupportedMechs(self): server.run() self.addCleanup(server.stop) - primary_response = OpReply('ismaster', True, - minWireVersion=2, maxWireVersion=6) - client = MongoClient(server.uri, - username='username', - password='password') + primary_response = OpReply("ismaster", True, minWireVersion=2, maxWireVersion=6) + client = MongoClient(server.uri, username="username", password="password") self.addCleanup(client.close) # New monitoring sockets send data during handshake. - heartbeat = server.receives('ismaster') + heartbeat = server.receives("ismaster") heartbeat.ok(primary_response) - future = go(client.db.command, 'whatever') + future = go(client.db.command, "whatever") for request in server: - if request.matches('ismaster'): + if request.matches("ismaster"): if request.client_port == heartbeat.client_port: # This is the monitor again, keep going. request.ok(primary_response) else: # Handshaking a new application socket should send # saslSupportedMechs and speculativeAuthenticate. - self.assertEqual(request['saslSupportedMechs'], - 'admin.username') - self.assertIn( - 'saslStart', request['speculativeAuthenticate']) - auth = {'conversationId': 1, 'done': False, - 'payload': b'r=wPleNM8S5p8gMaffMDF7Py4ru9bnmmoqb0' - b'1WNPsil6o=pAvr6B1garhlwc6MKNQ93ZfFky' - b'tXdF9r,s=4dcxugMJq2P4hQaDbGXZR8uR3ei' - b'PHrSmh4uhkg==,i=15000'} - request.ok('ismaster', True, - saslSupportedMechs=['SCRAM-SHA-256'], - speculativeAuthenticate=auth, - minWireVersion=2, maxWireVersion=6) + self.assertEqual(request["saslSupportedMechs"], "admin.username") + self.assertIn("saslStart", request["speculativeAuthenticate"]) + auth = { + "conversationId": 1, + "done": False, + "payload": b"r=wPleNM8S5p8gMaffMDF7Py4ru9bnmmoqb0" + b"1WNPsil6o=pAvr6B1garhlwc6MKNQ93ZfFky" + b"tXdF9r,s=4dcxugMJq2P4hQaDbGXZR8uR3ei" + b"PHrSmh4uhkg==,i=15000", + } + request.ok( + "ismaster", + True, + saslSupportedMechs=["SCRAM-SHA-256"], + speculativeAuthenticate=auth, + minWireVersion=2, + maxWireVersion=6, + ) # Authentication should immediately fail with: # OperationFailure: Server returned an invalid nonce. 
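
The canned speculativeAuthenticate payload above carries a fixed server nonce, so it cannot extend whatever nonce the client just generated, and the assertRaises below expects the resulting failure. In SCRAM (RFC 5802) the check is a prefix test; a sketch (function name hypothetical):

def server_nonce_ok(client_nonce, server_nonce):
    # SCRAM requires the server nonce to be the client nonce plus a
    # server-generated suffix; anything else is an invalid nonce.
    return server_nonce.startswith(client_nonce)

assert server_nonce_ok("abc123", "abc123xyz")
assert not server_nonce_ok("abc123", "wPleNM8S5p8gMaffMDF7Py4")
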
with self.assertRaises(OperationFailure): @@ -157,5 +164,5 @@ def test_client_handshake_saslSupportedMechs(self): return -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_initial_ismaster.py b/test/mockupdb/test_initial_ismaster.py index c67fcbf9e1..155ae6152e 100644 --- a/test/mockupdb/test_initial_ismaster.py +++ b/test/mockupdb/test_initial_ismaster.py @@ -13,11 +13,11 @@ # limitations under the License. import time +import unittest from mockupdb import MockupDB, wait_until -from pymongo import MongoClient -import unittest +from pymongo import MongoClient class TestInitialIsMaster(unittest.TestCase): @@ -32,15 +32,13 @@ def test_initial_ismaster(self): # A single ismaster is enough for the client to be connected. self.assertFalse(client.nodes) - server.receives('ismaster').ok(ismaster=True, - minWireVersion=2, maxWireVersion=6) - wait_until(lambda: client.nodes, - 'update nodes', timeout=1) + server.receives("ismaster").ok(ismaster=True, minWireVersion=2, maxWireVersion=6) + wait_until(lambda: client.nodes, "update nodes", timeout=1) # At least 10 seconds before next heartbeat. - server.receives('ismaster').ok(ismaster=True, - minWireVersion=2, maxWireVersion=6) + server.receives("ismaster").ok(ismaster=True, minWireVersion=2, maxWireVersion=6) self.assertGreaterEqual(time.time() - start, 10) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_list_indexes.py b/test/mockupdb/test_list_indexes.py index b4787ff624..2bdbd7b910 100644 --- a/test/mockupdb/test_list_indexes.py +++ b/test/mockupdb/test_list_indexes.py @@ -14,42 +14,34 @@ """Test list_indexes with more than one batch.""" -from bson import SON +import unittest -from mockupdb import going, MockupDB, OpGetMore -from pymongo import MongoClient +from mockupdb import MockupDB, OpGetMore, going -import unittest +from bson import SON +from pymongo import MongoClient class TestListIndexes(unittest.TestCase): - def test_list_indexes_command(self): - server = MockupDB(auto_ismaster={'maxWireVersion': 6}) + server = MockupDB(auto_ismaster={"maxWireVersion": 6}) server.run() self.addCleanup(server.stop) client = MongoClient(server.uri) self.addCleanup(client.close) with going(client.test.collection.list_indexes) as cursor: - request = server.receives( - listIndexes='collection', namespace='test') - request.reply({'cursor': { - 'firstBatch': [{'name': 'index_0'}], - 'id': 123}}) + request = server.receives(listIndexes="collection", namespace="test") + request.reply({"cursor": {"firstBatch": [{"name": "index_0"}], "id": 123}}) with going(list, cursor()) as indexes: - request = server.receives(getMore=123, - namespace='test', - collection='collection') + request = server.receives(getMore=123, namespace="test", collection="collection") - request.reply({'cursor': { - 'nextBatch': [{'name': 'index_1'}], - 'id': 0}}) + request.reply({"cursor": {"nextBatch": [{"name": "index_1"}], "id": 0}}) - self.assertEqual([{'name': 'index_0'}, {'name': 'index_1'}], indexes()) + self.assertEqual([{"name": "index_0"}, {"name": "index_1"}], indexes()) for index_info in indexes(): self.assertIsInstance(index_info, SON) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_max_staleness.py b/test/mockupdb/test_max_staleness.py index 9bd65a1764..02efb6a718 100644 --- a/test/mockupdb/test_max_staleness.py +++ b/test/mockupdb/test_max_staleness.py @@ -12,33 +12,31 @@ # See the License for the specific language 
governing permissions and # limitations under the License. +import unittest + from mockupdb import MockupDB, going -from pymongo import MongoClient -import unittest +from pymongo import MongoClient class TestMaxStalenessMongos(unittest.TestCase): def test_mongos(self): mongos = MockupDB() - mongos.autoresponds('ismaster', maxWireVersion=6, - ismaster=True, msg='isdbgrid') + mongos.autoresponds("ismaster", maxWireVersion=6, ismaster=True, msg="isdbgrid") mongos.run() self.addCleanup(mongos.stop) # No maxStalenessSeconds. - uri = 'mongodb://localhost:%d/?readPreference=secondary' % mongos.port + uri = "mongodb://localhost:%d/?readPreference=secondary" % mongos.port client = MongoClient(uri) self.addCleanup(client.close) with going(client.db.coll.find_one) as future: request = mongos.receives() - self.assertNotIn( - 'maxStalenessSeconds', - request.doc['$readPreference']) + self.assertNotIn("maxStalenessSeconds", request.doc["$readPreference"]) self.assertTrue(request.slave_okay) - request.ok(cursor={'firstBatch': [], 'id': 0}) + request.ok(cursor={"firstBatch": [], "id": 0}) # find_one succeeds with no result. self.assertIsNone(future()) @@ -46,22 +44,22 @@ def test_mongos(self): # Set maxStalenessSeconds to 1. Client has no minimum with mongos, # we let mongos enforce the 90-second minimum and return an error: # SERVER-27146. - uri = 'mongodb://localhost:%d/?readPreference=secondary' \ - '&maxStalenessSeconds=1' % mongos.port + uri = ( + "mongodb://localhost:%d/?readPreference=secondary" + "&maxStalenessSeconds=1" % mongos.port + ) client = MongoClient(uri) self.addCleanup(client.close) with going(client.db.coll.find_one) as future: request = mongos.receives() - self.assertEqual( - 1, - request.doc['$readPreference']['maxStalenessSeconds']) + self.assertEqual(1, request.doc["$readPreference"]["maxStalenessSeconds"]) self.assertTrue(request.slave_okay) - request.ok(cursor={'firstBatch': [], 'id': 0}) + request.ok(cursor={"firstBatch": [], "id": 0}) self.assertIsNone(future()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_mixed_version_sharded.py b/test/mockupdb/test_mixed_version_sharded.py index d13af3562b..8b83b3ba3b 100644 --- a/test/mockupdb/test_mixed_version_sharded.py +++ b/test/mockupdb/test_mixed_version_sharded.py @@ -21,12 +21,13 @@ except ImportError: from Queue import Queue -from mockupdb import MockupDB, go, OpMsg -from pymongo import MongoClient - import unittest + +from mockupdb import MockupDB, OpMsg, go from operations import upgrades +from pymongo import MongoClient + class TestMixedVersionSharded(unittest.TestCase): def setup_server(self, upgrade): @@ -36,25 +37,29 @@ def setup_server(self, upgrade): self.q = Queue() for server in self.mongos_old, self.mongos_new: server.subscribe(self.q.put) - server.autoresponds('getlasterror') + server.autoresponds("getlasterror") server.run() self.addCleanup(server.stop) # Max wire version is too old for the upgraded operation. - self.mongos_old.autoresponds('ismaster', ismaster=True, msg='isdbgrid', - maxWireVersion=upgrade.wire_version - 1) + self.mongos_old.autoresponds( + "ismaster", ismaster=True, msg="isdbgrid", maxWireVersion=upgrade.wire_version - 1 + ) # Up-to-date max wire version. 
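
This setup gives the two mongoses different maxWireVersion values (the up-to-date responder is configured just below), and the test drives an operation that only the newer wire version supports. Selecting a capable server reduces to a bounds check; a sketch with illustrative names:

def server_supports(required_wire_version, advertised_max_wire_version):
    # A server qualifies only if its advertised maxWireVersion reaches
    # the wire version the operation was introduced in.
    return advertised_max_wire_version >= required_wire_version

# The old mongos advertises upgrade.wire_version - 1 and the new one
# advertises upgrade.wire_version, so only the new one qualifies.
assert not server_supports(6, 5)
assert server_supports(6, 6)
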
- self.mongos_new.autoresponds('ismaster', ismaster=True, msg='isdbgrid', - maxWireVersion=upgrade.wire_version) + self.mongos_new.autoresponds( + "ismaster", ismaster=True, msg="isdbgrid", maxWireVersion=upgrade.wire_version + ) - self.mongoses_uri = 'mongodb://%s,%s' % (self.mongos_old.address_string, - self.mongos_new.address_string) + self.mongoses_uri = "mongodb://%s,%s" % ( + self.mongos_old.address_string, + self.mongos_new.address_string, + ) self.client = MongoClient(self.mongoses_uri) def tearDown(self): - if hasattr(self, 'client') and self.client: + if hasattr(self, "client") and self.client: self.client.close() @@ -67,23 +72,24 @@ def test(self): go(upgrade.function, self.client) request = self.q.get(timeout=1) servers_used.add(request.server) - request.assert_matches(upgrade.old - if request.server is self.mongos_old - else upgrade.new) + request.assert_matches( + upgrade.old if request.server is self.mongos_old else upgrade.new + ) if time.time() > start + 10: - self.fail('never used both mongoses') + self.fail("never used both mongoses") + return test def generate_mixed_version_sharded_tests(): for upgrade in upgrades: test = create_mixed_version_sharded_test(upgrade) - test_name = 'test_%s' % upgrade.name.replace(' ', '_') + test_name = "test_%s" % upgrade.name.replace(" ", "_") test.__name__ = test_name setattr(TestMixedVersionSharded, test_name, test) generate_mixed_version_sharded_tests() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_mongos_command_read_mode.py b/test/mockupdb/test_mongos_command_read_mode.py index ccd40c2cd7..cf50528612 100644 --- a/test/mockupdb/test_mongos_command_read_mode.py +++ b/test/mockupdb/test_mongos_command_read_mode.py @@ -13,23 +13,26 @@ # limitations under the License. 
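
The generated test matrix that follows encodes one rule per operation type: when the slaveOk bit must be set for a mongos read and which $readPreference document rides along. A compact restatement of that decision table, mirroring the test logic (the op_type strings are the ones the operations fixtures use):

def expected_slave_ok(op_type, mode):
    # Mirrors the expectations the generated tests assert against.
    if op_type == "must-use-primary":
        return False
    if op_type in ("always-use-secondary", "may-use-secondary"):
        return mode != "primary"
    raise ValueError("unrecognized op_type %r" % op_type)

assert expected_slave_ok("may-use-secondary", "secondary")
assert expected_slave_ok("always-use-secondary", "nearest")
assert not expected_slave_ok("must-use-primary", "secondaryPreferred")
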
import itertools +import unittest + +from mockupdb import MockupDB, OpMsg, go, going +from operations import operations from bson import SON -from mockupdb import MockupDB, going, OpMsg, go from pymongo import MongoClient, ReadPreference -from pymongo.read_preferences import (make_read_preference, - read_pref_mode_from_name, - _MONGOS_MODES) - -import unittest -from operations import operations +from pymongo.read_preferences import ( + _MONGOS_MODES, + make_read_preference, + read_pref_mode_from_name, +) class TestMongosCommandReadMode(unittest.TestCase): def test_aggregate(self): server = MockupDB() - server.autoresponds('ismaster', ismaster=True, msg='isdbgrid', - minWireVersion=2, maxWireVersion=6) + server.autoresponds( + "ismaster", ismaster=True, msg="isdbgrid", minWireVersion=2, maxWireVersion=6 + ) self.addCleanup(server.stop) server.run() @@ -37,20 +40,25 @@ def test_aggregate(self): self.addCleanup(client.close) collection = client.test.collection with going(collection.aggregate, []): - command = server.receives(aggregate='collection', pipeline=[]) - self.assertFalse(command.slave_ok, 'SlaveOkay set') + command = server.receives(aggregate="collection", pipeline=[]) + self.assertFalse(command.slave_ok, "SlaveOkay set") command.ok(result=[{}]) - secondary_collection = collection.with_options( - read_preference=ReadPreference.SECONDARY) + secondary_collection = collection.with_options(read_preference=ReadPreference.SECONDARY) with going(secondary_collection.aggregate, []): - command = server.receives(OpMsg({"aggregate": "collection", - "pipeline": [], - '$readPreference': {'mode': 'secondary'}})) + command = server.receives( + OpMsg( + { + "aggregate": "collection", + "pipeline": [], + "$readPreference": {"mode": "secondary"}, + } + ) + ) command.ok(result=[{}]) - self.assertTrue(command.slave_ok, 'SlaveOkay not set') + self.assertTrue(command.slave_ok, "SlaveOkay not set") def create_mongos_read_mode_test(mode, operation): @@ -58,11 +66,11 @@ def test(self): server = MockupDB() self.addCleanup(server.stop) server.run() - server.autoresponds('ismaster', ismaster=True, msg='isdbgrid', - minWireVersion=2, maxWireVersion=6) + server.autoresponds( + "ismaster", ismaster=True, msg="isdbgrid", minWireVersion=2, maxWireVersion=6 + ) - pref = make_read_preference(read_pref_mode_from_name(mode), - tag_sets=None) + pref = make_read_preference(read_pref_mode_from_name(mode), tag_sets=None) client = MongoClient(server.uri, read_preference=pref) self.addCleanup(client.close) @@ -71,23 +79,21 @@ def test(self): request = server.receive() request.reply(operation.reply) - if operation.op_type == 'always-use-secondary': - self.assertEqual(ReadPreference.SECONDARY.document, - request.doc.get('$readPreference')) - slave_ok = mode != 'primary' - elif operation.op_type == 'must-use-primary': + if operation.op_type == "always-use-secondary": + self.assertEqual(ReadPreference.SECONDARY.document, request.doc.get("$readPreference")) + slave_ok = mode != "primary" + elif operation.op_type == "must-use-primary": slave_ok = False - elif operation.op_type == 'may-use-secondary': - slave_ok = mode != 'primary' - self.assertEqual(pref.document, - request.doc.get('$readPreference')) + elif operation.op_type == "may-use-secondary": + slave_ok = mode != "primary" + self.assertEqual(pref.document, request.doc.get("$readPreference")) else: - self.fail('unrecognized op_type %r' % operation.op_type) + self.fail("unrecognized op_type %r" % operation.op_type) if slave_ok: - self.assertTrue(request.slave_ok, 'SlaveOkay not 
set') + self.assertTrue(request.slave_ok, "SlaveOkay not set") else: - self.assertFalse(request.slave_ok, 'SlaveOkay set') + self.assertFalse(request.slave_ok, "SlaveOkay set") return test @@ -97,12 +103,11 @@ def generate_mongos_read_mode_tests(): for entry in matrix: mode, operation = entry - if mode == 'primary' and operation.op_type == 'always-use-secondary': + if mode == "primary" and operation.op_type == "always-use-secondary": # Skip something like command('foo', read_preference=SECONDARY). continue test = create_mongos_read_mode_test(mode, operation) - test_name = 'test_%s_with_mode_%s' % ( - operation.name.replace(' ', '_'), mode) + test_name = "test_%s_with_mode_%s" % (operation.name.replace(" ", "_"), mode) test.__name__ = test_name setattr(TestMongosCommandReadMode, test_name, test) @@ -110,5 +115,5 @@ def generate_mongos_read_mode_tests(): generate_mongos_read_mode_tests() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_network_disconnect_primary.py b/test/mockupdb/test_network_disconnect_primary.py index 1df5febb78..3ae9ff8db8 100755 --- a/test/mockupdb/test_network_disconnect_primary.py +++ b/test/mockupdb/test_network_disconnect_primary.py @@ -17,12 +17,13 @@ except ImportError: from Queue import Queue -from mockupdb import MockupDB, wait_until, OpReply, going, Future +import unittest + +from mockupdb import Future, MockupDB, OpReply, going, wait_until + +from pymongo import MongoClient from pymongo.errors import ConnectionFailure from pymongo.topology_description import TOPOLOGY_TYPE -from pymongo import MongoClient - -import unittest class TestNetworkDisconnectPrimary(unittest.TestCase): @@ -36,52 +37,53 @@ def test_network_disconnect_primary(self): self.addCleanup(server.stop) hosts = [server.address_string for server in servers] - primary_response = OpReply(ismaster=True, setName='rs', hosts=hosts, - minWireVersion=2, maxWireVersion=6) - primary.autoresponds('ismaster', primary_response) + primary_response = OpReply( + ismaster=True, setName="rs", hosts=hosts, minWireVersion=2, maxWireVersion=6 + ) + primary.autoresponds("ismaster", primary_response) secondary.autoresponds( - 'ismaster', - ismaster=False, secondary=True, setName='rs', hosts=hosts, - minWireVersion=2, maxWireVersion=6) - - client = MongoClient(primary.uri, replicaSet='rs') + "ismaster", + ismaster=False, + secondary=True, + setName="rs", + hosts=hosts, + minWireVersion=2, + maxWireVersion=6, + ) + + client = MongoClient(primary.uri, replicaSet="rs") self.addCleanup(client.close) - wait_until(lambda: client.primary == primary.address, - 'discover primary') + wait_until(lambda: client.primary == primary.address, "discover primary") topology = client._topology - self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, - topology.description.topology_type) + self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, topology.description.topology_type) # Open a socket in the application pool (calls ismaster). - with going(client.db.command, 'buildinfo'): - primary.receives('buildinfo').ok() + with going(client.db.command, "buildinfo"): + primary.receives("buildinfo").ok() # The primary hangs replying to ismaster. ismaster_future = Future() - primary.autoresponds('ismaster', - lambda r: r.ok(ismaster_future.result())) + primary.autoresponds("ismaster", lambda r: r.ok(ismaster_future.result())) # Network error on application operation. 
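
The hangup below surfaces as ConnectionFailure, and the surrounding assertions check the demotion and recovery of the topology: a network error against the primary drops ReplicaSetWithPrimary to ReplicaSetNoPrimary until a later ismaster succeeds. A toy sketch of the transition, reusing pymongo's TOPOLOGY_TYPE constants (the function itself is illustrative, not PyMongo's SDAM code):

from pymongo.topology_description import TOPOLOGY_TYPE

def after_primary_network_error(topology_type):
    # Losing the primary's connection resets its description, so a
    # ReplicaSetWithPrimary topology degrades to ReplicaSetNoPrimary.
    if topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
        return TOPOLOGY_TYPE.ReplicaSetNoPrimary
    return topology_type

assert (after_primary_network_error(TOPOLOGY_TYPE.ReplicaSetWithPrimary)
        == TOPOLOGY_TYPE.ReplicaSetNoPrimary)
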
with self.assertRaises(ConnectionFailure): - with going(client.db.command, 'buildinfo'): - primary.receives('buildinfo').hangup() + with going(client.db.command, "buildinfo"): + primary.receives("buildinfo").hangup() # Topology type is updated. - self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, - topology.description.topology_type) + self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, topology.description.topology_type) # Let ismasters through again. ismaster_future.set_result(primary_response) # Demand a primary. - with going(client.db.command, 'buildinfo'): - wait_until(lambda: client.primary == primary.address, - 'rediscover primary') - primary.receives('buildinfo').ok() + with going(client.db.command, "buildinfo"): + wait_until(lambda: client.primary == primary.address, "rediscover primary") + primary.receives("buildinfo").ok() + + self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, topology.description.topology_type) - self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, - topology.description.topology_type) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_op_msg.py b/test/mockupdb/test_op_msg.py index 35e70cebfc..d477d091b9 100755 --- a/test/mockupdb/test_op_msg.py +++ b/test/mockupdb/test_op_msg.py @@ -12,228 +12,252 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest from collections import namedtuple -from mockupdb import MockupDB, going, OpMsg, OpMsgReply, OP_MSG_FLAGS +from mockupdb import OP_MSG_FLAGS, MockupDB, OpMsg, OpMsgReply, going + from pymongo import MongoClient, WriteConcern -from pymongo.operations import InsertOne, UpdateOne, DeleteOne from pymongo.cursor import CursorType +from pymongo.operations import DeleteOne, InsertOne, UpdateOne -import unittest - - -Operation = namedtuple( - 'Operation', - ['name', 'function', 'request', 'reply']) +Operation = namedtuple("Operation", ["name", "function", "request", "reply"]) operations = [ Operation( - 'find_one', + "find_one", lambda coll: coll.find_one({}), request=OpMsg({"find": "coll"}, flags=0), - reply={'ok': 1, 'cursor': {'firstBatch': [], 'id': 0}}), + reply={"ok": 1, "cursor": {"firstBatch": [], "id": 0}}, + ), Operation( - 'aggregate', + "aggregate", lambda coll: coll.aggregate([]), request=OpMsg({"aggregate": "coll"}, flags=0), - reply={'ok': 1, 'cursor': {'firstBatch': [], 'id': 0}}), + reply={"ok": 1, "cursor": {"firstBatch": [], "id": 0}}, + ), Operation( - 'insert_one', + "insert_one", lambda coll: coll.insert_one({}), request=OpMsg({"insert": "coll"}, flags=0), - reply={'ok': 1, 'n': 1}), + reply={"ok": 1, "n": 1}, + ), Operation( - 'insert_one-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).insert_one({}), - request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + "insert_one-w0", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).insert_one({}), + request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), Operation( - 'insert_many', + "insert_many", lambda coll: coll.insert_many([{}, {}, {}]), request=OpMsg({"insert": "coll"}, flags=0), - reply={'ok': 1, 'n': 3}), + reply={"ok": 1, "n": 3}, + ), Operation( - 'insert_many-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).insert_many([{}, {}, {}]), + "insert_many-w0", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).insert_many([{}, {}, {}]), request=OpMsg({"insert": "coll"}, flags=0), - 
reply={'ok': 1, 'n': 3}), + reply={"ok": 1, "n": 3}, + ), Operation( - 'insert_many-w0-unordered', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).insert_many( - [{}, {}, {}], ordered=False), - request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + "insert_many-w0-unordered", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).insert_many( + [{}, {}, {}], ordered=False + ), + request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), Operation( - 'replace_one', + "replace_one", lambda coll: coll.replace_one({"_id": 1}, {"new": 1}), request=OpMsg({"update": "coll"}, flags=0), - reply={'ok': 1, 'n': 1, 'nModified': 1}), + reply={"ok": 1, "n": 1, "nModified": 1}, + ), Operation( - 'replace_one-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).replace_one({"_id": 1}, - {"new": 1}), - request=OpMsg({"update": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + "replace_one-w0", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).replace_one( + {"_id": 1}, {"new": 1} + ), + request=OpMsg({"update": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), Operation( - 'update_one', + "update_one", lambda coll: coll.update_one({"_id": 1}, {"$set": {"new": 1}}), request=OpMsg({"update": "coll"}, flags=0), - reply={'ok': 1, 'n': 1, 'nModified': 1}), + reply={"ok": 1, "n": 1, "nModified": 1}, + ), Operation( - 'replace_one-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).update_one({"_id": 1}, - {"$set": {"new": 1}}), - request=OpMsg({"update": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + "replace_one-w0", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).update_one( + {"_id": 1}, {"$set": {"new": 1}} + ), + request=OpMsg({"update": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), Operation( - 'update_many', + "update_many", lambda coll: coll.update_many({"_id": 1}, {"$set": {"new": 1}}), request=OpMsg({"update": "coll"}, flags=0), - reply={'ok': 1, 'n': 1, 'nModified': 1}), + reply={"ok": 1, "n": 1, "nModified": 1}, + ), Operation( - 'update_many-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).update_many({"_id": 1}, - {"$set": {"new": 1}}), - request=OpMsg({"update": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + "update_many-w0", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).update_many( + {"_id": 1}, {"$set": {"new": 1}} + ), + request=OpMsg({"update": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), Operation( - 'delete_one', + "delete_one", lambda coll: coll.delete_one({"a": 1}), request=OpMsg({"delete": "coll"}, flags=0), - reply={'ok': 1, 'n': 1}), + reply={"ok": 1, "n": 1}, + ), Operation( - 'delete_one-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).delete_one({"a": 1}), - request=OpMsg({"delete": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + "delete_one-w0", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).delete_one({"a": 1}), + request=OpMsg({"delete": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), Operation( - 'delete_many', + "delete_many", lambda coll: coll.delete_many({"a": 1}), request=OpMsg({"delete": "coll"}, flags=0), - reply={'ok': 1, 'n': 1}), + reply={"ok": 1, "n": 1}, + ), Operation( - 'delete_many-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).delete_many({"a": 1}), - 
request=OpMsg({"delete": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + "delete_many-w0", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).delete_many({"a": 1}), + request=OpMsg({"delete": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), # Legacy methods Operation( - 'bulk_write_insert', + "bulk_write_insert", lambda coll: coll.bulk_write([InsertOne({}), InsertOne({})]), request=OpMsg({"insert": "coll"}, flags=0), - reply={'ok': 1, 'n': 2}), + reply={"ok": 1, "n": 2}, + ), Operation( - 'bulk_write_insert-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).bulk_write([InsertOne({}), - InsertOne({})]), + "bulk_write_insert-w0", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write( + [InsertOne({}), InsertOne({})] + ), request=OpMsg({"insert": "coll"}, flags=0), - reply={'ok': 1, 'n': 2}), + reply={"ok": 1, "n": 2}, + ), Operation( - 'bulk_write_insert-w0-unordered', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).bulk_write( - [InsertOne({}), InsertOne({})], ordered=False), - request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + "bulk_write_insert-w0-unordered", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write( + [InsertOne({}), InsertOne({})], ordered=False + ), + request=OpMsg({"insert": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), Operation( - 'bulk_write_update', - lambda coll: coll.bulk_write([ - UpdateOne({"_id": 1}, {"$set": {"new": 1}}), - UpdateOne({"_id": 2}, {"$set": {"new": 1}})]), + "bulk_write_update", + lambda coll: coll.bulk_write( + [ + UpdateOne({"_id": 1}, {"$set": {"new": 1}}), + UpdateOne({"_id": 2}, {"$set": {"new": 1}}), + ] + ), request=OpMsg({"update": "coll"}, flags=0), - reply={'ok': 1, 'n': 2, 'nModified': 2}), + reply={"ok": 1, "n": 2, "nModified": 2}, + ), Operation( - 'bulk_write_update-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).bulk_write([ + "bulk_write_update-w0", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write( + [ UpdateOne({"_id": 1}, {"$set": {"new": 1}}), - UpdateOne({"_id": 2}, {"$set": {"new": 1}})]), + UpdateOne({"_id": 2}, {"$set": {"new": 1}}), + ] + ), request=OpMsg({"update": "coll"}, flags=0), - reply={'ok': 1, 'n': 2, 'nModified': 2}), + reply={"ok": 1, "n": 2, "nModified": 2}, + ), Operation( - 'bulk_write_update-w0-unordered', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).bulk_write([ + "bulk_write_update-w0-unordered", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write( + [ UpdateOne({"_id": 1}, {"$set": {"new": 1}}), - UpdateOne({"_id": 2}, {"$set": {"new": 1}})], ordered=False), - request=OpMsg({"update": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + UpdateOne({"_id": 2}, {"$set": {"new": 1}}), + ], + ordered=False, + ), + request=OpMsg({"update": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), Operation( - 'bulk_write_delete', - lambda coll: coll.bulk_write([ - DeleteOne({"_id": 1}), DeleteOne({"_id": 2})]), + "bulk_write_delete", + lambda coll: coll.bulk_write([DeleteOne({"_id": 1}), DeleteOne({"_id": 2})]), request=OpMsg({"delete": "coll"}, flags=0), - reply={'ok': 1, 'n': 2}), + reply={"ok": 1, "n": 2}, + ), Operation( - 'bulk_write_delete-w0', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).bulk_write([ - DeleteOne({"_id": 1}), DeleteOne({"_id": 2})]), + "bulk_write_delete-w0", + lambda 
coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write( + [DeleteOne({"_id": 1}), DeleteOne({"_id": 2})] + ), request=OpMsg({"delete": "coll"}, flags=0), - reply={'ok': 1, 'n': 2}), + reply={"ok": 1, "n": 2}, + ), Operation( - 'bulk_write_delete-w0-unordered', - lambda coll: coll.with_options( - write_concern=WriteConcern(w=0)).bulk_write([ - DeleteOne({"_id": 1}), DeleteOne({"_id": 2})], ordered=False), - request=OpMsg({"delete": "coll"}, flags=OP_MSG_FLAGS['moreToCome']), - reply=None), + "bulk_write_delete-w0-unordered", + lambda coll: coll.with_options(write_concern=WriteConcern(w=0)).bulk_write( + [DeleteOne({"_id": 1}), DeleteOne({"_id": 2})], ordered=False + ), + request=OpMsg({"delete": "coll"}, flags=OP_MSG_FLAGS["moreToCome"]), + reply=None, + ), ] operations_312 = [ Operation( - 'find_raw_batches', + "find_raw_batches", lambda coll: list(coll.find_raw_batches({})), request=[ OpMsg({"find": "coll"}, flags=0), OpMsg({"getMore": 7}, flags=0), ], reply=[ - {'ok': 1, 'cursor': {'firstBatch': [{}], 'id': 7}}, - {'ok': 1, 'cursor': {'nextBatch': [{}], 'id': 0}}, - ]), + {"ok": 1, "cursor": {"firstBatch": [{}], "id": 7}}, + {"ok": 1, "cursor": {"nextBatch": [{}], "id": 0}}, + ], + ), Operation( - 'aggregate_raw_batches', + "aggregate_raw_batches", lambda coll: list(coll.aggregate_raw_batches([])), request=[ OpMsg({"aggregate": "coll"}, flags=0), OpMsg({"getMore": 7}, flags=0), ], reply=[ - {'ok': 1, 'cursor': {'firstBatch': [], 'id': 7}}, - {'ok': 1, 'cursor': {'nextBatch': [{}], 'id': 0}}, - ]), + {"ok": 1, "cursor": {"firstBatch": [], "id": 7}}, + {"ok": 1, "cursor": {"nextBatch": [{}], "id": 0}}, + ], + ), Operation( - 'find_exhaust_cursor', + "find_exhaust_cursor", lambda coll: list(coll.find({}, cursor_type=CursorType.EXHAUST)), request=[ OpMsg({"find": "coll"}, flags=0), OpMsg({"getMore": 7}, flags=1 << 16), ], reply=[ - OpMsgReply( - {'ok': 1, 'cursor': {'firstBatch': [{}], 'id': 7}}, flags=0), - OpMsgReply( - {'ok': 1, 'cursor': {'nextBatch': [{}], 'id': 7}}, flags=2), - OpMsgReply( - {'ok': 1, 'cursor': {'nextBatch': [{}], 'id': 7}}, flags=2), - OpMsgReply( - {'ok': 1, 'cursor': {'nextBatch': [{}], 'id': 0}}, flags=0), - ]), + OpMsgReply({"ok": 1, "cursor": {"firstBatch": [{}], "id": 7}}, flags=0), + OpMsgReply({"ok": 1, "cursor": {"nextBatch": [{}], "id": 7}}, flags=2), + OpMsgReply({"ok": 1, "cursor": {"nextBatch": [{}], "id": 7}}, flags=2), + OpMsgReply({"ok": 1, "cursor": {"nextBatch": [{}], "id": 0}}, flags=0), + ], + ), ] class TestOpMsg(unittest.TestCase): - @classmethod def setUpClass(cls): cls.server = MockupDB(auto_ismaster=True, max_wire_version=8) @@ -271,6 +295,7 @@ def _test_operation(self, op): def operation_test(op): def test(self): self._test_operation(op) + return test @@ -284,5 +309,5 @@ def create_tests(ops): create_tests(operations_312) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_op_msg_read_preference.py b/test/mockupdb/test_op_msg_read_preference.py index 6ecc229ea1..287d747d9e 100644 --- a/test/mockupdb/test_op_msg_read_preference.py +++ b/test/mockupdb/test_op_msg_read_preference.py @@ -14,16 +14,18 @@ import copy import itertools - -from mockupdb import MockupDB, going, CommandBase -from pymongo import MongoClient, ReadPreference -from pymongo.read_preferences import (make_read_preference, - read_pref_mode_from_name, - _MONGOS_MODES) - import unittest + +from mockupdb import CommandBase, MockupDB, going from operations import operations +from pymongo import MongoClient, 
ReadPreference +from pymongo.read_preferences import ( + _MONGOS_MODES, + make_read_preference, + read_pref_mode_from_name, +) + class OpMsgReadPrefBase(unittest.TestCase): single_mongod = False @@ -37,22 +39,20 @@ def add_test(cls, mode, test_name, test): setattr(cls, test_name, test) def setup_client(self, read_preference): - client = MongoClient(self.primary.uri, - read_preference=read_preference) + client = MongoClient(self.primary.uri, read_preference=read_preference) self.addCleanup(client.close) return client class TestOpMsgMongos(OpMsgReadPrefBase): - @classmethod def setUpClass(cls): super(TestOpMsgMongos, cls).setUpClass() auto_ismaster = { - 'ismaster': True, - 'msg': 'isdbgrid', # Mongos. - 'minWireVersion': 2, - 'maxWireVersion': 6, + "ismaster": True, + "msg": "isdbgrid", # Mongos. + "minWireVersion": 2, + "maxWireVersion": 6, } cls.primary = MockupDB(auto_ismaster=auto_ismaster) cls.primary.run() @@ -65,7 +65,6 @@ def tearDownClass(cls): class TestOpMsgReplicaSet(OpMsgReadPrefBase): - @classmethod def setUpClass(cls): super(TestOpMsgReplicaSet, cls).setUpClass() @@ -73,21 +72,20 @@ def setUpClass(cls): for server in cls.primary, cls.secondary: server.run() - hosts = [server.address_string - for server in (cls.primary, cls.secondary)] + hosts = [server.address_string for server in (cls.primary, cls.secondary)] primary_ismaster = { - 'ismaster': True, - 'setName': 'rs', - 'hosts': hosts, - 'minWireVersion': 2, - 'maxWireVersion': 6, + "ismaster": True, + "setName": "rs", + "hosts": hosts, + "minWireVersion": 2, + "maxWireVersion": 6, } - cls.primary.autoresponds(CommandBase('ismaster'), primary_ismaster) + cls.primary.autoresponds(CommandBase("ismaster"), primary_ismaster) secondary_ismaster = copy.copy(primary_ismaster) - secondary_ismaster['ismaster'] = False - secondary_ismaster['secondary'] = True - cls.secondary.autoresponds(CommandBase('ismaster'), secondary_ismaster) + secondary_ismaster["ismaster"] = False + secondary_ismaster["secondary"] = True + cls.secondary.autoresponds(CommandBase("ismaster"), secondary_ismaster) @classmethod def tearDownClass(cls): @@ -99,18 +97,15 @@ def tearDownClass(cls): def add_test(cls, mode, test_name, test): # Skip nearest tests since we don't know if we will select the primary # or secondary. - if mode != 'nearest': + if mode != "nearest": setattr(cls, test_name, test) def setup_client(self, read_preference): - client = MongoClient(self.primary.uri, - replicaSet='rs', - read_preference=read_preference) + client = MongoClient(self.primary.uri, replicaSet="rs", read_preference=read_preference) # Run a command on a secondary to discover the topology. This ensures # that secondaryPreferred commands will select the secondary. 
- client.admin.command('ismaster', - read_preference=ReadPreference.SECONDARY) + client.admin.command("ismaster", read_preference=ReadPreference.SECONDARY) self.addCleanup(client.close) return client @@ -122,9 +117,9 @@ class TestOpMsgSingle(OpMsgReadPrefBase): def setUpClass(cls): super(TestOpMsgSingle, cls).setUpClass() auto_ismaster = { - 'ismaster': True, - 'minWireVersion': 2, - 'maxWireVersion': 6, + "ismaster": True, + "minWireVersion": 2, + "maxWireVersion": 6, } cls.primary = MockupDB(auto_ismaster=auto_ismaster) cls.primary.run() @@ -138,25 +133,24 @@ def tearDownClass(cls): def create_op_msg_read_mode_test(mode, operation): def test(self): - pref = make_read_preference(read_pref_mode_from_name(mode), - tag_sets=None) + pref = make_read_preference(read_pref_mode_from_name(mode), tag_sets=None) client = self.setup_client(read_preference=pref) - if operation.op_type == 'always-use-secondary': + if operation.op_type == "always-use-secondary": expected_server = self.secondary expected_pref = ReadPreference.SECONDARY - elif operation.op_type == 'must-use-primary': + elif operation.op_type == "must-use-primary": expected_server = self.primary expected_pref = ReadPreference.PRIMARY - elif operation.op_type == 'may-use-secondary': - if mode in ('primary', 'primaryPreferred'): + elif operation.op_type == "may-use-secondary": + if mode in ("primary", "primaryPreferred"): expected_server = self.primary else: expected_server = self.secondary expected_pref = pref else: - self.fail('unrecognized op_type %r' % operation.op_type) + self.fail("unrecognized op_type %r" % operation.op_type) # For single mongod we send primaryPreferred instead of primary. if expected_pref == ReadPreference.PRIMARY and self.single_mongod: expected_pref = ReadPreference.PRIMARY_PREFERRED @@ -164,9 +158,8 @@ def test(self): request = expected_server.receive() request.reply(operation.reply) - self.assertEqual(expected_pref.document, - request.doc.get('$readPreference')) - self.assertNotIn('$query', request.doc) + self.assertEqual(expected_pref.document, request.doc.get("$readPreference")) + self.assertNotIn("$query", request.doc) return test @@ -177,8 +170,7 @@ def generate_op_msg_read_mode_tests(): for entry in matrix: mode, operation = entry test = create_op_msg_read_mode_test(mode, operation) - test_name = 'test_%s_with_mode_%s' % ( - operation.name.replace(' ', '_'), mode) + test_name = "test_%s_with_mode_%s" % (operation.name.replace(" ", "_"), mode) test.__name__ = test_name for cls in TestOpMsgMongos, TestOpMsgReplicaSet, TestOpMsgSingle: cls.add_test(mode, test_name, test) @@ -187,5 +179,5 @@ def generate_op_msg_read_mode_tests(): generate_op_msg_read_mode_tests() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_query_read_pref_sharded.py b/test/mockupdb/test_query_read_pref_sharded.py index 21813f7b8e..2d8d87478c 100644 --- a/test/mockupdb/test_query_read_pref_sharded.py +++ b/test/mockupdb/test_query_read_pref_sharded.py @@ -14,24 +14,28 @@ """Test PyMongo query and read preference with a sharded cluster.""" +import unittest + +from mockupdb import MockupDB, OpMsg, going + from bson import SON from pymongo import MongoClient -from pymongo.read_preferences import (Primary, - PrimaryPreferred, - Secondary, - SecondaryPreferred, - Nearest) -from mockupdb import MockupDB, going, OpMsg - -import unittest +from pymongo.read_preferences import ( + Nearest, + Primary, + PrimaryPreferred, + Secondary, + SecondaryPreferred, +) class 
TestQueryAndReadModeSharded(unittest.TestCase): def test_query_and_read_mode_sharded_op_msg(self): """Test OP_MSG sends non-primary $readPreference and never $query.""" server = MockupDB() - server.autoresponds('ismaster', ismaster=True, msg='isdbgrid', - minWireVersion=2, maxWireVersion=6) + server.autoresponds( + "ismaster", ismaster=True, msg="isdbgrid", minWireVersion=2, maxWireVersion=6 + ) server.run() self.addCleanup(server.stop) @@ -44,23 +48,33 @@ def test_query_and_read_mode_sharded_op_msg(self): PrimaryPreferred(), Secondary(), Nearest(), - SecondaryPreferred([{'tag': 'value'}]),) + SecondaryPreferred([{"tag": "value"}]), + ) - for query in ({'a': 1}, {'$query': {'a': 1}},): + for query in ( + {"a": 1}, + {"$query": {"a": 1}}, + ): for mode in read_prefs: - collection = client.db.get_collection('test', - read_preference=mode) + collection = client.db.get_collection("test", read_preference=mode) cursor = collection.find(query.copy()) with going(next, cursor): request = server.receives() # Command is not nested in $query. - request.assert_matches(OpMsg( - SON([('find', 'test'), - ('filter', {'a': 1}), - ('$readPreference', mode.document)]))) + request.assert_matches( + OpMsg( + SON( + [ + ("find", "test"), + ("filter", {"a": 1}), + ("$readPreference", mode.document), + ] + ) + ) + ) - request.replies({'cursor': {'id': 0, 'firstBatch': [{}]}}) + request.replies({"cursor": {"id": 0, "firstBatch": [{}]}}) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_reset_and_request_check.py b/test/mockupdb/test_reset_and_request_check.py index 86c2085e39..b1f752a452 100755 --- a/test/mockupdb/test_reset_and_request_check.py +++ b/test/mockupdb/test_reset_and_request_check.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time import itertools +import time +import unittest from mockupdb import MockupDB, going, wait_until -from pymongo.server_type import SERVER_TYPE -from pymongo.errors import ConnectionFailure -from pymongo import MongoClient - -import unittest from operations import operations +from pymongo import MongoClient +from pymongo.errors import ConnectionFailure +from pymongo.server_type import SERVER_TYPE + class TestResetAndRequestCheck(unittest.TestCase): def __init__(self, *args, **kwargs): @@ -38,18 +38,18 @@ def responder(request): self.ismaster_time = time.time() return request.ok(ismaster=True, minWireVersion=2, maxWireVersion=6) - self.server.autoresponds('ismaster', responder) + self.server.autoresponds("ismaster", responder) self.server.run() self.addCleanup(self.server.stop) - kwargs = {'socketTimeoutMS': 100} + kwargs = {"socketTimeoutMS": 100} # Disable retryable reads when pymongo supports it. - kwargs['retryReads'] = False + kwargs["retryReads"] = False self.client = MongoClient(self.server.uri, **kwargs) - wait_until(lambda: self.client.nodes, 'connect to standalone') + wait_until(lambda: self.client.nodes, "connect to standalone") def tearDown(self): - if hasattr(self, 'client') and self.client: + if hasattr(self, "client") and self.client: self.client.close() def _test_disconnect(self, operation): @@ -71,11 +71,11 @@ def _test_disconnect(self, operation): after = time.time() # Demand a reconnect. 
- with going(self.client.db.command, 'buildinfo'): - self.server.receives('buildinfo').ok() + with going(self.client.db.command, "buildinfo"): + self.server.receives("buildinfo").ok() last = self.ismaster_time - self.assertGreaterEqual(last, after, 'called ismaster before needed') + self.assertGreaterEqual(last, after, "called ismaster before needed") def _test_timeout(self, operation): # Application operation times out. Test that client does *not* reset @@ -94,7 +94,7 @@ def _test_timeout(self, operation): self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type) after = self.ismaster_time - self.assertEqual(after, before, 'unneeded ismaster call') + self.assertEqual(after, before, "unneeded ismaster call") def _test_not_master(self, operation): # Application operation gets a "not master" error. @@ -113,7 +113,7 @@ def _test_not_master(self, operation): self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type) after = self.ismaster_time - self.assertGreater(after, before, 'ismaster not called') + self.assertGreater(after, before, "ismaster not called") def create_reset_test(operation, test_method): @@ -125,9 +125,9 @@ def test(self): def generate_reset_tests(): test_methods = [ - (TestResetAndRequestCheck._test_disconnect, 'test_disconnect'), - (TestResetAndRequestCheck._test_timeout, 'test_timeout'), - (TestResetAndRequestCheck._test_not_master, 'test_not_master'), + (TestResetAndRequestCheck._test_disconnect, "test_disconnect"), + (TestResetAndRequestCheck._test_timeout, "test_timeout"), + (TestResetAndRequestCheck._test_not_master, "test_not_master"), ] matrix = itertools.product(operations, test_methods) @@ -135,12 +135,12 @@ def generate_reset_tests(): for entry in matrix: operation, (test_method, name) = entry test = create_reset_test(operation, test_method) - test_name = '%s_%s' % (name, operation.name.replace(' ', '_')) + test_name = "%s_%s" % (name, operation.name.replace(" ", "_")) test.__name__ = test_name setattr(TestResetAndRequestCheck, test_name, test) generate_reset_tests() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_rsghost.py b/test/mockupdb/test_rsghost.py index 2f02503f54..354399728d 100644 --- a/test/mockupdb/test_rsghost.py +++ b/test/mockupdb/test_rsghost.py @@ -15,38 +15,45 @@ """Test connections to RSGhost nodes.""" import datetime +import unittest + +from mockupdb import MockupDB, going -from mockupdb import going, MockupDB from pymongo import MongoClient from pymongo.errors import ServerSelectionTimeoutError -import unittest - class TestRSGhost(unittest.TestCase): - def test_rsghost(self): rsother_response = { - 'ok': 1.0, 'ismaster': False, 'secondary': False, - 'info': 'Does not have a valid replica set config', - 'isreplicaset': True, 'maxBsonObjectSize': 16777216, - 'maxMessageSizeBytes': 48000000, 'maxWriteBatchSize': 100000, - 'localTime': datetime.datetime(2021, 11, 30, 0, 53, 4, 99000), - 'logicalSessionTimeoutMinutes': 30, 'connectionId': 3, - 'minWireVersion': 0, 'maxWireVersion': 15, 'readOnly': False} + "ok": 1.0, + "ismaster": False, + "secondary": False, + "info": "Does not have a valid replica set config", + "isreplicaset": True, + "maxBsonObjectSize": 16777216, + "maxMessageSizeBytes": 48000000, + "maxWriteBatchSize": 100000, + "localTime": datetime.datetime(2021, 11, 30, 0, 53, 4, 99000), + "logicalSessionTimeoutMinutes": 30, + "connectionId": 3, + "minWireVersion": 0, + "maxWireVersion": 15, + "readOnly": False, + } server = 
MockupDB(auto_ismaster=rsother_response) server.run() self.addCleanup(server.stop) # Default auto discovery yields a server selection timeout. with MongoClient(server.uri, serverSelectionTimeoutMS=250) as client: with self.assertRaises(ServerSelectionTimeoutError): - client.test.command('ping') + client.test.command("ping") # Direct connection succeeds. with MongoClient(server.uri, directConnection=True) as client: - with going(client.test.command, 'ping'): + with going(client.test.command, "ping"): request = server.receives(ping=1) request.reply() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_slave_okay_rs.py b/test/mockupdb/test_slave_okay_rs.py index 5ff6fced4e..5a162c08e3 100644 --- a/test/mockupdb/test_slave_okay_rs.py +++ b/test/mockupdb/test_slave_okay_rs.py @@ -17,12 +17,13 @@ Just make sure SlaveOkay is *not* set on primary reads. """ -from mockupdb import MockupDB, going -from pymongo import MongoClient - import unittest + +from mockupdb import MockupDB, going from operations import operations +from pymongo import MongoClient + class TestSlaveOkayRS(unittest.TestCase): def setup_server(self): @@ -31,24 +32,27 @@ def setup_server(self): server.run() self.addCleanup(server.stop) - hosts = [server.address_string - for server in (self.primary, self.secondary)] + hosts = [server.address_string for server in (self.primary, self.secondary)] self.primary.autoresponds( - 'ismaster', - ismaster=True, setName='rs', hosts=hosts, - minWireVersion=2, maxWireVersion=6) + "ismaster", ismaster=True, setName="rs", hosts=hosts, minWireVersion=2, maxWireVersion=6 + ) self.secondary.autoresponds( - 'ismaster', - ismaster=False, secondary=True, setName='rs', hosts=hosts, - minWireVersion=2, maxWireVersion=6) + "ismaster", + ismaster=False, + secondary=True, + setName="rs", + hosts=hosts, + minWireVersion=2, + maxWireVersion=6, + ) def create_slave_ok_rs_test(operation): def test(self): self.setup_server() - assert not operation.op_type == 'always-use-secondary' + assert operation.op_type != "always-use-secondary" - client = MongoClient(self.primary.uri, replicaSet='rs') + client = MongoClient(self.primary.uri, replicaSet="rs") self.addCleanup(client.close) with going(operation.function, client): request = self.primary.receive() @@ -63,11 +67,11 @@ def generate_slave_ok_rs_tests(): for operation in operations: # Don't test secondary operations with MockupDB; the server enforces the # SlaveOkay bit, so integration tests prove we set it. 
- if operation.op_type == 'always-use-secondary': + if operation.op_type == "always-use-secondary": continue test = create_slave_ok_rs_test(operation) - test_name = 'test_%s' % operation.name.replace(' ', '_') + test_name = "test_%s" % operation.name.replace(" ", "_") test.__name__ = test_name setattr(TestSlaveOkayRS, test_name, test) @@ -75,5 +79,5 @@ def generate_slave_ok_rs_tests(): generate_slave_ok_rs_tests() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_slave_okay_sharded.py b/test/mockupdb/test_slave_okay_sharded.py index 719de57553..44e570cada 100644 --- a/test/mockupdb/test_slave_okay_sharded.py +++ b/test/mockupdb/test_slave_okay_sharded.py @@ -20,20 +20,20 @@ """ import itertools -from pymongo.read_preferences import make_read_preference -from pymongo.read_preferences import read_pref_mode_from_name +from pymongo.read_preferences import make_read_preference, read_pref_mode_from_name try: from queue import Queue except ImportError: from Queue import Queue -from mockupdb import MockupDB, going -from pymongo import MongoClient - import unittest + +from mockupdb import MockupDB, going from operations import operations +from pymongo import MongoClient + class TestSlaveOkaySharded(unittest.TestCase): def setup_server(self): @@ -45,27 +45,29 @@ def setup_server(self): server.subscribe(self.q.put) server.run() self.addCleanup(server.stop) - server.autoresponds('ismaster', minWireVersion=2, maxWireVersion=6, - ismaster=True, msg='isdbgrid') + server.autoresponds( + "ismaster", minWireVersion=2, maxWireVersion=6, ismaster=True, msg="isdbgrid" + ) - self.mongoses_uri = 'mongodb://%s,%s' % (self.mongos1.address_string, - self.mongos2.address_string) + self.mongoses_uri = "mongodb://%s,%s" % ( + self.mongos1.address_string, + self.mongos2.address_string, + ) def create_slave_ok_sharded_test(mode, operation): def test(self): self.setup_server() - if operation.op_type == 'always-use-secondary': + if operation.op_type == "always-use-secondary": slave_ok = True - elif operation.op_type == 'may-use-secondary': - slave_ok = mode != 'primary' - elif operation.op_type == 'must-use-primary': + elif operation.op_type == "may-use-secondary": + slave_ok = mode != "primary" + elif operation.op_type == "must-use-primary": slave_ok = False else: - assert False, 'unrecognized op_type %r' % operation.op_type + assert False, "unrecognized op_type %r" % operation.op_type - pref = make_read_preference(read_pref_mode_from_name(mode), - tag_sets=None) + pref = make_read_preference(read_pref_mode_from_name(mode), tag_sets=None) client = MongoClient(self.mongoses_uri, read_preference=pref) self.addCleanup(client.close) @@ -74,22 +76,21 @@ def test(self): request.reply(operation.reply) if slave_ok: - self.assertTrue(request.slave_ok, 'SlaveOkay not set') + self.assertTrue(request.slave_ok, "SlaveOkay not set") else: - self.assertFalse(request.slave_ok, 'SlaveOkay set') + self.assertFalse(request.slave_ok, "SlaveOkay set") return test def generate_slave_ok_sharded_tests(): - modes = 'primary', 'secondary', 'nearest' + modes = "primary", "secondary", "nearest" matrix = itertools.product(modes, operations) for entry in matrix: mode, operation = entry test = create_slave_ok_sharded_test(mode, operation) - test_name = 'test_%s_with_mode_%s' % ( - operation.name.replace(' ', '_'), mode) + test_name = "test_%s_with_mode_%s" % (operation.name.replace(" ", "_"), mode) test.__name__ = test_name setattr(TestSlaveOkaySharded, test_name, test) @@ -97,5 +98,5 @@ def 
generate_slave_ok_sharded_tests(): generate_slave_ok_sharded_tests() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mockupdb/test_slave_okay_single.py b/test/mockupdb/test_slave_okay_single.py index 83c0f925a4..98cd1f2706 100644 --- a/test/mockupdb/test_slave_okay_single.py +++ b/test/mockupdb/test_slave_okay_single.py @@ -20,16 +20,15 @@ """ import itertools +import unittest from mockupdb import MockupDB, going +from operations import operations + from pymongo import MongoClient -from pymongo.read_preferences import (make_read_preference, - read_pref_mode_from_name) +from pymongo.read_preferences import make_read_preference, read_pref_mode_from_name from pymongo.topology_description import TOPOLOGY_TYPE -import unittest -from operations import operations - def topology_type_name(client): topology_type = client._topology._description.topology_type @@ -46,20 +45,19 @@ def setUp(self): def create_slave_ok_single_test(mode, server_type, ismaster, operation): def test(self): ismaster_with_version = ismaster.copy() - ismaster_with_version['minWireVersion'] = 2 - ismaster_with_version['maxWireVersion'] = 6 - self.server.autoresponds('ismaster', **ismaster_with_version) - if operation.op_type == 'always-use-secondary': + ismaster_with_version["minWireVersion"] = 2 + ismaster_with_version["maxWireVersion"] = 6 + self.server.autoresponds("ismaster", **ismaster_with_version) + if operation.op_type == "always-use-secondary": slave_ok = True - elif operation.op_type == 'may-use-secondary': - slave_ok = mode != 'primary' or server_type != 'mongos' - elif operation.op_type == 'must-use-primary': - slave_ok = server_type != 'mongos' + elif operation.op_type == "may-use-secondary": + slave_ok = mode != "primary" or server_type != "mongos" + elif operation.op_type == "must-use-primary": + slave_ok = server_type != "mongos" else: - assert False, 'unrecognized op_type %r' % operation.op_type + assert False, "unrecognized op_type %r" % operation.op_type - pref = make_read_preference(read_pref_mode_from_name(mode), - tag_sets=None) + pref = make_read_preference(read_pref_mode_from_name(mode), tag_sets=None) client = MongoClient(self.server.uri, read_preference=pref) self.addCleanup(client.close) @@ -67,27 +65,30 @@ def test(self): request = self.server.receive() request.reply(operation.reply) - self.assertIn(topology_type_name(client), ['Sharded', 'Single']) + self.assertIn(topology_type_name(client), ["Sharded", "Single"]) return test def generate_slave_ok_single_tests(): - modes = 'primary', 'secondary', 'nearest' + modes = "primary", "secondary", "nearest" server_types = [ - ('standalone', {'ismaster': True}), - ('slave', {'ismaster': False}), - ('mongos', {'ismaster': True, 'msg': 'isdbgrid'})] + ("standalone", {"ismaster": True}), + ("slave", {"ismaster": False}), + ("mongos", {"ismaster": True, "msg": "isdbgrid"}), + ] matrix = itertools.product(modes, server_types, operations) for entry in matrix: mode, (server_type, ismaster), operation = entry - test = create_slave_ok_single_test(mode, server_type, ismaster, - operation) + test = create_slave_ok_single_test(mode, server_type, ismaster, operation) - test_name = 'test_%s_%s_with_mode_%s' % ( - operation.name.replace(' ', '_'), server_type, mode) + test_name = "test_%s_%s_with_mode_%s" % ( + operation.name.replace(" ", "_"), + server_type, + mode, + ) test.__name__ = test_name setattr(TestSlaveOkaySingle, test_name, test) @@ -96,5 +97,5 @@ def generate_slave_ok_single_tests(): generate_slave_ok_single_tests() 
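All three generate_*_tests() helpers in these files share the same closure-plus-setattr pattern for expanding a parameter matrix into individual unittest methods. A minimal, self-contained sketch of that pattern (the TestMatrix class, create_test factory, and parameter values are illustrative stand-ins, not part of the suite):

    import itertools
    import unittest

    class TestMatrix(unittest.TestCase):
        pass

    def create_test(mode, op_name):
        # Bind the matrix entry in a closure so each generated method
        # keeps its own parameters.
        def test(self):
            self.assertIn(mode, ("primary", "secondary", "nearest"))
            self.assertTrue(op_name)
        return test

    def generate_tests():
        # One named method per (mode, operation) pair, so unittest
        # discovery reports each combination individually.
        for mode, op_name in itertools.product(
            ("primary", "secondary", "nearest"), ("find_one", "count")
        ):
            test = create_test(mode, op_name)
            test.__name__ = "test_%s_with_mode_%s" % (op_name, mode)
            setattr(TestMatrix, test.__name__, test)

    generate_tests()

The factory function matters: a closure defined directly inside the loop over the loop variables would late-bind them, and every generated test would see only the final matrix entry.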
-if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/mod_wsgi_test/test_client.py b/test/mod_wsgi_test/test_client.py index 61cf8df674..4ea3ff78ce 100644 --- a/test/mod_wsgi_test/test_client.py +++ b/test/mod_wsgi_test/test_client.py @@ -18,7 +18,6 @@ import sys import threading import time - from optparse import OptionParser try: @@ -36,31 +35,49 @@ def parse_args(): - parser = OptionParser("""usage: %prog [options] mode url + parser = OptionParser( + """usage: %prog [options] mode url - mode:\tparallel or serial""") + mode:\tparallel or serial""" + ) # Should be enough that any connection leak will exhaust available file # descriptors. parser.add_option( - "-n", "--nrequests", type="int", - dest="nrequests", default=50 * 1000, - help="Number of times to GET the URL, in total") + "-n", + "--nrequests", + type="int", + dest="nrequests", + default=50 * 1000, + help="Number of times to GET the URL, in total", + ) parser.add_option( - "-t", "--nthreads", type="int", - dest="nthreads", default=100, - help="Number of threads with mode 'parallel'") + "-t", + "--nthreads", + type="int", + dest="nthreads", + default=100, + help="Number of threads with mode 'parallel'", + ) parser.add_option( - "-q", "--quiet", - action="store_false", dest="verbose", default=True, - help="Don't print status messages to stdout") + "-q", + "--quiet", + action="store_false", + dest="verbose", + default=True, + help="Don't print status messages to stdout", + ) parser.add_option( - "-c", "--continue", - action="store_true", dest="continue_", default=False, - help="Continue after HTTP errors") + "-c", + "--continue", + action="store_true", + dest="continue_", + default=False, + help="Continue after HTTP errors", + ) try: options, (mode, url) = parser.parse_args() @@ -68,7 +85,7 @@ def parse_args(): parser.print_usage() sys.exit(1) - if mode not in ('parallel', 'serial'): + if mode not in ("parallel", "serial"): parser.print_usage() sys.exit(1) @@ -117,18 +134,22 @@ def run(self): def main(options, mode, url): start_time = time.time() errors = 0 - if mode == 'parallel': + if mode == "parallel": nrequests_per_thread = options.nrequests // options.nthreads if options.verbose: - print ( - 'Getting %s %s times total in %s threads, ' - '%s times per thread' % ( - url, nrequests_per_thread * options.nthreads, - options.nthreads, nrequests_per_thread)) + print( + "Getting %s %s times total in %s threads, " + "%s times per thread" + % ( + url, + nrequests_per_thread * options.nthreads, + options.nthreads, + nrequests_per_thread, + ) + ) threads = [ - URLGetterThread(options, url, nrequests_per_thread) - for _ in range(options.nthreads) + URLGetterThread(options, url, nrequests_per_thread) for _ in range(options.nthreads) ] for t in threads: @@ -140,14 +161,11 @@ def main(options, mode, url): errors = sum([t.errors for t in threads]) nthreads_with_errors = len([t for t in threads if t.errors]) if nthreads_with_errors: - print('%d threads had errors! %d errors in total' % ( - nthreads_with_errors, errors)) + print("%d threads had errors! %d errors in total" % (nthreads_with_errors, errors)) else: - assert mode == 'serial' + assert mode == "serial" if options.verbose: - print('Getting %s %s times in one thread' % ( - url, options.nrequests - )) + print("Getting %s %s times in one thread" % (url, options.nrequests)) for i in range(1, options.nrequests + 1): try: @@ -163,16 +181,16 @@ def main(options, mode, url): print(i) if errors: - print('%d errors!' % errors) + print("%d errors!" 
% errors) if options.verbose: - print('Completed in %.2f seconds' % (time.time() - start_time)) + print("Completed in %.2f seconds" % (time.time() - start_time)) if errors: # Failure sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": options, mode, url = parse_args() main(options, mode, url) diff --git a/test/ocsp/test_ocsp.py b/test/ocsp/test_ocsp.py index 07197e73b6..cce846feac 100644 --- a/test/ocsp/test_ocsp.py +++ b/test/ocsp/test_ocsp.py @@ -22,19 +22,17 @@ sys.path[0:0] = [""] import pymongo - from pymongo.errors import ServerSelectionTimeoutError - CA_FILE = os.environ.get("CA_FILE") -OCSP_TLS_SHOULD_SUCCEED = (os.environ.get('OCSP_TLS_SHOULD_SUCCEED') == 'true') +OCSP_TLS_SHOULD_SUCCEED = os.environ.get("OCSP_TLS_SHOULD_SUCCEED") == "true" # Enable logs in this format: # 2020-06-08 23:49:35,982 DEBUG ocsp_support Peer did not staple an OCSP response -FORMAT = '%(asctime)s %(levelname)s %(module)s %(message)s' +FORMAT = "%(asctime)s %(levelname)s %(module)s %(message)s" logging.basicConfig(format=FORMAT, level=logging.DEBUG) -if sys.platform == 'win32': +if sys.platform == "win32": # The non-stapled OCSP endpoint check is slow on Windows. TIMEOUT_MS = 5000 else: @@ -42,15 +40,17 @@ def _connect(options): - uri = ("mongodb://localhost:27017/?serverSelectionTimeoutMS=%s" - "&tlsCAFile=%s&%s") % (TIMEOUT_MS, CA_FILE, options) + uri = ("mongodb://localhost:27017/?serverSelectionTimeoutMS=%s" "&tlsCAFile=%s&%s") % ( + TIMEOUT_MS, + CA_FILE, + options, + ) print(uri) client = pymongo.MongoClient(uri) - client.admin.command('ping') + client.admin.command("ping") class TestOCSP(unittest.TestCase): - def test_tls_insecure(self): # Should always succeed options = "tls=true&tlsInsecure=true" @@ -65,12 +65,11 @@ def test_tls(self): options = "tls=true" if not OCSP_TLS_SHOULD_SUCCEED: self.assertRaisesRegex( - ServerSelectionTimeoutError, - "invalid status response", - _connect, options) + ServerSelectionTimeoutError, "invalid status response", _connect, options + ) else: _connect(options) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/performance/perf_test.py b/test/performance/perf_test.py index d84e67aca4..dccebc6bdf 100644 --- a/test/performance/perf_test.py +++ b/test/performance/perf_test.py @@ -28,30 +28,30 @@ sys.path[0:0] = [""] +from test import client_context, host, port, unittest + from bson import decode, encode from bson.json_util import loads from gridfs import GridFSBucket from pymongo import MongoClient -from test import client_context, host, port, unittest NUM_ITERATIONS = 100 MAX_ITERATION_TIME = 300 NUM_DOCS = 10000 -TEST_PATH = os.environ.get('TEST_PATH', os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.path.join('data'))) +TEST_PATH = os.environ.get( + "TEST_PATH", os.path.join(os.path.dirname(os.path.realpath(__file__)), os.path.join("data")) +) -OUTPUT_FILE = os.environ.get('OUTPUT_FILE') +OUTPUT_FILE = os.environ.get("OUTPUT_FILE") result_data = [] + def tearDownModule(): - output = json.dumps({ - 'results': result_data - }, indent=4) + output = json.dumps({"results": result_data}, indent=4) if OUTPUT_FILE: - with open(OUTPUT_FILE, 'w') as opf: + with open(OUTPUT_FILE, "w") as opf: opf.write(output) else: print(output) @@ -68,7 +68,6 @@ def __exit__(self, *args): class PerformanceTest(object): - @classmethod def setUpClass(cls): client_context.init() @@ -80,16 +79,8 @@ def tearDown(self): name = self.__class__.__name__ median = self.percentile(50) result = self.data_size / median - 
print('Running %s. MEDIAN=%s' % (self.__class__.__name__, - self.percentile(50))) - result_data.append({ - 'name': name, - 'results': { - '1': { - 'ops_per_sec': result - } - } - }) + print("Running %s. MEDIAN=%s" % (self.__class__.__name__, self.percentile(50))) + result_data.append({"name": name, "results": {"1": {"ops_per_sec": result}}}) def before(self): pass @@ -98,12 +89,12 @@ def after(self): pass def percentile(self, percentile): - if hasattr(self, 'results'): + if hasattr(self, "results"): sorted_results = sorted(self.results) percentile_index = int(len(sorted_results) * percentile / 100) - 1 return sorted_results[percentile_index] else: - self.fail('Test execution failed') + self.fail("Test execution failed") def runTest(self): results = [] @@ -111,7 +102,7 @@ def runTest(self): self.max_iterations = NUM_ITERATIONS for i in range(NUM_ITERATIONS): if time.monotonic() - start > MAX_ITERATION_TIME: - warnings.warn('Test timed out, completed %s iterations.' % i) + warnings.warn("Test timed out, completed %s iterations." % i) break self.before() with Timer() as timer: @@ -126,9 +117,7 @@ def runTest(self): class BsonEncodingTest(PerformanceTest): def setUp(self): # Location of test data. - with open( - os.path.join(TEST_PATH, - os.path.join('extended_bson', self.dataset))) as data: + with open(os.path.join(TEST_PATH, os.path.join("extended_bson", self.dataset))) as data: self.document = loads(data.read()) def do_task(self): @@ -139,9 +128,7 @@ def do_task(self): class BsonDecodingTest(PerformanceTest): def setUp(self): # Location of test data. - with open( - os.path.join(TEST_PATH, - os.path.join('extended_bson', self.dataset))) as data: + with open(os.path.join(TEST_PATH, os.path.join("extended_bson", self.dataset))) as data: self.document = encode(json.loads(data.read())) def do_task(self): @@ -150,41 +137,42 @@ def do_task(self): class TestFlatEncoding(BsonEncodingTest, unittest.TestCase): - dataset = 'flat_bson.json' + dataset = "flat_bson.json" data_size = 75310000 class TestFlatDecoding(BsonDecodingTest, unittest.TestCase): - dataset = 'flat_bson.json' + dataset = "flat_bson.json" data_size = 75310000 class TestDeepEncoding(BsonEncodingTest, unittest.TestCase): - dataset = 'deep_bson.json' + dataset = "deep_bson.json" data_size = 19640000 class TestDeepDecoding(BsonDecodingTest, unittest.TestCase): - dataset = 'deep_bson.json' + dataset = "deep_bson.json" data_size = 19640000 class TestFullEncoding(BsonEncodingTest, unittest.TestCase): - dataset = 'full_bson.json' + dataset = "full_bson.json" data_size = 57340000 class TestFullDecoding(BsonDecodingTest, unittest.TestCase): - dataset = 'full_bson.json' + dataset = "full_bson.json" data_size = 57340000 # SINGLE-DOC BENCHMARKS class TestRunCommand(PerformanceTest, unittest.TestCase): data_size = 160000 + def setUp(self): self.client = client_context.client - self.client.drop_database('perftest') + self.client.drop_database("perftest") def do_task(self): command = self.client.perftest.command @@ -196,29 +184,29 @@ class TestDocument(PerformanceTest): def setUp(self): # Location of test data. 
with open( - os.path.join( - TEST_PATH, os.path.join( - 'single_and_multi_document', self.dataset)), 'r') as data: + os.path.join(TEST_PATH, os.path.join("single_and_multi_document", self.dataset)), "r" + ) as data: self.document = json.loads(data.read()) self.client = client_context.client - self.client.drop_database('perftest') + self.client.drop_database("perftest") def tearDown(self): super(TestDocument, self).tearDown() - self.client.drop_database('perftest') + self.client.drop_database("perftest") def before(self): - self.corpus = self.client.perftest.create_collection('corpus') + self.corpus = self.client.perftest.create_collection("corpus") def after(self): - self.client.perftest.drop_collection('corpus') + self.client.perftest.drop_collection("corpus") class TestFindOneByID(TestDocument, unittest.TestCase): data_size = 16220000 + def setUp(self): - self.dataset = 'tweet.json' + self.dataset = "tweet.json" super(TestFindOneByID, self).setUp() documents = [self.document.copy() for _ in range(NUM_DOCS)] @@ -229,7 +217,7 @@ def setUp(self): def do_task(self): find_one = self.corpus.find_one for _id in self.inserted_ids: - find_one({'_id': _id}) + find_one({"_id": _id}) def before(self): pass @@ -240,8 +228,9 @@ def after(self): class TestSmallDocInsertOne(TestDocument, unittest.TestCase): data_size = 2750000 + def setUp(self): - self.dataset = 'small_doc.json' + self.dataset = "small_doc.json" super(TestSmallDocInsertOne, self).setUp() self.documents = [self.document.copy() for _ in range(NUM_DOCS)] @@ -254,8 +243,9 @@ def do_task(self): class TestLargeDocInsertOne(TestDocument, unittest.TestCase): data_size = 27310890 + def setUp(self): - self.dataset = 'large_doc.json' + self.dataset = "large_doc.json" super(TestLargeDocInsertOne, self).setUp() self.documents = [self.document.copy() for _ in range(10)] @@ -269,14 +259,13 @@ def do_task(self): # MULTI-DOC BENCHMARKS class TestFindManyAndEmptyCursor(TestDocument, unittest.TestCase): data_size = 16220000 + def setUp(self): - self.dataset = 'tweet.json' + self.dataset = "tweet.json" super(TestFindManyAndEmptyCursor, self).setUp() for _ in range(10): - self.client.perftest.command( - 'insert', 'corpus', - documents=[self.document] * 1000) + self.client.perftest.command("insert", "corpus", documents=[self.document] * 1000) self.corpus = self.client.perftest.corpus def do_task(self): @@ -291,13 +280,14 @@ def after(self): class TestSmallDocBulkInsert(TestDocument, unittest.TestCase): data_size = 2750000 + def setUp(self): - self.dataset = 'small_doc.json' + self.dataset = "small_doc.json" super(TestSmallDocBulkInsert, self).setUp() self.documents = [self.document.copy() for _ in range(NUM_DOCS)] def before(self): - self.corpus = self.client.perftest.create_collection('corpus') + self.corpus = self.client.perftest.create_collection("corpus") def do_task(self): self.corpus.insert_many(self.documents, ordered=True) @@ -305,13 +295,14 @@ def do_task(self): class TestLargeDocBulkInsert(TestDocument, unittest.TestCase): data_size = 27310890 + def setUp(self): - self.dataset = 'large_doc.json' + self.dataset = "large_doc.json" super(TestLargeDocBulkInsert, self).setUp() self.documents = [self.document.copy() for _ in range(10)] def before(self): - self.corpus = self.client.perftest.create_collection('corpus') + self.corpus = self.client.perftest.create_collection("corpus") def do_task(self): self.corpus.insert_many(self.documents, ordered=True) @@ -319,47 +310,48 @@ def do_task(self): class TestGridFsUpload(PerformanceTest, unittest.TestCase): 
data_size = 52428800 + def setUp(self): self.client = client_context.client - self.client.drop_database('perftest') + self.client.drop_database("perftest") gridfs_path = os.path.join( - TEST_PATH, - os.path.join('single_and_multi_document', 'gridfs_large.bin')) - with open(gridfs_path, 'rb') as data: + TEST_PATH, os.path.join("single_and_multi_document", "gridfs_large.bin") + ) + with open(gridfs_path, "rb") as data: self.document = data.read() self.bucket = GridFSBucket(self.client.perftest) def tearDown(self): super(TestGridFsUpload, self).tearDown() - self.client.drop_database('perftest') + self.client.drop_database("perftest") def before(self): - self.bucket.upload_from_stream('init', b'x') + self.bucket.upload_from_stream("init", b"x") def do_task(self): - self.bucket.upload_from_stream('gridfstest', self.document) + self.bucket.upload_from_stream("gridfstest", self.document) class TestGridFsDownload(PerformanceTest, unittest.TestCase): data_size = 52428800 + def setUp(self): self.client = client_context.client - self.client.drop_database('perftest') + self.client.drop_database("perftest") gridfs_path = os.path.join( - TEST_PATH, - os.path.join('single_and_multi_document', 'gridfs_large.bin')) + TEST_PATH, os.path.join("single_and_multi_document", "gridfs_large.bin") + ) self.bucket = GridFSBucket(self.client.perftest) - with open(gridfs_path, 'rb') as gfile: - self.uploaded_id = self.bucket.upload_from_stream( - 'gridfstest', gfile) + with open(gridfs_path, "rb") as gfile: + self.uploaded_id = self.bucket.upload_from_stream("gridfstest", gfile) def tearDown(self): super(TestGridFsDownload, self).tearDown() - self.client.drop_database('perftest') + self.client.drop_database("perftest") def do_task(self): self.bucket.open_download_stream(self.uploaded_id).read() @@ -381,17 +373,17 @@ def mp_map(map_func, files): def insert_json_file(filename): - with open(filename, 'r') as data: + with open(filename, "r") as data: coll = proc_client.perftest.corpus coll.insert_many([json.loads(line) for line in data]) def insert_json_file_with_file_id(filename): documents = [] - with open(filename, 'r') as data: + with open(filename, "r") as data: for line in data: doc = json.loads(line) - doc['file'] = filename + doc["file"] = filename documents.append(doc) coll = proc_client.perftest.corpus coll.insert_many(documents) @@ -399,11 +391,11 @@ def insert_json_file_with_file_id(filename): def read_json_file(filename): coll = proc_client.perftest.corpus - temp = tempfile.TemporaryFile(mode='w') + temp = tempfile.TemporaryFile(mode="w") try: temp.writelines( - [json.dumps(doc) + '\n' for - doc in coll.find({'file': filename}, {'_id': False})]) + [json.dumps(doc) + "\n" for doc in coll.find({"file": filename}, {"_id": False})] + ) finally: temp.close() @@ -411,7 +403,7 @@ def read_json_file(filename): def insert_gridfs_file(filename): bucket = GridFSBucket(proc_client.perftest) - with open(filename, 'rb') as gfile: + with open(filename, "rb") as gfile: bucket.upload_from_stream(filename, gfile) @@ -427,41 +419,39 @@ def read_gridfs_file(filename): class TestJsonMultiImport(PerformanceTest, unittest.TestCase): data_size = 565000000 + def setUp(self): self.client = client_context.client - self.client.drop_database('perftest') + self.client.drop_database("perftest") def before(self): - self.client.perftest.command({'create': 'corpus'}) + self.client.perftest.command({"create": "corpus"}) self.corpus = self.client.perftest.corpus - ldjson_path = os.path.join( - TEST_PATH, os.path.join('parallel', 
'ldjson_multi')) - self.files = [os.path.join( - ldjson_path, s) for s in os.listdir(ldjson_path)] + ldjson_path = os.path.join(TEST_PATH, os.path.join("parallel", "ldjson_multi")) + self.files = [os.path.join(ldjson_path, s) for s in os.listdir(ldjson_path)] def do_task(self): mp_map(insert_json_file, self.files) def after(self): - self.client.perftest.drop_collection('corpus') + self.client.perftest.drop_collection("corpus") def tearDown(self): super(TestJsonMultiImport, self).tearDown() - self.client.drop_database('perftest') + self.client.drop_database("perftest") class TestJsonMultiExport(PerformanceTest, unittest.TestCase): data_size = 565000000 + def setUp(self): self.client = client_context.client - self.client.drop_database('perftest') - self.client.perfest.corpus.create_index('file') + self.client.drop_database("perftest") + self.client.perftest.corpus.create_index("file") - ldjson_path = os.path.join( - TEST_PATH, os.path.join('parallel', 'ldjson_multi')) - self.files = [os.path.join( - ldjson_path, s) for s in os.listdir(ldjson_path)] + ldjson_path = os.path.join(TEST_PATH, os.path.join("parallel", "ldjson_multi")) + self.files = [os.path.join(ldjson_path, s) for s in os.listdir(ldjson_path)] mp_map(insert_json_file_with_file_id, self.files) @@ -470,48 +460,46 @@ def do_task(self): def tearDown(self): super(TestJsonMultiExport, self).tearDown() - self.client.drop_database('perftest') + self.client.drop_database("perftest") class TestGridFsMultiFileUpload(PerformanceTest, unittest.TestCase): data_size = 262144000 + def setUp(self): self.client = client_context.client - self.client.drop_database('perftest') + self.client.drop_database("perftest") def before(self): - self.client.perftest.drop_collection('fs.files') - self.client.perftest.drop_collection('fs.chunks') + self.client.perftest.drop_collection("fs.files") + self.client.perftest.drop_collection("fs.chunks") self.bucket = GridFSBucket(self.client.perftest) - gridfs_path = os.path.join( - TEST_PATH, os.path.join('parallel', 'gridfs_multi')) - self.files = [os.path.join( - gridfs_path, s) for s in os.listdir(gridfs_path)] + gridfs_path = os.path.join(TEST_PATH, os.path.join("parallel", "gridfs_multi")) + self.files = [os.path.join(gridfs_path, s) for s in os.listdir(gridfs_path)] def do_task(self): mp_map(insert_gridfs_file, self.files) def tearDown(self): super(TestGridFsMultiFileUpload, self).tearDown() - self.client.drop_database('perftest') + self.client.drop_database("perftest") class TestGridFsMultiFileDownload(PerformanceTest, unittest.TestCase): data_size = 262144000 + def setUp(self): self.client = client_context.client - self.client.drop_database('perftest') + self.client.drop_database("perftest") bucket = GridFSBucket(self.client.perftest) - gridfs_path = os.path.join( - TEST_PATH, os.path.join('parallel', 'gridfs_multi')) - self.files = [os.path.join( - gridfs_path, s) for s in os.listdir(gridfs_path)] + gridfs_path = os.path.join(TEST_PATH, os.path.join("parallel", "gridfs_multi")) + self.files = [os.path.join(gridfs_path, s) for s in os.listdir(gridfs_path)] for fname in self.files: - with open(fname, 'rb') as gfile: + with open(fname, "rb") as gfile: bucket.upload_from_stream(fname, gfile) def do_task(self): @@ -519,7 +507,7 @@ def do_task(self): def tearDown(self): super(TestGridFsMultiFileDownload, self).tearDown() - self.client.drop_database('perftest') + self.client.drop_database("perftest") if __name__ == "__main__": diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 8b1ece8ad6..c0823a6d9a 
100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -15,19 +15,17 @@ """Tools for mocking parts of PyMongo to test other parts.""" import contextlib -from functools import partial import weakref +from functools import partial +from test import client_context -from pymongo import common -from pymongo import MongoClient +from pymongo import MongoClient, common from pymongo.errors import AutoReconnect, NetworkTimeout from pymongo.hello import Hello, HelloCompat from pymongo.monitor import Monitor from pymongo.pool import Pool from pymongo.server_description import ServerDescription -from test import client_context - class MockPool(Pool): def __init__(self, client, pair, *args, **kwargs): @@ -42,14 +40,13 @@ def __init__(self, client, pair, *args, **kwargs): @contextlib.contextmanager def get_socket(self, all_credentials, handler=None): client = self.client - host_and_port = '%s:%s' % (self.mock_host, self.mock_port) + host_and_port = "%s:%s" % (self.mock_host, self.mock_port) if host_and_port in client.mock_down_hosts: - raise AutoReconnect('mock error') + raise AutoReconnect("mock error") assert host_and_port in ( - client.mock_standalones - + client.mock_members - + client.mock_mongoses), "bad host: %s" % host_and_port + client.mock_standalones + client.mock_members + client.mock_mongoses + ), ("bad host: %s" % host_and_port) with Pool.get_socket(self, all_credentials, handler) as sock_info: sock_info.mock_host = self.mock_host @@ -79,34 +76,31 @@ def close(self): class MockMonitor(Monitor): - def __init__( - self, - client, - server_description, - topology, - pool, - topology_settings): + def __init__(self, client, server_description, topology, pool, topology_settings): # MockMonitor gets a 'client' arg, regular monitors don't. Weakref it # to avoid cycles. self.client = weakref.proxy(client) - Monitor.__init__( - self, - server_description, - topology, - pool, - topology_settings) + Monitor.__init__(self, server_description, topology, pool, topology_settings) def _check_once(self): client = self.client address = self._server_description.address - response, rtt = client.mock_hello('%s:%d' % address) + response, rtt = client.mock_hello("%s:%d" % address) return ServerDescription(address, Hello(response), rtt) class MockClient(MongoClient): def __init__( - self, standalones, members, mongoses, hello_hosts=None, - arbiters=None, down_hosts=None, *args, **kwargs): + self, + standalones, + members, + mongoses, + hello_hosts=None, + arbiters=None, + down_hosts=None, + *args, + **kwargs + ): """A MongoClient connected to the default server, with a mock topology. standalones, members, mongoses, arbiters, and down_hosts determine the @@ -144,8 +138,8 @@ def __init__( # Hostname -> round trip time self.mock_rtts = {} - kwargs['_pool_class'] = partial(MockPool, self) - kwargs['_monitor_class'] = partial(MockMonitor, self) + kwargs["_pool_class"] = partial(MockPool, self) + kwargs["_monitor_class"] = partial(MockMonitor, self) client_options = client_context.default_client_options.copy() client_options.update(kwargs) @@ -175,53 +169,57 @@ def mock_hello(self, host): max_wire_version = common.MAX_SUPPORTED_WIRE_VERSION max_write_batch_size = self.mock_max_write_batch_sizes.get( - host, common.MAX_WRITE_BATCH_SIZE) + host, common.MAX_WRITE_BATCH_SIZE + ) rtt = self.mock_rtts.get(host, 0) # host is like 'a:1'. 
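# (mock_hello() dispatches on that "host:port" string: hosts listed in
# mock_down_hosts raise NetworkTimeout, while standalones, replica-set
# members, and mongoses each receive a synthetic hello response built
# below.)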
if host in self.mock_down_hosts: - raise NetworkTimeout('mock timeout') + raise NetworkTimeout("mock timeout") elif host in self.mock_standalones: response = { - 'ok': 1, + "ok": 1, HelloCompat.LEGACY_CMD: True, - 'minWireVersion': min_wire_version, - 'maxWireVersion': max_wire_version, - 'maxWriteBatchSize': max_write_batch_size} + "minWireVersion": min_wire_version, + "maxWireVersion": max_wire_version, + "maxWriteBatchSize": max_write_batch_size, + } elif host in self.mock_members: - primary = (host == self.mock_primary) + primary = host == self.mock_primary # Simulate a replica set member. response = { - 'ok': 1, + "ok": 1, HelloCompat.LEGACY_CMD: primary, - 'secondary': not primary, - 'setName': 'rs', - 'hosts': self.mock_hello_hosts, - 'minWireVersion': min_wire_version, - 'maxWireVersion': max_wire_version, - 'maxWriteBatchSize': max_write_batch_size} + "secondary": not primary, + "setName": "rs", + "hosts": self.mock_hello_hosts, + "minWireVersion": min_wire_version, + "maxWireVersion": max_wire_version, + "maxWriteBatchSize": max_write_batch_size, + } if self.mock_primary: - response['primary'] = self.mock_primary + response["primary"] = self.mock_primary if host in self.mock_arbiters: - response['arbiterOnly'] = True - response['secondary'] = False + response["arbiterOnly"] = True + response["secondary"] = False elif host in self.mock_mongoses: response = { - 'ok': 1, + "ok": 1, HelloCompat.LEGACY_CMD: True, - 'minWireVersion': min_wire_version, - 'maxWireVersion': max_wire_version, - 'msg': 'isdbgrid', - 'maxWriteBatchSize': max_write_batch_size} + "minWireVersion": min_wire_version, + "maxWireVersion": max_wire_version, + "msg": "isdbgrid", + "maxWriteBatchSize": max_write_batch_size, + } else: # In test_internal_ips(), we try to connect to a host listed # in hello['hosts'] but not publicly accessible. 
- raise AutoReconnect('Unknown host: %s' % host) + raise AutoReconnect("Unknown host: %s" % host) return response, rtt diff --git a/test/qcheck.py b/test/qcheck.py index 57e0940b72..4cce7b5bc8 100644 --- a/test/qcheck.py +++ b/test/qcheck.py @@ -83,9 +83,7 @@ def gen_unichar(): def gen_unicode(gen_length): - return lambda: "".join([x for x in - gen_list(gen_unichar(), gen_length)() if - x not in ".$"]) + return lambda: "".join([x for x in gen_list(gen_unichar(), gen_length)() if x not in ".$"]) def gen_list(generator, gen_length): @@ -93,22 +91,24 @@ def gen_list(generator, gen_length): def gen_datetime(): - return lambda: datetime.datetime(random.randint(1970, 2037), - random.randint(1, 12), - random.randint(1, 28), - random.randint(0, 23), - random.randint(0, 59), - random.randint(0, 59), - random.randint(0, 999) * 1000) + return lambda: datetime.datetime( + random.randint(1970, 2037), + random.randint(1, 12), + random.randint(1, 28), + random.randint(0, 23), + random.randint(0, 59), + random.randint(0, 59), + random.randint(0, 999) * 1000, + ) def gen_dict(gen_key, gen_value, gen_length): - def a_dict(gen_key, gen_value, length): result = {} for _ in range(length): result[gen_key()] = gen_value() return result + return lambda: a_dict(gen_key, gen_value, gen_length()) @@ -128,6 +128,7 @@ def gen_flags(): flags = flags | re.VERBOSE return flags + return lambda: re.compile(pattern(), gen_flags()) @@ -142,15 +143,17 @@ def gen_dbref(): def gen_mongo_value(depth, ref): - choices = [gen_unicode(gen_range(0, 50)), - gen_printable_string(gen_range(0, 50)), - my_map(gen_string(gen_range(0, 1000)), bytes), - gen_int(), - gen_float(), - gen_boolean(), - gen_datetime(), - gen_objectid(), - lift(None)] + choices = [ + gen_unicode(gen_range(0, 50)), + gen_printable_string(gen_range(0, 50)), + my_map(gen_string(gen_range(0, 1000)), bytes), + gen_int(), + gen_float(), + gen_boolean(), + gen_datetime(), + gen_objectid(), + lift(None), + ] if ref: choices.append(gen_dbref()) if depth > 0: @@ -164,9 +167,10 @@ def gen_mongo_list(depth, ref): def gen_mongo_dict(depth, ref=True): - return my_map(gen_dict(gen_unicode(gen_range(0, 20)), - gen_mongo_value(depth - 1, ref), - gen_range(0, 10)), SON) + return my_map( + gen_dict(gen_unicode(gen_range(0, 20)), gen_mongo_value(depth - 1, ref), gen_range(0, 10)), + SON, + ) def simplify(case): # TODO this is a hack @@ -236,8 +240,10 @@ def check_unittest(test, predicate, generator): counter_examples = check(predicate, generator) if counter_examples: failures = len(counter_examples) - message = "\n".join([" -> %s" % f for f in - counter_examples[:examples]]) - message = ("found %d counter examples, displaying first %d:\n%s" % - (failures, min(failures, examples), message)) + message = "\n".join([" -> %s" % f for f in counter_examples[:examples]]) + message = "found %d counter examples, displaying first %d:\n%s" % ( + failures, + min(failures, examples), + message, + ) test.fail(message) diff --git a/test/test_auth.py b/test/test_auth.py index d0724dce72..1d8bc3d43c 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -17,41 +17,43 @@ import os import sys import threading - from urllib.parse import quote_plus sys.path[0:0] = [""] +from test import IntegrationTest, SkipTest, Version, client_context, unittest +from test.utils import ( + AllowListEventListener, + delay, + ignore_deprecations, + rs_or_single_client, + rs_or_single_client_noauth, + single_client, + single_client_noauth, +) + from pymongo import MongoClient, monitoring from pymongo.auth import 
HAVE_KERBEROS, _build_credentials_tuple from pymongo.errors import OperationFailure from pymongo.hello import HelloCompat from pymongo.read_preferences import ReadPreference from pymongo.saslprep import HAVE_STRINGPREP -from test import client_context, IntegrationTest, SkipTest, unittest, Version -from test.utils import (delay, - ignore_deprecations, - single_client, - rs_or_single_client, - rs_or_single_client_noauth, - single_client_noauth, - AllowListEventListener) # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX. -GSSAPI_HOST = os.environ.get('GSSAPI_HOST') -GSSAPI_PORT = int(os.environ.get('GSSAPI_PORT', '27017')) -GSSAPI_PRINCIPAL = os.environ.get('GSSAPI_PRINCIPAL') -GSSAPI_SERVICE_NAME = os.environ.get('GSSAPI_SERVICE_NAME', 'mongodb') -GSSAPI_CANONICALIZE = os.environ.get('GSSAPI_CANONICALIZE', 'false') -GSSAPI_SERVICE_REALM = os.environ.get('GSSAPI_SERVICE_REALM') -GSSAPI_PASS = os.environ.get('GSSAPI_PASS') -GSSAPI_DB = os.environ.get('GSSAPI_DB', 'test') - -SASL_HOST = os.environ.get('SASL_HOST') -SASL_PORT = int(os.environ.get('SASL_PORT', '27017')) -SASL_USER = os.environ.get('SASL_USER') -SASL_PASS = os.environ.get('SASL_PASS') -SASL_DB = os.environ.get('SASL_DB', '$external') +GSSAPI_HOST = os.environ.get("GSSAPI_HOST") +GSSAPI_PORT = int(os.environ.get("GSSAPI_PORT", "27017")) +GSSAPI_PRINCIPAL = os.environ.get("GSSAPI_PRINCIPAL") +GSSAPI_SERVICE_NAME = os.environ.get("GSSAPI_SERVICE_NAME", "mongodb") +GSSAPI_CANONICALIZE = os.environ.get("GSSAPI_CANONICALIZE", "false") +GSSAPI_SERVICE_REALM = os.environ.get("GSSAPI_SERVICE_REALM") +GSSAPI_PASS = os.environ.get("GSSAPI_PASS") +GSSAPI_DB = os.environ.get("GSSAPI_DB", "test") + +SASL_HOST = os.environ.get("SASL_HOST") +SASL_PORT = int(os.environ.get("SASL_PORT", "27017")) +SASL_USER = os.environ.get("SASL_USER") +SASL_PASS = os.environ.get("SASL_PASS") +SASL_DB = os.environ.get("SASL_DB", "$external") class AutoAuthenticateThread(threading.Thread): @@ -70,45 +72,41 @@ def __init__(self, collection): self.success = False def run(self): - assert self.collection.find_one({'$where': delay(1)}) is not None + assert self.collection.find_one({"$where": delay(1)}) is not None self.success = True class TestGSSAPI(unittest.TestCase): - @classmethod def setUpClass(cls): if not HAVE_KERBEROS: - raise SkipTest('Kerberos module not available.') + raise SkipTest("Kerberos module not available.") if not GSSAPI_HOST or not GSSAPI_PRINCIPAL: - raise SkipTest( - 'Must set GSSAPI_HOST and GSSAPI_PRINCIPAL to test GSSAPI') + raise SkipTest("Must set GSSAPI_HOST and GSSAPI_PRINCIPAL to test GSSAPI") cls.service_realm_required = ( - GSSAPI_SERVICE_REALM is not None and - GSSAPI_SERVICE_REALM not in GSSAPI_PRINCIPAL) - mech_properties = 'SERVICE_NAME:%s' % (GSSAPI_SERVICE_NAME,) - mech_properties += ( - ',CANONICALIZE_HOST_NAME:%s' % (GSSAPI_CANONICALIZE,)) + GSSAPI_SERVICE_REALM is not None and GSSAPI_SERVICE_REALM not in GSSAPI_PRINCIPAL + ) + mech_properties = "SERVICE_NAME:%s" % (GSSAPI_SERVICE_NAME,) + mech_properties += ",CANONICALIZE_HOST_NAME:%s" % (GSSAPI_CANONICALIZE,) if GSSAPI_SERVICE_REALM is not None: - mech_properties += ',SERVICE_REALM:%s' % (GSSAPI_SERVICE_REALM,) + mech_properties += ",SERVICE_REALM:%s" % (GSSAPI_SERVICE_REALM,) cls.mech_properties = mech_properties def test_credentials_hashing(self): # GSSAPI credentials are properly hashed. 
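# (The tuples double as hash keys, so two built from identical settings
# must compare equal, while changing any property, such as SERVICE_NAME,
# must produce a distinct key; the set() sizes asserted below check both
# directions.)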
- creds0 = _build_credentials_tuple( - 'GSSAPI', None, 'user', 'pass', {}, None) + creds0 = _build_credentials_tuple("GSSAPI", None, "user", "pass", {}, None) creds1 = _build_credentials_tuple( - 'GSSAPI', None, 'user', 'pass', - {'authmechanismproperties': {'SERVICE_NAME': 'A'}}, None) + "GSSAPI", None, "user", "pass", {"authmechanismproperties": {"SERVICE_NAME": "A"}}, None + ) creds2 = _build_credentials_tuple( - 'GSSAPI', None, 'user', 'pass', - {'authmechanismproperties': {'SERVICE_NAME': 'A'}}, None) + "GSSAPI", None, "user", "pass", {"authmechanismproperties": {"SERVICE_NAME": "A"}}, None + ) creds3 = _build_credentials_tuple( - 'GSSAPI', None, 'user', 'pass', - {'authmechanismproperties': {'SERVICE_NAME': 'B'}}, None) + "GSSAPI", None, "user", "pass", {"authmechanismproperties": {"SERVICE_NAME": "B"}}, None + ) self.assertEqual(1, len(set([creds1, creds2]))) self.assertEqual(3, len(set([creds0, creds1, creds2, creds3]))) @@ -116,24 +114,28 @@ def test_credentials_hashing(self): @ignore_deprecations def test_gssapi_simple(self): if GSSAPI_PASS is not None: - uri = ('mongodb://%s:%s@%s:%d/?authMechanism=' - 'GSSAPI' % (quote_plus(GSSAPI_PRINCIPAL), - GSSAPI_PASS, - GSSAPI_HOST, - GSSAPI_PORT)) + uri = "mongodb://%s:%s@%s:%d/?authMechanism=" "GSSAPI" % ( + quote_plus(GSSAPI_PRINCIPAL), + GSSAPI_PASS, + GSSAPI_HOST, + GSSAPI_PORT, + ) else: - uri = ('mongodb://%s@%s:%d/?authMechanism=' - 'GSSAPI' % (quote_plus(GSSAPI_PRINCIPAL), - GSSAPI_HOST, - GSSAPI_PORT)) + uri = "mongodb://%s@%s:%d/?authMechanism=" "GSSAPI" % ( + quote_plus(GSSAPI_PRINCIPAL), + GSSAPI_HOST, + GSSAPI_PORT, + ) if not self.service_realm_required: # Without authMechanismProperties. - client = MongoClient(GSSAPI_HOST, - GSSAPI_PORT, - username=GSSAPI_PRINCIPAL, - password=GSSAPI_PASS, - authMechanism='GSSAPI') + client = MongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + ) client[GSSAPI_DB].collection.find_one() @@ -142,60 +144,68 @@ def test_gssapi_simple(self): client[GSSAPI_DB].collection.find_one() # Authenticate with authMechanismProperties. - client = MongoClient(GSSAPI_HOST, - GSSAPI_PORT, - username=GSSAPI_PRINCIPAL, - password=GSSAPI_PASS, - authMechanism='GSSAPI', - authMechanismProperties=self.mech_properties) + client = MongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + authMechanismProperties=self.mech_properties, + ) client[GSSAPI_DB].collection.find_one() # Log in using URI, with authMechanismProperties. 
- mech_uri = uri + '&authMechanismProperties=%s' % (self.mech_properties,) + mech_uri = uri + "&authMechanismProperties=%s" % (self.mech_properties,) client = MongoClient(mech_uri) client[GSSAPI_DB].collection.find_one() - set_name = client.admin.command(HelloCompat.LEGACY_CMD).get('setName') + set_name = client.admin.command(HelloCompat.LEGACY_CMD).get("setName") if set_name: if not self.service_realm_required: # Without authMechanismProperties - client = MongoClient(GSSAPI_HOST, - GSSAPI_PORT, - username=GSSAPI_PRINCIPAL, - password=GSSAPI_PASS, - authMechanism='GSSAPI', - replicaSet=set_name) + client = MongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + replicaSet=set_name, + ) client[GSSAPI_DB].list_collection_names() - uri = uri + '&replicaSet=%s' % (str(set_name),) + uri = uri + "&replicaSet=%s" % (str(set_name),) client = MongoClient(uri) client[GSSAPI_DB].list_collection_names() # With authMechanismProperties - client = MongoClient(GSSAPI_HOST, - GSSAPI_PORT, - username=GSSAPI_PRINCIPAL, - password=GSSAPI_PASS, - authMechanism='GSSAPI', - authMechanismProperties=self.mech_properties, - replicaSet=set_name) + client = MongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + authMechanismProperties=self.mech_properties, + replicaSet=set_name, + ) client[GSSAPI_DB].list_collection_names() - mech_uri = mech_uri + '&replicaSet=%s' % (str(set_name),) + mech_uri = mech_uri + "&replicaSet=%s" % (str(set_name),) client = MongoClient(mech_uri) client[GSSAPI_DB].list_collection_names() @ignore_deprecations def test_gssapi_threaded(self): - client = MongoClient(GSSAPI_HOST, - GSSAPI_PORT, - username=GSSAPI_PRINCIPAL, - password=GSSAPI_PASS, - authMechanism='GSSAPI', - authMechanismProperties=self.mech_properties) + client = MongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + authMechanismProperties=self.mech_properties, + ) # Authentication succeeded? client.server_info() @@ -209,7 +219,7 @@ def test_gssapi_threaded(self): if not collection.count_documents({}): try: collection.drop() - collection.insert_one({'_id': 1}) + collection.insert_one({"_id": 1}) except OperationFailure: raise SkipTest("User must be able to write.") @@ -222,15 +232,17 @@ def test_gssapi_threaded(self): thread.join() self.assertTrue(thread.success) - set_name = client.admin.command(HelloCompat.LEGACY_CMD).get('setName') + set_name = client.admin.command(HelloCompat.LEGACY_CMD).get("setName") if set_name: - client = MongoClient(GSSAPI_HOST, - GSSAPI_PORT, - username=GSSAPI_PRINCIPAL, - password=GSSAPI_PASS, - authMechanism='GSSAPI', - authMechanismProperties=self.mech_properties, - replicaSet=set_name) + client = MongoClient( + GSSAPI_HOST, + GSSAPI_PORT, + username=GSSAPI_PRINCIPAL, + password=GSSAPI_PASS, + authMechanism="GSSAPI", + authMechanismProperties=self.mech_properties, + replicaSet=set_name, + ) # Succeeded? 
client.server_info() @@ -246,99 +258,107 @@ def test_gssapi_threaded(self): class TestSASLPlain(unittest.TestCase): - @classmethod def setUpClass(cls): if not SASL_HOST or not SASL_USER or not SASL_PASS: - raise SkipTest('Must set SASL_HOST, ' - 'SASL_USER, and SASL_PASS to test SASL') + raise SkipTest("Must set SASL_HOST, " "SASL_USER, and SASL_PASS to test SASL") def test_sasl_plain(self): - client = MongoClient(SASL_HOST, - SASL_PORT, - username=SASL_USER, - password=SASL_PASS, - authSource=SASL_DB, - authMechanism='PLAIN') + client = MongoClient( + SASL_HOST, + SASL_PORT, + username=SASL_USER, + password=SASL_PASS, + authSource=SASL_DB, + authMechanism="PLAIN", + ) client.ldap.test.find_one() - uri = ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;' - 'authSource=%s' % (quote_plus(SASL_USER), - quote_plus(SASL_PASS), - SASL_HOST, SASL_PORT, SASL_DB)) + uri = "mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;" "authSource=%s" % ( + quote_plus(SASL_USER), + quote_plus(SASL_PASS), + SASL_HOST, + SASL_PORT, + SASL_DB, + ) client = MongoClient(uri) client.ldap.test.find_one() - set_name = client.admin.command(HelloCompat.LEGACY_CMD).get('setName') + set_name = client.admin.command(HelloCompat.LEGACY_CMD).get("setName") if set_name: - client = MongoClient(SASL_HOST, - SASL_PORT, - replicaSet=set_name, - username=SASL_USER, - password=SASL_PASS, - authSource=SASL_DB, - authMechanism='PLAIN') + client = MongoClient( + SASL_HOST, + SASL_PORT, + replicaSet=set_name, + username=SASL_USER, + password=SASL_PASS, + authSource=SASL_DB, + authMechanism="PLAIN", + ) client.ldap.test.find_one() - uri = ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;' - 'authSource=%s;replicaSet=%s' % (quote_plus(SASL_USER), - quote_plus(SASL_PASS), - SASL_HOST, SASL_PORT, - SASL_DB, str(set_name))) + uri = "mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;" "authSource=%s;replicaSet=%s" % ( + quote_plus(SASL_USER), + quote_plus(SASL_PASS), + SASL_HOST, + SASL_PORT, + SASL_DB, + str(set_name), + ) client = MongoClient(uri) client.ldap.test.find_one() def test_sasl_plain_bad_credentials(self): def auth_string(user, password): - uri = ('mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;' - 'authSource=%s' % (quote_plus(user), - quote_plus(password), - SASL_HOST, SASL_PORT, SASL_DB)) + uri = "mongodb://%s:%s@%s:%d/?authMechanism=PLAIN;" "authSource=%s" % ( + quote_plus(user), + quote_plus(password), + SASL_HOST, + SASL_PORT, + SASL_DB, + ) return uri - bad_user = MongoClient(auth_string('not-user', SASL_PASS)) - bad_pwd = MongoClient(auth_string(SASL_USER, 'not-pwd')) + bad_user = MongoClient(auth_string("not-user", SASL_PASS)) + bad_pwd = MongoClient(auth_string(SASL_USER, "not-pwd")) # OperationFailure raised upon connecting. 
- self.assertRaises(OperationFailure, bad_user.admin.command, 'ping') - self.assertRaises(OperationFailure, bad_pwd.admin.command, 'ping') + self.assertRaises(OperationFailure, bad_user.admin.command, "ping") + self.assertRaises(OperationFailure, bad_pwd.admin.command, "ping") class TestSCRAMSHA1(IntegrationTest): - @client_context.require_auth def setUp(self): super(TestSCRAMSHA1, self).setUp() - client_context.create_user( - 'pymongo_test', 'user', 'pass', roles=['userAdmin', 'readWrite']) + client_context.create_user("pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"]) def tearDown(self): - client_context.drop_user('pymongo_test', 'user') + client_context.drop_user("pymongo_test", "user") super(TestSCRAMSHA1, self).tearDown() def test_scram_sha1(self): host, port = client_context.host, client_context.port client = rs_or_single_client_noauth( - 'mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1' - % (host, port)) - client.pymongo_test.command('dbstats') + "mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port) + ) + client.pymongo_test.command("dbstats") if client_context.is_rs: - uri = ('mongodb://user:pass' - '@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1' - '&replicaSet=%s' % (host, port, - client_context.replica_set_name)) + uri = ( + "mongodb://user:pass" + "@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" + "&replicaSet=%s" % (host, port, client_context.replica_set_name) + ) client = single_client_noauth(uri) - client.pymongo_test.command('dbstats') - db = client.get_database( - 'pymongo_test', read_preference=ReadPreference.SECONDARY) - db.command('dbstats') + client.pymongo_test.command("dbstats") + db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) + db.command("dbstats") # https://github.com/mongodb/specifications/blob/master/source/auth/auth.rst#scram-sha-256-and-mechanism-negotiation class TestSCRAM(IntegrationTest): - @client_context.require_auth @client_context.require_version_min(3, 7, 2) def setUp(self): @@ -356,114 +376,118 @@ def tearDown(self): def test_scram_skip_empty_exchange(self): listener = AllowListEventListener("saslStart", "saslContinue") client_context.create_user( - 'testscram', 'sha256', 'pwd', roles=['dbOwner'], - mechanisms=['SCRAM-SHA-256']) + "testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] + ) client = rs_or_single_client_noauth( - username='sha256', password='pwd', authSource='testscram', - event_listeners=[listener]) - client.testscram.command('dbstats') + username="sha256", password="pwd", authSource="testscram", event_listeners=[listener] + ) + client.testscram.command("dbstats") if client_context.version < (4, 4, -1): # Assert we sent the skipEmptyExchange option. - first_event = listener.results['started'][0] - self.assertEqual(first_event.command_name, 'saslStart') - self.assertEqual( - first_event.command['options'], {'skipEmptyExchange': True}) + first_event = listener.results["started"][0] + self.assertEqual(first_event.command_name, "saslStart") + self.assertEqual(first_event.command["options"], {"skipEmptyExchange": True}) # Assert the third exchange was skipped on servers that support it. # Note that the first exchange occurs on the connection handshake. 
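# (Hence on 4.4+, where saslStart rides the handshake and the empty
# final exchange is skipped, only a lone saslContinue is observed; older
# servers still show the full saslStart/saslContinue/saslContinue
# conversation.)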
started = listener.started_command_names() if client_context.version.at_least(4, 4, -1): - self.assertEqual(started, ['saslContinue']) + self.assertEqual(started, ["saslContinue"]) else: - self.assertEqual( - started, ['saslStart', 'saslContinue', 'saslContinue']) + self.assertEqual(started, ["saslStart", "saslContinue", "saslContinue"]) def test_scram(self): # Step 1: create users client_context.create_user( - 'testscram', 'sha1', 'pwd', roles=['dbOwner'], - mechanisms=['SCRAM-SHA-1']) + "testscram", "sha1", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-1"] + ) client_context.create_user( - 'testscram', 'sha256', 'pwd', roles=['dbOwner'], - mechanisms=['SCRAM-SHA-256']) + "testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] + ) client_context.create_user( - 'testscram', 'both', 'pwd', roles=['dbOwner'], - mechanisms=['SCRAM-SHA-1', 'SCRAM-SHA-256']) + "testscram", + "both", + "pwd", + roles=["dbOwner"], + mechanisms=["SCRAM-SHA-1", "SCRAM-SHA-256"], + ) # Step 2: verify auth success cases - client = rs_or_single_client_noauth( - username='sha1', password='pwd', authSource='testscram') - client.testscram.command('dbstats') + client = rs_or_single_client_noauth(username="sha1", password="pwd", authSource="testscram") + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='sha1', password='pwd', authSource='testscram', - authMechanism='SCRAM-SHA-1') - client.testscram.command('dbstats') + username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" + ) + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='sha256', password='pwd', authSource='testscram') - client.testscram.command('dbstats') + username="sha256", password="pwd", authSource="testscram" + ) + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='sha256', password='pwd', authSource='testscram', - authMechanism='SCRAM-SHA-256') - client.testscram.command('dbstats') + username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" + ) + client.testscram.command("dbstats") # Step 2: SCRAM-SHA-1 and SCRAM-SHA-256 client = rs_or_single_client_noauth( - username='both', password='pwd', authSource='testscram', - authMechanism='SCRAM-SHA-1') - client.testscram.command('dbstats') + username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" + ) + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='both', password='pwd', authSource='testscram', - authMechanism='SCRAM-SHA-256') - client.testscram.command('dbstats') + username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" + ) + client.testscram.command("dbstats") self.listener.results.clear() client = rs_or_single_client_noauth( - username='both', password='pwd', authSource='testscram', - event_listeners=[self.listener]) - client.testscram.command('dbstats') + username="both", password="pwd", authSource="testscram", event_listeners=[self.listener] + ) + client.testscram.command("dbstats") if client_context.version.at_least(4, 4, -1): # Speculative authentication in 4.4+ sends saslStart with the # handshake. 
- self.assertEqual(self.listener.results['started'], []) + self.assertEqual(self.listener.results["started"], []) else: - started = self.listener.results['started'][0] - self.assertEqual(started.command.get('mechanism'), 'SCRAM-SHA-256') + started = self.listener.results["started"][0] + self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256") # Step 3: verify auth failure conditions client = rs_or_single_client_noauth( - username='sha1', password='pwd', authSource='testscram', - authMechanism='SCRAM-SHA-256') + username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" + ) with self.assertRaises(OperationFailure): - client.testscram.command('dbstats') + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='sha256', password='pwd', authSource='testscram', - authMechanism='SCRAM-SHA-1') + username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" + ) with self.assertRaises(OperationFailure): - client.testscram.command('dbstats') + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='not-a-user', password='pwd', authSource='testscram') + username="not-a-user", password="pwd", authSource="testscram" + ) with self.assertRaises(OperationFailure): - client.testscram.command('dbstats') + client.testscram.command("dbstats") if client_context.is_rs: host, port = client_context.host, client_context.port - uri = ('mongodb://both:pwd@%s:%d/testscram' - '?replicaSet=%s' % (host, port, - client_context.replica_set_name)) + uri = "mongodb://both:pwd@%s:%d/testscram" "?replicaSet=%s" % ( + host, + port, + client_context.replica_set_name, + ) client = single_client_noauth(uri) - client.testscram.command('dbstats') - db = client.get_database( - 'testscram', read_preference=ReadPreference.SECONDARY) - db.command('dbstats') + client.testscram.command("dbstats") + db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY) + db.command("dbstats") - @unittest.skipUnless(HAVE_STRINGPREP, 'Cannot test without stringprep') + @unittest.skipUnless(HAVE_STRINGPREP, "Cannot test without stringprep") def test_scram_saslprep(self): # Step 4: test SASLprep host, port = client_context.host, client_context.port @@ -472,59 +496,66 @@ def test_scram_saslprep(self): # becomes 'IX'. SASLprep is only supported when the standard # library provides stringprep. 
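
Note (editor, not part of the patch): the two equivalences this comment relies on can be checked directly. A short sketch, assuming a Python interpreter built with the stringprep module; pymongo ships the helper as pymongo.saslprep.saslprep:

    from pymongo.saslprep import saslprep  # falls back to a stub without stringprep

    assert saslprep("\u2168") == "IX"    # ROMAN NUMERAL NINE normalizes to 'IX'
    assert saslprep("I\u00ADX") == "IX"  # the soft hyphen is stripped
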
client_context.create_user( - 'testscram', '\u2168', '\u2163', roles=['dbOwner'], - mechanisms=['SCRAM-SHA-256']) + "testscram", "\u2168", "\u2163", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] + ) client_context.create_user( - 'testscram', 'IX', 'IX', roles=['dbOwner'], - mechanisms=['SCRAM-SHA-256']) + "testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] + ) client = rs_or_single_client_noauth( - username='\u2168', password='\u2163', authSource='testscram') - client.testscram.command('dbstats') + username="\u2168", password="\u2163", authSource="testscram" + ) + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='\u2168', password='\u2163', authSource='testscram', - authMechanism='SCRAM-SHA-256') - client.testscram.command('dbstats') + username="\u2168", + password="\u2163", + authSource="testscram", + authMechanism="SCRAM-SHA-256", + ) + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='\u2168', password='IV', authSource='testscram') - client.testscram.command('dbstats') + username="\u2168", password="IV", authSource="testscram" + ) + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='IX', password='I\u00ADX', authSource='testscram') - client.testscram.command('dbstats') + username="IX", password="I\u00ADX", authSource="testscram" + ) + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='IX', password='I\u00ADX', authSource='testscram', - authMechanism='SCRAM-SHA-256') - client.testscram.command('dbstats') + username="IX", + password="I\u00ADX", + authSource="testscram", + authMechanism="SCRAM-SHA-256", + ) + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - username='IX', password='IX', authSource='testscram', - authMechanism='SCRAM-SHA-256') - client.testscram.command('dbstats') + username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256" + ) + client.testscram.command("dbstats") client = rs_or_single_client_noauth( - 'mongodb://\u2168:\u2163@%s:%d/testscram' % (host, port)) - client.testscram.command('dbstats') - client = rs_or_single_client_noauth( - 'mongodb://\u2168:IV@%s:%d/testscram' % (host, port)) - client.testscram.command('dbstats') + "mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port) + ) + client.testscram.command("dbstats") + client = rs_or_single_client_noauth("mongodb://\u2168:IV@%s:%d/testscram" % (host, port)) + client.testscram.command("dbstats") - client = rs_or_single_client_noauth( - 'mongodb://IX:I\u00ADX@%s:%d/testscram' % (host, port)) - client.testscram.command('dbstats') - client = rs_or_single_client_noauth( - 'mongodb://IX:IX@%s:%d/testscram' % (host, port)) - client.testscram.command('dbstats') + client = rs_or_single_client_noauth("mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port)) + client.testscram.command("dbstats") + client = rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port)) + client.testscram.command("dbstats") def test_cache(self): client = single_client() # Force authentication. 
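
Note (editor, not part of the patch): "ping" recurs throughout these tests because it is the cheapest command that forces a connection — and therefore authentication — on an otherwise lazy client. For instance:

    from pymongo import MongoClient

    client = MongoClient(connect=False)  # defers all I/O
    client.admin.command("ping")         # connects and authenticates now
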
- client.admin.command('ping') + client.admin.command("ping") all_credentials = client._MongoClient__all_credentials - credentials = all_credentials.get('admin') + credentials = all_credentials.get("admin") cache = credentials.cache self.assertIsNotNone(cache) data = cache.data @@ -553,7 +584,7 @@ def test_scram_threaded(self): coll = client_context.client.db.test coll.drop() - coll.insert_one({'_id': 1}) + coll.insert_one({"_id": 1}) # The first thread to call find() will authenticate coll = rs_or_single_client().db.test @@ -568,71 +599,68 @@ def test_scram_threaded(self): class TestAuthURIOptions(IntegrationTest): - @client_context.require_auth def setUp(self): super(TestAuthURIOptions, self).setUp() - client_context.create_user('admin', 'admin', 'pass') - client_context.create_user( - 'pymongo_test', 'user', 'pass', ['userAdmin', 'readWrite']) + client_context.create_user("admin", "admin", "pass") + client_context.create_user("pymongo_test", "user", "pass", ["userAdmin", "readWrite"]) def tearDown(self): - client_context.drop_user('pymongo_test', 'user') - client_context.drop_user('admin', 'admin') + client_context.drop_user("pymongo_test", "user") + client_context.drop_user("admin", "admin") super(TestAuthURIOptions, self).tearDown() def test_uri_options(self): # Test default to admin host, port = client_context.host, client_context.port - client = rs_or_single_client_noauth( - 'mongodb://admin:pass@%s:%d' % (host, port)) - self.assertTrue(client.admin.command('dbstats')) + client = rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + self.assertTrue(client.admin.command("dbstats")) if client_context.is_rs: - uri = ('mongodb://admin:pass@%s:%d/?replicaSet=%s' % ( - host, port, client_context.replica_set_name)) + uri = "mongodb://admin:pass@%s:%d/?replicaSet=%s" % ( + host, + port, + client_context.replica_set_name, + ) client = single_client_noauth(uri) - self.assertTrue(client.admin.command('dbstats')) - db = client.get_database( - 'admin', read_preference=ReadPreference.SECONDARY) - self.assertTrue(db.command('dbstats')) + self.assertTrue(client.admin.command("dbstats")) + db = client.get_database("admin", read_preference=ReadPreference.SECONDARY) + self.assertTrue(db.command("dbstats")) # Test explicit database - uri = 'mongodb://user:pass@%s:%d/pymongo_test' % (host, port) + uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) client = rs_or_single_client_noauth(uri) - self.assertRaises(OperationFailure, client.admin.command, 'dbstats') - self.assertTrue(client.pymongo_test.command('dbstats')) + self.assertRaises(OperationFailure, client.admin.command, "dbstats") + self.assertTrue(client.pymongo_test.command("dbstats")) if client_context.is_rs: - uri = ('mongodb://user:pass@%s:%d/pymongo_test?replicaSet=%s' % ( - host, port, client_context.replica_set_name)) + uri = "mongodb://user:pass@%s:%d/pymongo_test?replicaSet=%s" % ( + host, + port, + client_context.replica_set_name, + ) client = single_client_noauth(uri) - self.assertRaises(OperationFailure, - client.admin.command, 'dbstats') - self.assertTrue(client.pymongo_test.command('dbstats')) - db = client.get_database( - 'pymongo_test', read_preference=ReadPreference.SECONDARY) - self.assertTrue(db.command('dbstats')) + self.assertRaises(OperationFailure, client.admin.command, "dbstats") + self.assertTrue(client.pymongo_test.command("dbstats")) + db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) + self.assertTrue(db.command("dbstats")) # Test authSource - uri = 
('mongodb://user:pass@%s:%d' - '/pymongo_test2?authSource=pymongo_test' % (host, port)) + uri = "mongodb://user:pass@%s:%d" "/pymongo_test2?authSource=pymongo_test" % (host, port) client = rs_or_single_client_noauth(uri) - self.assertRaises(OperationFailure, - client.pymongo_test2.command, 'dbstats') - self.assertTrue(client.pymongo_test.command('dbstats')) + self.assertRaises(OperationFailure, client.pymongo_test2.command, "dbstats") + self.assertTrue(client.pymongo_test.command("dbstats")) if client_context.is_rs: - uri = ('mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=' - '%s;authSource=pymongo_test' % ( - host, port, client_context.replica_set_name)) + uri = ( + "mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=" + "%s;authSource=pymongo_test" % (host, port, client_context.replica_set_name) + ) client = single_client_noauth(uri) - self.assertRaises(OperationFailure, - client.pymongo_test2.command, 'dbstats') - self.assertTrue(client.pymongo_test.command('dbstats')) - db = client.get_database( - 'pymongo_test', read_preference=ReadPreference.SECONDARY) - self.assertTrue(db.command('dbstats')) + self.assertRaises(OperationFailure, client.pymongo_test2.command, "dbstats") + self.assertTrue(client.pymongo_test.command("dbstats")) + db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) + self.assertTrue(db.command("dbstats")) if __name__ == "__main__": diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 8bf0dcb21c..449ae90bd2 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -21,12 +21,11 @@ sys.path[0:0] = [""] -from pymongo import MongoClient from test import unittest +from pymongo import MongoClient -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'auth') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") class TestAuthSpec(unittest.TestCase): @@ -34,11 +33,10 @@ class TestAuthSpec(unittest.TestCase): def create_test(test_case): - def run_test(self): - uri = test_case['uri'] - valid = test_case['valid'] - credential = test_case.get('credential') + uri = test_case["uri"] + valid = test_case["valid"] + credential = test_case.get("credential") if not valid: self.assertRaises(Exception, MongoClient, uri, connect=False) @@ -49,39 +47,34 @@ def run_test(self): self.assertIsNone(credentials) else: self.assertIsNotNone(credentials) - self.assertEqual(credentials.username, credential['username']) - self.assertEqual(credentials.password, credential['password']) - self.assertEqual(credentials.source, credential['source']) - if credential['mechanism'] is not None: - self.assertEqual( - credentials.mechanism, credential['mechanism']) + self.assertEqual(credentials.username, credential["username"]) + self.assertEqual(credentials.password, credential["password"]) + self.assertEqual(credentials.source, credential["source"]) + if credential["mechanism"] is not None: + self.assertEqual(credentials.mechanism, credential["mechanism"]) else: - self.assertEqual(credentials.mechanism, 'DEFAULT') - expected = credential['mechanism_properties'] + self.assertEqual(credentials.mechanism, "DEFAULT") + expected = credential["mechanism_properties"] if expected is not None: actual = credentials.mechanism_properties for key, val in expected.items(): - if 'SERVICE_NAME' in expected: - self.assertEqual( - actual.service_name, expected['SERVICE_NAME']) - elif 'CANONICALIZE_HOST_NAME' in expected: - self.assertEqual( - actual.canonicalize_host_name, - expected['CANONICALIZE_HOST_NAME']) - elif 
'SERVICE_REALM' in expected: + if "SERVICE_NAME" in expected: + self.assertEqual(actual.service_name, expected["SERVICE_NAME"]) + elif "CANONICALIZE_HOST_NAME" in expected: self.assertEqual( - actual.service_realm, - expected['SERVICE_REALM']) - elif 'AWS_SESSION_TOKEN' in expected: + actual.canonicalize_host_name, expected["CANONICALIZE_HOST_NAME"] + ) + elif "SERVICE_REALM" in expected: + self.assertEqual(actual.service_realm, expected["SERVICE_REALM"]) + elif "AWS_SESSION_TOKEN" in expected: self.assertEqual( - actual.aws_session_token, - expected['AWS_SESSION_TOKEN']) + actual.aws_session_token, expected["AWS_SESSION_TOKEN"] + ) else: - self.fail('Unhandled property: %s' % (key,)) + self.fail("Unhandled property: %s" % (key,)) else: - if credential['mechanism'] == 'MONGODB-AWS': - self.assertIsNone( - credentials.mechanism_properties.aws_session_token) + if credential["mechanism"] == "MONGODB-AWS": + self.assertIsNone(credentials.mechanism_properties.aws_session_token) else: self.assertIsNone(credentials.mechanism_properties) @@ -89,19 +82,16 @@ def run_test(self): def create_tests(): - for filename in glob.glob(os.path.join(_TEST_PATH, '*.json')): + for filename in glob.glob(os.path.join(_TEST_PATH, "*.json")): test_suffix, _ = os.path.splitext(os.path.basename(filename)) with open(filename) as auth_tests: - test_cases = json.load(auth_tests)['tests'] + test_cases = json.load(auth_tests)["tests"] for test_case in test_cases: - if test_case.get('optional', False): + if test_case.get("optional", False): continue test_method = create_test(test_case) - name = str(test_case['description'].lower().replace(' ', '_')) - setattr( - TestAuthSpec, - 'test_%s_%s' % (test_suffix, name), - test_method) + name = str(test_case["description"].lower().replace(" ", "_")) + setattr(TestAuthSpec, "test_%s_%s" % (test_suffix, name), test_method) create_tests() diff --git a/test/test_binary.py b/test/test_binary.py index e6b681fc51..39777d1c58 100644 --- a/test/test_binary.py +++ b/test/test_binary.py @@ -25,58 +25,57 @@ sys.path[0:0] = [""] -import bson +from test import IntegrationTest, client_context, unittest +from test.utils import ignore_deprecations +import bson from bson import decode, encode from bson.binary import * from bson.codec_options import CodecOptions from bson.son import SON - from pymongo.common import validate_uuid_representation from pymongo.mongo_client import MongoClient from pymongo.write_concern import WriteConcern -from test import client_context, unittest, IntegrationTest -from test.utils import ignore_deprecations - class TestBinary(unittest.TestCase): - @classmethod def setUpClass(cls): # Generated by the Java driver from_java = ( - b'bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu' - b'Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND' - b'ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+' - b'XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1' - b'aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR' - b'jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA' - b'AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z' - b'DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf' - b'aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx' - b'29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My' - b'1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB' - b'W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp' - b'bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc' - b'0MQAA') + 
b"bAAAAAdfaWQAUCBQxkVm+XdxJ9tOBW5ld2d1aWQAEAAAAAMIQkfACFu" + b"Z/0RustLOU/G6Am5ld2d1aWRzdHJpbmcAJQAAAGZmOTk1YjA4LWMwND" + b"ctNDIwOC1iYWYxLTUzY2VkMmIyNmU0NAAAbAAAAAdfaWQAUCBQxkVm+" + b"XdxJ9tPBW5ld2d1aWQAEAAAAANgS/xhRXXv8kfIec+dYdyCAm5ld2d1" + b"aWRzdHJpbmcAJQAAAGYyZWY3NTQ1LTYxZmMtNGI2MC04MmRjLTYxOWR" + b"jZjc5Yzg0NwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tQBW5ld2d1aWQAEA" + b"AAAAPqREIbhZPUJOSdHCJIgaqNAm5ld2d1aWRzdHJpbmcAJQAAADI0Z" + b"DQ5Mzg1LTFiNDItNDRlYS04ZGFhLTgxNDgyMjFjOWRlNAAAbAAAAAdf" + b"aWQAUCBQxkVm+XdxJ9tRBW5ld2d1aWQAEAAAAANjQBn/aQuNfRyfNyx" + b"29COkAm5ld2d1aWRzdHJpbmcAJQAAADdkOGQwYjY5LWZmMTktNDA2My" + b"1hNDIzLWY0NzYyYzM3OWYxYwAAbAAAAAdfaWQAUCBQxkVm+XdxJ9tSB" + b"W5ld2d1aWQAEAAAAAMtSv/Et1cAQUFHUYevqxaLAm5ld2d1aWRzdHJp" + b"bmcAJQAAADQxMDA1N2I3LWM0ZmYtNGEyZC04YjE2LWFiYWY4NzUxNDc" + b"0MQAA" + ) cls.java_data = base64.b64decode(from_java) # Generated by the .net driver from_csharp = ( - b'ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl' - b'iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2' - b'ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V' - b'pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl' - b'AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A' - b'ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z' - b'oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU' - b'zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn' - b'dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA' - b'CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT' - b'QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP' - b'MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00' - b'ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=') + b"ZAAAABBfaWQAAAAAAAVuZXdndWlkABAAAAAD+MkoCd/Jy0iYJ7Vhl" + b"iF3BAJuZXdndWlkc3RyaW5nACUAAAAwOTI4YzlmOC1jOWRmLTQ4Y2" + b"ItOTgyNy1iNTYxOTYyMTc3MDQAAGQAAAAQX2lkAAEAAAAFbmV3Z3V" + b"pZAAQAAAAA9MD0oXQe6VOp7mK4jkttWUCbmV3Z3VpZHN0cmluZwAl" + b"AAAAODVkMjAzZDMtN2JkMC00ZWE1LWE3YjktOGFlMjM5MmRiNTY1A" + b"ABkAAAAEF9pZAACAAAABW5ld2d1aWQAEAAAAAPRmIO2auc/Tprq1Z" + b"oQ1oNYAm5ld2d1aWRzdHJpbmcAJQAAAGI2ODM5OGQxLWU3NmEtNGU" + b"zZi05YWVhLWQ1OWExMGQ2ODM1OAAAZAAAABBfaWQAAwAAAAVuZXdn" + b"dWlkABAAAAADISpriopuTEaXIa7arYOCFAJuZXdndWlkc3RyaW5nA" + b"CUAAAA4YTZiMmEyMS02ZThhLTQ2NGMtOTcyMS1hZWRhYWQ4MzgyMT" + b"QAAGQAAAAQX2lkAAQAAAAFbmV3Z3VpZAAQAAAAA98eg0CFpGlPihP" + b"MwOmYGOMCbmV3Z3VpZHN0cmluZwAlAAAANDA4MzFlZGYtYTQ4NS00" + b"ZjY5LThhMTMtY2NjMGU5OTgxOGUzAAA=" + ) cls.csharp_data = base64.b64decode(from_csharp) def test_binary(self): @@ -122,20 +121,15 @@ def test_equality(self): def test_repr(self): one = Binary(b"hello world") - self.assertEqual(repr(one), - "Binary(%s, 0)" % (repr(b"hello world"),)) + self.assertEqual(repr(one), "Binary(%s, 0)" % (repr(b"hello world"),)) two = Binary(b"hello world", 2) - self.assertEqual(repr(two), - "Binary(%s, 2)" % (repr(b"hello world"),)) + self.assertEqual(repr(two), "Binary(%s, 2)" % (repr(b"hello world"),)) three = Binary(b"\x08\xFF") - self.assertEqual(repr(three), - "Binary(%s, 0)" % (repr(b"\x08\xFF"),)) + self.assertEqual(repr(three), "Binary(%s, 0)" % (repr(b"\x08\xFF"),)) four = Binary(b"\x08\xFF", 2) - self.assertEqual(repr(four), - "Binary(%s, 2)" % (repr(b"\x08\xFF"),)) + self.assertEqual(repr(four), "Binary(%s, 2)" % (repr(b"\x08\xFF"),)) five = Binary(b"test", 100) - self.assertEqual(repr(five), - "Binary(%s, 100)" % (repr(b"test"),)) + self.assertEqual(repr(five), "Binary(%s, 100)" % (repr(b"test"),)) def test_hash(self): one = Binary(b"hello world") @@ -150,9 +144,11 @@ def test_uuid_subtype_4(self): expected_bin = Binary(expected_uuid.bytes, 4) 
doc = {"uuid": expected_bin} encoded = encode(doc) - for uuid_rep in (UuidRepresentation.PYTHON_LEGACY, - UuidRepresentation.JAVA_LEGACY, - UuidRepresentation.CSHARP_LEGACY): + for uuid_rep in ( + UuidRepresentation.PYTHON_LEGACY, + UuidRepresentation.JAVA_LEGACY, + UuidRepresentation.CSHARP_LEGACY, + ): opts = CodecOptions(uuid_representation=uuid_rep) self.assertEqual(expected_bin, decode(encoded, opts)["uuid"]) opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) @@ -163,39 +159,39 @@ def test_legacy_java_uuid(self): data = self.java_data docs = bson.decode_all(data, CodecOptions(SON, False, PYTHON_LEGACY)) for d in docs: - self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"])) docs = bson.decode_all(data, CodecOptions(SON, False, STANDARD)) for d in docs: - self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"])) docs = bson.decode_all(data, CodecOptions(SON, False, CSHARP_LEGACY)) for d in docs: - self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"])) docs = bson.decode_all(data, CodecOptions(SON, False, JAVA_LEGACY)) for d in docs: - self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) # Test encoding - encoded = b''.join([ - encode(doc, False, CodecOptions(uuid_representation=PYTHON_LEGACY)) - for doc in docs]) + encoded = b"".join( + [encode(doc, False, CodecOptions(uuid_representation=PYTHON_LEGACY)) for doc in docs] + ) self.assertNotEqual(data, encoded) - encoded = b''.join([ - encode(doc, False, CodecOptions(uuid_representation=STANDARD)) - for doc in docs]) + encoded = b"".join( + [encode(doc, False, CodecOptions(uuid_representation=STANDARD)) for doc in docs] + ) self.assertNotEqual(data, encoded) - encoded = b''.join([ - encode(doc, False, CodecOptions(uuid_representation=CSHARP_LEGACY)) - for doc in docs]) + encoded = b"".join( + [encode(doc, False, CodecOptions(uuid_representation=CSHARP_LEGACY)) for doc in docs] + ) self.assertNotEqual(data, encoded) - encoded = b''.join([ - encode(doc, False, CodecOptions(uuid_representation=JAVA_LEGACY)) - for doc in docs]) + encoded = b"".join( + [encode(doc, False, CodecOptions(uuid_representation=JAVA_LEGACY)) for doc in docs] + ) self.assertEqual(data, encoded) @client_context.require_connection @@ -203,21 +199,19 @@ def test_legacy_java_uuid_roundtrip(self): data = self.java_data docs = bson.decode_all(data, CodecOptions(SON, False, JAVA_LEGACY)) - client_context.client.pymongo_test.drop_collection('java_uuid') + client_context.client.pymongo_test.drop_collection("java_uuid") db = client_context.client.pymongo_test - coll = db.get_collection( - 'java_uuid', CodecOptions(uuid_representation=JAVA_LEGACY)) + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=JAVA_LEGACY)) coll.insert_many(docs) self.assertEqual(5, coll.count_documents({})) for d in coll.find(): - self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) - coll = db.get_collection( - 'java_uuid', CodecOptions(uuid_representation=PYTHON_LEGACY)) + coll = db.get_collection("java_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) for d in coll.find(): - self.assertNotEqual(d['newguid'], d['newguidstring']) - 
client_context.client.pymongo_test.drop_collection('java_uuid') + self.assertNotEqual(d["newguid"], d["newguidstring"]) + client_context.client.pymongo_test.drop_collection("java_uuid") def test_legacy_csharp_uuid(self): data = self.csharp_data @@ -225,39 +219,39 @@ def test_legacy_csharp_uuid(self): # Test decoding docs = bson.decode_all(data, CodecOptions(SON, False, PYTHON_LEGACY)) for d in docs: - self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"])) docs = bson.decode_all(data, CodecOptions(SON, False, STANDARD)) for d in docs: - self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"])) docs = bson.decode_all(data, CodecOptions(SON, False, JAVA_LEGACY)) for d in docs: - self.assertNotEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertNotEqual(d["newguid"], uuid.UUID(d["newguidstring"])) docs = bson.decode_all(data, CodecOptions(SON, False, CSHARP_LEGACY)) for d in docs: - self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) # Test encoding - encoded = b''.join([ - encode(doc, False, CodecOptions(uuid_representation=PYTHON_LEGACY)) - for doc in docs]) + encoded = b"".join( + [encode(doc, False, CodecOptions(uuid_representation=PYTHON_LEGACY)) for doc in docs] + ) self.assertNotEqual(data, encoded) - encoded = b''.join([ - encode(doc, False, CodecOptions(uuid_representation=STANDARD)) - for doc in docs]) + encoded = b"".join( + [encode(doc, False, CodecOptions(uuid_representation=STANDARD)) for doc in docs] + ) self.assertNotEqual(data, encoded) - encoded = b''.join([ - encode(doc, False, CodecOptions(uuid_representation=JAVA_LEGACY)) - for doc in docs]) + encoded = b"".join( + [encode(doc, False, CodecOptions(uuid_representation=JAVA_LEGACY)) for doc in docs] + ) self.assertNotEqual(data, encoded) - encoded = b''.join([ - encode(doc, False, CodecOptions(uuid_representation=CSHARP_LEGACY)) - for doc in docs]) + encoded = b"".join( + [encode(doc, False, CodecOptions(uuid_representation=CSHARP_LEGACY)) for doc in docs] + ) self.assertEqual(data, encoded) @client_context.require_connection @@ -265,29 +259,25 @@ def test_legacy_csharp_uuid_roundtrip(self): data = self.csharp_data docs = bson.decode_all(data, CodecOptions(SON, False, CSHARP_LEGACY)) - client_context.client.pymongo_test.drop_collection('csharp_uuid') + client_context.client.pymongo_test.drop_collection("csharp_uuid") db = client_context.client.pymongo_test - coll = db.get_collection( - 'csharp_uuid', CodecOptions(uuid_representation=CSHARP_LEGACY)) + coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=CSHARP_LEGACY)) coll.insert_many(docs) self.assertEqual(5, coll.count_documents({})) for d in coll.find(): - self.assertEqual(d['newguid'], uuid.UUID(d['newguidstring'])) + self.assertEqual(d["newguid"], uuid.UUID(d["newguidstring"])) - coll = db.get_collection( - 'csharp_uuid', CodecOptions(uuid_representation=PYTHON_LEGACY)) + coll = db.get_collection("csharp_uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) for d in coll.find(): - self.assertNotEqual(d['newguid'], d['newguidstring']) - client_context.client.pymongo_test.drop_collection('csharp_uuid') + self.assertNotEqual(d["newguid"], d["newguidstring"]) + client_context.client.pymongo_test.drop_collection("csharp_uuid") def test_uri_to_uuid(self): uri = "mongodb://foo/?uuidrepresentation=csharpLegacy" client = 
MongoClient(uri, connect=False) - self.assertEqual( - client.pymongo_test.test.codec_options.uuid_representation, - CSHARP_LEGACY) + self.assertEqual(client.pymongo_test.test.codec_options.uuid_representation, CSHARP_LEGACY) @client_context.require_connection def test_uuid_queries(self): @@ -296,37 +286,39 @@ def test_uuid_queries(self): coll.drop() uu = uuid.uuid4() - coll.insert_one({'uuid': Binary(uu.bytes, 3)}) + coll.insert_one({"uuid": Binary(uu.bytes, 3)}) self.assertEqual(1, coll.count_documents({})) # Test regular UUID queries (using subtype 4). coll = db.get_collection( - "test", CodecOptions( - uuid_representation=UuidRepresentation.STANDARD)) - self.assertEqual(0, coll.count_documents({'uuid': uu})) - coll.insert_one({'uuid': uu}) + "test", CodecOptions(uuid_representation=UuidRepresentation.STANDARD) + ) + self.assertEqual(0, coll.count_documents({"uuid": uu})) + coll.insert_one({"uuid": uu}) self.assertEqual(2, coll.count_documents({})) - docs = list(coll.find({'uuid': uu})) + docs = list(coll.find({"uuid": uu})) self.assertEqual(1, len(docs)) - self.assertEqual(uu, docs[0]['uuid']) + self.assertEqual(uu, docs[0]["uuid"]) # Test both. uu_legacy = Binary.from_uuid(uu, UuidRepresentation.PYTHON_LEGACY) - predicate = {'uuid': {'$in': [uu, uu_legacy]}} + predicate = {"uuid": {"$in": [uu, uu_legacy]}} self.assertEqual(2, coll.count_documents(predicate)) docs = list(coll.find(predicate)) self.assertEqual(2, len(docs)) coll.drop() def test_pickle(self): - b1 = Binary(b'123', 2) + b1 = Binary(b"123", 2) # For testing backwards compatibility with pre-2.4 pymongo - p = (b"\x80\x03cbson.binary\nBinary\nq\x00C\x03123q\x01\x85q" - b"\x02\x81q\x03}q\x04X\x10\x00\x00\x00_Binary__subtypeq" - b"\x05K\x02sb.") + p = ( + b"\x80\x03cbson.binary\nBinary\nq\x00C\x03123q\x01\x85q" + b"\x02\x81q\x03}q\x04X\x10\x00\x00\x00_Binary__subtypeq" + b"\x05K\x02sb." 
+ ) - if not sys.version.startswith('3.0'): + if not sys.version.startswith("3.0"): self.assertEqual(b1, pickle.loads(p)) for proto in range(pickle.HIGHEST_PROTOCOL + 1): @@ -342,15 +334,15 @@ def test_pickle(self): self.assertEqual(uul, pickle.loads(pickle.dumps(uul, proto))) def test_buffer_protocol(self): - b0 = Binary(b'123', 2) + b0 = Binary(b"123", 2) - self.assertEqual(b0, Binary(memoryview(b'123'), 2)) - self.assertEqual(b0, Binary(bytearray(b'123'), 2)) - with mmap.mmap(-1, len(b'123')) as mm: - mm.write(b'123') + self.assertEqual(b0, Binary(memoryview(b"123"), 2)) + self.assertEqual(b0, Binary(bytearray(b"123"), 2)) + with mmap.mmap(-1, len(b"123")) as mm: + mm.write(b"123") mm.seek(0) self.assertEqual(b0, Binary(mm, 2)) - self.assertEqual(b0, Binary(array.array('B', b'123'), 2)) + self.assertEqual(b0, Binary(array.array("B", b"123"), 2)) class TestUuidSpecExplicitCoding(unittest.TestCase): @@ -366,40 +358,37 @@ def _hex_to_bytes(hexstring): # Explicit encoding prose test #1 def test_encoding_1(self): obj = Binary.from_uuid(self.uuid) - expected_obj = Binary( - self._hex_to_bytes("00112233445566778899AABBCCDDEEFF"), 4) + expected_obj = Binary(self._hex_to_bytes("00112233445566778899AABBCCDDEEFF"), 4) self.assertEqual(obj, expected_obj) - def _test_encoding_w_uuid_rep( - self, uuid_rep, expected_hexstring, expected_subtype): + def _test_encoding_w_uuid_rep(self, uuid_rep, expected_hexstring, expected_subtype): obj = Binary.from_uuid(self.uuid, uuid_rep) - expected_obj = Binary( - self._hex_to_bytes(expected_hexstring), expected_subtype) + expected_obj = Binary(self._hex_to_bytes(expected_hexstring), expected_subtype) self.assertEqual(obj, expected_obj) # Explicit encoding prose test #2 def test_encoding_2(self): self._test_encoding_w_uuid_rep( - UuidRepresentation.STANDARD, - "00112233445566778899AABBCCDDEEFF", 4) + UuidRepresentation.STANDARD, "00112233445566778899AABBCCDDEEFF", 4 + ) # Explicit encoding prose test #3 def test_encoding_3(self): self._test_encoding_w_uuid_rep( - UuidRepresentation.JAVA_LEGACY, - "7766554433221100FFEEDDCCBBAA9988", 3) + UuidRepresentation.JAVA_LEGACY, "7766554433221100FFEEDDCCBBAA9988", 3 + ) # Explicit encoding prose test #4 def test_encoding_4(self): self._test_encoding_w_uuid_rep( - UuidRepresentation.CSHARP_LEGACY, - "33221100554477668899AABBCCDDEEFF", 3) + UuidRepresentation.CSHARP_LEGACY, "33221100554477668899AABBCCDDEEFF", 3 + ) # Explicit encoding prose test #5 def test_encoding_5(self): self._test_encoding_w_uuid_rep( - UuidRepresentation.PYTHON_LEGACY, - "00112233445566778899AABBCCDDEEFF", 3) + UuidRepresentation.PYTHON_LEGACY, "00112233445566778899AABBCCDDEEFF", 3 + ) # Explicit encoding prose test #6 def test_encoding_6(self): @@ -408,17 +397,18 @@ def test_encoding_6(self): # Explicit decoding prose test #1 def test_decoding_1(self): - obj = Binary( - self._hex_to_bytes("00112233445566778899AABBCCDDEEFF"), 4) + obj = Binary(self._hex_to_bytes("00112233445566778899AABBCCDDEEFF"), 4) # Case i: self.assertEqual(obj.as_uuid(), self.uuid) # Case ii: self.assertEqual(obj.as_uuid(UuidRepresentation.STANDARD), self.uuid) # Cases iii-vi: - for uuid_rep in (UuidRepresentation.JAVA_LEGACY, - UuidRepresentation.CSHARP_LEGACY, - UuidRepresentation.PYTHON_LEGACY): + for uuid_rep in ( + UuidRepresentation.JAVA_LEGACY, + UuidRepresentation.CSHARP_LEGACY, + UuidRepresentation.PYTHON_LEGACY, + ): with self.assertRaises(ValueError): obj.as_uuid(uuid_rep) @@ -429,31 +419,29 @@ def _test_decoding_legacy(self, hexstring, uuid_rep): with 
self.assertRaises(ValueError): obj.as_uuid() # Cases ii-iii: - for rep in (UuidRepresentation.STANDARD, - UuidRepresentation.UNSPECIFIED): + for rep in (UuidRepresentation.STANDARD, UuidRepresentation.UNSPECIFIED): with self.assertRaises(ValueError): obj.as_uuid(rep) # Case iv: - self.assertEqual(obj.as_uuid(uuid_rep), - self.uuid) + self.assertEqual(obj.as_uuid(uuid_rep), self.uuid) # Explicit decoding prose test #2 def test_decoding_2(self): self._test_decoding_legacy( - "7766554433221100FFEEDDCCBBAA9988", - UuidRepresentation.JAVA_LEGACY) + "7766554433221100FFEEDDCCBBAA9988", UuidRepresentation.JAVA_LEGACY + ) # Explicit decoding prose test #3 def test_decoding_3(self): self._test_decoding_legacy( - "33221100554477668899AABBCCDDEEFF", - UuidRepresentation.CSHARP_LEGACY) + "33221100554477668899AABBCCDDEEFF", UuidRepresentation.CSHARP_LEGACY + ) # Explicit decoding prose test #4 def test_decoding_4(self): self._test_decoding_legacy( - "00112233445566778899AABBCCDDEEFF", - UuidRepresentation.PYTHON_LEGACY) + "00112233445566778899AABBCCDDEEFF", UuidRepresentation.PYTHON_LEGACY + ) class TestUuidSpecImplicitCoding(IntegrationTest): @@ -468,95 +456,90 @@ def _hex_to_bytes(hexstring): def _get_coll_w_uuid_rep(self, uuid_rep): codec_options = self.client.codec_options.with_options( - uuid_representation=validate_uuid_representation(None, uuid_rep)) + uuid_representation=validate_uuid_representation(None, uuid_rep) + ) coll = self.db.get_collection( - 'pymongo_test', codec_options=codec_options, - write_concern=WriteConcern("majority")) + "pymongo_test", codec_options=codec_options, write_concern=WriteConcern("majority") + ) return coll def _test_encoding(self, uuid_rep, expected_hexstring, expected_subtype): coll = self._get_coll_w_uuid_rep(uuid_rep) coll.delete_many({}) - coll.insert_one({'_id': self.uuid}) + coll.insert_one({"_id": self.uuid}) self.assertTrue( - coll.find_one({"_id": Binary( - self._hex_to_bytes(expected_hexstring), expected_subtype)})) + coll.find_one({"_id": Binary(self._hex_to_bytes(expected_hexstring), expected_subtype)}) + ) # Implicit encoding prose test #1 def test_encoding_1(self): - self._test_encoding( - "javaLegacy", "7766554433221100FFEEDDCCBBAA9988", 3) + self._test_encoding("javaLegacy", "7766554433221100FFEEDDCCBBAA9988", 3) # Implicit encoding prose test #2 def test_encoding_2(self): - self._test_encoding( - "csharpLegacy", "33221100554477668899AABBCCDDEEFF", 3) + self._test_encoding("csharpLegacy", "33221100554477668899AABBCCDDEEFF", 3) # Implicit encoding prose test #3 def test_encoding_3(self): - self._test_encoding( - "pythonLegacy", "00112233445566778899AABBCCDDEEFF", 3) + self._test_encoding("pythonLegacy", "00112233445566778899AABBCCDDEEFF", 3) # Implicit encoding prose test #4 def test_encoding_4(self): - self._test_encoding( - "standard", "00112233445566778899AABBCCDDEEFF", 4) + self._test_encoding("standard", "00112233445566778899AABBCCDDEEFF", 4) # Implicit encoding prose test #5 def test_encoding_5(self): with self.assertRaises(ValueError): - self._test_encoding( - "unspecifed", "dummy", -1) - - def _test_decoding(self, client_uuid_representation_string, - legacy_field_uuid_representation, - expected_standard_field_value, - expected_legacy_field_value): + self._test_encoding("unspecifed", "dummy", -1) + + def _test_decoding( + self, + client_uuid_representation_string, + legacy_field_uuid_representation, + expected_standard_field_value, + expected_legacy_field_value, + ): coll = self._get_coll_w_uuid_rep(client_uuid_representation_string) 
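
Note (editor, not part of the patch): the byte layouts the prose tests above assert can be reproduced outside the harness. A self-contained check using the same test vector (Binary is a bytes subclass, so .hex() works on it):

    from uuid import UUID
    from bson.binary import Binary, UuidRepresentation

    u = UUID("00112233-4455-6677-8899-aabbccddeeff")

    std = Binary.from_uuid(u)  # STANDARD: subtype 4, RFC 4122 byte order
    assert (std.subtype, std.hex()) == (4, "00112233445566778899aabbccddeeff")

    # JAVA_LEGACY: subtype 3, each 8-byte half reversed
    java = Binary.from_uuid(u, UuidRepresentation.JAVA_LEGACY)
    assert (java.subtype, java.hex()) == (3, "7766554433221100ffeeddccbbaa9988")

    # Decoding back requires naming the same representation again.
    assert java.as_uuid(UuidRepresentation.JAVA_LEGACY) == u
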
coll.drop() standard_val = Binary.from_uuid(self.uuid, UuidRepresentation.STANDARD) legacy_val = Binary.from_uuid(self.uuid, legacy_field_uuid_representation) - coll.insert_one({'standard': standard_val, 'legacy': legacy_val}) + coll.insert_one({"standard": standard_val, "legacy": legacy_val}) doc = coll.find_one() - self.assertEqual(doc['standard'], expected_standard_field_value) - self.assertEqual(doc['legacy'], expected_legacy_field_value) + self.assertEqual(doc["standard"], expected_standard_field_value) + self.assertEqual(doc["legacy"], expected_legacy_field_value) # Implicit decoding prose test #1 def test_decoding_1(self): - standard_binary = Binary.from_uuid( - self.uuid, UuidRepresentation.STANDARD) + standard_binary = Binary.from_uuid(self.uuid, UuidRepresentation.STANDARD) self._test_decoding( - "javaLegacy", UuidRepresentation.JAVA_LEGACY, - standard_binary, self.uuid) + "javaLegacy", UuidRepresentation.JAVA_LEGACY, standard_binary, self.uuid + ) self._test_decoding( - "csharpLegacy", UuidRepresentation.CSHARP_LEGACY, - standard_binary, self.uuid) + "csharpLegacy", UuidRepresentation.CSHARP_LEGACY, standard_binary, self.uuid + ) self._test_decoding( - "pythonLegacy", UuidRepresentation.PYTHON_LEGACY, - standard_binary, self.uuid) + "pythonLegacy", UuidRepresentation.PYTHON_LEGACY, standard_binary, self.uuid + ) # Implicit decoding pose test #2 def test_decoding_2(self): - legacy_binary = Binary.from_uuid( - self.uuid, UuidRepresentation.PYTHON_LEGACY) - self._test_decoding( - "standard", UuidRepresentation.PYTHON_LEGACY, - self.uuid, legacy_binary) + legacy_binary = Binary.from_uuid(self.uuid, UuidRepresentation.PYTHON_LEGACY) + self._test_decoding("standard", UuidRepresentation.PYTHON_LEGACY, self.uuid, legacy_binary) # Implicit decoding pose test #3 def test_decoding_3(self): - expected_standard_value = Binary.from_uuid( - self.uuid, UuidRepresentation.STANDARD) - for legacy_uuid_rep in (UuidRepresentation.PYTHON_LEGACY, - UuidRepresentation.CSHARP_LEGACY, - UuidRepresentation.JAVA_LEGACY): - expected_legacy_value = Binary.from_uuid( - self.uuid, legacy_uuid_rep) + expected_standard_value = Binary.from_uuid(self.uuid, UuidRepresentation.STANDARD) + for legacy_uuid_rep in ( + UuidRepresentation.PYTHON_LEGACY, + UuidRepresentation.CSHARP_LEGACY, + UuidRepresentation.JAVA_LEGACY, + ): + expected_legacy_value = Binary.from_uuid(self.uuid, legacy_uuid_rep) self._test_decoding( - "unspecified", legacy_uuid_rep, - expected_standard_value, expected_legacy_value) + "unspecified", legacy_uuid_rep, expected_standard_value, expected_legacy_value + ) if __name__ == "__main__": diff --git a/test/test_bson.py b/test/test_bson.py index eb4f4e47c2..a552c17edd 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -21,44 +21,43 @@ import datetime import mmap import os +import pickle import re import sys import tempfile import uuid -import pickle - -from collections import abc, OrderedDict +from collections import OrderedDict, abc from io import BytesIO sys.path[0:0] = [""] +from test import qcheck, unittest +from test.utils import ExceptionCatchingThread + import bson -from bson import (BSON, - decode, - decode_all, - decode_file_iter, - decode_iter, - encode, - EPOCH_AWARE, - is_valid, - Regex) +from bson import ( + BSON, + EPOCH_AWARE, + Regex, + decode, + decode_all, + decode_file_iter, + decode_iter, + encode, + is_valid, +) from bson.binary import Binary, UuidRepresentation from bson.code import Code from bson.codec_options import CodecOptions +from bson.dbref import DBRef +from 
bson.errors import InvalidBSON, InvalidDocument from bson.int64 import Int64 +from bson.max_key import MaxKey +from bson.min_key import MinKey from bson.objectid import ObjectId -from bson.dbref import DBRef from bson.son import SON from bson.timestamp import Timestamp -from bson.errors import (InvalidBSON, - InvalidDocument) -from bson.max_key import MaxKey -from bson.min_key import MinKey -from bson.tz_util import (FixedOffset, - utc) - -from test import qcheck, unittest -from test.utils import ExceptionCatchingThread +from bson.tz_util import FixedOffset, utc class NotADict(abc.MutableMapping): @@ -95,7 +94,6 @@ def __repr__(self): class DSTAwareTimezone(datetime.tzinfo): - def __init__(self, offset, name, dst_start_month, dst_end_month): self.__offset = offset self.__dst_start_month = dst_start_month @@ -121,11 +119,10 @@ class TestBSON(unittest.TestCase): def assertInvalid(self, data): self.assertRaises(InvalidBSON, decode, data) - def check_encode_then_decode(self, doc_class=dict, decoder=decode, - encoder=encode): + def check_encode_then_decode(self, doc_class=dict, decoder=decode, encoder=encode): # Work around http://bugs.jython.org/issue1728 - if sys.platform.startswith('java'): + if sys.platform.startswith("java"): doc_class = SON def helper(doc): @@ -134,8 +131,7 @@ def helper(doc): helper({}) helper({"test": "hello"}) - self.assertTrue(isinstance(decoder(encoder( - {"hello": "world"}))["hello"], str)) + self.assertTrue(isinstance(decoder(encoder({"hello": "world"}))["hello"], str)) helper({"mike": -10120}) helper({"long": Int64(10)}) helper({"really big long": 2147483648}) @@ -148,9 +144,8 @@ def helper(doc): helper({"a binary": Binary(b"test", 128)}) helper({"a binary": Binary(b"test", 254)}) helper({"another binary": Binary(b"test", 2)}) - helper(SON([('test dst', datetime.datetime(1993, 4, 4, 2))])) - helper(SON([('test negative dst', - datetime.datetime(1, 1, 1, 1, 1, 1))])) + helper(SON([("test dst", datetime.datetime(1993, 4, 4, 2))])) + helper(SON([("test negative dst", datetime.datetime(1, 1, 1, 1, 1, 1))])) helper({"big float": float(10000000000)}) helper({"ref": DBRef("coll", 5)}) helper({"ref": DBRef("coll", 5, foo="bar", bar=4)}) @@ -160,14 +155,12 @@ def helper(doc): helper({"foo": MinKey()}) helper({"foo": MaxKey()}) helper({"$field": Code("function(){ return true; }")}) - helper({"$field": Code("return function(){ return x; }", scope={'x': False})}) + helper({"$field": Code("return function(){ return x; }", scope={"x": False})}) def encode_then_decode(doc): - return doc_class(doc) == decoder(encode(doc), CodecOptions( - document_class=doc_class)) + return doc_class(doc) == decoder(encode(doc), CodecOptions(document_class=doc_class)) - qcheck.check_unittest(self, encode_then_decode, - qcheck.gen_mongo_dict(3)) + qcheck.check_unittest(self, encode_then_decode, qcheck.gen_mongo_dict(3)) def test_encode_then_decode(self): self.check_encode_then_decode() @@ -177,18 +170,20 @@ def test_encode_then_decode_any_mapping(self): def test_encode_then_decode_legacy(self): self.check_encode_then_decode( - encoder=BSON.encode, - decoder=lambda *args: BSON(args[0]).decode(*args[1:])) + encoder=BSON.encode, decoder=lambda *args: BSON(args[0]).decode(*args[1:]) + ) def test_encode_then_decode_any_mapping_legacy(self): self.check_encode_then_decode( - doc_class=NotADict, encoder=BSON.encode, - decoder=lambda *args: BSON(args[0]).decode(*args[1:])) + doc_class=NotADict, + encoder=BSON.encode, + decoder=lambda *args: BSON(args[0]).decode(*args[1:]), + ) def 
test_encoding_defaultdict(self): - dct = collections.defaultdict(dict, [('foo', 'bar')]) + dct = collections.defaultdict(dict, [("foo", "bar")]) encode(dct) - self.assertEqual(dct, collections.defaultdict(dict, [('foo', 'bar')])) + self.assertEqual(dct, collections.defaultdict(dict, [("foo", "bar")])) def test_basic_validation(self): self.assertRaises(TypeError, is_valid, 100) @@ -209,117 +204,132 @@ def test_basic_validation(self): self.assertInvalid(b"\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12") self.assertInvalid(b"\x09\x00\x00\x00\x10a\x00\x05\x00") self.assertInvalid(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") - self.assertInvalid(b"\x13\x00\x00\x00\x02foo\x00" - b"\x04\x00\x00\x00bar\x00\x00") - self.assertInvalid(b"\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00" - b"\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00") - self.assertInvalid(b"\x15\x00\x00\x00\x03foo\x00\x0c" - b"\x00\x00\x00\x08bar\x00\x01\x00\x00") - self.assertInvalid(b"\x1c\x00\x00\x00\x03foo\x00" - b"\x12\x00\x00\x00\x02bar\x00" - b"\x05\x00\x00\x00baz\x00\x00\x00") - self.assertInvalid(b"\x10\x00\x00\x00\x02a\x00" - b"\x04\x00\x00\x00abc\xff\x00") - - def test_bad_string_lengths(self): - self.assertInvalid( - b"\x0c\x00\x00\x00\x02\x00" - b"\x00\x00\x00\x00\x00\x00") + self.assertInvalid(b"\x13\x00\x00\x00\x02foo\x00" b"\x04\x00\x00\x00bar\x00\x00") self.assertInvalid( - b"\x12\x00\x00\x00\x02\x00" - b"\xff\xff\xff\xfffoobar\x00\x00") + b"\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00" b"\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00" + ) self.assertInvalid( - b"\x0c\x00\x00\x00\x0e\x00" - b"\x00\x00\x00\x00\x00\x00") + b"\x15\x00\x00\x00\x03foo\x00\x0c" b"\x00\x00\x00\x08bar\x00\x01\x00\x00" + ) self.assertInvalid( - b"\x12\x00\x00\x00\x0e\x00" - b"\xff\xff\xff\xfffoobar\x00\x00") + b"\x1c\x00\x00\x00\x03foo\x00" + b"\x12\x00\x00\x00\x02bar\x00" + b"\x05\x00\x00\x00baz\x00\x00\x00" + ) + self.assertInvalid(b"\x10\x00\x00\x00\x02a\x00" b"\x04\x00\x00\x00abc\xff\x00") + + def test_bad_string_lengths(self): + self.assertInvalid(b"\x0c\x00\x00\x00\x02\x00" b"\x00\x00\x00\x00\x00\x00") + self.assertInvalid(b"\x12\x00\x00\x00\x02\x00" b"\xff\xff\xff\xfffoobar\x00\x00") + self.assertInvalid(b"\x0c\x00\x00\x00\x0e\x00" b"\x00\x00\x00\x00\x00\x00") + self.assertInvalid(b"\x12\x00\x00\x00\x0e\x00" b"\xff\xff\xff\xfffoobar\x00\x00") self.assertInvalid( - b"\x18\x00\x00\x00\x0c\x00" - b"\x00\x00\x00\x00\x00RY\xb5j" - b"\xfa[\xd8A\xd6X]\x99\x00") + b"\x18\x00\x00\x00\x0c\x00" b"\x00\x00\x00\x00\x00RY\xb5j" b"\xfa[\xd8A\xd6X]\x99\x00" + ) self.assertInvalid( b"\x1e\x00\x00\x00\x0c\x00" b"\xff\xff\xff\xfffoobar\x00" - b"RY\xb5j\xfa[\xd8A\xd6X]\x99\x00") - self.assertInvalid( - b"\x0c\x00\x00\x00\r\x00" - b"\x00\x00\x00\x00\x00\x00") - self.assertInvalid( - b"\x0c\x00\x00\x00\r\x00" - b"\xff\xff\xff\xff\x00\x00") + b"RY\xb5j\xfa[\xd8A\xd6X]\x99\x00" + ) + self.assertInvalid(b"\x0c\x00\x00\x00\r\x00" b"\x00\x00\x00\x00\x00\x00") + self.assertInvalid(b"\x0c\x00\x00\x00\r\x00" b"\xff\xff\xff\xff\x00\x00") self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\x00\x00" b"\x00\x00\x00\x0c\x00\x00" b"\x00\x02\x00\x01\x00\x00" - b"\x00\x00\x00\x00") + b"\x00\x00\x00\x00" + ) self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\xff\xff" b"\xff\xff\x00\x0c\x00\x00" b"\x00\x02\x00\x01\x00\x00" - b"\x00\x00\x00\x00") + b"\x00\x00\x00\x00" + ) self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\x01\x00" b"\x00\x00\x00\x0c\x00\x00" b"\x00\x02\x00\x00\x00\x00" - b"\x00\x00\x00\x00") + b"\x00\x00\x00\x00" + ) self.assertInvalid( 
b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\x01\x00" b"\x00\x00\x00\x0c\x00\x00" b"\x00\x02\x00\xff\xff\xff" - b"\xff\x00\x00\x00") + b"\xff\x00\x00\x00" + ) def test_random_data_is_not_bson(self): - qcheck.check_unittest(self, qcheck.isnt(is_valid), - qcheck.gen_string(qcheck.gen_range(0, 40))) + qcheck.check_unittest( + self, qcheck.isnt(is_valid), qcheck.gen_string(qcheck.gen_range(0, 40)) + ) def test_basic_decode(self): - self.assertEqual({"test": "hello world"}, - decode(b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74\x00\x0C" - b"\x00\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F" - b"\x72\x6C\x64\x00\x00")) - self.assertEqual([{"test": "hello world"}, {}], - decode_all(b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" - b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" - b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" - b"\x05\x00\x00\x00\x00")) - self.assertEqual([{"test": "hello world"}, {}], - list(decode_iter( - b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" - b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" - b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" - b"\x05\x00\x00\x00\x00"))) - self.assertEqual([{"test": "hello world"}, {}], - list(decode_file_iter(BytesIO( - b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" - b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" - b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" - b"\x05\x00\x00\x00\x00")))) + self.assertEqual( + {"test": "hello world"}, + decode( + b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74\x00\x0C" + b"\x00\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F" + b"\x72\x6C\x64\x00\x00" + ), + ) + self.assertEqual( + [{"test": "hello world"}, {}], + decode_all( + b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" + b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" + b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" + b"\x05\x00\x00\x00\x00" + ), + ) + self.assertEqual( + [{"test": "hello world"}, {}], + list( + decode_iter( + b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" + b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" + b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" + b"\x05\x00\x00\x00\x00" + ) + ), + ) + self.assertEqual( + [{"test": "hello world"}, {}], + list( + decode_file_iter( + BytesIO( + b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" + b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" + b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" + b"\x05\x00\x00\x00\x00" + ) + ) + ), + ) def test_decode_all_buffer_protocol(self): - docs = [{'foo': 'bar'}, {}] + docs = [{"foo": "bar"}, {}] bs = b"".join(map(encode, docs)) self.assertEqual(docs, decode_all(bytearray(bs))) self.assertEqual(docs, decode_all(memoryview(bs))) - self.assertEqual(docs, decode_all(memoryview(b'1' + bs + b'1')[1:-1])) - self.assertEqual(docs, decode_all(array.array('B', bs))) + self.assertEqual(docs, decode_all(memoryview(b"1" + bs + b"1")[1:-1])) + self.assertEqual(docs, decode_all(array.array("B", bs))) with mmap.mmap(-1, len(bs)) as mm: mm.write(bs) mm.seek(0) self.assertEqual(docs, decode_all(mm)) def test_decode_buffer_protocol(self): - doc = {'foo': 'bar'} + doc = {"foo": "bar"} bs = encode(doc) self.assertEqual(doc, decode(bs)) self.assertEqual(doc, decode(bytearray(bs))) self.assertEqual(doc, decode(memoryview(bs))) - self.assertEqual(doc, decode(memoryview(b'1' + bs + b'1')[1:-1])) - self.assertEqual(doc, decode(array.array('B', bs))) + self.assertEqual(doc, decode(memoryview(b"1" + bs + b"1")[1:-1])) + self.assertEqual(doc, decode(array.array("B", bs))) with mmap.mmap(-1, len(bs)) as mm: mm.write(bs) mm.seek(0) @@ -329,8 +339,7 @@ def test_invalid_decodes(self): # Invalid object size (not enough bytes in document for even # an object size of first object. 
# NOTE: decode_all and decode_iter don't care, not sure if they should? - self.assertRaises(InvalidBSON, list, - decode_file_iter(BytesIO(b"\x1B"))) + self.assertRaises(InvalidBSON, list, decode_file_iter(BytesIO(b"\x1B"))) bad_bsons = [ # An object size that's too small to even include the object size, @@ -338,21 +347,27 @@ def test_invalid_decodes(self): b"\x01\x00\x00\x00\x00", # One object, but with object size listed smaller than it is in the # data. - (b"\x1A\x00\x00\x00\x0E\x74\x65\x73\x74" - b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" - b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" - b"\x05\x00\x00\x00\x00"), + ( + b"\x1A\x00\x00\x00\x0E\x74\x65\x73\x74" + b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" + b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" + b"\x05\x00\x00\x00\x00" + ), # One object, missing the EOO at the end. - (b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" - b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" - b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" - b"\x05\x00\x00\x00"), + ( + b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" + b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" + b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" + b"\x05\x00\x00\x00" + ), # One object, sized correctly, with a spot for an EOO, but the EOO # isn't 0x00. - (b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" - b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" - b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" - b"\x05\x00\x00\x00\xFF"), + ( + b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" + b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" + b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" + b"\x05\x00\x00\x00\xFF" + ), ] for i, data in enumerate(bad_bsons): msg = "bad_bson[{}]".format(i) @@ -371,14 +386,17 @@ def test_invalid_decodes(self): def test_invalid_field_name(self): # Decode a truncated field with self.assertRaises(InvalidBSON) as ctx: - decode(b'\x0b\x00\x00\x00\x02field\x00') + decode(b"\x0b\x00\x00\x00\x02field\x00") # Assert that the InvalidBSON error message is not empty. 
self.assertTrue(str(ctx.exception)) def test_data_timestamp(self): - self.assertEqual({"test": Timestamp(4, 20)}, - decode(b"\x13\x00\x00\x00\x11\x74\x65\x73\x74\x00\x14" - b"\x00\x00\x00\x04\x00\x00\x00\x00")) + self.assertEqual( + {"test": Timestamp(4, 20)}, + decode( + b"\x13\x00\x00\x00\x11\x74\x65\x73\x74\x00\x14" b"\x00\x00\x00\x04\x00\x00\x00\x00" + ), + ) def test_basic_encode(self): self.assertRaises(TypeError, encode, 100) @@ -388,83 +406,102 @@ def test_basic_encode(self): self.assertEqual(encode({}), BSON(b"\x05\x00\x00\x00\x00")) self.assertEqual(encode({}), b"\x05\x00\x00\x00\x00") - self.assertEqual(encode({"test": "hello world"}), - b"\x1B\x00\x00\x00\x02\x74\x65\x73\x74\x00\x0C\x00" - b"\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F\x72\x6C" - b"\x64\x00\x00") - self.assertEqual(encode({"mike": 100}), - b"\x0F\x00\x00\x00\x10\x6D\x69\x6B\x65\x00\x64\x00" - b"\x00\x00\x00") - self.assertEqual(encode({"hello": 1.5}), - b"\x14\x00\x00\x00\x01\x68\x65\x6C\x6C\x6F\x00\x00" - b"\x00\x00\x00\x00\x00\xF8\x3F\x00") - self.assertEqual(encode({"true": True}), - b"\x0C\x00\x00\x00\x08\x74\x72\x75\x65\x00\x01\x00") - self.assertEqual(encode({"false": False}), - b"\x0D\x00\x00\x00\x08\x66\x61\x6C\x73\x65\x00\x00" - b"\x00") - self.assertEqual(encode({"empty": []}), - b"\x11\x00\x00\x00\x04\x65\x6D\x70\x74\x79\x00\x05" - b"\x00\x00\x00\x00\x00") - self.assertEqual(encode({"none": {}}), - b"\x10\x00\x00\x00\x03\x6E\x6F\x6E\x65\x00\x05\x00" - b"\x00\x00\x00\x00") - self.assertEqual(encode({"test": Binary(b"test", 0)}), - b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00" - b"\x00\x00\x00\x74\x65\x73\x74\x00") - self.assertEqual(encode({"test": Binary(b"test", 2)}), - b"\x18\x00\x00\x00\x05\x74\x65\x73\x74\x00\x08\x00" - b"\x00\x00\x02\x04\x00\x00\x00\x74\x65\x73\x74\x00") - self.assertEqual(encode({"test": Binary(b"test", 128)}), - b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00" - b"\x00\x00\x80\x74\x65\x73\x74\x00") - self.assertEqual(encode({"test": None}), - b"\x0B\x00\x00\x00\x0A\x74\x65\x73\x74\x00\x00") - self.assertEqual(encode({"date": datetime.datetime(2007, 1, 8, - 0, 30, 11)}), - b"\x13\x00\x00\x00\x09\x64\x61\x74\x65\x00\x38\xBE" - b"\x1C\xFF\x0F\x01\x00\x00\x00") - self.assertEqual(encode({"regex": re.compile(b"a*b", - re.IGNORECASE)}), - b"\x12\x00\x00\x00\x0B\x72\x65\x67\x65\x78\x00\x61" - b"\x2A\x62\x00\x69\x00\x00") - self.assertEqual(encode({"$where": Code("test")}), - b"\x16\x00\x00\x00\r$where\x00\x05\x00\x00\x00test" - b"\x00\x00") - self.assertEqual(encode({"$field": - Code("function(){ return true;}", scope=None)}), - b"+\x00\x00\x00\r$field\x00\x1a\x00\x00\x00" - b"function(){ return true;}\x00\x00") - self.assertEqual(encode({"$field": - Code("return function(){ return x; }", - scope={'x': False})}), - b"=\x00\x00\x00\x0f$field\x000\x00\x00\x00\x1f\x00" - b"\x00\x00return function(){ return x; }\x00\t\x00" - b"\x00\x00\x08x\x00\x00\x00\x00") + self.assertEqual( + encode({"test": "hello world"}), + b"\x1B\x00\x00\x00\x02\x74\x65\x73\x74\x00\x0C\x00" + b"\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F\x72\x6C" + b"\x64\x00\x00", + ) + self.assertEqual( + encode({"mike": 100}), + b"\x0F\x00\x00\x00\x10\x6D\x69\x6B\x65\x00\x64\x00" b"\x00\x00\x00", + ) + self.assertEqual( + encode({"hello": 1.5}), + b"\x14\x00\x00\x00\x01\x68\x65\x6C\x6C\x6F\x00\x00" b"\x00\x00\x00\x00\x00\xF8\x3F\x00", + ) + self.assertEqual( + encode({"true": True}), b"\x0C\x00\x00\x00\x08\x74\x72\x75\x65\x00\x01\x00" + ) + self.assertEqual( + encode({"false": False}), 
b"\x0D\x00\x00\x00\x08\x66\x61\x6C\x73\x65\x00\x00" b"\x00" + ) + self.assertEqual( + encode({"empty": []}), + b"\x11\x00\x00\x00\x04\x65\x6D\x70\x74\x79\x00\x05" b"\x00\x00\x00\x00\x00", + ) + self.assertEqual( + encode({"none": {}}), + b"\x10\x00\x00\x00\x03\x6E\x6F\x6E\x65\x00\x05\x00" b"\x00\x00\x00\x00", + ) + self.assertEqual( + encode({"test": Binary(b"test", 0)}), + b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00" b"\x00\x00\x00\x74\x65\x73\x74\x00", + ) + self.assertEqual( + encode({"test": Binary(b"test", 2)}), + b"\x18\x00\x00\x00\x05\x74\x65\x73\x74\x00\x08\x00" + b"\x00\x00\x02\x04\x00\x00\x00\x74\x65\x73\x74\x00", + ) + self.assertEqual( + encode({"test": Binary(b"test", 128)}), + b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00" b"\x00\x00\x80\x74\x65\x73\x74\x00", + ) + self.assertEqual(encode({"test": None}), b"\x0B\x00\x00\x00\x0A\x74\x65\x73\x74\x00\x00") + self.assertEqual( + encode({"date": datetime.datetime(2007, 1, 8, 0, 30, 11)}), + b"\x13\x00\x00\x00\x09\x64\x61\x74\x65\x00\x38\xBE" b"\x1C\xFF\x0F\x01\x00\x00\x00", + ) + self.assertEqual( + encode({"regex": re.compile(b"a*b", re.IGNORECASE)}), + b"\x12\x00\x00\x00\x0B\x72\x65\x67\x65\x78\x00\x61" b"\x2A\x62\x00\x69\x00\x00", + ) + self.assertEqual( + encode({"$where": Code("test")}), + b"\x16\x00\x00\x00\r$where\x00\x05\x00\x00\x00test" b"\x00\x00", + ) + self.assertEqual( + encode({"$field": Code("function(){ return true;}", scope=None)}), + b"+\x00\x00\x00\r$field\x00\x1a\x00\x00\x00" b"function(){ return true;}\x00\x00", + ) + self.assertEqual( + encode({"$field": Code("return function(){ return x; }", scope={"x": False})}), + b"=\x00\x00\x00\x0f$field\x000\x00\x00\x00\x1f\x00" + b"\x00\x00return function(){ return x; }\x00\t\x00" + b"\x00\x00\x08x\x00\x00\x00\x00", + ) unicode_empty_scope = Code("function(){ return 'héllo';}", {}) - self.assertEqual(encode({'$field': unicode_empty_scope}), - b"8\x00\x00\x00\x0f$field\x00+\x00\x00\x00\x1e\x00" - b"\x00\x00function(){ return 'h\xc3\xa9llo';}\x00\x05" - b"\x00\x00\x00\x00\x00") + self.assertEqual( + encode({"$field": unicode_empty_scope}), + b"8\x00\x00\x00\x0f$field\x00+\x00\x00\x00\x1e\x00" + b"\x00\x00function(){ return 'h\xc3\xa9llo';}\x00\x05" + b"\x00\x00\x00\x00\x00", + ) a = ObjectId(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B") - self.assertEqual(encode({"oid": a}), - b"\x16\x00\x00\x00\x07\x6F\x69\x64\x00\x00\x01\x02" - b"\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00") - self.assertEqual(encode({"ref": DBRef("coll", a)}), - b"\x2F\x00\x00\x00\x03ref\x00\x25\x00\x00\x00\x02" - b"$ref\x00\x05\x00\x00\x00coll\x00\x07$id\x00\x00" - b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00" - b"\x00") + self.assertEqual( + encode({"oid": a}), + b"\x16\x00\x00\x00\x07\x6F\x69\x64\x00\x00\x01\x02" + b"\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00", + ) + self.assertEqual( + encode({"ref": DBRef("coll", a)}), + b"\x2F\x00\x00\x00\x03ref\x00\x25\x00\x00\x00\x02" + b"$ref\x00\x05\x00\x00\x00coll\x00\x07$id\x00\x00" + b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00" + b"\x00", + ) def test_unknown_type(self): # Repr value differs with major python version - part = "type %r for fieldname 'foo'" % (b'\x14',) + part = "type %r for fieldname 'foo'" % (b"\x14",) docs = [ - b'\x0e\x00\x00\x00\x14foo\x00\x01\x00\x00\x00\x00', - (b'\x16\x00\x00\x00\x04foo\x00\x0c\x00\x00\x00\x140' - b'\x00\x01\x00\x00\x00\x00\x00'), - (b' \x00\x00\x00\x04bar\x00\x16\x00\x00\x00\x030\x00\x0e\x00\x00' - b'\x00\x14foo\x00\x01\x00\x00\x00\x00\x00\x00')] + 
b"\x0e\x00\x00\x00\x14foo\x00\x01\x00\x00\x00\x00", + (b"\x16\x00\x00\x00\x04foo\x00\x0c\x00\x00\x00\x140" b"\x00\x01\x00\x00\x00\x00\x00"), + ( + b" \x00\x00\x00\x04bar\x00\x16\x00\x00\x00\x030\x00\x0e\x00\x00" + b"\x00\x14foo\x00\x01\x00\x00\x00\x00\x00\x00" + ), + ] for bs in docs: try: decode(bs) @@ -481,21 +518,19 @@ def test_dbpointer(self): # not support creation of the DBPointer type, but will decode # DBPointer to DBRef. - bs = (b"\x18\x00\x00\x00\x0c\x00\x01\x00\x00" - b"\x00\x00RY\xb5j\xfa[\xd8A\xd6X]\x99\x00") + bs = b"\x18\x00\x00\x00\x0c\x00\x01\x00\x00" b"\x00\x00RY\xb5j\xfa[\xd8A\xd6X]\x99\x00" - self.assertEqual({'': DBRef('', ObjectId('5259b56afa5bd841d6585d99'))}, - decode(bs)) + self.assertEqual({"": DBRef("", ObjectId("5259b56afa5bd841d6585d99"))}, decode(bs)) def test_bad_dbref(self): - ref_only = {'ref': {'$ref': 'collection'}} - id_only = {'ref': {'$id': ObjectId()}} + ref_only = {"ref": {"$ref": "collection"}} + id_only = {"ref": {"$id": ObjectId()}} self.assertEqual(ref_only, decode(encode(ref_only))) self.assertEqual(id_only, decode(encode(id_only))) def test_bytes_as_keys(self): - doc = {b"foo": 'bar'} + doc = {b"foo": "bar"} # Since `bytes` are stored as Binary you can't use them # as keys in python 3.x. Using binary data as a key makes # no sense in BSON anyway and little sense in python. @@ -528,13 +563,10 @@ def test_large_datetime_truncation(self): self.assertEqual(dt2.second, dt1.second) def test_aware_datetime(self): - aware = datetime.datetime(1993, 4, 4, 2, - tzinfo=FixedOffset(555, "SomeZone")) + aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc) - self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45, tzinfo=utc), - as_utc) - after = decode(encode({"date": aware}), CodecOptions(tz_aware=True))[ - "date"] + self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45, tzinfo=utc), as_utc) + after = decode(encode({"date": aware}), CodecOptions(tz_aware=True))["date"] self.assertEqual(utc, after.tzinfo) self.assertEqual(as_utc, after) @@ -543,54 +575,47 @@ def test_local_datetime(self): tz = DSTAwareTimezone(60, "sixty-minutes", 4, 7) # It's not DST. - local = datetime.datetime(year=2025, month=12, hour=2, day=1, - tzinfo=tz) + local = datetime.datetime(year=2025, month=12, hour=2, day=1, tzinfo=tz) options = CodecOptions(tz_aware=True, tzinfo=tz) # Encode with this timezone, then decode to UTC. - encoded = encode({'date': local}, codec_options=options) - self.assertEqual(local.replace(hour=1, tzinfo=None), - decode(encoded)['date']) + encoded = encode({"date": local}, codec_options=options) + self.assertEqual(local.replace(hour=1, tzinfo=None), decode(encoded)["date"]) # It's DST. - local = datetime.datetime(year=2025, month=4, hour=1, day=1, - tzinfo=tz) - encoded = encode({'date': local}, codec_options=options) - self.assertEqual(local.replace(month=3, day=31, hour=23, tzinfo=None), - decode(encoded)['date']) + local = datetime.datetime(year=2025, month=4, hour=1, day=1, tzinfo=tz) + encoded = encode({"date": local}, codec_options=options) + self.assertEqual( + local.replace(month=3, day=31, hour=23, tzinfo=None), decode(encoded)["date"] + ) # Encode UTC, then decode in a different timezone. 
- encoded = encode({'date': local.replace(tzinfo=utc)}) - decoded = decode(encoded, options)['date'] + encoded = encode({"date": local.replace(tzinfo=utc)}) + decoded = decode(encoded, options)["date"] self.assertEqual(local.replace(hour=3), decoded) self.assertEqual(tz, decoded.tzinfo) # Test round-tripping. self.assertEqual( - local, decode(encode( - {'date': local}, codec_options=options), options)['date']) + local, decode(encode({"date": local}, codec_options=options), options)["date"] + ) # Test around the Unix Epoch. epochs = ( EPOCH_AWARE, - EPOCH_AWARE.astimezone(FixedOffset(120, 'one twenty')), - EPOCH_AWARE.astimezone(FixedOffset(-120, 'minus one twenty')) + EPOCH_AWARE.astimezone(FixedOffset(120, "one twenty")), + EPOCH_AWARE.astimezone(FixedOffset(-120, "minus one twenty")), ) utc_co = CodecOptions(tz_aware=True) for epoch in epochs: - doc = {'epoch': epoch} + doc = {"epoch": epoch} # We always retrieve datetimes in UTC unless told to do otherwise. - self.assertEqual( - EPOCH_AWARE, - decode(encode(doc), codec_options=utc_co)['epoch']) + self.assertEqual(EPOCH_AWARE, decode(encode(doc), codec_options=utc_co)["epoch"]) # Round-trip the epoch. local_co = CodecOptions(tz_aware=True, tzinfo=epoch.tzinfo) - self.assertEqual( - epoch, - decode(encode(doc), codec_options=local_co)['epoch']) + self.assertEqual(epoch, decode(encode(doc), codec_options=local_co)["epoch"]) def test_naive_decode(self): - aware = datetime.datetime(1993, 4, 4, 2, - tzinfo=FixedOffset(555, "SomeZone")) + aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) naive_utc = (aware - aware.utcoffset()).replace(tzinfo=None) self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45), naive_utc) after = decode(encode({"date": aware}))["date"] @@ -601,32 +626,30 @@ def test_dst(self): d = {"x": datetime.datetime(1993, 4, 4, 2)} self.assertEqual(d, decode(encode(d))) - @unittest.skip('Disabled due to http://bugs.python.org/issue25222') + @unittest.skip("Disabled due to http://bugs.python.org/issue25222") def test_bad_encode(self): - evil_list = {'a': []} - evil_list['a'].append(evil_list) + evil_list = {"a": []} + evil_list["a"].append(evil_list) evil_dict = {} - evil_dict['a'] = evil_dict + evil_dict["a"] = evil_dict for evil_data in [evil_dict, evil_list]: self.assertRaises(Exception, encode, evil_data) def test_overflow(self): self.assertTrue(encode({"x": 9223372036854775807})) - self.assertRaises(OverflowError, encode, - {"x": 9223372036854775808}) + self.assertRaises(OverflowError, encode, {"x": 9223372036854775808}) self.assertTrue(encode({"x": -9223372036854775808})) - self.assertRaises(OverflowError, encode, - {"x": -9223372036854775809}) + self.assertRaises(OverflowError, encode, {"x": -9223372036854775809}) def test_small_long_encode_decode(self): - encoded1 = encode({'x': 256}) - decoded1 = decode(encoded1)['x'] + encoded1 = encode({"x": 256}) + decoded1 = decode(encoded1)["x"] self.assertEqual(256, decoded1) self.assertEqual(type(256), type(decoded1)) - encoded2 = encode({'x': Int64(256)}) - decoded2 = decode(encoded2)['x'] + encoded2 = encode({"x": Int64(256)}) + decoded2 = decode(encoded2)["x"] expected = Int64(256) self.assertEqual(expected, decoded2) self.assertEqual(type(expected), type(decoded2)) @@ -634,18 +657,16 @@ def test_small_long_encode_decode(self): self.assertNotEqual(type(decoded1), type(decoded2)) def test_tuple(self): - self.assertEqual({"tuple": [1, 2]}, - decode(encode({"tuple": (1, 2)}))) + self.assertEqual({"tuple": [1, 2]}, decode(encode({"tuple": (1, 2)}))) def 
test_uuid(self): id = uuid.uuid4() # The default uuid_representation is UNSPECIFIED - with self.assertRaisesRegex(ValueError, 'cannot encode native uuid'): - bson.decode_all(encode({'uuid': id})) + with self.assertRaisesRegex(ValueError, "cannot encode native uuid"): + bson.decode_all(encode({"uuid": id})) opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - transformed_id = decode(encode({"id": id}, codec_options=opts), - codec_options=opts)["id"] + transformed_id = decode(encode({"id": id}, codec_options=opts), codec_options=opts)["id"] self.assertTrue(isinstance(transformed_id, uuid.UUID)) self.assertEqual(id, transformed_id) self.assertNotEqual(uuid.uuid4(), transformed_id) @@ -662,7 +683,7 @@ def test_uuid_legacy(self): # The C extension was segfaulting on unicode RegExs, so we have this test # that doesn't really test anything but the lack of a segfault. def test_unicode_regex(self): - regex = re.compile('revisi\xf3n') + regex = re.compile("revisi\xf3n") decode(encode({"regex": regex})) def test_non_string_keys(self): @@ -673,12 +694,12 @@ def test_utf8(self): self.assertEqual(w, decode(encode(w))) # b'a\xe9' == "aé".encode("iso-8859-1") - iso8859_bytes = b'a\xe9' + iso8859_bytes = b"a\xe9" y = {"hello": iso8859_bytes} # Stored as BSON binary subtype 0. out = decode(encode(y)) - self.assertTrue(isinstance(out['hello'], bytes)) - self.assertEqual(out['hello'], iso8859_bytes) + self.assertTrue(isinstance(out["hello"], bytes)) + self.assertEqual(out["hello"], iso8859_bytes) def test_null_character(self): doc = {"a": "\x00"} @@ -690,28 +711,27 @@ def test_null_character(self): self.assertRaises(InvalidDocument, encode, {b"\x00": "a"}) self.assertRaises(InvalidDocument, encode, {"\x00": "a"}) - self.assertRaises(InvalidDocument, encode, - {"a": re.compile(b"ab\x00c")}) - self.assertRaises(InvalidDocument, encode, - {"a": re.compile("ab\x00c")}) + self.assertRaises(InvalidDocument, encode, {"a": re.compile(b"ab\x00c")}) + self.assertRaises(InvalidDocument, encode, {"a": re.compile("ab\x00c")}) def test_move_id(self): - self.assertEqual(b"\x19\x00\x00\x00\x02_id\x00\x02\x00\x00\x00a\x00" - b"\x02a\x00\x02\x00\x00\x00a\x00\x00", - encode(SON([("a", "a"), ("_id", "a")]))) - - self.assertEqual(b"\x2c\x00\x00\x00" - b"\x02_id\x00\x02\x00\x00\x00b\x00" - b"\x03b\x00" - b"\x19\x00\x00\x00\x02a\x00\x02\x00\x00\x00a\x00" - b"\x02_id\x00\x02\x00\x00\x00a\x00\x00\x00", - encode(SON([("b", - SON([("a", "a"), ("_id", "a")])), - ("_id", "b")]))) + self.assertEqual( + b"\x19\x00\x00\x00\x02_id\x00\x02\x00\x00\x00a\x00" + b"\x02a\x00\x02\x00\x00\x00a\x00\x00", + encode(SON([("a", "a"), ("_id", "a")])), + ) + + self.assertEqual( + b"\x2c\x00\x00\x00" + b"\x02_id\x00\x02\x00\x00\x00b\x00" + b"\x03b\x00" + b"\x19\x00\x00\x00\x02a\x00\x02\x00\x00\x00a\x00" + b"\x02_id\x00\x02\x00\x00\x00a\x00\x00\x00", + encode(SON([("b", SON([("a", "a"), ("_id", "a")])), ("_id", "b")])), + ) def test_dates(self): - doc = {"early": datetime.datetime(1686, 5, 5), - "late": datetime.datetime(2086, 5, 5)} + doc = {"early": datetime.datetime(1686, 5, 5), "late": datetime.datetime(2086, 5, 5)} try: self.assertEqual(doc, decode(encode(doc))) except ValueError: @@ -724,15 +744,12 @@ def test_dates(self): def test_custom_class(self): self.assertIsInstance(decode(encode({})), dict) self.assertNotIsInstance(decode(encode({})), SON) - self.assertIsInstance( - decode(encode({}), CodecOptions(document_class=SON)), SON) + self.assertIsInstance(decode(encode({}), CodecOptions(document_class=SON)), SON) - self.assertEqual( 
- 1, decode(encode({"x": 1}), CodecOptions(document_class=SON))["x"]) + self.assertEqual(1, decode(encode({"x": 1}), CodecOptions(document_class=SON))["x"]) x = encode({"x": [{"y": 1}]}) - self.assertIsInstance( - decode(x, CodecOptions(document_class=SON))["x"][0], SON) + self.assertIsInstance(decode(x, CodecOptions(document_class=SON))["x"][0], SON) def test_subclasses(self): # make sure we can serialize subclasses of native Python types. @@ -745,9 +762,7 @@ class _myfloat(float): class _myunicode(str): pass - d = {'a': _myint(42), 'b': _myfloat(63.9), - 'c': _myunicode('hello world') - } + d = {"a": _myint(42), "b": _myfloat(63.9), "c": _myunicode("hello world")} d2 = decode(encode(d)) for key, value in d2.items(): orig_value = d[key] @@ -757,65 +772,60 @@ class _myunicode(str): def test_ordered_dict(self): d = OrderedDict([("one", 1), ("two", 2), ("three", 3), ("four", 4)]) - self.assertEqual( - d, decode(encode(d), CodecOptions(document_class=OrderedDict))) + self.assertEqual(d, decode(encode(d), CodecOptions(document_class=OrderedDict))) def test_bson_regex(self): # Invalid Python regex, though valid PCRE. - bson_re1 = Regex(r'[\w-\.]') - self.assertEqual(r'[\w-\.]', bson_re1.pattern) + bson_re1 = Regex(r"[\w-\.]") + self.assertEqual(r"[\w-\.]", bson_re1.pattern) self.assertEqual(0, bson_re1.flags) - doc1 = {'r': bson_re1} + doc1 = {"r": bson_re1} doc1_bson = ( - b'\x11\x00\x00\x00' # document length - b'\x0br\x00[\\w-\\.]\x00\x00' # r: regex - b'\x00') # document terminator + b"\x11\x00\x00\x00" b"\x0br\x00[\\w-\\.]\x00\x00" b"\x00" # document length # r: regex + ) # document terminator self.assertEqual(doc1_bson, encode(doc1)) self.assertEqual(doc1, decode(doc1_bson)) # Valid Python regex, with flags. - re2 = re.compile('.*', re.I | re.M | re.S | re.U | re.X) - bson_re2 = Regex('.*', re.I | re.M | re.S | re.U | re.X) + re2 = re.compile(".*", re.I | re.M | re.S | re.U | re.X) + bson_re2 = Regex(".*", re.I | re.M | re.S | re.U | re.X) - doc2_with_re = {'r': re2} - doc2_with_bson_re = {'r': bson_re2} + doc2_with_re = {"r": re2} + doc2_with_bson_re = {"r": bson_re2} doc2_bson = ( - b"\x11\x00\x00\x00" # document length - b"\x0br\x00.*\x00imsux\x00" # r: regex - b"\x00") # document terminator + b"\x11\x00\x00\x00" b"\x0br\x00.*\x00imsux\x00" b"\x00" # document length # r: regex + ) # document terminator self.assertEqual(doc2_bson, encode(doc2_with_re)) self.assertEqual(doc2_bson, encode(doc2_with_bson_re)) - self.assertEqual(re2.pattern, decode(doc2_bson)['r'].pattern) - self.assertEqual(re2.flags, decode(doc2_bson)['r'].flags) + self.assertEqual(re2.pattern, decode(doc2_bson)["r"].pattern) + self.assertEqual(re2.flags, decode(doc2_bson)["r"].flags) def test_regex_from_native(self): - self.assertEqual('.*', Regex.from_native(re.compile('.*')).pattern) - self.assertEqual(0, Regex.from_native(re.compile(b'')).flags) + self.assertEqual(".*", Regex.from_native(re.compile(".*")).pattern) + self.assertEqual(0, Regex.from_native(re.compile(b"")).flags) - regex = re.compile(b'', re.I | re.L | re.M | re.S | re.X) - self.assertEqual( - re.I | re.L | re.M | re.S | re.X, - Regex.from_native(regex).flags) + regex = re.compile(b"", re.I | re.L | re.M | re.S | re.X) + self.assertEqual(re.I | re.L | re.M | re.S | re.X, Regex.from_native(regex).flags) - unicode_regex = re.compile('', re.U) + unicode_regex = re.compile("", re.U) self.assertEqual(re.U, Regex.from_native(unicode_regex).flags) def test_regex_hash(self): - self.assertRaises(TypeError, hash, Regex('hello')) + self.assertRaises(TypeError, 
hash, Regex("hello")) def test_regex_comparison(self): - re1 = Regex('a') - re2 = Regex('b') + re1 = Regex("a") + re2 = Regex("b") self.assertNotEqual(re1, re2) - re1 = Regex('a', re.I) - re2 = Regex('a', re.M) + re1 = Regex("a", re.I) + re2 = Regex("a", re.M) self.assertNotEqual(re1, re2) - re1 = Regex('a', re.I) - re2 = Regex('a', re.I) + re1 = Regex("a", re.I) + re2 = Regex("a", re.I) self.assertEqual(re1, re2) def test_exception_wrapping(self): @@ -823,13 +833,12 @@ def test_exception_wrapping(self): # the final exception always matches InvalidBSON. # {'s': '\xff'}, will throw attempting to decode utf-8. - bad_doc = b'\x0f\x00\x00\x00\x02s\x00\x03\x00\x00\x00\xff\x00\x00\x00' + bad_doc = b"\x0f\x00\x00\x00\x02s\x00\x03\x00\x00\x00\xff\x00\x00\x00" with self.assertRaises(InvalidBSON) as context: decode_all(bad_doc) - self.assertIn("codec can't decode byte 0xff", - str(context.exception)) + self.assertIn("codec can't decode byte 0xff", str(context.exception)) def test_minkey_maxkey_comparison(self): # MinKey's <, <=, >, >=, !=, and ==. @@ -903,29 +912,25 @@ def test_timestamp_comparison(self): self.assertFalse(Timestamp(1, 0) > Timestamp(1, 0)) def test_timestamp_highorder_bits(self): - doc = {'a': Timestamp(0xFFFFFFFF, 0xFFFFFFFF)} - doc_bson = (b'\x10\x00\x00\x00' - b'\x11a\x00\xff\xff\xff\xff\xff\xff\xff\xff' - b'\x00') + doc = {"a": Timestamp(0xFFFFFFFF, 0xFFFFFFFF)} + doc_bson = b"\x10\x00\x00\x00" b"\x11a\x00\xff\xff\xff\xff\xff\xff\xff\xff" b"\x00" self.assertEqual(doc_bson, encode(doc)) self.assertEqual(doc, decode(doc_bson)) def test_bad_id_keys(self): - self.assertRaises(InvalidDocument, encode, - {"_id": {"$bad": 123}}, True) - self.assertRaises(InvalidDocument, encode, - {"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}}, True) - encode({"_id": {'$oid': "52d0b971b3ba219fdeb4170e"}}) + self.assertRaises(InvalidDocument, encode, {"_id": {"$bad": 123}}, True) + self.assertRaises( + InvalidDocument, encode, {"_id": {"$oid": "52d0b971b3ba219fdeb4170e"}}, True + ) + encode({"_id": {"$oid": "52d0b971b3ba219fdeb4170e"}}) def test_bson_encode_thread_safe(self): - def target(i): for j in range(1000): - my_int = type('MyInt_%s_%s' % (i, j), (int,), {}) - bson.encode({'my_int': my_int()}) + my_int = type("MyInt_%s_%s" % (i, j), (int,), {}) + bson.encode({"my_int": my_int()}) - threads = [ExceptionCatchingThread(target=target, args=(i,)) - for i in range(3)] + threads = [ExceptionCatchingThread(target=target, args=(i,)) for i in range(3)] for t in threads: t.start() @@ -943,11 +948,11 @@ def __init__(self, val): def __repr__(self): return repr(self.val) - self.assertEqual('1', repr(Wrapper(1))) + self.assertEqual("1", repr(Wrapper(1))) with self.assertRaisesRegex( - InvalidDocument, - "cannot encode object: 1, of type: " + repr(Wrapper)): - encode({'t': Wrapper(1)}) + InvalidDocument, "cannot encode object: 1, of type: " + repr(Wrapper) + ): + encode({"t": Wrapper(1)}) class TestCodecOptions(unittest.TestCase): @@ -965,69 +970,67 @@ def test_uuid_representation(self): self.assertRaises(ValueError, CodecOptions, uuid_representation=2) def test_tzinfo(self): - self.assertRaises(TypeError, CodecOptions, tzinfo='pacific') - tz = FixedOffset(42, 'forty-two') + self.assertRaises(TypeError, CodecOptions, tzinfo="pacific") + tz = FixedOffset(42, "forty-two") self.assertRaises(ValueError, CodecOptions, tzinfo=tz) self.assertEqual(tz, CodecOptions(tz_aware=True, tzinfo=tz).tzinfo) def test_codec_options_repr(self): - r = ("CodecOptions(document_class=dict, tz_aware=False, " - 
"uuid_representation=UuidRepresentation.UNSPECIFIED, " - "unicode_decode_error_handler='strict', " - "tzinfo=None, type_registry=TypeRegistry(type_codecs=[], " - "fallback_encoder=None))") + r = ( + "CodecOptions(document_class=dict, tz_aware=False, " + "uuid_representation=UuidRepresentation.UNSPECIFIED, " + "unicode_decode_error_handler='strict', " + "tzinfo=None, type_registry=TypeRegistry(type_codecs=[], " + "fallback_encoder=None))" + ) self.assertEqual(r, repr(CodecOptions())) def test_decode_all_defaults(self): # Test decode_all()'s default document_class is dict and tz_aware is # False. - doc = {'sub_document': {}, - 'dt': datetime.datetime.utcnow()} + doc = {"sub_document": {}, "dt": datetime.datetime.utcnow()} decoded = bson.decode_all(bson.encode(doc))[0] - self.assertIsInstance(decoded['sub_document'], dict) - self.assertIsNone(decoded['dt'].tzinfo) + self.assertIsInstance(decoded["sub_document"], dict) + self.assertIsNone(decoded["dt"].tzinfo) # The default uuid_representation is UNSPECIFIED - with self.assertRaisesRegex(ValueError, 'cannot encode native uuid'): - bson.decode_all(bson.encode({'uuid': uuid.uuid4()})) + with self.assertRaisesRegex(ValueError, "cannot encode native uuid"): + bson.decode_all(bson.encode({"uuid": uuid.uuid4()})) def test_unicode_decode_error_handler(self): enc = encode({"keystr": "foobar"}) # Test handling of bad key value, bad string value, and both. - invalid_key = enc[:7] + b'\xe9' + enc[8:] - invalid_val = enc[:18] + b'\xe9' + enc[19:] - invalid_both = enc[:7] + b'\xe9' + enc[8:18] + b'\xe9' + enc[19:] + invalid_key = enc[:7] + b"\xe9" + enc[8:] + invalid_val = enc[:18] + b"\xe9" + enc[19:] + invalid_both = enc[:7] + b"\xe9" + enc[8:18] + b"\xe9" + enc[19:] # Ensure that strict mode raises an error. for invalid in [invalid_key, invalid_val, invalid_both]: - self.assertRaises(InvalidBSON, decode, invalid, CodecOptions( - unicode_decode_error_handler="strict")) + self.assertRaises( + InvalidBSON, decode, invalid, CodecOptions(unicode_decode_error_handler="strict") + ) self.assertRaises(InvalidBSON, decode, invalid, CodecOptions()) self.assertRaises(InvalidBSON, decode, invalid) # Test all other error handlers. - for handler in ['replace', 'backslashreplace', 'surrogateescape', - 'ignore']: - expected_key = b'ke\xe9str'.decode('utf-8', handler) - expected_val = b'fo\xe9bar'.decode('utf-8', handler) - doc = decode(invalid_key, - CodecOptions(unicode_decode_error_handler=handler)) + for handler in ["replace", "backslashreplace", "surrogateescape", "ignore"]: + expected_key = b"ke\xe9str".decode("utf-8", handler) + expected_val = b"fo\xe9bar".decode("utf-8", handler) + doc = decode(invalid_key, CodecOptions(unicode_decode_error_handler=handler)) self.assertEqual(doc, {expected_key: "foobar"}) - doc = decode(invalid_val, - CodecOptions(unicode_decode_error_handler=handler)) + doc = decode(invalid_val, CodecOptions(unicode_decode_error_handler=handler)) self.assertEqual(doc, {"keystr": expected_val}) - doc = decode(invalid_both, - CodecOptions(unicode_decode_error_handler=handler)) + doc = decode(invalid_both, CodecOptions(unicode_decode_error_handler=handler)) self.assertEqual(doc, {expected_key: expected_val}) # Test handling bad error mode. 
- dec = decode(enc, - CodecOptions(unicode_decode_error_handler="junk")) + dec = decode(enc, CodecOptions(unicode_decode_error_handler="junk")) self.assertEqual(dec, {"keystr": "foobar"}) - self.assertRaises(InvalidBSON, decode, invalid_both, CodecOptions( - unicode_decode_error_handler="junk")) + self.assertRaises( + InvalidBSON, decode, invalid_both, CodecOptions(unicode_decode_error_handler="junk") + ) def round_trip_pickle(self, obj, pickled_with_older): pickled_with_older_obj = pickle.loads(pickled_with_older) @@ -1039,61 +1042,75 @@ def round_trip_pickle(self, obj, pickled_with_older): def test_regex_pickling(self): reg = Regex(".?") - pickled_with_3 = (b'\x80\x04\x959\x00\x00\x00\x00\x00\x00\x00\x8c\n' - b'bson.regex\x94\x8c\x05Regex\x94\x93\x94)\x81\x94}' - b'\x94(\x8c\x07pattern\x94\x8c\x02.?\x94\x8c\x05flag' - b's\x94K\x00ub.') + pickled_with_3 = ( + b"\x80\x04\x959\x00\x00\x00\x00\x00\x00\x00\x8c\n" + b"bson.regex\x94\x8c\x05Regex\x94\x93\x94)\x81\x94}" + b"\x94(\x8c\x07pattern\x94\x8c\x02.?\x94\x8c\x05flag" + b"s\x94K\x00ub." + ) self.round_trip_pickle(reg, pickled_with_3) def test_timestamp_pickling(self): ts = Timestamp(0, 1) - pickled_with_3 = (b'\x80\x04\x95Q\x00\x00\x00\x00\x00\x00\x00\x8c' - b'\x0ebson.timestamp\x94\x8c\tTimestamp\x94\x93\x94)' - b'\x81\x94}\x94(' - b'\x8c\x10_Timestamp__time\x94K\x00\x8c' - b'\x0f_Timestamp__inc\x94K\x01ub.') + pickled_with_3 = ( + b"\x80\x04\x95Q\x00\x00\x00\x00\x00\x00\x00\x8c" + b"\x0ebson.timestamp\x94\x8c\tTimestamp\x94\x93\x94)" + b"\x81\x94}\x94(" + b"\x8c\x10_Timestamp__time\x94K\x00\x8c" + b"\x0f_Timestamp__inc\x94K\x01ub." + ) self.round_trip_pickle(ts, pickled_with_3) def test_dbref_pickling(self): dbr = DBRef("foo", 5) - pickled_with_3 = (b'\x80\x04\x95q\x00\x00\x00\x00\x00\x00\x00\x8c\n' - b'bson.dbref\x94\x8c\x05DBRef\x94\x93\x94)\x81\x94}' - b'\x94(\x8c\x12_DBRef__collection\x94\x8c\x03foo\x94' - b'\x8c\n_DBRef__id\x94K\x05\x8c\x10_DBRef__database' - b'\x94N\x8c\x0e_DBRef__kwargs\x94}\x94ub.') + pickled_with_3 = ( + b"\x80\x04\x95q\x00\x00\x00\x00\x00\x00\x00\x8c\n" + b"bson.dbref\x94\x8c\x05DBRef\x94\x93\x94)\x81\x94}" + b"\x94(\x8c\x12_DBRef__collection\x94\x8c\x03foo\x94" + b"\x8c\n_DBRef__id\x94K\x05\x8c\x10_DBRef__database" + b"\x94N\x8c\x0e_DBRef__kwargs\x94}\x94ub." + ) self.round_trip_pickle(dbr, pickled_with_3) - dbr = DBRef("foo", 5, database='db', kwargs1=None) - pickled_with_3 = (b'\x80\x04\x95\x81\x00\x00\x00\x00\x00\x00\x00\x8c' - b'\nbson.dbref\x94\x8c\x05DBRef\x94\x93\x94)\x81\x94}' - b'\x94(\x8c\x12_DBRef__collection\x94\x8c\x03foo\x94' - b'\x8c\n_DBRef__id\x94K\x05\x8c\x10_DBRef__database' - b'\x94\x8c\x02db\x94\x8c\x0e_DBRef__kwargs\x94}\x94' - b'\x8c\x07kwargs1\x94Nsub.') + dbr = DBRef("foo", 5, database="db", kwargs1=None) + pickled_with_3 = ( + b"\x80\x04\x95\x81\x00\x00\x00\x00\x00\x00\x00\x8c" + b"\nbson.dbref\x94\x8c\x05DBRef\x94\x93\x94)\x81\x94}" + b"\x94(\x8c\x12_DBRef__collection\x94\x8c\x03foo\x94" + b"\x8c\n_DBRef__id\x94K\x05\x8c\x10_DBRef__database" + b"\x94\x8c\x02db\x94\x8c\x0e_DBRef__kwargs\x94}\x94" + b"\x8c\x07kwargs1\x94Nsub." + ) self.round_trip_pickle(dbr, pickled_with_3) def test_minkey_pickling(self): mink = MinKey() - pickled_with_3 = (b'\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c' - b'\x0cbson.min_key\x94\x8c\x06MinKey\x94\x93\x94)' - b'\x81\x94.') + pickled_with_3 = ( + b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c" + b"\x0cbson.min_key\x94\x8c\x06MinKey\x94\x93\x94)" + b"\x81\x94." 
+ )
 self.round_trip_pickle(mink, pickled_with_3)

 def test_maxkey_pickling(self):
 maxk = MaxKey()
- pickled_with_3 = (b'\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c'
- b'\x0cbson.max_key\x94\x8c\x06MaxKey\x94\x93\x94)'
- b'\x81\x94.')
+ pickled_with_3 = (
+ b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c"
+ b"\x0cbson.max_key\x94\x8c\x06MaxKey\x94\x93\x94)"
+ b"\x81\x94."
+ )
 self.round_trip_pickle(maxk, pickled_with_3)

 def test_int64_pickling(self):
 i64 = Int64(9)
- pickled_with_3 = (b'\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c\n'
- b'bson.int64\x94\x8c\x05Int64\x94\x93\x94K\t\x85\x94'
- b'\x81\x94.')
+ pickled_with_3 = (
+ b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c\n"
+ b"bson.int64\x94\x8c\x05Int64\x94\x93\x94K\t\x85\x94"
+ b"\x81\x94."
+ )
 self.round_trip_pickle(i64, pickled_with_3)
diff --git a/test/test_bson_corpus.py b/test/test_bson_corpus.py
index cbb702e405..4a46276573 100644
--- a/test/test_bson_corpus.py
+++ b/test/test_bson_corpus.py
@@ -21,54 +21,52 @@
 import json
 import os
 import sys
-
 from decimal import DecimalException

 sys.path[0:0] = [""]

+from test import unittest
+
 from bson import decode, encode, json_util
 from bson.binary import STANDARD
 from bson.codec_options import CodecOptions
-from bson.decimal128 import Decimal128
 from bson.dbref import DBRef
+from bson.decimal128 import Decimal128
 from bson.errors import InvalidBSON, InvalidDocument, InvalidId
 from bson.json_util import JSONMode
 from bson.son import SON
-from test import unittest
-
-_TEST_PATH = os.path.join(
- os.path.dirname(os.path.realpath(__file__)), 'bson_corpus')
+_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bson_corpus")
 _TESTS_TO_SKIP = {
 # Python cannot decode dates after year 9999.
- 'Y10K',
+ "Y10K",
 }
 _NON_PARSE_ERRORS = {
 # {"$date": <number>} is our legacy format which we still need to parse.
- 'Bad $date (number, not string or hash)',
+ "Bad $date (number, not string or hash)",
 # This variant of $numberLong may have been generated by an old version
 # of mongoexport.
- 'Bad $numberLong (number, not string)',
+ "Bad $numberLong (number, not string)",
 # Python's UUID constructor is very permissive.
- '$uuid invalid value--misplaced hyphens',
+ "$uuid invalid value--misplaced hyphens",
 # We parse Regex flags with extra characters, including nulls.
- 'Null byte in $regularExpression options',
+ "Null byte in $regularExpression options",
 }
 _IMPLCIT_LOSSY_TESTS = {
 # JSON decodes top-level $ref+$id as a DBRef but BSON doesn't.
- 'Document with key names similar to those of a DBRef'
+ "Document with key names similar to those of a DBRef"
 }
 _DEPRECATED_BSON_TYPES = {
 # Symbol
- '0x0E': str,
+ "0x0E": str,
 # Undefined
- '0x06': type(None),
+ "0x06": type(None),
 # DBPointer
- '0x0C': DBRef
+ "0x0C": DBRef,
 }
@@ -78,27 +76,23 @@
 # We normally encode UUID as binary subtype 0x03,
 # but we'll need to encode to subtype 0x04 for one of the tests.
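 # (For reference: 0x03 is the old/legacy UUID binary subtype and 0x04 the
 # RFC 4122 standard subtype, which UuidRepresentation.STANDARD maps to.)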
codec_options_uuid_04 = codec_options._replace(uuid_representation=STANDARD) -json_options_uuid_04 = json_util.JSONOptions(json_mode=JSONMode.CANONICAL, - uuid_representation=STANDARD) +json_options_uuid_04 = json_util.JSONOptions( + json_mode=JSONMode.CANONICAL, uuid_representation=STANDARD +) json_options_iso8601 = json_util.JSONOptions( - datetime_representation=json_util.DatetimeRepresentation.ISO8601, - json_mode=JSONMode.LEGACY) -to_extjson = functools.partial(json_util.dumps, - json_options=json_util.CANONICAL_JSON_OPTIONS) -to_extjson_uuid_04 = functools.partial(json_util.dumps, - json_options=json_options_uuid_04) -to_extjson_iso8601 = functools.partial(json_util.dumps, - json_options=json_options_iso8601) -to_relaxed_extjson = functools.partial( - json_util.dumps, json_options=json_util.RELAXED_JSON_OPTIONS) -to_bson_uuid_04 = functools.partial(encode, - codec_options=codec_options_uuid_04) + datetime_representation=json_util.DatetimeRepresentation.ISO8601, json_mode=JSONMode.LEGACY +) +to_extjson = functools.partial(json_util.dumps, json_options=json_util.CANONICAL_JSON_OPTIONS) +to_extjson_uuid_04 = functools.partial(json_util.dumps, json_options=json_options_uuid_04) +to_extjson_iso8601 = functools.partial(json_util.dumps, json_options=json_options_iso8601) +to_relaxed_extjson = functools.partial(json_util.dumps, json_options=json_util.RELAXED_JSON_OPTIONS) +to_bson_uuid_04 = functools.partial(encode, codec_options=codec_options_uuid_04) to_bson = functools.partial(encode, codec_options=codec_options) decode_bson = functools.partial(decode, codec_options=codec_options_no_tzaware) decode_extjson = functools.partial( json_util.loads, - json_options=json_util.JSONOptions(json_mode=JSONMode.CANONICAL, - document_class=SON)) + json_options=json_util.JSONOptions(json_mode=JSONMode.CANONICAL, document_class=SON), +) loads = functools.partial(json.loads, object_pairs_hook=SON) @@ -113,65 +107,62 @@ def assertJsonEqual(self, first, second, msg=None): def create_test(case_spec): - bson_type = case_spec['bson_type'] + bson_type = case_spec["bson_type"] # Test key is absent when testing top-level documents. - test_key = case_spec.get('test_key') - deprecated = case_spec.get('deprecated') + test_key = case_spec.get("test_key") + deprecated = case_spec.get("deprecated") def run_test(self): - for valid_case in case_spec.get('valid', []): - description = valid_case['description'] + for valid_case in case_spec.get("valid", []): + description = valid_case["description"] if description in _TESTS_TO_SKIP: continue # Special case for testing encoding UUID as binary subtype 0x04. 
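 # Corpus shorthand in this loop: cB/dB are the canonical/degenerate BSON
 # bytes (unhexlified from the fixture), cEJ/rEJ/dEJ the canonical/relaxed/
 # degenerate extended JSON strings for the same document.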
- if description.startswith('subtype 0x04'): + if description.startswith("subtype 0x04"): encode_extjson = to_extjson_uuid_04 encode_bson = to_bson_uuid_04 else: encode_extjson = to_extjson encode_bson = to_bson - cB = binascii.unhexlify(valid_case['canonical_bson'].encode('utf8')) - cEJ = valid_case['canonical_extjson'] - rEJ = valid_case.get('relaxed_extjson') - dEJ = valid_case.get('degenerate_extjson') + cB = binascii.unhexlify(valid_case["canonical_bson"].encode("utf8")) + cEJ = valid_case["canonical_extjson"] + rEJ = valid_case.get("relaxed_extjson") + dEJ = valid_case.get("degenerate_extjson") if description in _IMPLCIT_LOSSY_TESTS: - valid_case.setdefault('lossy', True) - lossy = valid_case.get('lossy') + valid_case.setdefault("lossy", True) + lossy = valid_case.get("lossy") # BSON double, use lowercase 'e+' to match Python's encoding - if bson_type == '0x01': - cEJ = cEJ.replace('E+', 'e+') + if bson_type == "0x01": + cEJ = cEJ.replace("E+", "e+") decoded_bson = decode_bson(cB) if not lossy: # Make sure we can parse the legacy (default) JSON format. legacy_json = json_util.dumps( - decoded_bson, json_options=json_util.LEGACY_JSON_OPTIONS) - self.assertEqual( - decode_extjson(legacy_json), decoded_bson, description) + decoded_bson, json_options=json_util.LEGACY_JSON_OPTIONS + ) + self.assertEqual(decode_extjson(legacy_json), decoded_bson, description) if deprecated: - if 'converted_bson' in valid_case: - converted_bson = binascii.unhexlify( - valid_case['converted_bson'].encode('utf8')) + if "converted_bson" in valid_case: + converted_bson = binascii.unhexlify(valid_case["converted_bson"].encode("utf8")) self.assertEqual(encode_bson(decoded_bson), converted_bson) self.assertJsonEqual( - encode_extjson(decode_bson(converted_bson)), - valid_case['converted_extjson']) + encode_extjson(decode_bson(converted_bson)), valid_case["converted_extjson"] + ) # Make sure we can decode the type. self.assertEqual(decoded_bson, decode_extjson(cEJ)) if test_key is not None: - self.assertIsInstance(decoded_bson[test_key], - _DEPRECATED_BSON_TYPES[bson_type]) + self.assertIsInstance(decoded_bson[test_key], _DEPRECATED_BSON_TYPES[bson_type]) continue # Jython can't handle NaN with a payload from # struct.(un)pack if endianness is specified in the format string. - if not (sys.platform.startswith("java") and - description == 'NaN with payload'): + if not (sys.platform.startswith("java") and description == "NaN with payload"): # Test round-tripping canonical bson. self.assertEqual(encode_bson(decoded_bson), cB, description) self.assertJsonEqual(encode_extjson(decoded_bson), cEJ) @@ -183,8 +174,8 @@ def run_test(self): self.assertEqual(encode_bson(decoded_json), cB) # Test round-tripping degenerate bson. - if 'degenerate_bson' in valid_case: - dB = binascii.unhexlify(valid_case['degenerate_bson'].encode('utf8')) + if "degenerate_bson" in valid_case: + dB = binascii.unhexlify(valid_case["degenerate_bson"].encode("utf8")) self.assertEqual(encode_bson(decode_bson(dB)), cB) # Test round-tripping degenerate extended json. 
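 # (A "degenerate" encoding is legal but non-canonical; the corpus expects
 # it to decode to the same document and re-encode to the canonical bytes.)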
@@ -200,53 +191,48 @@ def run_test(self): decoded_json = decode_extjson(rEJ) self.assertJsonEqual(to_relaxed_extjson(decoded_json), rEJ) - for decode_error_case in case_spec.get('decodeErrors', []): + for decode_error_case in case_spec.get("decodeErrors", []): with self.assertRaises(InvalidBSON): - decode_bson( - binascii.unhexlify(decode_error_case['bson'].encode('utf8'))) + decode_bson(binascii.unhexlify(decode_error_case["bson"].encode("utf8"))) - for parse_error_case in case_spec.get('parseErrors', []): - description = parse_error_case['description'] + for parse_error_case in case_spec.get("parseErrors", []): + description = parse_error_case["description"] if description in _NON_PARSE_ERRORS: - decode_extjson(parse_error_case['string']) + decode_extjson(parse_error_case["string"]) continue - if bson_type == '0x13': - self.assertRaises( - DecimalException, Decimal128, parse_error_case['string']) - elif bson_type == '0x00': + if bson_type == "0x13": + self.assertRaises(DecimalException, Decimal128, parse_error_case["string"]) + elif bson_type == "0x00": try: - doc = decode_extjson(parse_error_case['string']) + doc = decode_extjson(parse_error_case["string"]) # Null bytes are validated when encoding to BSON. - if 'Null' in description: + if "Null" in description: to_bson(doc) - raise AssertionError('exception not raised for test ' - 'case: ' + description) - except (ValueError, KeyError, TypeError, InvalidId, - InvalidDocument): + raise AssertionError("exception not raised for test " "case: " + description) + except (ValueError, KeyError, TypeError, InvalidId, InvalidDocument): pass - elif bson_type == '0x05': + elif bson_type == "0x05": try: - decode_extjson(parse_error_case['string']) - raise AssertionError('exception not raised for test ' - 'case: ' + description) + decode_extjson(parse_error_case["string"]) + raise AssertionError("exception not raised for test " "case: " + description) except (TypeError, ValueError): pass else: - raise AssertionError('cannot test parseErrors for type ' + - bson_type) + raise AssertionError("cannot test parseErrors for type " + bson_type) + return run_test def create_tests(): - for filename in glob.glob(os.path.join(_TEST_PATH, '*.json')): + for filename in glob.glob(os.path.join(_TEST_PATH, "*.json")): test_suffix, _ = os.path.splitext(os.path.basename(filename)) - with codecs.open(filename, encoding='utf-8') as bson_test_file: + with codecs.open(filename, encoding="utf-8") as bson_test_file: test_method = create_test(json.load(bson_test_file)) - setattr(TestBSONCorpus, 'test_' + test_suffix, test_method) + setattr(TestBSONCorpus, "test_" + test_suffix, test_method) create_tests() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_bulk.py b/test/test_bulk.py index 08740a437e..46be863c9d 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -16,31 +16,34 @@ import sys import uuid + from bson.binary import UuidRepresentation from bson.codec_options import CodecOptions sys.path[0:0] = [""] +from test import IntegrationTest, client_context, unittest +from test.utils import ( + remove_all_users, + rs_or_single_client_noauth, + single_client, + wait_until, +) + from bson import Binary from bson.objectid import ObjectId from pymongo.common import partition_node -from pymongo.errors import (BulkWriteError, - ConfigurationError, - InvalidOperation, - OperationFailure) +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + InvalidOperation, + OperationFailure, +) from pymongo.operations import * 
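 # (The star import supplies the InsertOne/UpdateOne/UpdateMany/ReplaceOne/
 # DeleteOne/DeleteMany request classes used throughout the tests below.)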
from pymongo.write_concern import WriteConcern -from test import (client_context, - unittest, - IntegrationTest) -from test.utils import (remove_all_users, - rs_or_single_client_noauth, - single_client, - wait_until) class BulkTestBase(IntegrationTest): - @classmethod def setUpClass(cls): super(BulkTestBase, cls).setUpClass() @@ -54,87 +57,91 @@ def setUp(self): def assertEqualResponse(self, expected, actual): """Compare response from bulk.execute() to expected response.""" for key, value in expected.items(): - if key == 'nModified': - self.assertEqual(value, actual['nModified']) - elif key == 'upserted': + if key == "nModified": + self.assertEqual(value, actual["nModified"]) + elif key == "upserted": expected_upserts = value - actual_upserts = actual['upserted'] + actual_upserts = actual["upserted"] self.assertEqual( - len(expected_upserts), len(actual_upserts), - 'Expected %d elements in "upserted", got %d' % ( - len(expected_upserts), len(actual_upserts))) + len(expected_upserts), + len(actual_upserts), + 'Expected %d elements in "upserted", got %d' + % (len(expected_upserts), len(actual_upserts)), + ) for e, a in zip(expected_upserts, actual_upserts): self.assertEqualUpsert(e, a) - elif key == 'writeErrors': + elif key == "writeErrors": expected_errors = value - actual_errors = actual['writeErrors'] + actual_errors = actual["writeErrors"] self.assertEqual( - len(expected_errors), len(actual_errors), - 'Expected %d elements in "writeErrors", got %d' % ( - len(expected_errors), len(actual_errors))) + len(expected_errors), + len(actual_errors), + 'Expected %d elements in "writeErrors", got %d' + % (len(expected_errors), len(actual_errors)), + ) for e, a in zip(expected_errors, actual_errors): self.assertEqualWriteError(e, a) else: self.assertEqual( - actual.get(key), value, - '%r value of %r does not match expected %r' % - (key, actual.get(key), value)) + actual.get(key), + value, + "%r value of %r does not match expected %r" % (key, actual.get(key), value), + ) def assertEqualUpsert(self, expected, actual): """Compare bulk.execute()['upserts'] to expected value. Like: {'index': 0, '_id': ObjectId()} """ - self.assertEqual(expected['index'], actual['index']) - if expected['_id'] == '...': + self.assertEqual(expected["index"], actual["index"]) + if expected["_id"] == "...": # Unspecified value. - self.assertTrue('_id' in actual) + self.assertTrue("_id" in actual) else: - self.assertEqual(expected['_id'], actual['_id']) + self.assertEqual(expected["_id"], actual["_id"]) def assertEqualWriteError(self, expected, actual): """Compare bulk.execute()['writeErrors'] to expected value. Like: {'index': 0, 'code': 123, 'errmsg': '...', 'op': { ... }} """ - self.assertEqual(expected['index'], actual['index']) - self.assertEqual(expected['code'], actual['code']) - if expected['errmsg'] == '...': + self.assertEqual(expected["index"], actual["index"]) + self.assertEqual(expected["code"], actual["code"]) + if expected["errmsg"] == "...": # Unspecified value. - self.assertTrue('errmsg' in actual) + self.assertTrue("errmsg" in actual) else: - self.assertEqual(expected['errmsg'], actual['errmsg']) + self.assertEqual(expected["errmsg"], actual["errmsg"]) - expected_op = expected['op'].copy() - actual_op = actual['op'].copy() - if expected_op.get('_id') == '...': + expected_op = expected["op"].copy() + actual_op = actual["op"].copy() + if expected_op.get("_id") == "...": # Unspecified _id. 
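 # ("..." is the fixtures' wildcard: assert the key is present, then pop
 # _id from both sides so the remainder of the op compares exactly.)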
- self.assertTrue('_id' in actual_op) - actual_op.pop('_id') - expected_op.pop('_id') + self.assertTrue("_id" in actual_op) + actual_op.pop("_id") + expected_op.pop("_id") self.assertEqual(expected_op, actual_op) class TestBulk(BulkTestBase): - def test_empty(self): self.assertRaises(InvalidOperation, self.coll.bulk_write, []) def test_insert(self): expected = { - 'nMatched': 0, - 'nModified': 0, - 'nUpserted': 0, - 'nInserted': 1, - 'nRemoved': 0, - 'upserted': [], - 'writeErrors': [], - 'writeConcernErrors': [] + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 1, + "nRemoved": 0, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], } result = self.coll.bulk_write([InsertOne({})]) @@ -145,14 +152,14 @@ def test_insert(self): def _test_update_many(self, update): expected = { - 'nMatched': 2, - 'nModified': 2, - 'nUpserted': 0, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [], - 'writeErrors': [], - 'writeConcernErrors': [] + "nMatched": 2, + "nModified": 2, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 0, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], } self.coll.insert_many([{}, {}]) @@ -162,11 +169,11 @@ def _test_update_many(self, update): self.assertTrue(result.modified_count in (2, None)) def test_update_many(self): - self._test_update_many({'$set': {'foo': 'bar'}}) + self._test_update_many({"$set": {"foo": "bar"}}) @client_context.require_version_min(4, 1, 11) def test_update_many_pipeline(self): - self._test_update_many([{'$set': {'foo': 'bar'}}]) + self._test_update_many([{"$set": {"foo": "bar"}}]) def test_array_filters_validation(self): self.assertRaises(TypeError, UpdateMany, {}, {}, array_filters={}) @@ -174,23 +181,21 @@ def test_array_filters_validation(self): def test_array_filters_unacknowledged(self): coll = self.coll_w0 - update_one = UpdateOne( - {}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) - update_many = UpdateMany( - {}, {'$set': {'y.$[i].b': 5}}, array_filters=[{'i.b': 1}]) + update_one = UpdateOne({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) + update_many = UpdateMany({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) self.assertRaises(ConfigurationError, coll.bulk_write, [update_one]) self.assertRaises(ConfigurationError, coll.bulk_write, [update_many]) def _test_update_one(self, update): expected = { - 'nMatched': 1, - 'nModified': 1, - 'nUpserted': 0, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [], - 'writeErrors': [], - 'writeConcernErrors': [] + "nMatched": 1, + "nModified": 1, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 0, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], } self.coll.insert_many([{}, {}]) @@ -201,28 +206,28 @@ def _test_update_one(self, update): self.assertTrue(result.modified_count in (1, None)) def test_update_one(self): - self._test_update_one({'$set': {'foo': 'bar'}}) + self._test_update_one({"$set": {"foo": "bar"}}) @client_context.require_version_min(4, 1, 11) def test_update_one_pipeline(self): - self._test_update_one([{'$set': {'foo': 'bar'}}]) + self._test_update_one([{"$set": {"foo": "bar"}}]) def test_replace_one(self): expected = { - 'nMatched': 1, - 'nModified': 1, - 'nUpserted': 0, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [], - 'writeErrors': [], - 'writeConcernErrors': [] + "nMatched": 1, + "nModified": 1, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 0, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], } self.coll.insert_many([{}, {}]) - result = 
self.coll.bulk_write([ReplaceOne({}, {'foo': 'bar'})]) + result = self.coll.bulk_write([ReplaceOne({}, {"foo": "bar"})]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(1, result.matched_count) self.assertTrue(result.modified_count in (1, None)) @@ -230,14 +235,14 @@ def test_replace_one(self): def test_remove(self): # Test removing all documents, ordered. expected = { - 'nMatched': 0, - 'nModified': 0, - 'nUpserted': 0, - 'nInserted': 0, - 'nRemoved': 2, - 'upserted': [], - 'writeErrors': [], - 'writeConcernErrors': [] + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 2, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], } self.coll.insert_many([{}, {}]) @@ -249,14 +254,14 @@ def test_remove_one(self): # Test removing one document, empty selector. self.coll.insert_many([{}, {}]) expected = { - 'nMatched': 0, - 'nModified': 0, - 'nUpserted': 0, - 'nInserted': 0, - 'nRemoved': 1, - 'upserted': [], - 'writeErrors': [], - 'writeConcernErrors': [] + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 0, + "nRemoved": 1, + "upserted": [], + "writeErrors": [], + "writeConcernErrors": [], } result = self.coll.bulk_write([DeleteOne({})]) @@ -267,23 +272,21 @@ def test_remove_one(self): def test_upsert(self): expected = { - 'nMatched': 0, - 'nModified': 0, - 'nUpserted': 1, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': '...'}] + "nMatched": 0, + "nModified": 0, + "nUpserted": 1, + "nInserted": 0, + "nRemoved": 0, + "upserted": [{"index": 0, "_id": "..."}], } - result = self.coll.bulk_write([ReplaceOne({}, - {'foo': 'bar'}, - upsert=True)]) + result = self.coll.bulk_write([ReplaceOne({}, {"foo": "bar"}, upsert=True)]) self.assertEqualResponse(expected, result.bulk_api_result) self.assertEqual(1, result.upserted_count) self.assertEqual(1, len(result.upserted_ids)) self.assertTrue(isinstance(result.upserted_ids.get(0), ObjectId)) - self.assertEqual(self.coll.count_documents({'foo': 'bar'}), 1) + self.assertEqual(self.coll.count_documents({"foo": "bar"}), 1) def test_numerous_inserts(self): # Ensure we don't exceed server's maxWriteBatchSize size limit. @@ -306,23 +309,23 @@ def test_bulk_max_message_size(self): # Generate a list of documents such that the first batched OP_MSG is # as close as possible to the 48MB limit. docs = [ - {'_id': 1, 'l': 's' * _16_MB}, - {'_id': 2, 'l': 's' * _16_MB}, - {'_id': 3, 'l': 's' * (_16_MB - 10000)}, + {"_id": 1, "l": "s" * _16_MB}, + {"_id": 2, "l": "s" * _16_MB}, + {"_id": 3, "l": "s" * (_16_MB - 10000)}, ] # Fill in the remaining ~10000 bytes with small documents. 
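 # (Three ~16 MB documents come to just under the 48 MB OP_MSG cap; the
 # small documents below pad out the remaining ~10000 bytes so the first
 # batch lands as close to the limit as possible.)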
for i in range(4, 10000): - docs.append({'_id': i}) + docs.append({"_id": i}) result = self.coll.insert_many(docs) self.assertEqual(len(docs), len(result.inserted_ids)) def test_generator_insert(self): def gen(): - yield {'a': 1, 'b': 1} - yield {'a': 1, 'b': 2} - yield {'a': 2, 'b': 3} - yield {'a': 3, 'b': 5} - yield {'a': 5, 'b': 8} + yield {"a": 1, "b": 1} + yield {"a": 1, "b": 2} + yield {"a": 2, "b": 3} + yield {"a": 3, "b": 5} + yield {"a": 5, "b": 8} result = self.coll.insert_many(gen()) self.assertEqual(5, len(result.inserted_ids)) @@ -348,134 +351,166 @@ def test_bulk_write_invalid_arguments(self): self.coll.bulk_write([{}]) def test_upsert_large(self): - big = 'a' * (client_context.max_bson_size - 37) - result = self.coll.bulk_write([ - UpdateOne({'x': 1}, {'$set': {'s': big}}, upsert=True)]) + big = "a" * (client_context.max_bson_size - 37) + result = self.coll.bulk_write([UpdateOne({"x": 1}, {"$set": {"s": big}}, upsert=True)]) self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 1, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': '...'}]}, - result.bulk_api_result) - - self.assertEqual(1, self.coll.count_documents({'x': 1})) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 1, + "nInserted": 0, + "nRemoved": 0, + "upserted": [{"index": 0, "_id": "..."}], + }, + result.bulk_api_result, + ) + + self.assertEqual(1, self.coll.count_documents({"x": 1})) def test_client_generated_upsert_id(self): - result = self.coll.bulk_write([ - UpdateOne({'_id': 0}, {'$set': {'a': 0}}, upsert=True), - ReplaceOne({'a': 1}, {'_id': 1}, upsert=True), - # This is just here to make the counts right in all cases. - ReplaceOne({'_id': 2}, {'_id': 2}, upsert=True), - ]) + result = self.coll.bulk_write( + [ + UpdateOne({"_id": 0}, {"$set": {"a": 0}}, upsert=True), + ReplaceOne({"a": 1}, {"_id": 1}, upsert=True), + # This is just here to make the counts right in all cases. + ReplaceOne({"_id": 2}, {"_id": 2}, upsert=True), + ] + ) self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 3, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': 0}, - {'index': 1, '_id': 1}, - {'index': 2, '_id': 2}]}, - result.bulk_api_result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 3, + "nInserted": 0, + "nRemoved": 0, + "upserted": [ + {"index": 0, "_id": 0}, + {"index": 1, "_id": 1}, + {"index": 2, "_id": 2}, + ], + }, + result.bulk_api_result, + ) def test_upsert_uuid_standard(self): options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) coll = self.coll.with_options(codec_options=options) uuids = [uuid.uuid4() for _ in range(3)] - result = coll.bulk_write([ - UpdateOne({'_id': uuids[0]}, {'$set': {'a': 0}}, upsert=True), - ReplaceOne({'a': 1}, {'_id': uuids[1]}, upsert=True), - # This is just here to make the counts right in all cases. - ReplaceOne({'_id': uuids[2]}, {'_id': uuids[2]}, upsert=True), - ]) + result = coll.bulk_write( + [ + UpdateOne({"_id": uuids[0]}, {"$set": {"a": 0}}, upsert=True), + ReplaceOne({"a": 1}, {"_id": uuids[1]}, upsert=True), + # This is just here to make the counts right in all cases. 
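+ # (Keyed on a fresh random uuid, this replace shouldn't match an
+ # existing document, so it always counts as an upsert.)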
+ ReplaceOne({"_id": uuids[2]}, {"_id": uuids[2]}, upsert=True), + ] + ) self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 3, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': uuids[0]}, - {'index': 1, '_id': uuids[1]}, - {'index': 2, '_id': uuids[2]}]}, - result.bulk_api_result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 3, + "nInserted": 0, + "nRemoved": 0, + "upserted": [ + {"index": 0, "_id": uuids[0]}, + {"index": 1, "_id": uuids[1]}, + {"index": 2, "_id": uuids[2]}, + ], + }, + result.bulk_api_result, + ) def test_upsert_uuid_unspecified(self): options = CodecOptions(uuid_representation=UuidRepresentation.UNSPECIFIED) coll = self.coll.with_options(codec_options=options) uuids = [Binary.from_uuid(uuid.uuid4()) for _ in range(3)] - result = coll.bulk_write([ - UpdateOne({'_id': uuids[0]}, {'$set': {'a': 0}}, upsert=True), - ReplaceOne({'a': 1}, {'_id': uuids[1]}, upsert=True), - # This is just here to make the counts right in all cases. - ReplaceOne({'_id': uuids[2]}, {'_id': uuids[2]}, upsert=True), - ]) + result = coll.bulk_write( + [ + UpdateOne({"_id": uuids[0]}, {"$set": {"a": 0}}, upsert=True), + ReplaceOne({"a": 1}, {"_id": uuids[1]}, upsert=True), + # This is just here to make the counts right in all cases. + ReplaceOne({"_id": uuids[2]}, {"_id": uuids[2]}, upsert=True), + ] + ) self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 3, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': uuids[0]}, - {'index': 1, '_id': uuids[1]}, - {'index': 2, '_id': uuids[2]}]}, - result.bulk_api_result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 3, + "nInserted": 0, + "nRemoved": 0, + "upserted": [ + {"index": 0, "_id": uuids[0]}, + {"index": 1, "_id": uuids[1]}, + {"index": 2, "_id": uuids[2]}, + ], + }, + result.bulk_api_result, + ) def test_upsert_uuid_standard_subdocuments(self): options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) coll = self.coll.with_options(codec_options=options) - ids = [ - {'f': Binary(bytes(i)), 'f2': uuid.uuid4()} - for i in range(3) - ] + ids = [{"f": Binary(bytes(i)), "f2": uuid.uuid4()} for i in range(3)] - result = coll.bulk_write([ - UpdateOne({'_id': ids[0]}, {'$set': {'a': 0}}, upsert=True), - ReplaceOne({'a': 1}, {'_id': ids[1]}, upsert=True), - # This is just here to make the counts right in all cases. - ReplaceOne({'_id': ids[2]}, {'_id': ids[2]}, upsert=True), - ]) + result = coll.bulk_write( + [ + UpdateOne({"_id": ids[0]}, {"$set": {"a": 0}}, upsert=True), + ReplaceOne({"a": 1}, {"_id": ids[1]}, upsert=True), + # This is just here to make the counts right in all cases. + ReplaceOne({"_id": ids[2]}, {"_id": ids[2]}, upsert=True), + ] + ) # The `Binary` values are returned as `bytes` objects. 
for _id in ids: - _id['f'] = bytes(_id['f']) + _id["f"] = bytes(_id["f"]) self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 3, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': ids[0]}, - {'index': 1, '_id': ids[1]}, - {'index': 2, '_id': ids[2]}]}, - result.bulk_api_result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 3, + "nInserted": 0, + "nRemoved": 0, + "upserted": [ + {"index": 0, "_id": ids[0]}, + {"index": 1, "_id": ids[1]}, + {"index": 2, "_id": ids[2]}, + ], + }, + result.bulk_api_result, + ) def test_single_ordered_batch(self): - result = self.coll.bulk_write([ - InsertOne({'a': 1}), - UpdateOne({'a': 1}, {'$set': {'b': 1}}), - UpdateOne({'a': 2}, {'$set': {'b': 2}}, upsert=True), - InsertOne({'a': 3}), - DeleteOne({'a': 3}), - ]) + result = self.coll.bulk_write( + [ + InsertOne({"a": 1}), + UpdateOne({"a": 1}, {"$set": {"b": 1}}), + UpdateOne({"a": 2}, {"$set": {"b": 2}}, upsert=True), + InsertOne({"a": 3}), + DeleteOne({"a": 3}), + ] + ) self.assertEqualResponse( - {'nMatched': 1, - 'nModified': 1, - 'nUpserted': 1, - 'nInserted': 2, - 'nRemoved': 1, - 'upserted': [{'index': 2, '_id': '...'}]}, - result.bulk_api_result) + { + "nMatched": 1, + "nModified": 1, + "nUpserted": 1, + "nInserted": 2, + "nRemoved": 1, + "upserted": [{"index": 2, "_id": "..."}], + }, + result.bulk_api_result, + ) def test_single_error_ordered_batch(self): - self.coll.create_index('a', unique=True) - self.addCleanup(self.coll.drop_index, [('a', 1)]) + self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) requests = [ - InsertOne({'b': 1, 'a': 1}), - UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True), - InsertOne({'b': 3, 'a': 2}), + InsertOne({"b": 1, "a": 1}), + UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), + InsertOne({"b": 3, "a": 2}), ] try: self.coll.bulk_write(requests) @@ -486,33 +521,41 @@ def test_single_error_ordered_batch(self): self.fail("Error not raised") self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 0, - 'nInserted': 1, - 'nRemoved': 0, - 'upserted': [], - 'writeConcernErrors': [], - 'writeErrors': [ - {'index': 1, - 'code': 11000, - 'errmsg': '...', - 'op': {'q': {'b': 2}, - 'u': {'$set': {'a': 1}}, - 'multi': False, - 'upsert': True}}]}, - result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 1, + "nRemoved": 0, + "upserted": [], + "writeConcernErrors": [], + "writeErrors": [ + { + "index": 1, + "code": 11000, + "errmsg": "...", + "op": { + "q": {"b": 2}, + "u": {"$set": {"a": 1}}, + "multi": False, + "upsert": True, + }, + } + ], + }, + result, + ) def test_multiple_error_ordered_batch(self): - self.coll.create_index('a', unique=True) - self.addCleanup(self.coll.drop_index, [('a', 1)]) + self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) requests = [ - InsertOne({'b': 1, 'a': 1}), - UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True), - UpdateOne({'b': 3}, {'$set': {'a': 2}}, upsert=True), - UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True), - InsertOne({'b': 4, 'a': 3}), - InsertOne({'b': 5, 'a': 1}), + InsertOne({"b": 1, "a": 1}), + UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), + UpdateOne({"b": 3}, {"$set": {"a": 2}}, upsert=True), + UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), + InsertOne({"b": 4, "a": 3}), + InsertOne({"b": 5, "a": 1}), ] try: @@ -524,50 +567,61 @@ def test_multiple_error_ordered_batch(self): self.fail("Error not raised") 
self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 0, - 'nInserted': 1, - 'nRemoved': 0, - 'upserted': [], - 'writeConcernErrors': [], - 'writeErrors': [ - {'index': 1, - 'code': 11000, - 'errmsg': '...', - 'op': {'q': {'b': 2}, - 'u': {'$set': {'a': 1}}, - 'multi': False, - 'upsert': True}}]}, - result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 1, + "nRemoved": 0, + "upserted": [], + "writeConcernErrors": [], + "writeErrors": [ + { + "index": 1, + "code": 11000, + "errmsg": "...", + "op": { + "q": {"b": 2}, + "u": {"$set": {"a": 1}}, + "multi": False, + "upsert": True, + }, + } + ], + }, + result, + ) def test_single_unordered_batch(self): requests = [ - InsertOne({'a': 1}), - UpdateOne({'a': 1}, {'$set': {'b': 1}}), - UpdateOne({'a': 2}, {'$set': {'b': 2}}, upsert=True), - InsertOne({'a': 3}), - DeleteOne({'a': 3}), + InsertOne({"a": 1}), + UpdateOne({"a": 1}, {"$set": {"b": 1}}), + UpdateOne({"a": 2}, {"$set": {"b": 2}}, upsert=True), + InsertOne({"a": 3}), + DeleteOne({"a": 3}), ] result = self.coll.bulk_write(requests, ordered=False) self.assertEqualResponse( - {'nMatched': 1, - 'nModified': 1, - 'nUpserted': 1, - 'nInserted': 2, - 'nRemoved': 1, - 'upserted': [{'index': 2, '_id': '...'}], - 'writeErrors': [], - 'writeConcernErrors': []}, - result.bulk_api_result) + { + "nMatched": 1, + "nModified": 1, + "nUpserted": 1, + "nInserted": 2, + "nRemoved": 1, + "upserted": [{"index": 2, "_id": "..."}], + "writeErrors": [], + "writeConcernErrors": [], + }, + result.bulk_api_result, + ) def test_single_error_unordered_batch(self): - self.coll.create_index('a', unique=True) - self.addCleanup(self.coll.drop_index, [('a', 1)]) + self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) requests = [ - InsertOne({'b': 1, 'a': 1}), - UpdateOne({'b': 2}, {'$set': {'a': 1}}, upsert=True), - InsertOne({'b': 3, 'a': 2}), + InsertOne({"b": 1, "a": 1}), + UpdateOne({"b": 2}, {"$set": {"a": 1}}, upsert=True), + InsertOne({"b": 3, "a": 2}), ] try: @@ -579,33 +633,41 @@ def test_single_error_unordered_batch(self): self.fail("Error not raised") self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 0, - 'nInserted': 2, - 'nRemoved': 0, - 'upserted': [], - 'writeConcernErrors': [], - 'writeErrors': [ - {'index': 1, - 'code': 11000, - 'errmsg': '...', - 'op': {'q': {'b': 2}, - 'u': {'$set': {'a': 1}}, - 'multi': False, - 'upsert': True}}]}, - result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 2, + "nRemoved": 0, + "upserted": [], + "writeConcernErrors": [], + "writeErrors": [ + { + "index": 1, + "code": 11000, + "errmsg": "...", + "op": { + "q": {"b": 2}, + "u": {"$set": {"a": 1}}, + "multi": False, + "upsert": True, + }, + } + ], + }, + result, + ) def test_multiple_error_unordered_batch(self): - self.coll.create_index('a', unique=True) - self.addCleanup(self.coll.drop_index, [('a', 1)]) + self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) requests = [ - InsertOne({'b': 1, 'a': 1}), - UpdateOne({'b': 2}, {'$set': {'a': 3}}, upsert=True), - UpdateOne({'b': 3}, {'$set': {'a': 4}}, upsert=True), - UpdateOne({'b': 4}, {'$set': {'a': 3}}, upsert=True), - InsertOne({'b': 5, 'a': 2}), - InsertOne({'b': 6, 'a': 1}), + InsertOne({"b": 1, "a": 1}), + UpdateOne({"b": 2}, {"$set": {"a": 3}}, upsert=True), + UpdateOne({"b": 3}, {"$set": {"a": 4}}, upsert=True), + UpdateOne({"b": 4}, {"$set": {"a": 3}}, upsert=True), + InsertOne({"b": 5, 
"a": 2}), + InsertOne({"b": 6, "a": 1}), ] try: @@ -618,35 +680,43 @@ def test_multiple_error_unordered_batch(self): # Assume the update at index 1 runs before the update at index 3, # although the spec does not require it. Same for inserts. self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 2, - 'nInserted': 2, - 'nRemoved': 0, - 'upserted': [ - {'index': 1, '_id': '...'}, - {'index': 2, '_id': '...'}], - 'writeConcernErrors': [], - 'writeErrors': [ - {'index': 3, - 'code': 11000, - 'errmsg': '...', - 'op': {'q': {'b': 4}, - 'u': {'$set': {'a': 3}}, - 'multi': False, - 'upsert': True}}, - {'index': 5, - 'code': 11000, - 'errmsg': '...', - 'op': {'_id': '...', 'b': 6, 'a': 1}}]}, - result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 2, + "nInserted": 2, + "nRemoved": 0, + "upserted": [{"index": 1, "_id": "..."}, {"index": 2, "_id": "..."}], + "writeConcernErrors": [], + "writeErrors": [ + { + "index": 3, + "code": 11000, + "errmsg": "...", + "op": { + "q": {"b": 4}, + "u": {"$set": {"a": 3}}, + "multi": False, + "upsert": True, + }, + }, + { + "index": 5, + "code": 11000, + "errmsg": "...", + "op": {"_id": "...", "b": 6, "a": 1}, + }, + ], + }, + result, + ) def test_large_inserts_ordered(self): - big = 'x' * client_context.max_bson_size + big = "x" * client_context.max_bson_size requests = [ - InsertOne({'b': 1, 'a': 1}), - InsertOne({'big': big}), - InsertOne({'b': 2, 'a': 2}), + InsertOne({"b": 1, "a": 1}), + InsertOne({"big": big}), + InsertOne({"b": 2, "a": 2}), ] try: @@ -657,29 +727,31 @@ def test_large_inserts_ordered(self): else: self.fail("Error not raised") - self.assertEqual(1, result['nInserted']) + self.assertEqual(1, result["nInserted"]) self.coll.delete_many({}) - big = 'x' * (1024 * 1024 * 4) - result = self.coll.bulk_write([ - InsertOne({'a': 1, 'big': big}), - InsertOne({'a': 2, 'big': big}), - InsertOne({'a': 3, 'big': big}), - InsertOne({'a': 4, 'big': big}), - InsertOne({'a': 5, 'big': big}), - InsertOne({'a': 6, 'big': big}), - ]) + big = "x" * (1024 * 1024 * 4) + result = self.coll.bulk_write( + [ + InsertOne({"a": 1, "big": big}), + InsertOne({"a": 2, "big": big}), + InsertOne({"a": 3, "big": big}), + InsertOne({"a": 4, "big": big}), + InsertOne({"a": 5, "big": big}), + InsertOne({"a": 6, "big": big}), + ] + ) self.assertEqual(6, result.inserted_count) self.assertEqual(6, self.coll.count_documents({})) def test_large_inserts_unordered(self): - big = 'x' * client_context.max_bson_size + big = "x" * client_context.max_bson_size requests = [ - InsertOne({'b': 1, 'a': 1}), - InsertOne({'big': big}), - InsertOne({'b': 2, 'a': 2}), + InsertOne({"b": 1, "a": 1}), + InsertOne({"big": big}), + InsertOne({"b": 2, "a": 2}), ] try: @@ -690,26 +762,28 @@ def test_large_inserts_unordered(self): else: self.fail("Error not raised") - self.assertEqual(2, result['nInserted']) + self.assertEqual(2, result["nInserted"]) self.coll.delete_many({}) - big = 'x' * (1024 * 1024 * 4) - result = self.coll.bulk_write([ - InsertOne({'a': 1, 'big': big}), - InsertOne({'a': 2, 'big': big}), - InsertOne({'a': 3, 'big': big}), - InsertOne({'a': 4, 'big': big}), - InsertOne({'a': 5, 'big': big}), - InsertOne({'a': 6, 'big': big}), - ], ordered=False) + big = "x" * (1024 * 1024 * 4) + result = self.coll.bulk_write( + [ + InsertOne({"a": 1, "big": big}), + InsertOne({"a": 2, "big": big}), + InsertOne({"a": 3, "big": big}), + InsertOne({"a": 4, "big": big}), + InsertOne({"a": 5, "big": big}), + InsertOne({"a": 6, "big": big}), + ], + ordered=False, + ) 
self.assertEqual(6, result.inserted_count) self.assertEqual(6, self.coll.count_documents({})) class BulkAuthorizationTestBase(BulkTestBase): - @classmethod @client_context.require_auth @client_context.require_no_api_version @@ -718,129 +792,123 @@ def setUpClass(cls): def setUp(self): super(BulkAuthorizationTestBase, self).setUp() - client_context.create_user( - self.db.name, 'readonly', 'pw', ['read']) + client_context.create_user(self.db.name, "readonly", "pw", ["read"]) self.db.command( - 'createRole', 'noremove', - privileges=[{ - 'actions': ['insert', 'update', 'find'], - 'resource': {'db': 'pymongo_test', 'collection': 'test'} - }], - roles=[]) - - client_context.create_user(self.db.name, 'noremove', 'pw', ['noremove']) + "createRole", + "noremove", + privileges=[ + { + "actions": ["insert", "update", "find"], + "resource": {"db": "pymongo_test", "collection": "test"}, + } + ], + roles=[], + ) + + client_context.create_user(self.db.name, "noremove", "pw", ["noremove"]) def tearDown(self): - self.db.command('dropRole', 'noremove') + self.db.command("dropRole", "noremove") remove_all_users(self.db) class TestBulkUnacknowledged(BulkTestBase): - def tearDown(self): self.coll.delete_many({}) def test_no_results_ordered_success(self): requests = [ - InsertOne({'a': 1}), - UpdateOne({'a': 3}, {'$set': {'b': 1}}, upsert=True), - InsertOne({'a': 2}), - DeleteOne({'a': 1}), + InsertOne({"a": 1}), + UpdateOne({"a": 3}, {"$set": {"b": 1}}, upsert=True), + InsertOne({"a": 2}), + DeleteOne({"a": 1}), ] result = self.coll_w0.bulk_write(requests) self.assertFalse(result.acknowledged) - wait_until(lambda: 2 == self.coll.count_documents({}), - 'insert 2 documents') - wait_until(lambda: self.coll.find_one({'_id': 1}) is None, - 'removed {"_id": 1}') + wait_until(lambda: 2 == self.coll.count_documents({}), "insert 2 documents") + wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}') def test_no_results_ordered_failure(self): requests = [ - InsertOne({'_id': 1}), - UpdateOne({'_id': 3}, {'$set': {'b': 1}}, upsert=True), - InsertOne({'_id': 2}), + InsertOne({"_id": 1}), + UpdateOne({"_id": 3}, {"$set": {"b": 1}}, upsert=True), + InsertOne({"_id": 2}), # Fails with duplicate key error. - InsertOne({'_id': 1}), + InsertOne({"_id": 1}), # Should not be executed since the batch is ordered. 
- DeleteOne({'_id': 1}), + DeleteOne({"_id": 1}), ] result = self.coll_w0.bulk_write(requests) self.assertFalse(result.acknowledged) - wait_until(lambda: 3 == self.coll.count_documents({}), - 'insert 3 documents') - self.assertEqual({'_id': 1}, self.coll.find_one({'_id': 1})) + wait_until(lambda: 3 == self.coll.count_documents({}), "insert 3 documents") + self.assertEqual({"_id": 1}, self.coll.find_one({"_id": 1})) def test_no_results_unordered_success(self): requests = [ - InsertOne({'a': 1}), - UpdateOne({'a': 3}, {'$set': {'b': 1}}, upsert=True), - InsertOne({'a': 2}), - DeleteOne({'a': 1}), + InsertOne({"a": 1}), + UpdateOne({"a": 3}, {"$set": {"b": 1}}, upsert=True), + InsertOne({"a": 2}), + DeleteOne({"a": 1}), ] result = self.coll_w0.bulk_write(requests, ordered=False) self.assertFalse(result.acknowledged) - wait_until(lambda: 2 == self.coll.count_documents({}), - 'insert 2 documents') - wait_until(lambda: self.coll.find_one({'_id': 1}) is None, - 'removed {"_id": 1}') + wait_until(lambda: 2 == self.coll.count_documents({}), "insert 2 documents") + wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}') def test_no_results_unordered_failure(self): requests = [ - InsertOne({'_id': 1}), - UpdateOne({'_id': 3}, {'$set': {'b': 1}}, upsert=True), - InsertOne({'_id': 2}), + InsertOne({"_id": 1}), + UpdateOne({"_id": 3}, {"$set": {"b": 1}}, upsert=True), + InsertOne({"_id": 2}), # Fails with duplicate key error. - InsertOne({'_id': 1}), + InsertOne({"_id": 1}), # Should be executed since the batch is unordered. - DeleteOne({'_id': 1}), + DeleteOne({"_id": 1}), ] result = self.coll_w0.bulk_write(requests, ordered=False) self.assertFalse(result.acknowledged) - wait_until(lambda: 2 == self.coll.count_documents({}), - 'insert 2 documents') - wait_until(lambda: self.coll.find_one({'_id': 1}) is None, - 'removed {"_id": 1}') + wait_until(lambda: 2 == self.coll.count_documents({}), "insert 2 documents") + wait_until(lambda: self.coll.find_one({"_id": 1}) is None, 'removed {"_id": 1}') class TestBulkAuthorization(BulkAuthorizationTestBase): - def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = rs_or_single_client_noauth(username='readonly', password='pw', - authSource='pymongo_test') + cli = rs_or_single_client_noauth( + username="readonly", password="pw", authSource="pymongo_test" + ) coll = cli.pymongo_test.test coll.find_one() - self.assertRaises(OperationFailure, coll.bulk_write, - [InsertOne({'x': 1})]) + self.assertRaises(OperationFailure, coll.bulk_write, [InsertOne({"x": 1})]) def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = rs_or_single_client_noauth(username='noremove', password='pw', - authSource='pymongo_test') + cli = rs_or_single_client_noauth( + username="noremove", password="pw", authSource="pymongo_test" + ) coll = cli.pymongo_test.test coll.find_one() requests = [ - InsertOne({'x': 1}), - ReplaceOne({'x': 2}, {'x': 2}, upsert=True), - DeleteMany({}), # Prohibited. - InsertOne({'x': 3}), # Never attempted. + InsertOne({"x": 1}), + ReplaceOne({"x": 2}, {"x": 2}, upsert=True), + DeleteMany({}), # Prohibited. + InsertOne({"x": 3}), # Never attempted. 
] self.assertRaises(OperationFailure, coll.bulk_write, requests) - self.assertEqual(set([1, 2]), set(self.coll.distinct('x'))) + self.assertEqual(set([1, 2]), set(self.coll.distinct("x"))) class TestBulkWriteConcern(BulkTestBase): - @classmethod def setUpClass(cls): super(TestBulkWriteConcern, cls).setUpClass() cls.w = client_context.w cls.secondary = None if cls.w > 1: - for member in client_context.hello['hosts']: - if member != client_context.hello['primary']: + for member in client_context.hello["hosts"]: + if member != client_context.hello["primary"]: cls.secondary = single_client(*partition_node(member)) break @@ -855,32 +923,23 @@ def cause_wtimeout(self, requests, ordered): # Use the rsSyncApplyStop failpoint to pause replication on a # secondary which will cause a wtimeout error. - self.secondary.admin.command('configureFailPoint', - 'rsSyncApplyStop', - mode='alwaysOn') + self.secondary.admin.command("configureFailPoint", "rsSyncApplyStop", mode="alwaysOn") try: - coll = self.coll.with_options( - write_concern=WriteConcern(w=self.w, wtimeout=1)) + coll = self.coll.with_options(write_concern=WriteConcern(w=self.w, wtimeout=1)) return coll.bulk_write(requests, ordered=ordered) finally: - self.secondary.admin.command('configureFailPoint', - 'rsSyncApplyStop', - mode='off') + self.secondary.admin.command("configureFailPoint", "rsSyncApplyStop", mode="off") @client_context.require_replica_set @client_context.require_secondaries_count(1) def test_write_concern_failure_ordered(self): # Ensure we don't raise on wnote. coll_ww = self.coll.with_options(write_concern=WriteConcern(w=self.w)) - result = coll_ww.bulk_write([ - DeleteOne({"something": "that does not exist"})]) + result = coll_ww.bulk_write([DeleteOne({"something": "that does not exist"})]) self.assertTrue(result.acknowledged) - requests = [ - InsertOne({'a': 1}), - InsertOne({'a': 2}) - ] + requests = [InsertOne({"a": 1}), InsertOne({"a": 2})] # Replication wtimeout is a 'soft' error. # It shouldn't stop batch processing. try: @@ -892,34 +951,37 @@ def test_write_concern_failure_ordered(self): self.fail("Error not raised") self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 0, - 'nInserted': 2, - 'nRemoved': 0, - 'upserted': [], - 'writeErrors': []}, - result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 0, + "nInserted": 2, + "nRemoved": 0, + "upserted": [], + "writeErrors": [], + }, + result, + ) # When talking to legacy servers there will be a # write concern error for each operation. - self.assertTrue(len(result['writeConcernErrors']) > 0) + self.assertTrue(len(result["writeConcernErrors"]) > 0) - failed = result['writeConcernErrors'][0] - self.assertEqual(64, failed['code']) - self.assertTrue(isinstance(failed['errmsg'], str)) + failed = result["writeConcernErrors"][0] + self.assertEqual(64, failed["code"]) + self.assertTrue(isinstance(failed["errmsg"], str)) self.coll.delete_many({}) - self.coll.create_index('a', unique=True) - self.addCleanup(self.coll.drop_index, [('a', 1)]) + self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) # Fail due to write concern support as well # as duplicate key error on ordered batch. 
requests = [ - InsertOne({'a': 1}), - ReplaceOne({'a': 3}, {'b': 1}, upsert=True), - InsertOne({'a': 1}), - InsertOne({'a': 2}), + InsertOne({"a": 1}), + ReplaceOne({"a": 3}, {"b": 1}, upsert=True), + InsertOne({"a": 1}), + InsertOne({"a": 2}), ] try: self.cause_wtimeout(requests, ordered=True) @@ -930,36 +992,36 @@ def test_write_concern_failure_ordered(self): self.fail("Error not raised") self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 1, - 'nInserted': 1, - 'nRemoved': 0, - 'upserted': [{'index': 1, '_id': '...'}], - 'writeErrors': [ - {'index': 2, - 'code': 11000, - 'errmsg': '...', - 'op': {'_id': '...', 'a': 1}}]}, - result) - - self.assertTrue(len(result['writeConcernErrors']) > 1) - failed = result['writeErrors'][0] - self.assertTrue("duplicate" in failed['errmsg']) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 1, + "nInserted": 1, + "nRemoved": 0, + "upserted": [{"index": 1, "_id": "..."}], + "writeErrors": [ + {"index": 2, "code": 11000, "errmsg": "...", "op": {"_id": "...", "a": 1}} + ], + }, + result, + ) + + self.assertTrue(len(result["writeConcernErrors"]) > 1) + failed = result["writeErrors"][0] + self.assertTrue("duplicate" in failed["errmsg"]) @client_context.require_replica_set @client_context.require_secondaries_count(1) def test_write_concern_failure_unordered(self): # Ensure we don't raise on wnote. coll_ww = self.coll.with_options(write_concern=WriteConcern(w=self.w)) - result = coll_ww.bulk_write([ - DeleteOne({"something": "that does not exist"})], ordered=False) + result = coll_ww.bulk_write([DeleteOne({"something": "that does not exist"})], ordered=False) self.assertTrue(result.acknowledged) requests = [ - InsertOne({'a': 1}), - UpdateOne({'a': 3}, {'$set': {'a': 3, 'b': 1}}, upsert=True), - InsertOne({'a': 2}), + InsertOne({"a": 1}), + UpdateOne({"a": 3}, {"$set": {"a": 3, "b": 1}}, upsert=True), + InsertOne({"a": 2}), ] # Replication wtimeout is a 'soft' error. # It shouldn't stop batch processing. @@ -971,24 +1033,24 @@ def test_write_concern_failure_unordered(self): else: self.fail("Error not raised") - self.assertEqual(2, result['nInserted']) - self.assertEqual(1, result['nUpserted']) - self.assertEqual(0, len(result['writeErrors'])) + self.assertEqual(2, result["nInserted"]) + self.assertEqual(1, result["nUpserted"]) + self.assertEqual(0, len(result["writeErrors"])) # When talking to legacy servers there will be a # write concern error for each operation. - self.assertTrue(len(result['writeConcernErrors']) > 1) + self.assertTrue(len(result["writeConcernErrors"]) > 1) self.coll.delete_many({}) - self.coll.create_index('a', unique=True) - self.addCleanup(self.coll.drop_index, [('a', 1)]) + self.coll.create_index("a", unique=True) + self.addCleanup(self.coll.drop_index, [("a", 1)]) # Fail due to write concern support as well # as duplicate key error on unordered batch. 
requests = [ - InsertOne({'a': 1}), - UpdateOne({'a': 3}, {'$set': {'a': 3, 'b': 1}}, upsert=True), - InsertOne({'a': 1}), - InsertOne({'a': 2}), + InsertOne({"a": 1}), + UpdateOne({"a": 3}, {"$set": {"a": 3, "b": 1}}, upsert=True), + InsertOne({"a": 1}), + InsertOne({"a": 2}), ] try: self.cause_wtimeout(requests, ordered=False) @@ -998,27 +1060,27 @@ def test_write_concern_failure_unordered(self): else: self.fail("Error not raised") - self.assertEqual(2, result['nInserted']) - self.assertEqual(1, result['nUpserted']) - self.assertEqual(1, len(result['writeErrors'])) + self.assertEqual(2, result["nInserted"]) + self.assertEqual(1, result["nUpserted"]) + self.assertEqual(1, len(result["writeErrors"])) # When talking to legacy servers there will be a # write concern error for each operation. - self.assertTrue(len(result['writeConcernErrors']) > 1) + self.assertTrue(len(result["writeConcernErrors"]) > 1) - failed = result['writeErrors'][0] - self.assertEqual(2, failed['index']) - self.assertEqual(11000, failed['code']) - self.assertTrue(isinstance(failed['errmsg'], str)) - self.assertEqual(1, failed['op']['a']) + failed = result["writeErrors"][0] + self.assertEqual(2, failed["index"]) + self.assertEqual(11000, failed["code"]) + self.assertTrue(isinstance(failed["errmsg"], str)) + self.assertEqual(1, failed["op"]["a"]) - failed = result['writeConcernErrors'][0] - self.assertEqual(64, failed['code']) - self.assertTrue(isinstance(failed['errmsg'], str)) + failed = result["writeConcernErrors"][0] + self.assertEqual(64, failed["code"]) + self.assertTrue(isinstance(failed["errmsg"], str)) - upserts = result['upserted'] + upserts = result["upserted"] self.assertEqual(1, len(upserts)) - self.assertEqual(1, upserts[0]['index']) - self.assertTrue(upserts[0].get('_id')) + self.assertEqual(1, upserts[0]["index"]) + self.assertTrue(upserts[0].get("_id")) if __name__ == "__main__": diff --git a/test/test_change_stream.py b/test/test_change_stream.py index a49f6972b2..e4791ae636 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -14,39 +14,41 @@ """Test the change_stream module.""" -import random import os +import random import re -import sys import string +import sys import threading import time import uuid - from itertools import product -sys.path[0:0] = [''] +sys.path[0:0] = [""] -from bson import ObjectId, SON, Timestamp, encode, json_util -from bson.binary import (ALL_UUID_REPRESENTATIONS, - Binary, - STANDARD, - PYTHON_LEGACY) +from test import IntegrationTest, client_context, unittest +from test.unified_format import generate_test_classes +from test.utils import ( + AllowListEventListener, + EventListener, + rs_or_single_client, + wait_until, +) + +from bson import SON, ObjectId, Timestamp, encode, json_util +from bson.binary import ALL_UUID_REPRESENTATIONS, PYTHON_LEGACY, STANDARD, Binary from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument - from pymongo import MongoClient from pymongo.command_cursor import CommandCursor -from pymongo.errors import (InvalidOperation, OperationFailure, - ServerSelectionTimeoutError) +from pymongo.errors import ( + InvalidOperation, + OperationFailure, + ServerSelectionTimeoutError, +) from pymongo.message import _CursorAddress from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern -from test import client_context, unittest, IntegrationTest -from test.unified_format import generate_test_classes -from test.utils import ( - EventListener, AllowListEventListener, rs_or_single_client, 
wait_until) - class TestChangeStreamBase(IntegrationTest): RUN_ON_LOAD_BALANCER = True @@ -69,7 +71,7 @@ def client_with_listener(self, *commands): def watched_collection(self, *args, **kwargs): """Return a collection that is watched by self.change_stream().""" # Construct a unique collection for each test. - collname = '.'.join(self.id().rsplit('.', 2)[1:]) + collname = ".".join(self.id().rsplit(".", 2)[1:]) return self.db.get_collection(collname, *args, **kwargs) def generate_invalidate_event(self, change_stream): @@ -80,27 +82,25 @@ def generate_unique_collnames(self, numcolls): """Generate numcolls collection names unique to a test.""" collnames = [] for idx in range(1, numcolls + 1): - collnames.append(self.id() + '_' + str(idx)) + collnames.append(self.id() + "_" + str(idx)) return collnames def get_resume_token(self, invalidate=False): """Get a resume token to use for starting a change stream.""" # Ensure targeted collection exists before starting. - coll = self.watched_collection(write_concern=WriteConcern('majority')) + coll = self.watched_collection(write_concern=WriteConcern("majority")) coll.insert_one({}) if invalidate: - with self.change_stream( - [{'$match': {'operationType': 'invalidate'}}]) as cs: + with self.change_stream([{"$match": {"operationType": "invalidate"}}]) as cs: if isinstance(cs._target, MongoClient): - self.skipTest( - "cluster-level change streams cannot be invalidated") + self.skipTest("cluster-level change streams cannot be invalidated") self.generate_invalidate_event(cs) - return cs.next()['_id'] + return cs.next()["_id"] else: with self.change_stream() as cs: - coll.insert_one({'data': 1}) - return cs.next()['_id'] + coll.insert_one({"data": 1}) + return cs.next()["_id"] def get_start_at_operation_time(self): """Get an operationTime. Advances the operation clock beyond the most @@ -123,18 +123,18 @@ def kill_change_stream_cursor(self, change_stream): class APITestsMixin(object): def test_watch(self): with self.change_stream( - [{'$project': {'foo': 0}}], full_document='updateLookup', - max_await_time_ms=1000, batch_size=100) as change_stream: - self.assertEqual([{'$project': {'foo': 0}}], - change_stream._pipeline) - self.assertEqual('updateLookup', change_stream._full_document) + [{"$project": {"foo": 0}}], + full_document="updateLookup", + max_await_time_ms=1000, + batch_size=100, + ) as change_stream: + self.assertEqual([{"$project": {"foo": 0}}], change_stream._pipeline) + self.assertEqual("updateLookup", change_stream._full_document) self.assertEqual(1000, change_stream._max_await_time_ms) self.assertEqual(100, change_stream._batch_size) self.assertIsInstance(change_stream._cursor, CommandCursor) - self.assertEqual( - 1000, change_stream._cursor._CommandCursor__max_await_time_ms) - self.watched_collection( - write_concern=WriteConcern("majority")).insert_one({}) + self.assertEqual(1000, change_stream._cursor._CommandCursor__max_await_time_ms) + self.watched_collection(write_concern=WriteConcern("majority")).insert_one({}) _ = change_stream.next() resume_token = change_stream.resume_token with self.assertRaises(TypeError): @@ -147,36 +147,32 @@ def test_watch(self): def test_try_next(self): # ChangeStreams only read majority committed data so use w:majority. 
- coll = self.watched_collection().with_options( - write_concern=WriteConcern("majority")) + coll = self.watched_collection().with_options(write_concern=WriteConcern("majority")) coll.drop() coll.insert_one({}) self.addCleanup(coll.drop) with self.change_stream(max_await_time_ms=250) as stream: - self.assertIsNone(stream.try_next()) # No changes initially. - coll.insert_one({}) # Generate a change. + self.assertIsNone(stream.try_next()) # No changes initially. + coll.insert_one({}) # Generate a change. # On sharded clusters, even majority-committed changes only show # up once an event that sorts after it shows up on the other # shard. So, we wait on try_next to eventually return changes. - wait_until(lambda: stream.try_next() is not None, - "get change from try_next") + wait_until(lambda: stream.try_next() is not None, "get change from try_next") def test_try_next_runs_one_getmore(self): listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. - client.admin.command('ping') + client.admin.command("ping") listener.results.clear() # ChangeStreams only read majority committed data so use w:majority. - coll = self.watched_collection().with_options( - write_concern=WriteConcern("majority")) + coll = self.watched_collection().with_options(write_concern=WriteConcern("majority")) coll.drop() # Create the watched collection before starting the change stream to # skip any "create" events. - coll.insert_one({'_id': 1}) + coll.insert_one({"_id": 1}) self.addCleanup(coll.drop) - with self.change_stream_with_client( - client, max_await_time_ms=250) as stream: + with self.change_stream_with_client(client, max_await_time_ms=250) as stream: self.assertEqual(listener.started_command_names(), ["aggregate"]) listener.results.clear() @@ -190,9 +186,8 @@ def test_try_next_runs_one_getmore(self): listener.results.clear() # Get at least one change before resuming. - coll.insert_one({'_id': 2}) - wait_until(lambda: stream.try_next() is not None, - "get change from try_next") + coll.insert_one({"_id": 2}) + wait_until(lambda: stream.try_next() is not None, "get change from try_next") listener.results.clear() # Cause the next request to initiate the resume process. @@ -204,43 +199,38 @@ def test_try_next_runs_one_getmore(self): # - resume with aggregate command # - no results, return immediately without another getMore self.assertIsNone(stream.try_next()) - self.assertEqual( - listener.started_command_names(), ["getMore", "aggregate"]) + self.assertEqual(listener.started_command_names(), ["getMore", "aggregate"]) listener.results.clear() # Stream still works after a resume. - coll.insert_one({'_id': 3}) - wait_until(lambda: stream.try_next() is not None, - "get change from try_next") - self.assertEqual(set(listener.started_command_names()), - set(["getMore"])) + coll.insert_one({"_id": 3}) + wait_until(lambda: stream.try_next() is not None, "get change from try_next") + self.assertEqual(set(listener.started_command_names()), set(["getMore"])) self.assertIsNone(stream.try_next()) def test_batch_size_is_honored(self): listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. - client.admin.command('ping') + client.admin.command("ping") listener.results.clear() # ChangeStreams only read majority committed data so use w:majority. 
- coll = self.watched_collection().with_options( - write_concern=WriteConcern("majority")) + coll = self.watched_collection().with_options(write_concern=WriteConcern("majority")) coll.drop() # Create the watched collection before starting the change stream to # skip any "create" events. - coll.insert_one({'_id': 1}) + coll.insert_one({"_id": 1}) self.addCleanup(coll.drop) # Expected batchSize. - expected = {'batchSize': 23} - with self.change_stream_with_client( - client, max_await_time_ms=250, batch_size=23) as stream: + expected = {"batchSize": 23} + with self.change_stream_with_client(client, max_await_time_ms=250, batch_size=23) as stream: # Confirm that batchSize is honored for initial batch. - cmd = listener.results['started'][0].command - self.assertEqual(cmd['cursor'], expected) + cmd = listener.results["started"][0].command + self.assertEqual(cmd["cursor"], expected) listener.results.clear() # Confirm that batchSize is honored by getMores. self.assertIsNone(stream.try_next()) - cmd = listener.results['started'][0].command + cmd = listener.results["started"][0].command key = next(iter(expected)) self.assertEqual(expected[key], cmd[key]) @@ -249,8 +239,7 @@ def test_batch_size_is_honored(self): def test_start_at_operation_time(self): optime = self.get_start_at_operation_time() - coll = self.watched_collection( - write_concern=WriteConcern("majority")) + coll = self.watched_collection(write_concern=WriteConcern("majority")) ndocs = 3 coll.insert_many([{"data": i} for i in range(ndocs)]) @@ -261,17 +250,16 @@ def test_start_at_operation_time(self): def _test_full_pipeline(self, expected_cs_stage): client, listener = self.client_with_listener("aggregate") results = listener.results - with self.change_stream_with_client( - client, [{'$project': {'foo': 0}}]) as _: + with self.change_stream_with_client(client, [{"$project": {"foo": 0}}]) as _: pass - self.assertEqual(1, len(results['started'])) - command = results['started'][0] - self.assertEqual('aggregate', command.command_name) - self.assertEqual([ - {'$changeStream': expected_cs_stage}, - {'$project': {'foo': 0}}], - command.command['pipeline']) + self.assertEqual(1, len(results["started"])) + command = results["started"][0] + self.assertEqual("aggregate", command.command_name) + self.assertEqual( + [{"$changeStream": expected_cs_stage}, {"$project": {"foo": 0}}], + command.command["pipeline"], + ) def test_full_pipeline(self): """$changeStream must be the first stage in a change stream pipeline @@ -282,21 +270,19 @@ def test_full_pipeline(self): def test_iteration(self): with self.change_stream(batch_size=2) as change_stream: num_inserted = 10 - self.watched_collection().insert_many( - [{} for _ in range(num_inserted)]) + self.watched_collection().insert_many([{} for _ in range(num_inserted)]) inserts_received = 0 for change in change_stream: - self.assertEqual(change['operationType'], 'insert') + self.assertEqual(change["operationType"], "insert") inserts_received += 1 if inserts_received == num_inserted: break self._test_invalidate_stops_iteration(change_stream) def _test_next_blocks(self, change_stream): - inserted_doc = {'_id': ObjectId()} + inserted_doc = {"_id": ObjectId()} changes = [] - t = threading.Thread( - target=lambda: changes.append(change_stream.next())) + t = threading.Thread(target=lambda: changes.append(change_stream.next())) t.start() # Sleep for a bit to prove that the call to next() blocks. 
time.sleep(1) @@ -308,8 +294,8 @@ def _test_next_blocks(self, change_stream): t.join(30) self.assertFalse(t.is_alive()) self.assertEqual(1, len(changes)) - self.assertEqual(changes[0]['operationType'], 'insert') - self.assertEqual(changes[0]['fullDocument'], inserted_doc) + self.assertEqual(changes[0]["operationType"], "insert") + self.assertEqual(changes[0]["fullDocument"], inserted_doc) def test_next_blocks(self): """Test that next blocks until a change is readable""" @@ -320,16 +306,19 @@ def test_next_blocks(self): def test_aggregate_cursor_blocks(self): """Test that an aggregate cursor blocks until a change is readable.""" with self.watched_collection().aggregate( - [{'$changeStream': {}}], maxAwaitTimeMS=250) as change_stream: + [{"$changeStream": {}}], maxAwaitTimeMS=250 + ) as change_stream: self._test_next_blocks(change_stream) def test_concurrent_close(self): """Ensure a ChangeStream can be closed from another thread.""" # Use a short await time to speed up the test. with self.change_stream(max_await_time_ms=250) as change_stream: + def iterate_cursor(): for _ in change_stream: pass + t = threading.Thread(target=iterate_cursor) t.start() self.watched_collection().insert_one({}) @@ -339,57 +328,55 @@ def iterate_cursor(): self.assertFalse(t.is_alive()) def test_unknown_full_document(self): - """Must rely on the server to raise an error on unknown fullDocument. - """ + """Must rely on the server to raise an error on unknown fullDocument.""" try: - with self.change_stream(full_document='notValidatedByPyMongo'): + with self.change_stream(full_document="notValidatedByPyMongo"): pass except OperationFailure: pass def test_change_operations(self): """Test each operation type.""" - expected_ns = {'db': self.watched_collection().database.name, - 'coll': self.watched_collection().name} + expected_ns = { + "db": self.watched_collection().database.name, + "coll": self.watched_collection().name, + } with self.change_stream() as change_stream: # Insert. - inserted_doc = {'_id': ObjectId(), 'foo': 'bar'} + inserted_doc = {"_id": ObjectId(), "foo": "bar"} self.watched_collection().insert_one(inserted_doc) change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['ns'], expected_ns) - self.assertEqual(change['fullDocument'], inserted_doc) + self.assertTrue(change["_id"]) + self.assertEqual(change["operationType"], "insert") + self.assertEqual(change["ns"], expected_ns) + self.assertEqual(change["fullDocument"], inserted_doc) # Update. 
- update_spec = {'$set': {'new': 1}, '$unset': {'foo': 1}} + update_spec = {"$set": {"new": 1}, "$unset": {"foo": 1}} self.watched_collection().update_one(inserted_doc, update_spec) change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'update') - self.assertEqual(change['ns'], expected_ns) - self.assertNotIn('fullDocument', change) - - expected_update_description = { - 'updatedFields': {'new': 1}, - 'removedFields': ['foo']} + self.assertTrue(change["_id"]) + self.assertEqual(change["operationType"], "update") + self.assertEqual(change["ns"], expected_ns) + self.assertNotIn("fullDocument", change) + + expected_update_description = {"updatedFields": {"new": 1}, "removedFields": ["foo"]} if client_context.version.at_least(4, 5, 0): - expected_update_description['truncatedArrays'] = [] - self.assertEqual(expected_update_description, - change['updateDescription']) + expected_update_description["truncatedArrays"] = [] + self.assertEqual(expected_update_description, change["updateDescription"]) # Replace. - self.watched_collection().replace_one({'new': 1}, {'foo': 'bar'}) + self.watched_collection().replace_one({"new": 1}, {"foo": "bar"}) change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'replace') - self.assertEqual(change['ns'], expected_ns) - self.assertEqual(change['fullDocument'], inserted_doc) + self.assertTrue(change["_id"]) + self.assertEqual(change["operationType"], "replace") + self.assertEqual(change["ns"], expected_ns) + self.assertEqual(change["fullDocument"], inserted_doc) # Delete. - self.watched_collection().delete_one({'foo': 'bar'}) + self.watched_collection().delete_one({"foo": "bar"}) change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'delete') - self.assertEqual(change['ns'], expected_ns) - self.assertNotIn('fullDocument', change) + self.assertTrue(change["_id"]) + self.assertEqual(change["operationType"], "delete") + self.assertEqual(change["ns"], expected_ns) + self.assertNotIn("fullDocument", change) # Invalidate. self._test_get_invalidate_event(change_stream) @@ -403,44 +390,42 @@ def test_start_after(self): # start_after can resume after invalidate. 
with self.change_stream(start_after=resume_token) as change_stream: - self.watched_collection().insert_one({'_id': 2}) + self.watched_collection().insert_one({"_id": 2}) change = change_stream.next() - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['fullDocument'], {'_id': 2}) + self.assertEqual(change["operationType"], "insert") + self.assertEqual(change["fullDocument"], {"_id": 2}) @client_context.require_version_min(4, 1, 1) def test_start_after_resume_process_with_changes(self): resume_token = self.get_resume_token(invalidate=True) - with self.change_stream(start_after=resume_token, - max_await_time_ms=250) as change_stream: - self.watched_collection().insert_one({'_id': 2}) + with self.change_stream(start_after=resume_token, max_await_time_ms=250) as change_stream: + self.watched_collection().insert_one({"_id": 2}) change = change_stream.next() - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['fullDocument'], {'_id': 2}) + self.assertEqual(change["operationType"], "insert") + self.assertEqual(change["fullDocument"], {"_id": 2}) self.assertIsNone(change_stream.try_next()) self.kill_change_stream_cursor(change_stream) - self.watched_collection().insert_one({'_id': 3}) + self.watched_collection().insert_one({"_id": 3}) change = change_stream.next() - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['fullDocument'], {'_id': 3}) + self.assertEqual(change["operationType"], "insert") + self.assertEqual(change["fullDocument"], {"_id": 3}) @client_context.require_no_mongos # Remove after SERVER-41196 @client_context.require_version_min(4, 1, 1) def test_start_after_resume_process_without_changes(self): resume_token = self.get_resume_token(invalidate=True) - with self.change_stream(start_after=resume_token, - max_await_time_ms=250) as change_stream: + with self.change_stream(start_after=resume_token, max_await_time_ms=250) as change_stream: self.assertIsNone(change_stream.try_next()) self.kill_change_stream_cursor(change_stream) - self.watched_collection().insert_one({'_id': 2}) + self.watched_collection().insert_one({"_id": 2}) change = change_stream.next() - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['fullDocument'], {'_id': 2}) + self.assertEqual(change["operationType"], "insert") + self.assertEqual(change["fullDocument"], {"_id": 2}) class ProseSpecTestsMixin(object): @@ -451,45 +436,41 @@ def _client_with_listener(self, *commands): return client, listener def _populate_and_exhaust_change_stream(self, change_stream, batch_size=3): - self.watched_collection().insert_many( - [{"data": k} for k in range(batch_size)]) + self.watched_collection().insert_many([{"data": k} for k in range(batch_size)]) for _ in range(batch_size): change = next(change_stream) return change - def _get_expected_resume_token_legacy(self, stream, - listener, previous_change=None): + def _get_expected_resume_token_legacy(self, stream, listener, previous_change=None): """Predicts what the resume token should currently be for server versions that don't support postBatchResumeToken. 
Assumes the stream has never returned any changes if previous_change is None.""" if previous_change is None: - agg_cmd = listener.results['started'][0] + agg_cmd = listener.results["started"][0] stage = agg_cmd.command["pipeline"][0]["$changeStream"] return stage.get("resumeAfter") or stage.get("startAfter") - return previous_change['_id'] + return previous_change["_id"] - def _get_expected_resume_token(self, stream, listener, - previous_change=None): + def _get_expected_resume_token(self, stream, listener, previous_change=None): """Predicts what the resume token should currently be for server versions that support postBatchResumeToken. Assumes the stream has never returned any changes if previous_change is None. Assumes listener is an AllowListEventListener that listens for aggregate and getMore commands.""" if previous_change is None or stream._cursor._has_next(): - token = self._get_expected_resume_token_legacy( - stream, listener, previous_change) + token = self._get_expected_resume_token_legacy(stream, listener, previous_change) if token is not None: return token - response = listener.results['succeeded'][-1].reply - return response['cursor']['postBatchResumeToken'] + response = listener.results["succeeded"][-1].reply + return response["cursor"]["postBatchResumeToken"] def _test_raises_error_on_missing_id(self, expected_exception): """ChangeStream will raise an exception if the server response is missing the resume token. """ - with self.change_stream([{'$project': {'_id': 0}}]) as change_stream: + with self.change_stream([{"$project": {"_id": 0}}]) as change_stream: self.watched_collection().insert_one({}) with self.assertRaises(expected_exception): next(change_stream) @@ -500,17 +481,17 @@ def _test_raises_error_on_missing_id(self, expected_exception): def _test_update_resume_token(self, expected_rt_getter): """ChangeStream must continuously track the last seen resumeToken.""" client, listener = self._client_with_listener("aggregate", "getMore") - coll = self.watched_collection(write_concern=WriteConcern('majority')) + coll = self.watched_collection(write_concern=WriteConcern("majority")) with self.change_stream_with_client(client) as change_stream: self.assertEqual( - change_stream.resume_token, - expected_rt_getter(change_stream, listener)) + change_stream.resume_token, expected_rt_getter(change_stream, listener) + ) for _ in range(3): coll.insert_one({}) change = next(change_stream) self.assertEqual( - change_stream.resume_token, - expected_rt_getter(change_stream, listener, change)) + change_stream.resume_token, expected_rt_getter(change_stream, listener, change) + ) # Prose test no. 1 @client_context.require_version_min(4, 0, 7) @@ -538,17 +519,16 @@ def test_raises_error_on_missing_id_418minus(self): # Prose test no. 3 def test_resume_on_error(self): with self.change_stream() as change_stream: - self.insert_one_and_check(change_stream, {'_id': 1}) + self.insert_one_and_check(change_stream, {"_id": 1}) # Cause a cursor not found error on the next getMore. self.kill_change_stream_cursor(change_stream) - self.insert_one_and_check(change_stream, {'_id': 2}) + self.insert_one_and_check(change_stream, {"_id": 2}) # Prose test no. 4 @client_context.require_failCommand_fail_point def test_no_resume_attempt_if_aggregate_command_fails(self): # Set non-retryable error on aggregate command. 
- fail_point = {'mode': {'times': 1}, - 'data': {'errorCode': 2, 'failCommands': ['aggregate']}} + fail_point = {"mode": {"times": 1}, "data": {"errorCode": 2, "failCommands": ["aggregate"]}} client, listener = self._client_with_listener("aggregate", "getMore") with self.fail_point(fail_point): try: @@ -557,9 +537,8 @@ def test_no_resume_attempt_if_aggregate_command_fails(self): pass # Driver should have attempted aggregate command only once. - self.assertEqual(len(listener.results['started']), 1) - self.assertEqual(listener.results['started'][0].command_name, - 'aggregate') + self.assertEqual(len(listener.results["started"]), 1) + self.assertEqual(listener.results["started"][0].command_name, "aggregate") # Prose test no. 5 - REMOVED # Prose test no. 6 - SKIPPED @@ -581,14 +560,15 @@ def test_initial_empty_batch(self): # Prose test no. 8 def test_kill_cursors(self): def raise_error(): - raise ServerSelectionTimeoutError('mock error') + raise ServerSelectionTimeoutError("mock error") + with self.change_stream() as change_stream: - self.insert_one_and_check(change_stream, {'_id': 1}) + self.insert_one_and_check(change_stream, {"_id": 1}) # Cause a cursor not found error on the next getMore. cursor = change_stream._cursor self.kill_change_stream_cursor(change_stream) cursor.close = raise_error - self.insert_one_and_check(change_stream, {'_id': 2}) + self.insert_one_and_check(change_stream, {"_id": 2}) # Prose test no. 9 @client_context.require_version_min(4, 0, 0) @@ -599,21 +579,21 @@ def test_start_at_operation_time_caching(self): with self.change_stream_with_client(client) as cs: self.kill_change_stream_cursor(cs) cs.try_next() - cmd = listener.results['started'][-1].command - self.assertIsNotNone(cmd["pipeline"][0]["$changeStream"].get( - "startAtOperationTime")) + cmd = listener.results["started"][-1].command + self.assertIsNotNone(cmd["pipeline"][0]["$changeStream"].get("startAtOperationTime")) # Case 2: change stream started with startAtOperationTime listener.results.clear() optime = self.get_start_at_operation_time() - with self.change_stream_with_client( - client, start_at_operation_time=optime) as cs: + with self.change_stream_with_client(client, start_at_operation_time=optime) as cs: self.kill_change_stream_cursor(cs) cs.try_next() - cmd = listener.results['started'][-1].command - self.assertEqual(cmd["pipeline"][0]["$changeStream"].get( - "startAtOperationTime"), optime, str([k.command for k in - listener.results['started']])) + cmd = listener.results["started"][-1].command + self.assertEqual( + cmd["pipeline"][0]["$changeStream"].get("startAtOperationTime"), + optime, + str([k.command for k in listener.results["started"]]), + ) # Prose test no. 10 - SKIPPED # This test is identical to prose test no. 3. @@ -626,9 +606,8 @@ def test_resumetoken_empty_batch(self): self.assertIsNone(change_stream.try_next()) resume_token = change_stream.resume_token - response = listener.results['succeeded'][0].reply - self.assertEqual(resume_token, - response["cursor"]["postBatchResumeToken"]) + response = listener.results["succeeded"][0].reply + self.assertEqual(resume_token, response["cursor"]["postBatchResumeToken"]) # Prose test no. 
11 @client_context.require_version_min(4, 0, 7) @@ -638,9 +617,8 @@ def test_resumetoken_exhausted_batch(self): self._populate_and_exhaust_change_stream(change_stream) resume_token = change_stream.resume_token - response = listener.results['succeeded'][-1].reply - self.assertEqual(resume_token, - response["cursor"]["postBatchResumeToken"]) + response = listener.results["succeeded"][-1].reply + self.assertEqual(resume_token, response["cursor"]["postBatchResumeToken"]) # Prose test no. 12 @client_context.require_version_max(4, 0, 7) @@ -665,7 +643,7 @@ def test_resumetoken_exhausted_batch_legacy(self): with self.change_stream() as change_stream: change = self._populate_and_exhaust_change_stream(change_stream) self.assertEqual(change_stream.resume_token, change["_id"]) - resume_point = change['_id'] + resume_point = change["_id"] # Resume token is _id of last change even if resumeAfter is specified. with self.change_stream(resume_after=resume_point) as change_stream: @@ -677,9 +655,9 @@ def test_resumetoken_partially_iterated_batch(self): # When batch has been iterated up to but not including the last element. # Resume token should be _id of previous change document. with self.change_stream() as change_stream: - self.watched_collection( - write_concern=WriteConcern('majority')).insert_many( - [{"data": k} for k in range(3)]) + self.watched_collection(write_concern=WriteConcern("majority")).insert_many( + [{"data": k} for k in range(3)] + ) for _ in range(2): change = next(change_stream) resume_token = change_stream.resume_token @@ -692,13 +670,12 @@ def _test_resumetoken_uniterated_nonempty_batch(self, resume_option): resume_point = self.get_resume_token() # Insert some documents so that firstBatch isn't empty. - self.watched_collection( - write_concern=WriteConcern("majority")).insert_many( - [{'a': 1}, {'b': 2}, {'c': 3}]) + self.watched_collection(write_concern=WriteConcern("majority")).insert_many( + [{"a": 1}, {"b": 2}, {"c": 3}] + ) # Resume token should be same as the resume option. - with self.change_stream( - **{resume_option: resume_point}) as change_stream: + with self.change_stream(**{resume_option: resume_point}) as change_stream: self.assertTrue(change_stream._cursor._has_next()) resume_token = change_stream.resume_token self.assertEqual(resume_token, resume_point) @@ -721,18 +698,15 @@ def test_startafter_resume_uses_startafter_after_empty_getMore(self): resume_point = self.get_resume_token() client, listener = self._client_with_listener("aggregate") - with self.change_stream_with_client( - client, start_after=resume_point) as change_stream: + with self.change_stream_with_client(client, start_after=resume_point) as change_stream: self.assertFalse(change_stream._cursor._has_next()) # No changes - change_stream.try_next() # No changes + change_stream.try_next() # No changes self.kill_change_stream_cursor(change_stream) - change_stream.try_next() # Resume attempt + change_stream.try_next() # Resume attempt - response = listener.results['started'][-1] - self.assertIsNone( - response.command["pipeline"][0]["$changeStream"].get("resumeAfter")) - self.assertIsNotNone( - response.command["pipeline"][0]["$changeStream"].get("startAfter")) + response = listener.results["started"][-1] + self.assertIsNone(response.command["pipeline"][0]["$changeStream"].get("resumeAfter")) + self.assertIsNotNone(response.command["pipeline"][0]["$changeStream"].get("startAfter")) # Prose test no. 
18 @client_context.require_version_min(4, 1, 1) @@ -741,19 +715,16 @@ def test_startafter_resume_uses_resumeafter_after_nonempty_getMore(self): resume_point = self.get_resume_token() client, listener = self._client_with_listener("aggregate") - with self.change_stream_with_client( - client, start_after=resume_point) as change_stream: + with self.change_stream_with_client(client, start_after=resume_point) as change_stream: self.assertFalse(change_stream._cursor._has_next()) # No changes self.watched_collection().insert_one({}) - next(change_stream) # Changes + next(change_stream) # Changes self.kill_change_stream_cursor(change_stream) - change_stream.try_next() # Resume attempt + change_stream.try_next() # Resume attempt - response = listener.results['started'][-1] - self.assertIsNotNone( - response.command["pipeline"][0]["$changeStream"].get("resumeAfter")) - self.assertIsNone( - response.command["pipeline"][0]["$changeStream"].get("startAfter")) + response = listener.results["started"][-1] + self.assertIsNotNone(response.command["pipeline"][0]["$changeStream"].get("resumeAfter")) + self.assertIsNone(response.command["pipeline"][0]["$changeStream"].get("startAfter")) class TestClusterChangeStream(TestChangeStreamBase, APITestsMixin): @@ -789,10 +760,9 @@ def _insert_and_check(self, change_stream, db, collname, doc): coll = db[collname] coll.insert_one(doc) change = next(change_stream) - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['ns'], {'db': db.name, - 'coll': collname}) - self.assertEqual(change['fullDocument'], doc) + self.assertEqual(change["operationType"], "insert") + self.assertEqual(change["ns"], {"db": db.name, "coll": collname}) + self.assertEqual(change["fullDocument"], doc) def insert_one_and_check(self, change_stream, doc): db = random.choice(self.dbs) @@ -803,22 +773,20 @@ def test_simple(self): collnames = self.generate_unique_collnames(3) with self.change_stream() as change_stream: for db, collname in product(self.dbs, collnames): - self._insert_and_check( - change_stream, db, collname, {'_id': collname} - ) + self._insert_and_check(change_stream, db, collname, {"_id": collname}) def test_aggregate_cursor_blocks(self): """Test that an aggregate cursor blocks until a change is readable.""" with self.client.admin.aggregate( - [{'$changeStream': {'allChangesForCluster': True}}], - maxAwaitTimeMS=250) as change_stream: + [{"$changeStream": {"allChangesForCluster": True}}], maxAwaitTimeMS=250 + ) as change_stream: self._test_next_blocks(change_stream) def test_full_pipeline(self): """$changeStream must be the first stage in a change stream pipeline sent to the server. """ - self._test_full_pipeline({'allChangesForCluster': True}) + self._test_full_pipeline({"allChangesForCluster": True}) class TestDatabaseChangeStream(TestChangeStreamBase, APITestsMixin): @@ -844,22 +812,22 @@ def _test_get_invalidate_event(self, change_stream): change = change_stream.next() # 4.1+ returns "drop" events for each collection in dropped database # and a "dropDatabase" event for the database itself. 
- if change['operationType'] == 'drop': - self.assertTrue(change['_id']) + if change["operationType"] == "drop": + self.assertTrue(change["_id"]) for _ in range(len(dropped_colls)): - ns = change['ns'] - self.assertEqual(ns['db'], change_stream._target.name) - self.assertIn(ns['coll'], dropped_colls) + ns = change["ns"] + self.assertEqual(ns["db"], change_stream._target.name) + self.assertIn(ns["coll"], dropped_colls) change = change_stream.next() - self.assertEqual(change['operationType'], 'dropDatabase') - self.assertTrue(change['_id']) - self.assertEqual(change['ns'], {'db': change_stream._target.name}) + self.assertEqual(change["operationType"], "dropDatabase") + self.assertTrue(change["_id"]) + self.assertEqual(change["ns"], {"db": change_stream._target.name}) # Get next change. change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'invalidate') - self.assertNotIn('ns', change) - self.assertNotIn('fullDocument', change) + self.assertTrue(change["_id"]) + self.assertEqual(change["operationType"], "invalidate") + self.assertNotIn("ns", change) + self.assertNotIn("fullDocument", change) # The ChangeStream should be dead. with self.assertRaises(StopIteration): change_stream.next() @@ -869,10 +837,9 @@ def _test_invalidate_stops_iteration(self, change_stream): change_stream._client.drop_database(self.db.name) # Check drop and dropDatabase events. for change in change_stream: - self.assertIn(change['operationType'], ( - 'drop', 'dropDatabase', 'invalidate')) + self.assertIn(change["operationType"], ("drop", "dropDatabase", "invalidate")) # Last change must be invalidate. - self.assertEqual(change['operationType'], 'invalidate') + self.assertEqual(change["operationType"], "invalidate") # Change stream must not allow further iteration. with self.assertRaises(StopIteration): change_stream.next() @@ -883,10 +850,9 @@ def _insert_and_check(self, change_stream, collname, doc): coll = self.db[collname] coll.insert_one(doc) change = next(change_stream) - self.assertEqual(change['operationType'], 'insert') - self.assertEqual(change['ns'], {'db': self.db.name, - 'coll': collname}) - self.assertEqual(change['fullDocument'], doc) + self.assertEqual(change["operationType"], "insert") + self.assertEqual(change["ns"], {"db": self.db.name, "coll": collname}) + self.assertEqual(change["fullDocument"], doc) def insert_one_and_check(self, change_stream, doc): self._insert_and_check(change_stream, self.id(), doc) @@ -896,26 +862,21 @@ def test_simple(self): with self.change_stream() as change_stream: for collname in collnames: self._insert_and_check( - change_stream, collname, - {'_id': Binary.from_uuid(uuid.uuid4())}) + change_stream, collname, {"_id": Binary.from_uuid(uuid.uuid4())} + ) def test_isolation(self): # Ensure inserts to other dbs don't show up in our ChangeStream. 
other_db = self.client.pymongo_test_temp - self.assertNotEqual( - other_db, self.db, msg="Isolation must be tested on separate DBs") + self.assertNotEqual(other_db, self.db, msg="Isolation must be tested on separate DBs") collname = self.id() with self.change_stream() as change_stream: - other_db[collname].insert_one( - {'_id': Binary.from_uuid(uuid.uuid4())}) - self._insert_and_check( - change_stream, collname, - {'_id': Binary.from_uuid(uuid.uuid4())}) + other_db[collname].insert_one({"_id": Binary.from_uuid(uuid.uuid4())}) + self._insert_and_check(change_stream, collname, {"_id": Binary.from_uuid(uuid.uuid4())}) self.client.drop_database(other_db) -class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, - ProseSpecTestsMixin): +class TestCollectionChangeStream(TestChangeStreamBase, APITestsMixin, ProseSpecTestsMixin): @classmethod @client_context.require_version_min(3, 5, 11) @client_context.require_no_mmap @@ -929,8 +890,11 @@ def setUp(self): self.watched_collection().insert_one({}) def change_stream_with_client(self, client, *args, **kwargs): - return client[self.db.name].get_collection( - self.watched_collection().name).watch(*args, **kwargs) + return ( + client[self.db.name] + .get_collection(self.watched_collection().name) + .watch(*args, **kwargs) + ) def generate_invalidate_event(self, change_stream): # Dropping the collection invalidates the change stream. @@ -940,9 +904,9 @@ def _test_invalidate_stops_iteration(self, change_stream): self.generate_invalidate_event(change_stream) # Check drop and dropDatabase events. for change in change_stream: - self.assertIn(change['operationType'], ('drop', 'invalidate')) + self.assertIn(change["operationType"], ("drop", "invalidate")) # Last change must be invalidate. - self.assertEqual(change['operationType'], 'invalidate') + self.assertEqual(change["operationType"], "invalidate") # Change stream must not allow further iteration. with self.assertRaises(StopIteration): change_stream.next() @@ -954,17 +918,18 @@ def _test_get_invalidate_event(self, change_stream): change_stream._target.drop() change = change_stream.next() # 4.1+ returns a "drop" change document. - if change['operationType'] == 'drop': - self.assertTrue(change['_id']) - self.assertEqual(change['ns'], { - 'db': change_stream._target.database.name, - 'coll': change_stream._target.name}) + if change["operationType"] == "drop": + self.assertTrue(change["_id"]) + self.assertEqual( + change["ns"], + {"db": change_stream._target.database.name, "coll": change_stream._target.name}, + ) # Last change should be invalidate. change = change_stream.next() - self.assertTrue(change['_id']) - self.assertEqual(change['operationType'], 'invalidate') - self.assertNotIn('ns', change) - self.assertNotIn('fullDocument', change) + self.assertTrue(change["_id"]) + self.assertEqual(change["operationType"], "invalidate") + self.assertNotIn("ns", change) + self.assertNotIn("fullDocument", change) # The ChangeStream should be dead. 
with self.assertRaises(StopIteration): change_stream.next() @@ -972,38 +937,36 @@ def _test_get_invalidate_event(self, change_stream): def insert_one_and_check(self, change_stream, doc): self.watched_collection().insert_one(doc) change = next(change_stream) - self.assertEqual(change['operationType'], 'insert') + self.assertEqual(change["operationType"], "insert") self.assertEqual( - change['ns'], {'db': self.watched_collection().database.name, - 'coll': self.watched_collection().name}) - self.assertEqual(change['fullDocument'], doc) + change["ns"], + {"db": self.watched_collection().database.name, "coll": self.watched_collection().name}, + ) + self.assertEqual(change["fullDocument"], doc) def test_raw(self): """Test with RawBSONDocument.""" - raw_coll = self.watched_collection( - codec_options=DEFAULT_RAW_BSON_OPTIONS) + raw_coll = self.watched_collection(codec_options=DEFAULT_RAW_BSON_OPTIONS) with raw_coll.watch() as change_stream: - raw_doc = RawBSONDocument(encode({'_id': 1})) + raw_doc = RawBSONDocument(encode({"_id": 1})) self.watched_collection().insert_one(raw_doc) change = next(change_stream) self.assertIsInstance(change, RawBSONDocument) - self.assertEqual(change['operationType'], 'insert') - self.assertEqual( - change['ns']['db'], self.watched_collection().database.name) - self.assertEqual( - change['ns']['coll'], self.watched_collection().name) - self.assertEqual(change['fullDocument'], raw_doc) + self.assertEqual(change["operationType"], "insert") + self.assertEqual(change["ns"]["db"], self.watched_collection().database.name) + self.assertEqual(change["ns"]["coll"], self.watched_collection().name) + self.assertEqual(change["fullDocument"], raw_doc) def test_uuid_representations(self): """Test with uuid document _ids and different uuid_representation.""" for uuid_representation in ALL_UUID_REPRESENTATIONS: for id_subtype in (STANDARD, PYTHON_LEGACY): options = self.watched_collection().codec_options.with_options( - uuid_representation=uuid_representation) + uuid_representation=uuid_representation + ) coll = self.watched_collection(codec_options=options) with coll.watch() as change_stream: - coll.insert_one( - {'_id': Binary(uuid.uuid4().bytes, id_subtype)}) + coll.insert_one({"_id": Binary(uuid.uuid4().bytes, id_subtype)}) _ = change_stream.next() resume_token = change_stream.resume_token @@ -1012,12 +975,12 @@ def test_uuid_representations(self): def test_document_id_order(self): """Test with document _ids that need their order preserved.""" - random_keys = random.sample(string.ascii_letters, - len(string.ascii_letters)) - random_doc = {'_id': SON([(key, key) for key in random_keys])} + random_keys = random.sample(string.ascii_letters, len(string.ascii_letters)) + random_doc = {"_id": SON([(key, key) for key in random_keys])} for document_class in (dict, SON, RawBSONDocument): options = self.watched_collection().codec_options.with_options( - document_class=document_class) + document_class=document_class + ) coll = self.watched_collection(codec_options=options) with coll.watch() as change_stream: coll.insert_one(random_doc) @@ -1033,12 +996,12 @@ def test_document_id_order(self): def test_read_concern(self): """Test readConcern is not validated by the driver.""" # Read concern 'local' is not allowed for $changeStream. - coll = self.watched_collection(read_concern=ReadConcern('local')) + coll = self.watched_collection(read_concern=ReadConcern("local")) with self.assertRaises(OperationFailure): coll.watch() # Does not error. 
- coll = self.watched_collection(read_concern=ReadConcern('majority')) + coll = self.watched_collection(read_concern=ReadConcern("majority")) with coll.watch(): pass @@ -1063,10 +1026,13 @@ def setUp(self): self.listener.results.clear() def setUpCluster(self, scenario_dict): - assets = [(scenario_dict["database_name"], - scenario_dict["collection_name"]), - (scenario_dict.get("database2_name", "db2"), - scenario_dict.get("collection2_name", "coll2"))] + assets = [ + (scenario_dict["database_name"], scenario_dict["collection_name"]), + ( + scenario_dict.get("database2_name", "db2"), + scenario_dict.get("collection2_name", "coll2"), + ), + ] for db, coll in assets: self.client.drop_database(db) self.client[db].create_collection(coll) @@ -1078,12 +1044,15 @@ def setFailPoint(self, scenario_dict): elif not client_context.test_commands_enabled: self.skipTest("Test commands must be enabled") - fail_cmd = SON([('configureFailPoint', 'failCommand')]) + fail_cmd = SON([("configureFailPoint", "failCommand")]) fail_cmd.update(fail_point) client_context.client.admin.command(fail_cmd) self.addCleanup( client_context.client.admin.command, - 'configureFailPoint', fail_cmd['configureFailPoint'], mode='off') + "configureFailPoint", + fail_cmd["configureFailPoint"], + mode="off", + ) def assert_list_contents_are_subset(self, superlist, sublist): """Check that each element in sublist is a subset of the corresponding @@ -1103,7 +1072,7 @@ def assert_dict_is_subset(self, superdict, subdict): exempt_fields = ["documentKey", "_id", "getMore"] for key, value in subdict.items(): if key not in superdict: - self.fail('Key %s not found in %s' % (key, superdict)) + self.fail("Key %s not found in %s" % (key, superdict)) if isinstance(value, dict): self.assert_dict_is_subset(superdict[key], value) continue @@ -1129,14 +1098,13 @@ def tearDown(self): self.listener.results.clear() -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'change_streams') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "change_streams") def camel_to_snake(camel): # Regex to convert CamelCase to snake_case. - snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel) - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower() + snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() def get_change_stream(client, scenario_def, test): @@ -1167,12 +1135,11 @@ def run_operation(client, operation): # Apply specified operations opname = camel_to_snake(operation["name"]) arguments = operation.get("arguments", {}) - if opname == 'rename': + if opname == "rename": # Special case for rename operation. 
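# Hedged aside, not part of the patch: why rename is special-cased here.
# camel_to_snake() (defined above) maps the spec's camelCase operation
# names onto PyMongo's snake_case methods, but for rename the *argument*
# name differs too: "to" in the spec versus "new_name" in PyMongo.
import re

def _camel_to_snake(camel):
    snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower()

assert _camel_to_snake("insertOne") == "insert_one"
assert _camel_to_snake("replaceOne") == "replace_one"
assert _camel_to_snake("rename") == "rename"  # the name maps; the args do not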
- arguments = {'new_name': arguments["to"]} - cmd = getattr(client.get_database( - operation["database"]).get_collection( - operation["collection"]), opname + arguments = {"new_name": arguments["to"]} + cmd = getattr( + client.get_database(operation["database"]).get_collection(operation["collection"]), opname ) return cmd(**arguments) @@ -1184,15 +1151,12 @@ def run_scenario(self): self.setFailPoint(test) is_error = test["result"].get("error", False) try: - with get_change_stream( - self.client, scenario_def, test - ) as change_stream: + with get_change_stream(self.client, scenario_def, test) as change_stream: for operation in test["operations"]: # Run specified operations run_operation(self.client, operation) num_expected_changes = len(test["result"].get("success", [])) - changes = [ - change_stream.next() for _ in range(num_expected_changes)] + changes = [change_stream.next() for _ in range(num_expected_changes)] # Run a next() to induce an error if one is expected and # there are no changes. if is_error and not changes: @@ -1226,7 +1190,7 @@ def run_scenario(self): def create_tests(): - for dirpath, _, filenames in os.walk(os.path.join(_TEST_PATH, 'legacy')): + for dirpath, _, filenames in os.walk(os.path.join(_TEST_PATH, "legacy")): dirname = os.path.split(dirpath)[-1] for filename in filenames: @@ -1235,31 +1199,25 @@ def create_tests(): test_type = os.path.splitext(filename)[0] - for test in scenario_def['tests']: + for test in scenario_def["tests"]: new_test = create_test(scenario_def, test) new_test = client_context.require_no_mmap(new_test) - if 'minServerVersion' in test: - min_ver = tuple( - int(elt) for - elt in test['minServerVersion'].split('.')) - new_test = client_context.require_version_min(*min_ver)( - new_test) - if 'maxServerVersion' in test: - max_ver = tuple( - int(elt) for - elt in test['maxServerVersion'].split('.')) - new_test = client_context.require_version_max(*max_ver)( - new_test) - - topologies = test['topology'] - new_test = client_context.require_cluster_type(topologies)( - new_test) - - test_name = 'test_%s_%s_%s' % ( + if "minServerVersion" in test: + min_ver = tuple(int(elt) for elt in test["minServerVersion"].split(".")) + new_test = client_context.require_version_min(*min_ver)(new_test) + if "maxServerVersion" in test: + max_ver = tuple(int(elt) for elt in test["maxServerVersion"].split(".")) + new_test = client_context.require_version_max(*max_ver)(new_test) + + topologies = test["topology"] + new_test = client_context.require_cluster_type(topologies)(new_test) + + test_name = "test_%s_%s_%s" % ( dirname, test_type.replace("-", "_"), - str(test['description'].replace(" ", "_"))) + str(test["description"].replace(" ", "_")), + ) new_test.__name__ = test_name setattr(TestAllLegacyScenarios, new_test.__name__, new_test) @@ -1268,10 +1226,13 @@ def create_tests(): create_tests() -globals().update(generate_test_classes( - os.path.join(_TEST_PATH, 'unified'), - module=__name__,)) +globals().update( + generate_test_classes( + os.path.join(_TEST_PATH, "unified"), + module=__name__, + ) +) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_client.py b/test/test_client.py index 8c89a45481..335c7dacfa 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -14,6 +14,7 @@ """Test the mongo_client module.""" +import _thread as thread import contextlib import copy import datetime @@ -23,78 +24,85 @@ import socket import struct import sys -import time -import _thread as thread import threading +import time import 
warnings sys.path[0:0] = [""] +from test import ( + HAVE_IPADDRESS, + IntegrationTest, + MockClientTest, + SkipTest, + client_context, + client_knobs, + db_pwd, + db_user, + unittest, +) +from test.pymongo_mocks import MockClient +from test.utils import ( + NTHREADS, + CMAPListener, + FunctionCallRecorder, + assertRaisesExactly, + connected, + delay, + get_pool, + gevent_monkey_patched, + is_greenthread_patched, + lazy_client_trial, + one, + remove_all_users, + rs_client, + rs_or_single_client, + rs_or_single_client_noauth, + single_client, + wait_until, +) + +import pymongo from bson import encode from bson.codec_options import CodecOptions, TypeEncoder, TypeRegistry from bson.son import SON from bson.tz_util import utc -import pymongo from pymongo import event_loggers, message, monitoring from pymongo.client_options import ClientOptions from pymongo.command_cursor import CommandCursor -from pymongo.common import CONNECT_TIMEOUT, _UUID_REPRESENTATIONS +from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT from pymongo.compression_support import _HAVE_SNAPPY, _HAVE_ZSTD from pymongo.cursor import Cursor, CursorType from pymongo.database import Database from pymongo.driver_info import DriverInfo -from pymongo.errors import (AutoReconnect, - ConfigurationError, - ConnectionFailure, - InvalidName, - InvalidURI, - NetworkTimeout, - OperationFailure, - ServerSelectionTimeoutError, - WriteConcernError, - InvalidOperation) +from pymongo.errors import ( + AutoReconnect, + ConfigurationError, + ConnectionFailure, + InvalidName, + InvalidOperation, + InvalidURI, + NetworkTimeout, + OperationFailure, + ServerSelectionTimeoutError, + WriteConcernError, +) from pymongo.hello import HelloCompat from pymongo.mongo_client import MongoClient -from pymongo.monitoring import (ServerHeartbeatListener, - ServerHeartbeatStartedEvent) -from pymongo.pool import SocketInfo, _METADATA, PoolOptions +from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent +from pymongo.pool import _METADATA, PoolOptions, SocketInfo from pymongo.read_preferences import ReadPreference from pymongo.server_description import ServerDescription -from pymongo.server_selectors import (readable_server_selector, - writable_server_selector) +from pymongo.server_selectors import readable_server_selector, writable_server_selector from pymongo.server_type import SERVER_TYPE from pymongo.settings import TOPOLOGY_TYPE from pymongo.srv_resolver import _HAVE_DNSPYTHON from pymongo.topology import _ErrorContext -from pymongo.topology_description import TopologyDescription, _updated_topology_description_srv_polling +from pymongo.topology_description import ( + TopologyDescription, + _updated_topology_description_srv_polling, +) from pymongo.write_concern import WriteConcern -from test import (client_context, - client_knobs, - SkipTest, - unittest, - IntegrationTest, - db_pwd, - db_user, - MockClientTest, - HAVE_IPADDRESS) -from test.pymongo_mocks import MockClient -from test.utils import (assertRaisesExactly, - connected, - CMAPListener, - delay, - FunctionCallRecorder, - get_pool, - gevent_monkey_patched, - is_greenthread_patched, - lazy_client_trial, - NTHREADS, - one, - remove_all_users, - rs_client, - rs_or_single_client, - rs_or_single_client_noauth, - single_client, - wait_until) class ClientUnitTest(unittest.TestCase): @@ -103,25 +111,26 @@ class ClientUnitTest(unittest.TestCase): @classmethod @client_context.require_connection def setUpClass(cls): - cls.client = rs_or_single_client(connect=False, - 
serverSelectionTimeoutMS=100) + cls.client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @classmethod def tearDownClass(cls): cls.client.close() def test_keyword_arg_defaults(self): - client = MongoClient(socketTimeoutMS=None, - connectTimeoutMS=20000, - waitQueueTimeoutMS=None, - replicaSet=None, - read_preference=ReadPreference.PRIMARY, - ssl=False, - tlsCertificateKeyFile=None, - tlsAllowInvalidCertificates=True, - tlsCAFile=None, - connect=False, - serverSelectionTimeoutMS=12000) + client = MongoClient( + socketTimeoutMS=None, + connectTimeoutMS=20000, + waitQueueTimeoutMS=None, + replicaSet=None, + read_preference=ReadPreference.PRIMARY, + ssl=False, + tlsCertificateKeyFile=None, + tlsAllowInvalidCertificates=True, + tlsCAFile=None, + connect=False, + serverSelectionTimeoutMS=12000, + ) options = client._MongoClient__options pool_opts = options.pool_options @@ -135,19 +144,17 @@ def test_keyword_arg_defaults(self): self.assertAlmostEqual(12, client.options.server_selection_timeout) def test_connect_timeout(self): - client = MongoClient(connect=False, connectTimeoutMS=None, - socketTimeoutMS=None) + client = MongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client._MongoClient__options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = MongoClient(connect=False, connectTimeoutMS=0, - socketTimeoutMS=0) + client = MongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client._MongoClient__options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) client = MongoClient( - 'mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0', - connect=False) + "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False + ) pool_opts = client._MongoClient__options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) @@ -165,18 +172,9 @@ def test_max_pool_size_zero(self): MongoClient(maxPoolSize=0) def test_uri_detection(self): - self.assertRaises( - ConfigurationError, - MongoClient, - "/foo") - self.assertRaises( - ConfigurationError, - MongoClient, - "://") - self.assertRaises( - ConfigurationError, - MongoClient, - "foo/") + self.assertRaises(ConfigurationError, MongoClient, "/foo") + self.assertRaises(ConfigurationError, MongoClient, "://") + self.assertRaises(ConfigurationError, MongoClient, "foo/") def test_get_db(self): def make_db(base, name): @@ -196,15 +194,14 @@ def make_db(base, name): def test_get_database(self): codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) - db = self.client.get_database( - 'foo', codec_options, ReadPreference.SECONDARY, write_concern) - self.assertEqual('foo', db.name) + db = self.client.get_database("foo", codec_options, ReadPreference.SECONDARY, write_concern) + self.assertEqual("foo", db.name) self.assertEqual(codec_options, db.codec_options) self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) def test_getattr(self): - self.assertTrue(isinstance(self.client['_does_not_exist'], Database)) + self.assertTrue(isinstance(self.client["_does_not_exist"], Database)) with self.assertRaises(AttributeError) as context: self.client._does_not_exist @@ -212,8 +209,7 @@ def test_getattr(self): # Message should be: # "AttributeError: MongoClient has no attribute '_does_not_exist'. 
To # access the _does_not_exist database, use client['_does_not_exist']". - self.assertIn("has no attribute '_does_not_exist'", - str(context.exception)) + self.assertIn("has no attribute '_does_not_exist'", str(context.exception)) def test_iteration(self): def iterate(): @@ -222,108 +218,111 @@ def iterate(): self.assertRaises(TypeError, iterate) def test_get_default_database(self): - c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host, - client_context.port), - connect=False) - self.assertEqual(Database(c, 'foo'), c.get_default_database()) + c = rs_or_single_client( + "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False + ) + self.assertEqual(Database(c, "foo"), c.get_default_database()) # Test that default doesn't override the URI value. - self.assertEqual(Database(c, 'foo'), c.get_default_database('bar')) + self.assertEqual(Database(c, "foo"), c.get_default_database("bar")) codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) - db = c.get_default_database( - None, codec_options, ReadPreference.SECONDARY, write_concern) - self.assertEqual('foo', db.name) + db = c.get_default_database(None, codec_options, ReadPreference.SECONDARY, write_concern) + self.assertEqual("foo", db.name) self.assertEqual(codec_options, db.codec_options) self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) - c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host, - client_context.port), - connect=False) - self.assertEqual(Database(c, 'foo'), c.get_default_database('foo')) + c = rs_or_single_client( + "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False + ) + self.assertEqual(Database(c, "foo"), c.get_default_database("foo")) def test_get_default_database_error(self): # URI with no database. - c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host, - client_context.port), - connect=False) + c = rs_or_single_client( + "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False + ) self.assertRaises(ConfigurationError, c.get_default_database) def test_get_default_database_with_authsource(self): # Ensure we distinguish database name from authSource. - uri = "mongodb://%s:%d/foo?authSource=src" % ( - client_context.host, client_context.port) + uri = "mongodb://%s:%d/foo?authSource=src" % (client_context.host, client_context.port) c = rs_or_single_client(uri, connect=False) - self.assertEqual(Database(c, 'foo'), c.get_default_database()) + self.assertEqual(Database(c, "foo"), c.get_default_database()) def test_get_database_default(self): - c = rs_or_single_client("mongodb://%s:%d/foo" % (client_context.host, - client_context.port), - connect=False) - self.assertEqual(Database(c, 'foo'), c.get_database()) + c = rs_or_single_client( + "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False + ) + self.assertEqual(Database(c, "foo"), c.get_database()) def test_get_database_default_error(self): # URI with no database. - c = rs_or_single_client("mongodb://%s:%d/" % (client_context.host, - client_context.port), - connect=False) + c = rs_or_single_client( + "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False + ) self.assertRaises(ConfigurationError, c.get_database) def test_get_database_default_with_authsource(self): # Ensure we distinguish database name from authSource. 
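# Hedged aside, not part of the patch: the URI semantics these authSource
# tests pin down, standalone (host and credentials illustrative;
# connect=False stays offline). The URI path names the default database,
# while authSource only selects where credentials are verified.
from pymongo import MongoClient

_c = MongoClient("mongodb://u:p@localhost:27017/foo?authSource=admin", connect=False)
assert _c.get_default_database().name == "foo"  # not "admin"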
- uri = "mongodb://%s:%d/foo?authSource=src" % ( - client_context.host, client_context.port) + uri = "mongodb://%s:%d/foo?authSource=src" % (client_context.host, client_context.port) c = rs_or_single_client(uri, connect=False) - self.assertEqual(Database(c, 'foo'), c.get_database()) + self.assertEqual(Database(c, "foo"), c.get_database()) def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". with self.assertRaises(ConfigurationError): - MongoClient('mongodb://host/?readpreferencetags=dc:east') + MongoClient("mongodb://host/?readpreferencetags=dc:east") with self.assertRaises(ConfigurationError): - MongoClient('mongodb://host/?' - 'readpreference=primary&readpreferencetags=dc:east') + MongoClient("mongodb://host/?" "readpreference=primary&readpreferencetags=dc:east") def test_read_preference(self): c = rs_or_single_client( - "mongodb://host", connect=False, - readpreference=ReadPreference.NEAREST.mongos_mode) + "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode + ) self.assertEqual(c.read_preference, ReadPreference.NEAREST) def test_metadata(self): metadata = copy.deepcopy(_METADATA) - metadata['application'] = {'name': 'foobar'} - client = MongoClient( - "mongodb://foo:27017/?appname=foobar&connect=false") + metadata["application"] = {"name": "foobar"} + client = MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") options = client._MongoClient__options self.assertEqual(options.pool_options.metadata, metadata) - client = MongoClient('foo', 27017, appname='foobar', connect=False) + client = MongoClient("foo", 27017, appname="foobar", connect=False) options = client._MongoClient__options self.assertEqual(options.pool_options.metadata, metadata) # No error - MongoClient(appname='x' * 128) - self.assertRaises(ValueError, MongoClient, appname='x' * 129) + MongoClient(appname="x" * 128) + self.assertRaises(ValueError, MongoClient, appname="x" * 129) # Bad "driver" options. - self.assertRaises(TypeError, DriverInfo, 'Foo', 1, 'a') - self.assertRaises(TypeError, DriverInfo, version="1", platform='a') + self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") + self.assertRaises(TypeError, DriverInfo, version="1", platform="a") self.assertRaises(TypeError, DriverInfo) self.assertRaises(TypeError, MongoClient, driver=1) - self.assertRaises(TypeError, MongoClient, driver='abc') - self.assertRaises(TypeError, MongoClient, driver=('Foo', '1', 'a')) + self.assertRaises(TypeError, MongoClient, driver="abc") + self.assertRaises(TypeError, MongoClient, driver=("Foo", "1", "a")) # Test appending to driver info. 
- metadata['driver']['name'] = 'PyMongo|FooDriver' - metadata['driver']['version'] = '%s|1.2.3' % ( - _METADATA['driver']['version'],) - client = MongoClient('foo', 27017, appname='foobar', - driver=DriverInfo('FooDriver', '1.2.3', None), connect=False) + metadata["driver"]["name"] = "PyMongo|FooDriver" + metadata["driver"]["version"] = "%s|1.2.3" % (_METADATA["driver"]["version"],) + client = MongoClient( + "foo", + 27017, + appname="foobar", + driver=DriverInfo("FooDriver", "1.2.3", None), + connect=False, + ) options = client._MongoClient__options self.assertEqual(options.pool_options.metadata, metadata) - metadata['platform'] = '%s|FooPlatform' % ( - _METADATA['platform'],) - client = MongoClient('foo', 27017, appname='foobar', - driver=DriverInfo('FooDriver', '1.2.3', 'FooPlatform'), connect=False) + metadata["platform"] = "%s|FooPlatform" % (_METADATA["platform"],) + client = MongoClient( + "foo", + 27017, + appname="foobar", + driver=DriverInfo("FooDriver", "1.2.3", "FooPlatform"), + connect=False, + ) options = client._MongoClient__options self.assertEqual(options.pool_options.metadata, metadata) @@ -331,12 +330,14 @@ def test_kwargs_codec_options(self): class MyFloatType(object): def __init__(self, x): self.__x = x + @property def x(self): return self.__x class MyFloatAsIntEncoder(TypeEncoder): python_type = MyFloatType + def transform_python(self, value): return int(value) @@ -344,8 +345,8 @@ def transform_python(self, value): document_class = SON type_registry = TypeRegistry([MyFloatAsIntEncoder()]) tz_aware = True - uuid_representation_label = 'javaLegacy' - unicode_decode_error_handler = 'ignore' + uuid_representation_label = "javaLegacy" + unicode_decode_error_handler = "ignore" tzinfo = utc c = MongoClient( document_class=document_class, @@ -354,63 +355,62 @@ def transform_python(self, value): uuidrepresentation=uuid_representation_label, unicode_decode_error_handler=unicode_decode_error_handler, tzinfo=tzinfo, - connect=False + connect=False, ) self.assertEqual(c.codec_options.document_class, document_class) self.assertEqual(c.codec_options.type_registry, type_registry) self.assertEqual(c.codec_options.tz_aware, tz_aware) self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label]) - self.assertEqual( - c.codec_options.unicode_decode_error_handler, - unicode_decode_error_handler) + c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + ) + self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual(c.codec_options.tzinfo, tzinfo) def test_uri_codec_options(self): # Ensure codec options are passed in correctly - uuid_representation_label = 'javaLegacy' - unicode_decode_error_handler = 'ignore' - uri = ("mongodb://%s:%d/foo?tz_aware=true&uuidrepresentation=" - "%s&unicode_decode_error_handler=%s" % ( - client_context.host, - client_context.port, - uuid_representation_label, - unicode_decode_error_handler)) + uuid_representation_label = "javaLegacy" + unicode_decode_error_handler = "ignore" + uri = ( + "mongodb://%s:%d/foo?tz_aware=true&uuidrepresentation=" + "%s&unicode_decode_error_handler=%s" + % ( + client_context.host, + client_context.port, + uuid_representation_label, + unicode_decode_error_handler, + ) + ) c = MongoClient(uri, connect=False) self.assertEqual(c.codec_options.tz_aware, True) self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label]) - self.assertEqual( - 
c.codec_options.unicode_decode_error_handler, - unicode_decode_error_handler) + c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + ) + self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. - uri = ("mongodb://localhost/?ssl=true&replicaSet=name" - "&readPreference=primary") - c = MongoClient(uri, ssl=False, replicaSet="newname", - readPreference="secondaryPreferred") + uri = "mongodb://localhost/?ssl=true&replicaSet=name" "&readPreference=primary" + c = MongoClient(uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred") clopts = c._MongoClient__options opts = clopts._options - self.assertEqual(opts['tls'], False) + self.assertEqual(opts["tls"], False) self.assertEqual(clopts.replica_set_name, "newname") - self.assertEqual( - clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) - @unittest.skipUnless( - _HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed") + @unittest.skipUnless(_HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed") def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. from pymongo.srv_resolver import _resolve + patched_resolver = FunctionCallRecorder(_resolve) pymongo.srv_resolver._resolve = patched_resolver + def reset_resolver(): pymongo.srv_resolver._resolve = _resolve + self.addCleanup(reset_resolver) # Setup. @@ -424,7 +424,7 @@ def test_scenario(args, kwargs, expected_value): patched_resolver.reset() MongoClient(*args, **kwargs) for _, kw in patched_resolver.call_list(): - self.assertAlmostEqual(kw['lifetime'], expected_value) + self.assertAlmostEqual(kw["lifetime"], expected_value) # No timeout specified. test_scenario((base_uri,), {}, CONNECT_TIMEOUT) @@ -433,7 +433,7 @@ def test_scenario(args, kwargs, expected_value): test_scenario((uri_with_timeout,), {}, expected_uri_value) # Timeout only specified in keyword arguments. - kwarg = {'connectTimeoutMS': connectTimeoutMS} + kwarg = {"connectTimeoutMS": connectTimeoutMS} test_scenario((base_uri,), kwarg, expected_kw_value) # Timeout specified in both kwargs and connection string. @@ -442,23 +442,27 @@ def test_scenario(args, kwargs, expected_value): def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): - MongoClient('mongodb://localhost/?ssl=true', tls=False, - connect=False) + MongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False) # Matching SSL and TLS options should not cause errors. - c = MongoClient('mongodb://localhost/?ssl=false', tls=False, - connect=False) - self.assertEqual(c._MongoClient__options._options['tls'], False) + c = MongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) + self.assertEqual(c._MongoClient__options._options["tls"], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): - MongoClient('mongodb://localhost/?tlsInsecure=true', - connect=False, tlsAllowInvalidHostnames=True) + MongoClient( + "mongodb://localhost/?tlsInsecure=true", + connect=False, + tlsAllowInvalidHostnames=True, + ) # Conflicting legacy tlsInsecure options should also raise an error. 
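# Hedged aside, not part of the patch: the guarantee the security-option
# checks around here pin down, standalone. Conflicting TLS settings split
# between the URI and keyword arguments raise InvalidURI instead of one
# silently overriding the other.
from pymongo import MongoClient
from pymongo.errors import InvalidURI

try:
    MongoClient("mongodb://localhost/?tlsInsecure=true",
                connect=False, tlsAllowInvalidCertificates=False)
except InvalidURI:
    pass  # refused: tlsInsecure conflicts with the explicit kwarg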
with self.assertRaises(InvalidURI): - MongoClient('mongodb://localhost/?tlsInsecure=true', - connect=False, tlsAllowInvalidCertificates=False) + MongoClient( + "mongodb://localhost/?tlsInsecure=true", + connect=False, + tlsAllowInvalidCertificates=False, + ) # Conflicting kwargs should raise InvalidURI with self.assertRaises(InvalidURI): @@ -467,11 +471,13 @@ def test_uri_security_options(self): def test_event_listeners(self): c = MongoClient(event_listeners=[], connect=False) self.assertEqual(c.options.event_listeners, []) - listeners = [event_loggers.CommandLogger(), - event_loggers.HeartbeatLogger(), - event_loggers.ServerLogger(), - event_loggers.TopologyLogger(), - event_loggers.ConnectionPoolLogger()] + listeners = [ + event_loggers.CommandLogger(), + event_loggers.HeartbeatLogger(), + event_loggers.ServerLogger(), + event_loggers.TopologyLogger(), + event_loggers.ConnectionPoolLogger(), + ] c = MongoClient(event_listeners=listeners, connect=False) self.assertEqual(c.options.event_listeners, listeners) @@ -488,16 +494,19 @@ def test_client_options(self): class TestClient(IntegrationTest): def test_multiple_uris(self): with self.assertRaises(ConfigurationError): - MongoClient(host=['mongodb+srv://cluster-a.abc12.mongodb.net', - 'mongodb+srv://cluster-b.abc12.mongodb.net', - 'mongodb+srv://cluster-c.abc12.mongodb.net']) + MongoClient( + host=[ + "mongodb+srv://cluster-a.abc12.mongodb.net", + "mongodb+srv://cluster-b.abc12.mongodb.net", + "mongodb+srv://cluster-c.abc12.mongodb.net", + ] + ) def test_max_idle_time_reaper_default(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove sockets when maxIdleTimeMS not set client = rs_or_single_client() - server = client._get_topology().select_server( - readable_server_selector) + server = client._get_topology().select_server(readable_server_selector) with server._pool.get_socket({}) as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) @@ -507,89 +516,78 @@ def test_max_idle_time_reaper_default(self): def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one - client = rs_or_single_client(maxIdleTimeMS=500, - minPoolSize=1) - server = client._get_topology().select_server( - readable_server_selector) + client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + server = client._get_topology().select_server(readable_server_selector) with server._pool.get_socket({}) as sock_info: pass # When the reaper runs at the same time as the get_socket, two # sockets could be created and checked into the pool. self.assertGreaterEqual(len(server._pool.sockets), 1) - wait_until(lambda: sock_info not in server._pool.sockets, - "remove stale socket") - wait_until(lambda: 1 <= len(server._pool.sockets), - "replace stale socket") + wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") + wait_until(lambda: 1 <= len(server._pool.sockets), "replace stale socket") client.close() def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new sockets. 
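# Hedged aside, not part of the patch: the three pool knobs these reaper
# tests combine, standalone (connect=False stays offline). maxIdleTimeMS
# bounds how long an idle socket may live, minPoolSize makes the reaper
# replace what it removes, and maxPoolSize caps the replacements.
from pymongo import MongoClient

_c = MongoClient(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1, connect=False)
_opts = _c._MongoClient__options.pool_options
assert _opts.max_idle_time_seconds == 0.5
assert (_opts.min_pool_size, _opts.max_pool_size) == (1, 1)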
- client = rs_or_single_client(maxIdleTimeMS=500, - minPoolSize=1, - maxPoolSize=1) - server = client._get_topology().select_server( - readable_server_selector) + client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) + server = client._get_topology().select_server(readable_server_selector) with server._pool.get_socket({}) as sock_info: pass # When the reaper runs at the same time as the get_socket, # maxPoolSize=1 should prevent two sockets from being created. self.assertEqual(1, len(server._pool.sockets)) - wait_until(lambda: sock_info not in server._pool.sockets, - "remove stale socket") - wait_until(lambda: 1 == len(server._pool.sockets), - "replace stale socket") + wait_until(lambda: sock_info not in server._pool.sockets, "remove stale socket") + wait_until(lambda: 1 == len(server._pool.sockets), "replace stale socket") client.close() def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper has removed idle socket and NOT replaced it client = rs_or_single_client(maxIdleTimeMS=500) - server = client._get_topology().select_server( - readable_server_selector) + server = client._get_topology().select_server(readable_server_selector) with server._pool.get_socket({}) as sock_info_one: pass # Assert that the pool does not close sockets prematurely. - time.sleep(.300) + time.sleep(0.300) with server._pool.get_socket({}) as sock_info_two: pass self.assertIs(sock_info_one, sock_info_two) wait_until( lambda: 0 == len(server._pool.sockets), - "stale socket reaped and new one NOT added to the pool") + "stale socket reaped and new one NOT added to the pool", + ) client.close() def test_min_pool_size(self): - with client_knobs(kill_cursor_frequency=.1): + with client_knobs(kill_cursor_frequency=0.1): client = rs_or_single_client() - server = client._get_topology().select_server( - readable_server_selector) + server = client._get_topology().select_server(readable_server_selector) self.assertEqual(0, len(server._pool.sockets)) # Assert that pool started up at minPoolSize client = rs_or_single_client(minPoolSize=10) - server = client._get_topology().select_server( - readable_server_selector) - wait_until(lambda: 10 == len(server._pool.sockets), - "pool initialized with 10 sockets") + server = client._get_topology().select_server(readable_server_selector) + wait_until(lambda: 10 == len(server._pool.sockets), "pool initialized with 10 sockets") # Assert that if a socket is closed, a new one takes its place with server._pool.get_socket({}) as sock_info: sock_info.close_socket(None) - wait_until(lambda: 10 == len(server._pool.sockets), - "a closed socket gets replaced from the pool") + wait_until( + lambda: 10 == len(server._pool.sockets), + "a closed socket gets replaced from the pool", + ) self.assertFalse(sock_info in server._pool.sockets) def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): client = rs_or_single_client(maxIdleTimeMS=500) - server = client._get_topology().select_server( - readable_server_selector) + server = client._get_topology().select_server(readable_server_selector) with server._pool.get_socket({}) as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) - time.sleep(1) # Sleep so that the socket becomes stale. + time.sleep(1) # Sleep so that the socket becomes stale. 
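# Hedged aside, not part of the patch: an assumed reconstruction of the
# test.utils.wait_until() helper these pool tests lean on: poll a
# predicate until it holds, or fail with the given description.
import time

def _wait_until(predicate, description, timeout=10):
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return
        time.sleep(0.05)
    raise AssertionError("Didn't ever %s" % description)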
with server._pool.get_socket({}) as new_sock_info: self.assertNotEqual(sock_info, new_sock_info) @@ -599,8 +597,7 @@ def test_max_idle_time_checkout(self): # Test that sockets are reused if maxIdleTimeMS is not set. client = rs_or_single_client() - server = client._get_topology().select_server( - readable_server_selector) + server = client._get_topology().select_server(readable_server_selector) with server._pool.get_socket({}) as sock_info: pass self.assertEqual(1, len(server._pool.sockets)) @@ -616,15 +613,14 @@ def test_constants(self): host, port = client_context.host, client_context.port kwargs = client_context.default_client_options.copy() if client_context.auth_enabled: - kwargs['username'] = db_user - kwargs['password'] = db_pwd + kwargs["username"] = db_user + kwargs["password"] = db_pwd # Set bad defaults. MongoClient.HOST = "somedomainthatdoesntexist.org" MongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - connected(MongoClient(serverSelectionTimeoutMS=10, - **kwargs)) + connected(MongoClient(serverSelectionTimeoutMS=10, **kwargs)) # Override the defaults. No error. connected(MongoClient(host, port, **kwargs)) @@ -657,7 +653,7 @@ def test_init_disconnected(self): self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) self.assertIsNone(c.address) # PYTHON-2981 - c.admin.command('ping') # connect + c.admin.command("ping") # connect if client_context.is_rs: # The primary's host and port are from the replica set config. self.assertIsNotNone(c.address) @@ -665,66 +661,68 @@ def test_init_disconnected(self): self.assertEqual(c.address, (host, port)) bad_host = "somedomainthatdoesntexist.org" - c = MongoClient(bad_host, port, connectTimeoutMS=1, - serverSelectionTimeoutMS=10) + c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = MongoClient(uri, connectTimeoutMS=1, - serverSelectionTimeoutMS=10) + c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) self.assertRaises(ConnectionFailure, c.pymongo_test.test.find_one) def test_equality(self): - seed = '%s:%s' % list(self.client._topology_settings.seeds)[0] + seed = "%s:%s" % list(self.client._topology_settings.seeds)[0] c = rs_or_single_client(seed, connect=False) self.addCleanup(c.close) self.assertEqual(client_context.client, c) # Explicitly test inequality self.assertFalse(client_context.client != c) - c = rs_or_single_client('invalid.com', connect=False) + c = rs_or_single_client("invalid.com", connect=False) self.addCleanup(c.close) self.assertNotEqual(client_context.client, c) self.assertTrue(client_context.client != c) # Seeds differ: - self.assertNotEqual(MongoClient('a', connect=False), - MongoClient('b', connect=False)) + self.assertNotEqual(MongoClient("a", connect=False), MongoClient("b", connect=False)) # Same seeds but out of order still compares equal: - self.assertEqual(MongoClient(['a', 'b', 'c'], connect=False), - MongoClient(['c', 'a', 'b'], connect=False)) + self.assertEqual( + MongoClient(["a", "b", "c"], connect=False), MongoClient(["c", "a", "b"], connect=False) + ) def test_hashable(self): - seed = '%s:%s' % list(self.client._topology_settings.seeds)[0] + seed = "%s:%s" % list(self.client._topology_settings.seeds)[0] c = rs_or_single_client(seed, connect=False) self.addCleanup(c.close) self.assertIn(c, 
{client_context.client}) - c = rs_or_single_client('invalid.com', connect=False) + c = rs_or_single_client("invalid.com", connect=False) self.addCleanup(c.close) self.assertNotIn(c, {client_context.client}) def test_host_w_port(self): with self.assertRaises(ValueError): - connected(MongoClient("%s:1234567" % (client_context.host,), - connectTimeoutMS=1, - serverSelectionTimeoutMS=10)) + connected( + MongoClient( + "%s:1234567" % (client_context.host,), + connectTimeoutMS=1, + serverSelectionTimeoutMS=10, + ) + ) def test_repr(self): # Used to test 'eval' below. import bson client = MongoClient( - 'mongodb://localhost:27017,localhost:27018/?replicaSet=replset' - '&connectTimeoutMS=12345&w=1&wtimeoutms=100', - connect=False, document_class=SON) + "mongodb://localhost:27017,localhost:27018/?replicaSet=replset" + "&connectTimeoutMS=12345&w=1&wtimeoutms=100", + connect=False, + document_class=SON, + ) the_repr = repr(client) - self.assertIn('MongoClient(host=', the_repr) + self.assertIn("MongoClient(host=", the_repr) self.assertIn( - "document_class=bson.son.SON, " - "tz_aware=False, " - "connect=False, ", - the_repr) + "document_class=bson.son.SON, " "tz_aware=False, " "connect=False, ", the_repr + ) self.assertIn("connecttimeoutms=12345", the_repr) self.assertIn("replicaset='replset'", the_repr) self.assertIn("w=1", the_repr) @@ -732,20 +730,18 @@ def test_repr(self): self.assertEqual(eval(the_repr), client) - client = MongoClient("localhost:27017,localhost:27018", - replicaSet='replset', - connectTimeoutMS=12345, - socketTimeoutMS=None, - w=1, - wtimeoutms=100, - connect=False) + client = MongoClient( + "localhost:27017,localhost:27018", + replicaSet="replset", + connectTimeoutMS=12345, + socketTimeoutMS=None, + w=1, + wtimeoutms=100, + connect=False, + ) the_repr = repr(client) - self.assertIn('MongoClient(host=', the_repr) - self.assertIn( - "document_class=dict, " - "tz_aware=False, " - "connect=False, ", - the_repr) + self.assertIn("MongoClient(host=", the_repr) + self.assertIn("document_class=dict, " "tz_aware=False, " "connect=False, ", the_repr) self.assertIn("connecttimeoutms=12345", the_repr) self.assertIn("replicaset='replset'", the_repr) self.assertIn("sockettimeoutms=None", the_repr) @@ -755,11 +751,10 @@ def test_repr(self): self.assertEqual(eval(the_repr), client) def test_getters(self): - wait_until(lambda: client_context.nodes == self.client.nodes, - "find all nodes") + wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes") def test_list_databases(self): - cmd_docs = self.client.admin.command('listDatabases')['databases'] + cmd_docs = self.client.admin.command("listDatabases")["databases"] cursor = self.client.list_databases() self.assertIsInstance(cursor, CommandCursor) helper_docs = list(cursor) @@ -806,7 +801,7 @@ def test_drop_database(self): if client_context.is_rs: wc_client = rs_or_single_client(w=len(client_context.nodes) + 1) with self.assertRaises(WriteConcernError): - wc_client.drop_database('pymongo_test2') + wc_client.drop_database("pymongo_test2") self.client.drop_database(self.client.pymongo_test2) dbs = self.client.list_database_names() @@ -820,7 +815,7 @@ def test_close(self): self.assertRaises(InvalidOperation, coll.count_documents, {}) def test_close_kills_cursors(self): - if sys.platform.startswith('java'): + if sys.platform.startswith("java"): # We can't figure out how to make this test reliable with Jython. 
raise SkipTest("Can't test with Jython") test_client = rs_or_single_client() @@ -865,7 +860,7 @@ def test_close_stops_kill_cursors_thread(self): self.assertTrue(client._kill_cursors_executor._stopped) # Reusing the closed client should raise an InvalidOperation error. - self.assertRaises(InvalidOperation, client.admin.command, 'ping') + self.assertRaises(InvalidOperation, client.admin.command, "ping") # Thread is still stopped. self.assertTrue(client._kill_cursors_executor._stopped) @@ -879,7 +874,7 @@ def test_uri_connect_option(self): self.assertFalse(kc_thread and kc_thread.is_alive()) # Using the client should open topology and start the thread. - client.admin.command('ping') + client.admin.command("ping") self.assertTrue(client._topology._opened) kc_thread = client._kill_cursors_executor._thread self.assertTrue(kc_thread and kc_thread.is_alive()) @@ -918,16 +913,13 @@ def test_auth_from_uri(self): self.addCleanup(client_context.drop_user, "admin", "admin") self.addCleanup(remove_all_users, self.client.pymongo_test) - client_context.create_user( - "pymongo_test", "user", "pass", roles=['userAdmin', 'readWrite']) + client_context.create_user("pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"]) with self.assertRaises(OperationFailure): - connected(rs_or_single_client_noauth( - "mongodb://a:b@%s:%d" % (host, port))) + connected(rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port))) # No error. - connected(rs_or_single_client_noauth( - "mongodb://admin:pass@%s:%d" % (host, port))) + connected(rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) @@ -935,21 +927,21 @@ def test_auth_from_uri(self): connected(rs_or_single_client_noauth(uri)) # No error. - connected(rs_or_single_client_noauth( - "mongodb://user:pass@%s:%d/pymongo_test" % (host, port))) + connected( + rs_or_single_client_noauth("mongodb://user:pass@%s:%d/pymongo_test" % (host, port)) + ) # Auth with lazy connection. rs_or_single_client_noauth( - "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), - connect=False).pymongo_test.test.find_one() + "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False + ).pymongo_test.test.find_one() # Wrong password. 
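# Hedged aside, not part of the patch: the wrong-password path exercised
# just below, standalone (assumes a mongod with auth enabled; host and
# credentials illustrative). With connect=False the handshake is deferred,
# so the bad password surfaces as OperationFailure on first use.
from pymongo import MongoClient
from pymongo.errors import OperationFailure

_bad = MongoClient("mongodb://user:wrong@localhost/pymongo_test", connect=False)
try:
    _bad.pymongo_test.test.find_one()
except OperationFailure:
    pass  # authentication fails lazily, at the first real operation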
bad_client = rs_or_single_client_noauth( - "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), - connect=False) + "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False + ) - self.assertRaises(OperationFailure, - bad_client.pymongo_test.test.find_one) + self.assertRaises(OperationFailure, bad_client.pymongo_test.test.find_one) @client_context.require_auth def test_username_and_password(self): @@ -968,26 +960,23 @@ def test_username_and_password(self): c.server_info() with self.assertRaises(OperationFailure): - rs_or_single_client_noauth( - username="ad min", password="foo").server_info() + rs_or_single_client_noauth(username="ad min", password="foo").server_info() @client_context.require_auth def test_lazy_auth_raises_operation_failure(self): lazy_client = rs_or_single_client_noauth( - "mongodb://user:wrong@%s/pymongo_test" % (client_context.host,), - connect=False) + "mongodb://user:wrong@%s/pymongo_test" % (client_context.host,), connect=False + ) - assertRaisesExactly( - OperationFailure, lazy_client.test.collection.find_one) + assertRaisesExactly(OperationFailure, lazy_client.test.collection.find_one) @client_context.require_no_tls def test_unix_socket(self): if not hasattr(socket, "AF_UNIX"): raise SkipTest("UNIX-sockets are not supported on this system") - mongodb_socket = '/tmp/mongodb-%d.sock' % (client_context.port,) - encoded_socket = ( - '%2Ftmp%2F' + 'mongodb-%d.sock' % (client_context.port,)) + mongodb_socket = "/tmp/mongodb-%d.sock" % (client_context.port,) + encoded_socket = "%2Ftmp%2F" + "mongodb-%d.sock" % (client_context.port,) if not os.access(mongodb_socket, os.R_OK): raise SkipTest("Socket file is not accessible") @@ -1003,8 +992,9 @@ def test_unix_socket(self): # Confirm it fails with a missing socket. self.assertRaises( ConnectionFailure, - connected, MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", - serverSelectionTimeoutMS=100)) + connected, + MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100), + ) def test_document_class(self): c = self.client @@ -1026,7 +1016,8 @@ def test_timeouts(self): connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, - serverSelectionTimeoutMS=10500) + serverSelectionTimeoutMS=10500, + ) self.assertEqual(10.5, get_pool(client).opts.connect_timeout) self.assertEqual(10.5, get_pool(client).opts.socket_timeout) self.assertEqual(10.5, get_pool(client).opts.max_idle_time_seconds) @@ -1043,14 +1034,11 @@ def test_socket_timeout_ms_validation(self): c = connected(rs_or_single_client(socketTimeoutMS=0)) self.assertEqual(None, get_pool(c).opts.socket_timeout) - self.assertRaises(ValueError, - rs_or_single_client, socketTimeoutMS=-1) + self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=-1) - self.assertRaises(ValueError, - rs_or_single_client, socketTimeoutMS=1e10) + self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS=1e10) - self.assertRaises(ValueError, - rs_or_single_client, socketTimeoutMS='foo') + self.assertRaises(ValueError, rs_or_single_client, socketTimeoutMS="foo") def test_socket_timeout(self): no_timeout = self.client @@ -1066,6 +1054,7 @@ def test_socket_timeout(self): def get_x(db): doc = next(db.test.find().where(where_func)) return doc["x"] + self.assertEqual(1, get_x(no_timeout.pymongo_test)) self.assertRaises(NetworkTimeout, get_x, timeout.pymongo_test) @@ -1076,28 +1065,23 @@ def test_server_selection_timeout(self): client = MongoClient(serverSelectionTimeoutMS=0, connect=False) self.assertAlmostEqual(0, 
client.options.server_selection_timeout) - self.assertRaises(ValueError, MongoClient, - serverSelectionTimeoutMS="foo", connect=False) - self.assertRaises(ValueError, MongoClient, - serverSelectionTimeoutMS=-1, connect=False) - self.assertRaises(ConfigurationError, MongoClient, - serverSelectionTimeoutMS=None, connect=False) + self.assertRaises(ValueError, MongoClient, serverSelectionTimeoutMS="foo", connect=False) + self.assertRaises(ValueError, MongoClient, serverSelectionTimeoutMS=-1, connect=False) + self.assertRaises( + ConfigurationError, MongoClient, serverSelectionTimeoutMS=None, connect=False + ) - client = MongoClient( - 'mongodb://localhost/?serverSelectionTimeoutMS=100', connect=False) + client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) - client = MongoClient( - 'mongodb://localhost/?serverSelectionTimeoutMS=0', connect=False) + client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) self.assertAlmostEqual(0, client.options.server_selection_timeout) # Test invalid timeout in URI ignored and set to default. - client = MongoClient( - 'mongodb://localhost/?serverSelectionTimeoutMS=-1', connect=False) + client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) - client = MongoClient( - 'mongodb://localhost/?serverSelectionTimeoutMS=', connect=False) + client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) self.assertAlmostEqual(30, client.options.server_selection_timeout) def test_waitQueueTimeoutMS(self): @@ -1107,12 +1091,11 @@ def test_waitQueueTimeoutMS(self): def test_socketKeepAlive(self): pool = get_pool(self.client) with pool.get_socket({}) as sock_info: - keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, - socket.SO_KEEPALIVE) + keepalive = sock_info.sock.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE) self.assertTrue(keepalive) def test_tz_aware(self): - self.assertRaises(ValueError, MongoClient, tz_aware='foo') + self.assertRaises(ValueError, MongoClient, tz_aware="foo") aware = rs_or_single_client(tz_aware=True) naive = self.client @@ -1125,7 +1108,8 @@ def test_tz_aware(self): self.assertEqual(utc, aware.pymongo_test.test.find_one()["x"].tzinfo) self.assertEqual( aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None), - naive.pymongo_test.test.find_one()["x"]) + naive.pymongo_test.test.find_one()["x"], + ) @client_context.require_ipv6 def test_ipv6(self): @@ -1140,7 +1124,7 @@ def test_ipv6(self): uri = "mongodb://%s[::1]:%d" % (auth_str, client_context.port) if client_context.is_rs: - uri += '/?replicaSet=' + client_context.replica_set_name + uri += "/?replicaSet=" + client_context.replica_set_name client = rs_or_single_client_noauth(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) @@ -1170,7 +1154,7 @@ def test_contextlib(self): client.pymongo_test.test.find_one() def test_interrupt_signal(self): - if sys.platform.startswith('java'): + if sys.platform.startswith("java"): # We can't figure out how to raise an exception on a thread that's # blocked on a socket, whether that's the main thread or a worker, # without simply killing the whole thread in Jython. 
This suggests @@ -1187,8 +1171,8 @@ def test_interrupt_signal(self): where = delay(1.5) # Need exactly 1 document so find() will execute its $where clause once - db.drop_collection('foo') - db.foo.insert_one({'_id': 1}) + db.drop_collection("foo") + db.foo.insert_one({"_id": 1}) old_signal_handler = None try: @@ -1199,7 +1183,8 @@ def test_interrupt_signal(self): # sock.recv(): TypeError: 'int' object is not callable # We don't know what causes this, so we hack around it. - if sys.platform == 'win32': + if sys.platform == "win32": + def interrupter(): # Raises KeyboardInterrupt in the main thread time.sleep(0.25) @@ -1218,7 +1203,7 @@ def sigalarm(num, frame): raised = False try: # Will be interrupted by a KeyboardInterrupt. - next(db.foo.find({'$where': where})) + next(db.foo.find({"$where": where})) except KeyboardInterrupt: raised = True @@ -1229,10 +1214,7 @@ def sigalarm(num, frame): # Raises AssertionError due to PYTHON-294 -- Mongo's response to # the previous find() is still waiting to be read on the socket, # so the request id's don't match. - self.assertEqual( - {'_id': 1}, - next(db.foo.find()) - ) + self.assertEqual({"_id": 1}, next(db.foo.find())) finally: if old_signal_handler: signal.signal(signal.SIGALRM, old_signal_handler) @@ -1249,10 +1231,8 @@ def test_operation_failure(self): self.assertGreaterEqual(socket_count, 1) old_sock_info = next(iter(pool.sockets)) client.pymongo_test.test.drop() - client.pymongo_test.test.insert_one({'_id': 'foo'}) - self.assertRaises( - OperationFailure, - client.pymongo_test.test.insert_one, {'_id': 'foo'}) + client.pymongo_test.test.insert_one({"_id": "foo"}) + self.assertRaises(OperationFailure, client.pymongo_test.test.insert_one, {"_id": "foo"}) self.assertEqual(socket_count, len(pool.sockets)) new_sock_info = next(iter(pool.sockets)) @@ -1264,27 +1244,26 @@ def test_lazy_connect_w0(self): # Use a separate collection to avoid races where we're still # completing an operation on a collection while the next test begins. - client_context.client.drop_database('test_lazy_connect_w0') - self.addCleanup( - client_context.client.drop_database, 'test_lazy_connect_w0') + client_context.client.drop_database("test_lazy_connect_w0") + self.addCleanup(client_context.client.drop_database, "test_lazy_connect_w0") client = rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.insert_one({}) wait_until( - lambda: client.test_lazy_connect_w0.test.count_documents({}) == 1, - "find one document") + lambda: client.test_lazy_connect_w0.test.count_documents({}) == 1, "find one document" + ) client = rs_or_single_client(connect=False, w=0) - client.test_lazy_connect_w0.test.update_one({}, {'$set': {'x': 1}}) + client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) wait_until( - lambda: client.test_lazy_connect_w0.test.find_one().get('x') == 1, - "update one document") + lambda: client.test_lazy_connect_w0.test.find_one().get("x") == 1, "update one document" + ) client = rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.delete_one({}) wait_until( - lambda: client.test_lazy_connect_w0.test.count_documents({}) == 0, - "delete one document") + lambda: client.test_lazy_connect_w0.test.count_documents({}) == 0, "delete one document" + ) @client_context.require_no_mongos def test_exhaust_network_error(self): @@ -1316,9 +1295,7 @@ def test_auth_network_error(self): # when authenticating a new socket with cached credentials. # Get a client with one socket so we detect if it's leaked. 
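# Hedged aside, not part of the patch: an assumed reconstruction of the
# test.utils.connected() helper used on the next line: force a lazy
# client to finish its handshake (and fail fast) via a cheap command.
def _connected(client):
    client.admin.command("ping")  # raises if the deployment is unreachable
    return client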
- c = connected(rs_or_single_client(maxPoolSize=1, - waitQueueTimeoutMS=1, - retryReads=False)) + c = connected(rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False)) # Cause a network error on the actual socket. pool = get_pool(c) @@ -1334,8 +1311,7 @@ def test_auth_network_error(self): @client_context.require_no_replica_set def test_connect_to_standalone_using_replica_set_name(self): - client = single_client(replicaSet='anything', - serverSelectionTimeoutMS=100) + client = single_client(replicaSet="anything", serverSelectionTimeoutMS=100) with self.assertRaises(AutoReconnect): client.test.test.find_one() @@ -1346,16 +1322,24 @@ def test_stale_getmore(self): # the topology before the getMore message is sent. Test that # MongoClient._run_operation_with_response handles the error. with self.assertRaises(AutoReconnect): - client = rs_client(connect=False, - serverSelectionTimeoutMS=100) + client = rs_client(connect=False, serverSelectionTimeoutMS=100) client._run_operation( - operation=message._GetMore('pymongo_test', 'collection', - 101, 1234, client.codec_options, - ReadPreference.PRIMARY, - None, client, None, None, False), - unpack_res=Cursor( - client.pymongo_test.collection)._unpack_response, - address=('not-a-member', 27017)) + operation=message._GetMore( + "pymongo_test", + "collection", + 101, + 1234, + client.codec_options, + ReadPreference.PRIMARY, + None, + client, + None, + None, + False, + ), + unpack_res=Cursor(client.pymongo_test.collection)._unpack_response, + address=("not-a-member", 27017), + ) def test_heartbeat_frequency_ms(self): class HeartbeatStartedListener(ServerHeartbeatListener): @@ -1382,15 +1366,17 @@ def init(self, *args): ServerHeartbeatStartedEvent.__init__ = init listener = HeartbeatStartedListener() uri = "mongodb://%s:%d/?heartbeatFrequencyMS=500" % ( - client_context.host, client_context.port) + client_context.host, + client_context.port, + ) client = single_client(uri, event_listeners=[listener]) - wait_until(lambda: len(listener.results) >= 2, - "record two ServerHeartbeatStartedEvents") + wait_until( + lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" + ) # Default heartbeatFrequencyMS is 10 sec. Check the interval was # closer to 0.5 sec with heartbeatFrequencyMS configured. 
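# Hedged aside, not part of the patch: a standalone version of the
# measurement below (assumes a reachable mongod). heartbeatFrequencyMS
# shortens the monitor's polling interval from its 10-second default.
import time
from pymongo import MongoClient, monitoring

class _HeartbeatTimes(monitoring.ServerHeartbeatListener):
    def __init__(self):
        self.times = []
    def started(self, event):
        self.times.append(time.monotonic())
    def succeeded(self, event):
        pass
    def failed(self, event):
        pass

_listener = _HeartbeatTimes()
_c = MongoClient("mongodb://localhost/?heartbeatFrequencyMS=500",
                 event_listeners=[_listener])
# successive entries in _listener.times then land roughly 0.5s apart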
- self.assertAlmostEqual( - heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) + self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) client.close() finally: @@ -1401,7 +1387,7 @@ def test_small_heartbeat_frequency_ms(self): with self.assertRaises(ConfigurationError) as context: MongoClient(uri) - self.assertIn('heartbeatFrequencyMS', str(context.exception)) + self.assertIn("heartbeatFrequencyMS", str(context.exception)) def test_compression(self): def compression_settings(client): @@ -1411,16 +1397,16 @@ def compression_settings(client): uri = "mongodb://localhost:27017/?compressors=zlib" client = MongoClient(uri, connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.compressors, ["zlib"]) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" client = MongoClient(uri, connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, 4) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" client = MongoClient(uri, connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017" client = MongoClient(uri, connect=False) @@ -1435,7 +1421,7 @@ def compression_settings(client): uri = "mongodb://localhost:27017/?compressors=foobar,zlib" client = MongoClient(uri, connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) # According to the connection string spec, unsupported values @@ -1443,12 +1429,12 @@ def compression_settings(client): uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" client = MongoClient(uri, connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" client = MongoClient(uri, connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ['zlib']) + self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) if not _HAVE_SNAPPY: @@ -1460,11 +1446,11 @@ def compression_settings(client): uri = "mongodb://localhost:27017/?compressors=snappy" client = MongoClient(uri, connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ['snappy']) + self.assertEqual(opts.compressors, ["snappy"]) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" client = MongoClient(uri, connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ['snappy', 'zlib']) + self.assertEqual(opts.compressors, ["snappy", "zlib"]) if not _HAVE_ZSTD: uri = "mongodb://localhost:27017/?compressors=zstd" @@ -1475,11 +1461,11 @@ def compression_settings(client): uri = "mongodb://localhost:27017/?compressors=zstd" client = MongoClient(uri, connect=False) opts = compression_settings(client) - self.assertEqual(opts.compressors, ['zstd']) + self.assertEqual(opts.compressors, ["zstd"]) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" client = MongoClient(uri, connect=False) opts = 
compression_settings(client) - self.assertEqual(opts.compressors, ['zstd', 'zlib']) + self.assertEqual(opts.compressors, ["zstd", "zlib"]) options = client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: @@ -1491,7 +1477,7 @@ def compression_settings(client): def test_reset_during_update_pool(self): client = rs_or_single_client(minPoolSize=10) self.addCleanup(client.close) - client.admin.command('ping') + client.admin.command("ping") pool = get_pool(client) generation = pool.gen.get_overall() @@ -1507,9 +1493,8 @@ def stop(self): def run(self): while self.running: - exc = AutoReconnect('mock pool error') - ctx = _ErrorContext( - exc, 0, pool.gen.get_overall(), False, None) + exc = AutoReconnect("mock pool error") + ctx = _ErrorContext(exc, 0, pool.gen.get_overall(), False, None) client._topology.handle_error(pool.address, ctx) time.sleep(0.001) @@ -1521,24 +1506,23 @@ def run(self): try: while True: for _ in range(10): - client._topology.update_pool( - client._MongoClient__all_credentials) + client._topology.update_pool(client._MongoClient__all_credentials) if generation != pool.gen.get_overall(): break finally: t.stop() t.join() - client.admin.command('ping') + client.admin.command("ping") def test_background_connections_do_not_hold_locks(self): min_pool_size = 10 client = rs_or_single_client( - serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, - connect=False) + serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False + ) self.addCleanup(client.close) # Create a single connection in the pool. - client.admin.command('ping') + client.admin.command("ping") # Cause new connections stall for a few seconds. pool = get_pool(client) @@ -1550,15 +1534,15 @@ def stall_connect(*args, **kwargs): pool.connect = stall_connect # Un-patch Pool.connect to break the cyclic reference. - self.addCleanup(delattr, pool, 'connect') + self.addCleanup(delattr, pool, "connect") # Wait for the background thread to start creating connections - wait_until(lambda: len(pool.sockets) > 1, 'start creating connections') + wait_until(lambda: len(pool.sockets) > 1, "start creating connections") # Assert that application operations do not block. for _ in range(10): start = time.monotonic() - client.admin.command('ping') + client.admin.command("ping") total = time.monotonic() - start # Each ping command should not take more than 2 seconds self.assertLess(total, 2) @@ -1567,28 +1551,27 @@ def stall_connect(*args, **kwargs): def test_direct_connection(self): # direct_connection=True should result in Single topology. client = rs_or_single_client(directConnection=True) - client.admin.command('ping') + client.admin.command("ping") self.assertEqual(len(client.nodes), 1) - self.assertEqual(client._topology_settings.get_topology_type(), - TOPOLOGY_TYPE.Single) + self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) client.close() # direct_connection=False should result in RS topology. client = rs_or_single_client(directConnection=False) - client.admin.command('ping') + client.admin.command("ping") self.assertGreaterEqual(len(client.nodes), 1) - self.assertIn(client._topology_settings.get_topology_type(), - [TOPOLOGY_TYPE.ReplicaSetNoPrimary, - TOPOLOGY_TYPE.ReplicaSetWithPrimary]) + self.assertIn( + client._topology_settings.get_topology_type(), + [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], + ) client.close() # directConnection=True, should error with multiple hosts as a list. 
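        # (A multi-host seed list implies server discovery, which a direct
        # connection by definition skips, so combining the two is rejected
        # up front with ConfigurationError.)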
with self.assertRaises(ConfigurationError): - MongoClient(['host1', 'host2'], directConnection=True) + MongoClient(["host1", "host2"], directConnection=True) - @unittest.skipIf(sys.platform.startswith('java'), - 'Jython does not support gc.get_objects') - @unittest.skipIf('PyPy' in sys.version, 'PYTHON-2927 fails often on PyPy') + @unittest.skipIf(sys.platform.startswith("java"), "Jython does not support gc.get_objects") + @unittest.skipIf("PyPy" in sys.version, "PYTHON-2927 fails often on PyPy") def test_continuous_network_errors(self): def server_description_count(): i = 0 @@ -1599,12 +1582,12 @@ def server_description_count(): except ReferenceError: pass return i + gc.collect() with client_knobs(min_heartbeat_interval=0.003): client = MongoClient( - 'invalid:27017', - heartbeatFrequencyMS=3, - serverSelectionTimeoutMS=100) + "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=100 + ) initial_count = server_description_count() self.addCleanup(client.close) with self.assertRaises(ServerSelectionTimeoutError): @@ -1620,15 +1603,15 @@ def server_description_count(): def test_network_error_message(self): client = single_client(retryReads=False) self.addCleanup(client.close) - client.admin.command('ping') # connect - with self.fail_point({'mode': {'times': 1}, - 'data': {'closeConnection': True, - 'failCommands': ['find']}}): - expected = '%s:%s: ' % client.address + client.admin.command("ping") # connect + with self.fail_point( + {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} + ): + expected = "%s:%s: " % client.address with self.assertRaisesRegex(AutoReconnect, expected): client.pymongo_test.test.find_one({}) - @unittest.skipIf('PyPy' in sys.version, 'PYTHON-2938 could fail on PyPy') + @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") def test_process_periodic_tasks(self): client = rs_or_single_client() coll = client.db.collection @@ -1640,49 +1623,45 @@ def test_process_periodic_tasks(self): client.close() # Add cursor to kill cursors queue del cursor - wait_until(lambda: client._MongoClient__kill_cursors_queue, - "waited for cursor to be added to queue") + wait_until( + lambda: client._MongoClient__kill_cursors_queue, + "waited for cursor to be added to queue", + ) client._process_periodic_tasks() # This must not raise or print any exceptions with self.assertRaises(InvalidOperation): coll.insert_many([{} for _ in range(5)]) - @unittest.skipUnless( - _HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed") + @unittest.skipUnless(_HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed") def test_service_name_from_kwargs(self): client = MongoClient( - 'mongodb+srv://user:password@test22.test.build.10gen.cc', - srvServiceName='customname', connect=False) - self.assertEqual(client._topology_settings.srv_service_name, - 'customname') + "mongodb+srv://user:password@test22.test.build.10gen.cc", + srvServiceName="customname", + connect=False, + ) + self.assertEqual(client._topology_settings.srv_service_name, "customname") client = MongoClient( - 'mongodb+srv://user:password@test22.test.build.10gen.cc' - '/?srvServiceName=shouldbeoverriden', - srvServiceName='customname', connect=False) - self.assertEqual(client._topology_settings.srv_service_name, - 'customname') + "mongodb+srv://user:password@test22.test.build.10gen.cc" + "/?srvServiceName=shouldbeoverriden", + srvServiceName="customname", + connect=False, + ) + self.assertEqual(client._topology_settings.srv_service_name, "customname") client = 
MongoClient( - 'mongodb+srv://user:password@test22.test.build.10gen.cc' - '/?srvServiceName=customname', - connect=False) - self.assertEqual(client._topology_settings.srv_service_name, - 'customname') - - @unittest.skipUnless( - _HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed") + "mongodb+srv://user:password@test22.test.build.10gen.cc" "/?srvServiceName=customname", + connect=False, + ) + self.assertEqual(client._topology_settings.srv_service_name, "customname") + + @unittest.skipUnless(_HAVE_DNSPYTHON, "DNS-related tests need dnspython to be installed") def test_srv_max_hosts_kwarg(self): + client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/") + self.assertGreater(len(client.topology_description.server_descriptions()), 1) + client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = MongoClient( - 'mongodb+srv://test1.test.build.10gen.cc/') - self.assertGreater( - len(client.topology_description.server_descriptions()), 1) - client = MongoClient( - 'mongodb+srv://test1.test.build.10gen.cc/', srvmaxhosts=1) - self.assertEqual( - len(client.topology_description.server_descriptions()), 1) - client = MongoClient( - 'mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1', - srvmaxhosts=2) - self.assertEqual( - len(client.topology_description.server_descriptions()), 2) + "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 + ) + self.assertEqual(len(client.topology_description.server_descriptions()), 2) class TestExhaustCursor(IntegrationTest): @@ -1705,8 +1684,8 @@ def test_exhaust_query_server_error(self): # This will cause OperationFailure in all mongo versions since # the value for $orderby must be a document. cursor = collection.find( - SON([('$query', {}), ('$orderby', True)]), - cursor_type=CursorType.EXHAUST) + SON([("$query", {}), ("$orderby", True)]), cursor_type=CursorType.EXHAUST + ) self.assertRaises(OperationFailure, cursor.next) self.assertFalse(sock_info.closed) @@ -1740,8 +1719,8 @@ def receive_message(request_id): SocketInfo.receive_message(sock_info, request_id) # responseFlags bit 1 is QueryFailure. - msg = struct.pack('= count, - 'find %s %s event(s)' % (count, event), timeout=timeout) + event = OBJECT_TYPES[op["event"]] + count = op["count"] + timeout = op.get("timeout", 10000) / 1000.0 + wait_until( + lambda: self.listener.event_count(event) >= count, + "find %s %s event(s)" % (count, event), + timeout=timeout, + ) def check_out(self, op): """Run the 'checkOut' operation.""" - label = op['label'] + label = op["label"] with self.pool.get_socket({}) as sock_info: # Call 'pin_cursor' so we can hold the socket. 
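            # (Pinning tells the pool not to check the connection back in
            # when this 'with' block exits; the runner stores it under its
            # label and only returns it via an explicit 'checkIn' operation.)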
sock_info.pin_cursor() @@ -130,7 +130,7 @@ def check_out(self, op): def check_in(self, op): """Run the 'checkIn' operation.""" - label = op['connection'] + label = op["connection"] sock_info = self.labels[label] self.pool.return_socket(sock_info) @@ -148,8 +148,8 @@ def close(self, op): def run_operation(self, op): """Run a single operation in a test.""" - op_name = camel_to_snake(op['name']) - thread = op['thread'] + op_name = camel_to_snake(op["name"]) + thread = op["thread"] meth = getattr(self, op_name) if thread: self.targets[thread].schedule(lambda: meth(op)) @@ -164,9 +164,9 @@ def run_operations(self, ops): def check_object(self, actual, expected): """Assert that the actual object matches the expected object.""" - self.assertEqual(type(actual), OBJECT_TYPES[expected['type']]) + self.assertEqual(type(actual), OBJECT_TYPES[expected["type"]]) for attr, expected_val in expected.items(): - if attr == 'type': + if attr == "type": continue c2s = camel_to_snake(attr) actual_val = getattr(actual, c2s) @@ -182,62 +182,60 @@ def check_event(self, actual, expected): def actual_events(self, ignore): """Return all the non-ignored events.""" ignore = tuple(OBJECT_TYPES[name] for name in ignore) - return [event for event in self.listener.events - if not isinstance(event, ignore)] + return [event for event in self.listener.events if not isinstance(event, ignore)] def check_events(self, events, ignore): """Check the events of a test.""" actual_events = self.actual_events(ignore) for actual, expected in zip(actual_events, events): - self.logs.append('Checking event actual: %r vs expected: %r' % ( - actual, expected)) + self.logs.append("Checking event actual: %r vs expected: %r" % (actual, expected)) self.check_event(actual, expected) if len(events) > len(actual_events): - self.fail('missing events: %r' % (events[len(actual_events):],)) + self.fail("missing events: %r" % (events[len(actual_events) :],)) def check_error(self, actual, expected): - message = expected.pop('message') + message = expected.pop("message") self.check_object(actual, expected) self.assertIn(message, str(actual)) def _set_fail_point(self, client, command_args): - cmd = SON([('configureFailPoint', 'failCommand')]) + cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) client.admin.command(cmd) def set_fail_point(self, command_args): if not client_context.supports_failCommand_fail_point: - self.skipTest('failCommand fail point must be supported') + self.skipTest("failCommand fail point must be supported") self._set_fail_point(self.client, command_args) def run_scenario(self, scenario_def, test): """Run a CMAP spec test.""" self.logs = [] - self.assertEqual(scenario_def['version'], 1) - self.assertIn(scenario_def['style'], ['unit', 'integration']) + self.assertEqual(scenario_def["version"], 1) + self.assertIn(scenario_def["style"], ["unit", "integration"]) self.listener = CMAPListener() self._ops = [] # Configure the fail point before creating the client. 
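        # (Arming failCommand before the client exists guarantees the pool's
        # very first operations hit it; the cleanup registered below disarms
        # it with mode "off" even if the test body raises.)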
- if 'failPoint' in test: - fp = test['failPoint'] + if "failPoint" in test: + fp = test["failPoint"] self.set_fail_point(fp) - self.addCleanup(self.set_fail_point, { - 'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'}) - - opts = test['poolOptions'].copy() - opts['event_listeners'] = [self.listener] - opts['_monitor_class'] = DummyMonitor - opts['connect'] = False + self.addCleanup( + self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) + + opts = test["poolOptions"].copy() + opts["event_listeners"] = [self.listener] + opts["_monitor_class"] = DummyMonitor + opts["connect"] = False # Support backgroundThreadIntervalMS, default to 50ms. - interval = opts.pop('backgroundThreadIntervalMS', 50) + interval = opts.pop("backgroundThreadIntervalMS", 50) if interval < 0: kill_cursor_frequency = 99999999 else: - kill_cursor_frequency = interval/1000.0 - with client_knobs(kill_cursor_frequency=kill_cursor_frequency, - min_heartbeat_interval=.05): + kill_cursor_frequency = interval / 1000.0 + with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05): client = single_client(**opts) # Update the SD to a known type because the DummyMonitor will not. # Note we cannot simply call topology.on_change because that would @@ -245,10 +243,10 @@ def run_scenario(self, scenario_def, test): # PoolReadyEvents. Instead, update the initial state before # opening the Topology. td = client_context.client._topology.description - sd = td.server_descriptions()[(client_context.host, - client_context.port)] + sd = td.server_descriptions()[(client_context.host, client_context.port)] client._topology._description = updated_topology_description( - client._topology._description, sd) + client._topology._description, sd + ) # When backgroundThreadIntervalMS is negative we do not start the # background thread to ensure it never runs. if interval < 0: @@ -274,37 +272,37 @@ def cleanup(): self.addCleanup(cleanup) try: - if test['error']: + if test["error"]: with self.assertRaises(PyMongoError) as ctx: - self.run_operations(test['operations']) - self.check_error(ctx.exception, test['error']) + self.run_operations(test["operations"]) + self.check_error(ctx.exception, test["error"]) else: - self.run_operations(test['operations']) + self.run_operations(test["operations"]) - self.check_events(test['events'], test['ignore']) + self.check_events(test["events"], test["ignore"]) except Exception: # Print the events after a test failure. - print('\nFailed test: %r' % (test['description'],)) - print('Operations:') + print("\nFailed test: %r" % (test["description"],)) + print("Operations:") for op in self._ops: print(op) - print('Threads:') + print("Threads:") print(self.targets) - print('Connections:') + print("Connections:") print(self.labels) - print('Events:') + print("Events:") for event in self.listener.events: print(event) - print('Log:') + print("Log:") for log in self.logs: print(log) raise POOL_OPTIONS = { - 'maxPoolSize': 50, - 'minPoolSize': 1, - 'maxIdleTimeMS': 10000, - 'waitQueueTimeoutMS': 10000 + "maxPoolSize": 50, + "minPoolSize": 1, + "maxIdleTimeMS": 10000, + "waitQueueTimeoutMS": 10000, } # @@ -319,11 +317,10 @@ def test_1_client_connection_pool_options(self): def test_2_all_client_pools_have_same_options(self): client = rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) - client.admin.command('ping') + client.admin.command("ping") # Discover at least one secondary. 
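        # (A ping with a secondary read preference forces the client to open
        # a pool on at least one more server, so the loop below compares the
        # options of several pools, not just the primary's.)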
if client_context.has_secondaries: - client.admin.command( - 'ping', read_preference=ReadPreference.SECONDARY) + client.admin.command("ping", read_preference=ReadPreference.SECONDARY) pools = get_pools(client) pool_opts = pools[0].opts @@ -332,9 +329,8 @@ def test_2_all_client_pools_have_same_options(self): self.assertEqual(pool.opts, pool_opts) def test_3_uri_connection_pool_options(self): - opts = '&'.join(['%s=%s' % (k, v) - for k, v in self.POOL_OPTIONS.items()]) - uri = 'mongodb://%s/?%s' % (client_context.pair, opts) + opts = "&".join(["%s=%s" % (k, v) for k, v in self.POOL_OPTIONS.items()]) + uri = "mongodb://%s/?%s" % (client_context.pair, opts) client = rs_or_single_client(uri) self.addCleanup(client.close) pool_opts = get_pool(client).opts @@ -347,18 +343,16 @@ def test_4_subscribe_to_events(self): self.assertEqual(listener.event_count(PoolCreatedEvent), 1) # Creates a new connection. - client.admin.command('ping') - self.assertEqual( - listener.event_count(ConnectionCheckOutStartedEvent), 1) + client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 1) self.assertEqual(listener.event_count(ConnectionCreatedEvent), 1) self.assertEqual(listener.event_count(ConnectionReadyEvent), 1) self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 1) self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 1) # Uses the existing connection. - client.admin.command('ping') - self.assertEqual( - listener.event_count(ConnectionCheckOutStartedEvent), 2) + client.admin.command("ping") + self.assertEqual(listener.event_count(ConnectionCheckOutStartedEvent), 2) self.assertEqual(listener.event_count(ConnectionCheckedOutEvent), 2) self.assertEqual(listener.event_count(ConnectionCheckedInEvent), 2) @@ -373,49 +367,44 @@ def test_5_check_out_fails_connection_error(self): pool = get_pool(client) def mock_connect(*args, **kwargs): - raise ConnectionFailure('connect failed') + raise ConnectionFailure("connect failed") + pool.connect = mock_connect # Un-patch Pool.connect to break the cyclic reference. - self.addCleanup(delattr, pool, 'connect') + self.addCleanup(delattr, pool, "connect") # Attempt to create a new connection. - with self.assertRaisesRegex(ConnectionFailure, 'connect failed'): - client.admin.command('ping') + with self.assertRaisesRegex(ConnectionFailure, "connect failed"): + client.admin.command("ping") self.assertIsInstance(listener.events[0], PoolCreatedEvent) self.assertIsInstance(listener.events[1], PoolReadyEvent) - self.assertIsInstance(listener.events[2], - ConnectionCheckOutStartedEvent) - self.assertIsInstance(listener.events[3], - ConnectionCheckOutFailedEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[3], ConnectionCheckOutFailedEvent) self.assertIsInstance(listener.events[4], PoolClearedEvent) failed_event = listener.events[3] - self.assertEqual( - failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR) + self.assertEqual(failed_event.reason, ConnectionCheckOutFailedReason.CONN_ERROR) def test_5_check_out_fails_auth_error(self): listener = CMAPListener() client = single_client_noauth( - username="notauser", password="fail", - event_listeners=[listener]) + username="notauser", password="fail", event_listeners=[listener] + ) self.addCleanup(client.close) # Attempt to create a new connection. 
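        # (Unlike the connect-failure case above, the TCP connection is
        # established and only the auth handshake fails, so a
        # ConnectionCreatedEvent and a ConnectionClosedEvent are recorded
        # before the checkout is reported as failed.)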
- with self.assertRaisesRegex(OperationFailure, 'failed'): - client.admin.command('ping') + with self.assertRaisesRegex(OperationFailure, "failed"): + client.admin.command("ping") self.assertIsInstance(listener.events[0], PoolCreatedEvent) self.assertIsInstance(listener.events[1], PoolReadyEvent) - self.assertIsInstance(listener.events[2], - ConnectionCheckOutStartedEvent) + self.assertIsInstance(listener.events[2], ConnectionCheckOutStartedEvent) self.assertIsInstance(listener.events[3], ConnectionCreatedEvent) # Error happens here. self.assertIsInstance(listener.events[4], ConnectionClosedEvent) - self.assertIsInstance(listener.events[5], - ConnectionCheckOutFailedEvent) - self.assertEqual(listener.events[5].reason, - ConnectionCheckOutFailedReason.CONN_ERROR) + self.assertIsInstance(listener.events[5], ConnectionCheckOutFailedEvent) + self.assertEqual(listener.events[5].reason, ConnectionCheckOutFailedReason.CONN_ERROR) # # Extra non-spec tests @@ -426,13 +415,13 @@ def assertRepr(self, obj): self.assertEqual(repr(new_obj), repr(obj)) def test_events_repr(self): - host = ('localhost', 27017) + host = ("localhost", 27017) self.assertRepr(ConnectionCheckedInEvent(host, 1)) self.assertRepr(ConnectionCheckedOutEvent(host, 1)) - self.assertRepr(ConnectionCheckOutFailedEvent( - host, ConnectionCheckOutFailedReason.POOL_CLOSED)) - self.assertRepr(ConnectionClosedEvent( - host, 1, ConnectionClosedReason.POOL_CLOSED)) + self.assertRepr( + ConnectionCheckOutFailedEvent(host, ConnectionCheckOutFailedReason.POOL_CLOSED) + ) + self.assertRepr(ConnectionClosedEvent(host, 1, ConnectionClosedReason.POOL_CLOSED)) self.assertRepr(ConnectionCreatedEvent(host, 1)) self.assertRepr(ConnectionReadyEvent(host, 1)) self.assertRepr(ConnectionCheckOutStartedEvent(host)) @@ -446,7 +435,7 @@ def test_close_leaves_pool_unpaused(self): # test_threads.TestThreads.test_client_disconnect listener = CMAPListener() client = single_client(event_listeners=[listener]) - client.admin.command('ping') + client.admin.command("ping") pool = get_pool(client) client.close() self.assertEqual(1, listener.event_count(PoolClearedEvent)) @@ -464,7 +453,6 @@ def run_scenario(self): class CMAPTestCreator(TestCreator): - def tests(self, scenario_def): """Extract the tests from a spec file. 
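As an aside on the CMAP runner above: run_operation dispatches spec operation
names such as "waitForEvent" onto the matching snake_case test methods via
camel_to_snake, presumably the helper from test/utils.py. A minimal sketch of
what such a converter does, shown purely for illustration and not necessarily
the exact implementation:

    import re

    def camel_to_snake(camel):
        # "checkOut" -> "check_out", "waitForEvent" -> "wait_for_event"
        snake = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", camel)
        return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", snake).lower()

Keeping the JSON spec files in camelCase while the Python methods stay
snake_case lets the same test corpus drive every driver language.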
diff --git a/test/test_code.py b/test/test_code.py index c5e190f363..97586573da 100644 --- a/test/test_code.py +++ b/test/test_code.py @@ -17,11 +17,13 @@ """Tests for the Code wrapper.""" import sys + sys.path[0:0] = [""] -from bson.code import Code from test import unittest +from bson.code import Code + class TestCode(unittest.TestCase): def test_types(self): @@ -36,6 +38,7 @@ def test_read_only(self): def set_c(): c.scope = 5 + self.assertRaises(AttributeError, set_c) def test_code(self): @@ -46,15 +49,15 @@ def test_code(self): self.assertTrue(isinstance(a_code, Code)) self.assertFalse(isinstance(a_string, Code)) self.assertIsNone(a_code.scope) - with_scope = Code('hello world', {'my_var': 5}) - self.assertEqual({'my_var': 5}, with_scope.scope) - empty_scope = Code('hello world', {}) + with_scope = Code("hello world", {"my_var": 5}) + self.assertEqual({"my_var": 5}, with_scope.scope) + empty_scope = Code("hello world", {}) self.assertEqual({}, empty_scope.scope) - another_scope = Code(with_scope, {'new_var': 42}) + another_scope = Code(with_scope, {"new_var": 42}) self.assertEqual(str(with_scope), str(another_scope)) - self.assertEqual({'new_var': 42, 'my_var': 5}, another_scope.scope) + self.assertEqual({"new_var": 42, "my_var": 5}, another_scope.scope) # No error. - Code('héllø world¡') + Code("héllø world¡") def test_repr(self): c = Code("hello world", {}) @@ -97,8 +100,7 @@ def test_scope_preserved(self): def test_scope_kwargs(self): self.assertEqual({"a": 1}, Code("", a=1).scope) self.assertEqual({"a": 1}, Code("", {"a": 2}, a=1).scope) - self.assertEqual({"a": 1, "b": 2, "c": 3}, - Code("", {"b": 2}, a=1, c=3).scope) + self.assertEqual({"a": 1, "b": 2, "c": 3}, Code("", {"b": 2}, a=1, c=3).scope) if __name__ == "__main__": diff --git a/test/test_collation.py b/test/test_collation.py index f0139b4a22..7631289cbc 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -16,40 +16,47 @@ import functools import warnings +from test import IntegrationTest, client_context, unittest +from test.utils import EventListener, rs_or_single_client from pymongo.collation import ( Collation, - CollationCaseFirst, CollationStrength, CollationAlternate, - CollationMaxVariable) + CollationAlternate, + CollationCaseFirst, + CollationMaxVariable, + CollationStrength, +) from pymongo.errors import ConfigurationError -from pymongo.operations import (DeleteMany, DeleteOne, IndexModel, ReplaceOne, - UpdateMany, UpdateOne) +from pymongo.operations import ( + DeleteMany, + DeleteOne, + IndexModel, + ReplaceOne, + UpdateMany, + UpdateOne, +) from pymongo.write_concern import WriteConcern -from test import client_context, IntegrationTest, unittest -from test.utils import EventListener, rs_or_single_client class TestCollationObject(unittest.TestCase): - def test_constructor(self): self.assertRaises(TypeError, Collation, locale=42) # Fill in a locale to test the other options. - _Collation = functools.partial(Collation, 'en_US') + _Collation = functools.partial(Collation, "en_US") # No error. 
_Collation(caseFirst=CollationCaseFirst.UPPER) - self.assertRaises(TypeError, _Collation, caseLevel='true') - self.assertRaises(ValueError, _Collation, strength='six') - self.assertRaises(TypeError, _Collation, - numericOrdering='true') + self.assertRaises(TypeError, _Collation, caseLevel="true") + self.assertRaises(ValueError, _Collation, strength="six") + self.assertRaises(TypeError, _Collation, numericOrdering="true") self.assertRaises(TypeError, _Collation, alternate=5) self.assertRaises(TypeError, _Collation, maxVariable=2) - self.assertRaises(TypeError, _Collation, normalization='false') - self.assertRaises(TypeError, _Collation, backwards='true') + self.assertRaises(TypeError, _Collation, normalization="false") + self.assertRaises(TypeError, _Collation, backwards="true") # No errors. - Collation('en_US', future_option='bar', another_option=42) + Collation("en_US", future_option="bar", another_option=42) collation = Collation( - 'en_US', + "en_US", caseLevel=True, caseFirst=CollationCaseFirst.UPPER, strength=CollationStrength.QUATERNARY, @@ -57,24 +64,27 @@ def test_constructor(self): alternate=CollationAlternate.SHIFTED, maxVariable=CollationMaxVariable.SPACE, normalization=True, - backwards=True) - - self.assertEqual({ - 'locale': 'en_US', - 'caseLevel': True, - 'caseFirst': 'upper', - 'strength': 4, - 'numericOrdering': True, - 'alternate': 'shifted', - 'maxVariable': 'space', - 'normalization': True, - 'backwards': True - }, collation.document) - - self.assertEqual({ - 'locale': 'en_US', - 'backwards': True - }, Collation('en_US', backwards=True).document) + backwards=True, + ) + + self.assertEqual( + { + "locale": "en_US", + "caseLevel": True, + "caseFirst": "upper", + "strength": 4, + "numericOrdering": True, + "alternate": "shifted", + "maxVariable": "space", + "normalization": True, + "backwards": True, + }, + collation.document, + ) + + self.assertEqual( + {"locale": "en_US", "backwards": True}, Collation("en_US", backwards=True).document + ) class TestCollation(IntegrationTest): @@ -85,7 +95,7 @@ def setUpClass(cls): cls.listener = EventListener() cls.client = rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test - cls.collation = Collation('en_US') + cls.collation = Collation("en_US") cls.warn_context = warnings.catch_warnings() cls.warn_context.__enter__() warnings.simplefilter("ignore", DeprecationWarning) @@ -102,38 +112,33 @@ def tearDown(self): super(TestCollation, self).tearDown() def last_command_started(self): - return self.listener.results['started'][-1].command + return self.listener.results["started"][-1].command def assertCollationInLastCommand(self): - self.assertEqual( - self.collation.document, - self.last_command_started()['collation']) + self.assertEqual(self.collation.document, self.last_command_started()["collation"]) def test_create_collection(self): self.db.test.drop() - self.db.create_collection('test', collation=self.collation) + self.db.create_collection("test", collation=self.collation) self.assertCollationInLastCommand() # Test passing collation as a dict as well. 
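        # (Collation.document is a plain dict, and every API that takes a
        # Collation instance accepts the equivalent dict unchanged, so both
        # spellings produce the same command.)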
self.db.test.drop() self.listener.results.clear() - self.db.create_collection('test', collation=self.collation.document) + self.db.create_collection("test", collation=self.collation.document) self.assertCollationInLastCommand() def test_index_model(self): - model = IndexModel([('a', 1), ('b', -1)], collation=self.collation) - self.assertEqual(self.collation.document, model.document['collation']) + model = IndexModel([("a", 1), ("b", -1)], collation=self.collation) + self.assertEqual(self.collation.document, model.document["collation"]) def test_create_index(self): - self.db.test.create_index('foo', collation=self.collation) - ci_cmd = self.listener.results['started'][0].command - self.assertEqual( - self.collation.document, - ci_cmd['indexes'][0]['collation']) + self.db.test.create_index("foo", collation=self.collation) + ci_cmd = self.listener.results["started"][0].command + self.assertEqual(self.collation.document, ci_cmd["indexes"][0]["collation"]) def test_aggregate(self): - self.db.test.aggregate([{'$group': {'_id': 42}}], - collation=self.collation) + self.db.test.aggregate([{"$group": {"_id": 42}}], collation=self.collation) self.assertCollationInLastCommand() def test_count_documents(self): @@ -141,15 +146,15 @@ def test_count_documents(self): self.assertCollationInLastCommand() def test_distinct(self): - self.db.test.distinct('foo', collation=self.collation) + self.db.test.distinct("foo", collation=self.collation) self.assertCollationInLastCommand() self.listener.results.clear() - self.db.test.find(collation=self.collation).distinct('foo') + self.db.test.find(collation=self.collation).distinct("foo") self.assertCollationInLastCommand() def test_find_command(self): - self.db.test.insert_one({'is this thing on?': True}) + self.db.test.insert_one({"is this thing on?": True}) self.listener.results.clear() next(self.db.test.find(collation=self.collation)) self.assertCollationInLastCommand() @@ -159,127 +164,118 @@ def test_explain_command(self): self.db.test.find(collation=self.collation).explain() # The collation should be part of the explained command. 
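        # (explain wraps the find command it describes, so the collation is
        # nested under the "explain" key instead of sitting at the top level
        # of the command document.)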
self.assertEqual( - self.collation.document, - self.last_command_started()['explain']['collation']) + self.collation.document, self.last_command_started()["explain"]["collation"] + ) def test_delete(self): - self.db.test.delete_one({'foo': 42}, collation=self.collation) - command = self.listener.results['started'][0].command - self.assertEqual( - self.collation.document, - command['deletes'][0]['collation']) + self.db.test.delete_one({"foo": 42}, collation=self.collation) + command = self.listener.results["started"][0].command + self.assertEqual(self.collation.document, command["deletes"][0]["collation"]) self.listener.results.clear() - self.db.test.delete_many({'foo': 42}, collation=self.collation) - command = self.listener.results['started'][0].command - self.assertEqual( - self.collation.document, - command['deletes'][0]['collation']) + self.db.test.delete_many({"foo": 42}, collation=self.collation) + command = self.listener.results["started"][0].command + self.assertEqual(self.collation.document, command["deletes"][0]["collation"]) def test_update(self): - self.db.test.replace_one({'foo': 42}, {'foo': 43}, - collation=self.collation) - command = self.listener.results['started'][0].command - self.assertEqual( - self.collation.document, - command['updates'][0]['collation']) + self.db.test.replace_one({"foo": 42}, {"foo": 43}, collation=self.collation) + command = self.listener.results["started"][0].command + self.assertEqual(self.collation.document, command["updates"][0]["collation"]) self.listener.results.clear() - self.db.test.update_one({'foo': 42}, {'$set': {'foo': 43}}, - collation=self.collation) - command = self.listener.results['started'][0].command - self.assertEqual( - self.collation.document, - command['updates'][0]['collation']) + self.db.test.update_one({"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation) + command = self.listener.results["started"][0].command + self.assertEqual(self.collation.document, command["updates"][0]["collation"]) self.listener.results.clear() - self.db.test.update_many({'foo': 42}, {'$set': {'foo': 43}}, - collation=self.collation) - command = self.listener.results['started'][0].command - self.assertEqual( - self.collation.document, - command['updates'][0]['collation']) + self.db.test.update_many({"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation) + command = self.listener.results["started"][0].command + self.assertEqual(self.collation.document, command["updates"][0]["collation"]) def test_find_and(self): - self.db.test.find_one_and_delete({'foo': 42}, collation=self.collation) + self.db.test.find_one_and_delete({"foo": 42}, collation=self.collation) self.assertCollationInLastCommand() self.listener.results.clear() - self.db.test.find_one_and_update({'foo': 42}, {'$set': {'foo': 43}}, - collation=self.collation) + self.db.test.find_one_and_update( + {"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation + ) self.assertCollationInLastCommand() self.listener.results.clear() - self.db.test.find_one_and_replace({'foo': 42}, {'foo': 43}, - collation=self.collation) + self.db.test.find_one_and_replace({"foo": 42}, {"foo": 43}, collation=self.collation) self.assertCollationInLastCommand() def test_bulk_write(self): - self.db.test.collection.bulk_write([ - DeleteOne({'noCollation': 42}), - DeleteMany({'noCollation': 42}), - DeleteOne({'foo': 42}, collation=self.collation), - DeleteMany({'foo': 42}, collation=self.collation), - ReplaceOne({'noCollation': 24}, {'bar': 42}), - UpdateOne({'noCollation': 84}, {'$set': {'bar': 10}}, 
upsert=True), - UpdateMany({'noCollation': 45}, {'$set': {'bar': 42}}), - ReplaceOne({'foo': 24}, {'foo': 42}, collation=self.collation), - UpdateOne({'foo': 84}, {'$set': {'foo': 10}}, upsert=True, - collation=self.collation), - UpdateMany({'foo': 45}, {'$set': {'foo': 42}}, - collation=self.collation) - ]) - - delete_cmd = self.listener.results['started'][0].command - update_cmd = self.listener.results['started'][1].command + self.db.test.collection.bulk_write( + [ + DeleteOne({"noCollation": 42}), + DeleteMany({"noCollation": 42}), + DeleteOne({"foo": 42}, collation=self.collation), + DeleteMany({"foo": 42}, collation=self.collation), + ReplaceOne({"noCollation": 24}, {"bar": 42}), + UpdateOne({"noCollation": 84}, {"$set": {"bar": 10}}, upsert=True), + UpdateMany({"noCollation": 45}, {"$set": {"bar": 42}}), + ReplaceOne({"foo": 24}, {"foo": 42}, collation=self.collation), + UpdateOne( + {"foo": 84}, {"$set": {"foo": 10}}, upsert=True, collation=self.collation + ), + UpdateMany({"foo": 45}, {"$set": {"foo": 42}}, collation=self.collation), + ] + ) + + delete_cmd = self.listener.results["started"][0].command + update_cmd = self.listener.results["started"][1].command def check_ops(ops): for op in ops: - if 'noCollation' in op['q']: - self.assertNotIn('collation', op) + if "noCollation" in op["q"]: + self.assertNotIn("collation", op) else: - self.assertEqual(self.collation.document, - op['collation']) + self.assertEqual(self.collation.document, op["collation"]) - check_ops(delete_cmd['deletes']) - check_ops(update_cmd['updates']) + check_ops(delete_cmd["deletes"]) + check_ops(update_cmd["updates"]) def test_indexes_same_keys_different_collations(self): self.db.test.drop() - usa_collation = Collation('en_US') - ja_collation = Collation('ja') - self.db.test.create_indexes([ - IndexModel('fieldname', collation=usa_collation), - IndexModel('fieldname', name='japanese_version', - collation=ja_collation), - IndexModel('fieldname', name='simple') - ]) + usa_collation = Collation("en_US") + ja_collation = Collation("ja") + self.db.test.create_indexes( + [ + IndexModel("fieldname", collation=usa_collation), + IndexModel("fieldname", name="japanese_version", collation=ja_collation), + IndexModel("fieldname", name="simple"), + ] + ) indexes = self.db.test.index_information() - self.assertEqual(usa_collation.document['locale'], - indexes['fieldname_1']['collation']['locale']) - self.assertEqual(ja_collation.document['locale'], - indexes['japanese_version']['collation']['locale']) - self.assertNotIn('collation', indexes['simple']) - self.db.test.drop_index('fieldname_1') + self.assertEqual( + usa_collation.document["locale"], indexes["fieldname_1"]["collation"]["locale"] + ) + self.assertEqual( + ja_collation.document["locale"], indexes["japanese_version"]["collation"]["locale"] + ) + self.assertNotIn("collation", indexes["simple"]) + self.db.test.drop_index("fieldname_1") indexes = self.db.test.index_information() - self.assertIn('japanese_version', indexes) - self.assertIn('simple', indexes) - self.assertNotIn('fieldname', indexes) + self.assertIn("japanese_version", indexes) + self.assertIn("simple", indexes) + self.assertNotIn("fieldname", indexes) def test_unacknowledged_write(self): unacknowledged = WriteConcern(w=0) - collection = self.db.get_collection( - 'test', write_concern=unacknowledged) + collection = self.db.get_collection("test", write_concern=unacknowledged) with self.assertRaises(ConfigurationError): collection.update_one( - {'hello': 'world'}, {'$set': {'hello': 'moon'}}, - 
collation=self.collation) - update_one = UpdateOne({'hello': 'world'}, {'$set': {'hello': 'moon'}}, - collation=self.collation) + {"hello": "world"}, {"$set": {"hello": "moon"}}, collation=self.collation + ) + update_one = UpdateOne( + {"hello": "world"}, {"$set": {"hello": "moon"}}, collation=self.collation + ) with self.assertRaises(ConfigurationError): collection.bulk_write([update_one]) def test_cursor_collation(self): - self.db.test.insert_one({'hello': 'world'}) + self.db.test.insert_one({"hello": "world"}) next(self.db.test.find().collation(self.collation)) self.assertCollationInLastCommand() diff --git a/test/test_collection.py b/test/test_collection.py index 79a2a907a6..f7d069e2e1 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -19,53 +19,61 @@ import contextlib import re import sys - from codecs import utf_8_decode from collections import defaultdict sys.path[0:0] = [""] +from test import client_context, unittest +from test.test_client import IntegrationTest +from test.utils import ( + IMPOSSIBLE_WRITE_CONCERN, + EventListener, + get_pool, + is_mongos, + rs_or_single_client, + single_client, + wait_until, +) + from bson import encode -from bson.raw_bson import RawBSONDocument -from bson.regex import Regex from bson.codec_options import CodecOptions from bson.objectid import ObjectId +from bson.raw_bson import RawBSONDocument +from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT from pymongo.bulk import BulkWriteError from pymongo.collection import Collection, ReturnDocument from pymongo.command_cursor import CommandCursor from pymongo.cursor import CursorType -from pymongo.errors import (ConfigurationError, - DocumentTooLarge, - DuplicateKeyError, - ExecutionTimeout, - InvalidDocument, - InvalidName, - InvalidOperation, - OperationFailure, - WriteConcernError) +from pymongo.errors import ( + ConfigurationError, + DocumentTooLarge, + DuplicateKeyError, + ExecutionTimeout, + InvalidDocument, + InvalidName, + InvalidOperation, + OperationFailure, + WriteConcernError, +) from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.mongo_client import MongoClient from pymongo.operations import * from pymongo.read_concern import DEFAULT_READ_CONCERN from pymongo.read_preferences import ReadPreference -from pymongo.results import (InsertOneResult, - InsertManyResult, - UpdateResult, - DeleteResult) +from pymongo.results import ( + DeleteResult, + InsertManyResult, + InsertOneResult, + UpdateResult, +) from pymongo.write_concern import WriteConcern -from test import client_context, unittest -from test.test_client import IntegrationTest -from test.utils import (get_pool, is_mongos, - rs_or_single_client, single_client, - wait_until, EventListener, - IMPOSSIBLE_WRITE_CONCERN) class TestCollectionNoConnect(unittest.TestCase): - """Test Collection features on a client that does not connect. - """ + """Test Collection features on a client that does not connect.""" @classmethod def setUpClass(cls): @@ -91,7 +99,7 @@ def make_col(base, name): def test_getattr(self): coll = self.db.test - self.assertTrue(isinstance(coll['_does_not_exist'], Collection)) + self.assertTrue(isinstance(coll["_does_not_exist"], Collection)) with self.assertRaises(AttributeError) as context: coll._does_not_exist @@ -100,8 +108,7 @@ def test_getattr(self): # "AttributeError: Collection has no attribute '_does_not_exist'. 
To # access the test._does_not_exist collection, use # database['test._does_not_exist']." - self.assertIn("has no attribute '_does_not_exist'", - str(context.exception)) + self.assertIn("has no attribute '_does_not_exist'", str(context.exception)) coll2 = coll.with_options(write_concern=WriteConcern(w=0)) self.assertEqual(coll2.write_concern, WriteConcern(w=0)) @@ -116,7 +123,6 @@ def test_iteration(self): class TestCollection(IntegrationTest): - @classmethod def setUpClass(cls): super(TestCollection, cls).setUpClass() @@ -138,8 +144,8 @@ def write_concern_collection(self): with self.assertRaises(WriteConcernError): # Unsatisfiable write concern. yield Collection( - self.db, 'test', - write_concern=WriteConcern(w=len(client_context.nodes) + 1)) + self.db, "test", write_concern=WriteConcern(w=len(client_context.nodes) + 1) + ) else: yield self.db.test @@ -158,33 +164,33 @@ def test_create(self): db = client_context.client.pymongo_test db.create_test_no_wc.drop() wait_until( - lambda: 'create_test_no_wc' not in db.list_collection_names(), - 'drop create_test_no_wc collection') - Collection(db, name='create_test_no_wc', create=True) + lambda: "create_test_no_wc" not in db.list_collection_names(), + "drop create_test_no_wc collection", + ) + Collection(db, name="create_test_no_wc", create=True) wait_until( - lambda: 'create_test_no_wc' in db.list_collection_names(), - 'create create_test_no_wc collection') + lambda: "create_test_no_wc" in db.list_collection_names(), + "create create_test_no_wc collection", + ) # SERVER-33317 - if (not client_context.is_mongos or not - client_context.version.at_least(3, 7, 0)): + if not client_context.is_mongos or not client_context.version.at_least(3, 7, 0): with self.assertRaises(OperationFailure): Collection( - db, name='create-test-wc', - write_concern=IMPOSSIBLE_WRITE_CONCERN, - create=True) + db, name="create-test-wc", write_concern=IMPOSSIBLE_WRITE_CONCERN, create=True + ) def test_drop_nonexistent_collection(self): - self.db.drop_collection('test') - self.assertFalse('test' in self.db.list_collection_names()) + self.db.drop_collection("test") + self.assertFalse("test" in self.db.list_collection_names()) # No exception - self.db.drop_collection('test') + self.db.drop_collection("test") def test_create_indexes(self): db = self.db - self.assertRaises(TypeError, db.test.create_indexes, 'foo') - self.assertRaises(TypeError, db.test.create_indexes, ['foo']) + self.assertRaises(TypeError, db.test.create_indexes, "foo") + self.assertRaises(TypeError, db.test.create_indexes, ["foo"]) self.assertRaises(TypeError, IndexModel, 5) self.assertRaises(ValueError, IndexModel, []) @@ -193,8 +199,7 @@ def test_create_indexes(self): self.assertEqual(len(db.test.index_information()), 1) db.test.create_indexes([IndexModel("hello")]) - db.test.create_indexes([IndexModel([("hello", DESCENDING), - ("world", ASCENDING)])]) + db.test.create_indexes([IndexModel([("hello", DESCENDING), ("world", ASCENDING)])]) # Tuple instead of list. 
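        # (IndexModel accepts any sequence of (key, direction) pairs, so a
        # tuple of tuples behaves exactly like the list form used above.)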
db.test.create_indexes([IndexModel((("world", ASCENDING),))]) @@ -202,9 +207,9 @@ def test_create_indexes(self): self.assertEqual(len(db.test.index_information()), 4) db.test.drop_indexes() - names = db.test.create_indexes([IndexModel([("hello", DESCENDING), - ("world", ASCENDING)], - name="hello_world")]) + names = db.test.create_indexes( + [IndexModel([("hello", DESCENDING), ("world", ASCENDING)], name="hello_world")] + ) self.assertEqual(names, ["hello_world"]) db.test.drop_indexes() @@ -214,37 +219,35 @@ def test_create_indexes(self): db.test.drop_indexes() self.assertEqual(len(db.test.index_information()), 1) - names = db.test.create_indexes([IndexModel([("hello", DESCENDING), - ("world", ASCENDING)]), - IndexModel("hello")]) + names = db.test.create_indexes( + [IndexModel([("hello", DESCENDING), ("world", ASCENDING)]), IndexModel("hello")] + ) info = db.test.index_information() for name in names: self.assertTrue(name in info) db.test.drop() - db.test.insert_one({'a': 1}) - db.test.insert_one({'a': 1}) - self.assertRaises( - DuplicateKeyError, - db.test.create_indexes, - [IndexModel('a', unique=True)]) + db.test.insert_one({"a": 1}) + db.test.insert_one({"a": 1}) + self.assertRaises(DuplicateKeyError, db.test.create_indexes, [IndexModel("a", unique=True)]) with self.write_concern_collection() as coll: - coll.create_indexes([IndexModel('hello')]) + coll.create_indexes([IndexModel("hello")]) @client_context.require_version_max(4, 3, -1) def test_create_indexes_commitQuorum_requires_44(self): db = self.db with self.assertRaisesRegex( - ConfigurationError, - 'Must be connected to MongoDB 4\.4\+ to use the commitQuorum ' - 'option for createIndexes'): - db.coll.create_indexes([IndexModel('a')], commitQuorum="majority") + ConfigurationError, + "Must be connected to MongoDB 4\.4\+ to use the commitQuorum " + "option for createIndexes", + ): + db.coll.create_indexes([IndexModel("a")], commitQuorum="majority") @client_context.require_no_standalone @client_context.require_version_min(4, 4, -1) def test_create_indexes_commitQuorum(self): - self.db.coll.create_indexes([IndexModel('a')], commitQuorum="majority") + self.db.coll.create_indexes([IndexModel("a")], commitQuorum="majority") def test_create_index(self): db = self.db @@ -266,8 +269,7 @@ def test_create_index(self): self.assertEqual(len(db.test.index_information()), 4) db.test.drop_indexes() - ix = db.test.create_index([("hello", DESCENDING), - ("world", ASCENDING)], name="hello_world") + ix = db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], name="hello_world") self.assertEqual(ix, "hello_world") db.test.drop_indexes() @@ -281,13 +283,12 @@ def test_create_index(self): self.assertTrue("hello_-1_world_1" in db.test.index_information()) db.test.drop() - db.test.insert_one({'a': 1}) - db.test.insert_one({'a': 1}) - self.assertRaises( - DuplicateKeyError, db.test.create_index, 'a', unique=True) + db.test.insert_one({"a": 1}) + db.test.insert_one({"a": 1}) + self.assertRaises(DuplicateKeyError, db.test.create_index, "a", unique=True) with self.write_concern_collection() as coll: - coll.create_index([('hello', DESCENDING)]) + coll.create_index([("hello", DESCENDING)]) def test_drop_index(self): db = self.db @@ -316,31 +317,22 @@ def test_drop_index(self): self.assertTrue("hello_1" in db.test.index_information()) with self.write_concern_collection() as coll: - coll.drop_index('hello_1') + coll.drop_index("hello_1") @client_context.require_no_mongos @client_context.require_test_commands def 
test_index_management_max_time_ms(self): coll = self.db.test - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="alwaysOn") + self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: + self.assertRaises(ExecutionTimeout, coll.create_index, "foo", maxTimeMS=1) self.assertRaises( - ExecutionTimeout, coll.create_index, "foo", maxTimeMS=1) - self.assertRaises( - ExecutionTimeout, - coll.create_indexes, - [IndexModel("foo")], - maxTimeMS=1) - self.assertRaises( - ExecutionTimeout, coll.drop_index, "foo", maxTimeMS=1) - self.assertRaises( - ExecutionTimeout, coll.drop_indexes, maxTimeMS=1) + ExecutionTimeout, coll.create_indexes, [IndexModel("foo")], maxTimeMS=1 + ) + self.assertRaises(ExecutionTimeout, coll.drop_index, "foo", maxTimeMS=1) + self.assertRaises(ExecutionTimeout, coll.drop_indexes, maxTimeMS=1) finally: - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="off") + self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") def test_list_indexes(self): db = self.db @@ -357,16 +349,15 @@ def map_indexes(indexes): db.test.create_index("hello") indexes = list(db.test.list_indexes()) self.assertEqual(len(indexes), 2) - self.assertEqual(map_indexes(indexes)["hello_1"]["key"], - SON([("hello", ASCENDING)])) + self.assertEqual(map_indexes(indexes)["hello_1"]["key"], SON([("hello", ASCENDING)])) - db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], - unique=True) + db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], unique=True) indexes = list(db.test.list_indexes()) self.assertEqual(len(indexes), 3) index_map = map_indexes(indexes) - self.assertEqual(index_map["hello_-1_world_1"]["key"], - SON([("hello", DESCENDING), ("world", ASCENDING)])) + self.assertEqual( + index_map["hello_-1_world_1"]["key"], SON([("hello", DESCENDING), ("world", ASCENDING)]) + ) self.assertEqual(True, index_map["hello_-1_world_1"]["unique"]) # List indexes on a collection that does not exist. 
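For context on the next hunk: index_information() is essentially
list_indexes() flattened into a dict keyed by index name, with each "key"
document normalized to a list of (field, direction) pairs. A rough sketch of
that relationship, illustrative only (the real method lives on Collection and
also handles sessions):

    def index_information(collection):
        # One entry per index, keyed by name, with "key" converted from a
        # SON document into (field, direction) tuples.
        info = {}
        for index in collection.list_indexes():
            index = dict(index)
            index["key"] = list(index["key"].items())
            info[index.pop("name")] = index
        return info

This is why the assertions below compare "key" against lists of tuples such
as [("hello", ASCENDING)].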
@@ -386,26 +377,23 @@ def test_index_info(self): db.test.create_index("hello") self.assertEqual(len(db.test.index_information()), 2) - self.assertEqual(db.test.index_information()["hello_1"]["key"], - [("hello", ASCENDING)]) + self.assertEqual(db.test.index_information()["hello_1"]["key"], [("hello", ASCENDING)]) - db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], - unique=True) - self.assertEqual(db.test.index_information()["hello_1"]["key"], - [("hello", ASCENDING)]) + db.test.create_index([("hello", DESCENDING), ("world", ASCENDING)], unique=True) + self.assertEqual(db.test.index_information()["hello_1"]["key"], [("hello", ASCENDING)]) self.assertEqual(len(db.test.index_information()), 3) - self.assertEqual([("hello", DESCENDING), ("world", ASCENDING)], - db.test.index_information()["hello_-1_world_1"]["key"] - ) self.assertEqual( - True, db.test.index_information()["hello_-1_world_1"]["unique"]) + [("hello", DESCENDING), ("world", ASCENDING)], + db.test.index_information()["hello_-1_world_1"]["key"], + ) + self.assertEqual(True, db.test.index_information()["hello_-1_world_1"]["unique"]) def test_index_geo2d(self): db = self.db db.test.drop_indexes() - self.assertEqual('loc_2d', db.test.create_index([("loc", GEO2D)])) - index_info = db.test.index_information()['loc_2d'] - self.assertEqual([('loc', '2d')], index_info['key']) + self.assertEqual("loc_2d", db.test.create_index([("loc", GEO2D)])) + index_info = db.test.index_information()["loc_2d"] + self.assertEqual([("loc", "2d")], index_info["key"]) # geoSearch was deprecated in 4.4 and removed in 5.0 @client_context.require_version_max(4, 5) @@ -413,35 +401,29 @@ def test_index_geo2d(self): def test_index_haystack(self): db = self.db db.test.drop() - _id = db.test.insert_one({ - "pos": {"long": 34.2, "lat": 33.3}, - "type": "restaurant" - }).inserted_id - db.test.insert_one({ - "pos": {"long": 34.2, "lat": 37.3}, "type": "restaurant" - }) - db.test.insert_one({ - "pos": {"long": 59.1, "lat": 87.2}, "type": "office" - }) - db.test.create_index( - [("pos", "geoHaystack"), ("type", ASCENDING)], - bucketSize=1 - ) - - results = db.command(SON([ - ("geoSearch", "test"), - ("near", [33, 33]), - ("maxDistance", 6), - ("search", {"type": "restaurant"}), - ("limit", 30), - ]))['results'] + _id = db.test.insert_one( + {"pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"} + ).inserted_id + db.test.insert_one({"pos": {"long": 34.2, "lat": 37.3}, "type": "restaurant"}) + db.test.insert_one({"pos": {"long": 59.1, "lat": 87.2}, "type": "office"}) + db.test.create_index([("pos", "geoHaystack"), ("type", ASCENDING)], bucketSize=1) + + results = db.command( + SON( + [ + ("geoSearch", "test"), + ("near", [33, 33]), + ("maxDistance", 6), + ("search", {"type": "restaurant"}), + ("limit", 30), + ] + ) + )["results"] self.assertEqual(2, len(results)) - self.assertEqual({ - "_id": _id, - "pos": {"long": 34.2, "lat": 33.3}, - "type": "restaurant" - }, results[0]) + self.assertEqual( + {"_id": _id, "pos": {"long": 34.2, "lat": 33.3}, "type": "restaurant"}, results[0] + ) @client_context.require_no_mongos def test_index_text(self): @@ -451,38 +433,33 @@ def test_index_text(self): index_info = db.test.index_information()["t_text"] self.assertTrue("weights" in index_info) - db.test.insert_many([ - {'t': 'spam eggs and spam'}, - {'t': 'spam'}, - {'t': 'egg sausage and bacon'}]) + db.test.insert_many( + [{"t": "spam eggs and spam"}, {"t": "spam"}, {"t": "egg sausage and bacon"}] + ) # MongoDB 2.6 text search. Create 'score' field in projection. 
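        # ({"$meta": "textScore"} projects a computed relevance score rather
        # than a stored field; materializing it as "score" is what lets the
        # sort below order the results by relevance.)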
- cursor = db.test.find( - {'$text': {'$search': 'spam'}}, - {'score': {'$meta': 'textScore'}}) + cursor = db.test.find({"$text": {"$search": "spam"}}, {"score": {"$meta": "textScore"}}) # Sort by 'score' field. - cursor.sort([('score', {'$meta': 'textScore'})]) + cursor.sort([("score", {"$meta": "textScore"})]) results = list(cursor) - self.assertTrue(results[0]['score'] >= results[1]['score']) + self.assertTrue(results[0]["score"] >= results[1]["score"]) db.test.drop_indexes() def test_index_2dsphere(self): db = self.db db.test.drop_indexes() - self.assertEqual("geo_2dsphere", - db.test.create_index([("geo", GEOSPHERE)])) + self.assertEqual("geo_2dsphere", db.test.create_index([("geo", GEOSPHERE)])) for dummy, info in db.test.index_information().items(): - field, idx_type = info['key'][0] - if field == 'geo' and idx_type == '2dsphere': + field, idx_type = info["key"][0] + if field == "geo" and idx_type == "2dsphere": break else: self.fail("2dsphere index not found.") - poly = {"type": "Polygon", - "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} + poly = {"type": "Polygon", "coordinates": [[[40, 5], [40, 6], [41, 6], [41, 5], [40, 5]]]} query = {"geo": {"$within": {"$geometry": poly}}} # This query will error without a 2dsphere index. @@ -492,12 +469,11 @@ def test_index_2dsphere(self): def test_index_hashed(self): db = self.db db.test.drop_indexes() - self.assertEqual("a_hashed", - db.test.create_index([("a", HASHED)])) + self.assertEqual("a_hashed", db.test.create_index([("a", HASHED)])) for dummy, info in db.test.index_information().items(): - field, idx_type = info['key'][0] - if field == 'a' and idx_type == 'hashed': + field, idx_type = info["key"][0] + if field == "a" and idx_type == "hashed": break else: self.fail("hashed index not found.") @@ -507,25 +483,25 @@ def test_index_hashed(self): def test_index_sparse(self): db = self.db db.test.drop_indexes() - db.test.create_index([('key', ASCENDING)], sparse=True) - self.assertTrue(db.test.index_information()['key_1']['sparse']) + db.test.create_index([("key", ASCENDING)], sparse=True) + self.assertTrue(db.test.index_information()["key_1"]["sparse"]) def test_index_background(self): db = self.db db.test.drop_indexes() - db.test.create_index([('keya', ASCENDING)]) - db.test.create_index([('keyb', ASCENDING)], background=False) - db.test.create_index([('keyc', ASCENDING)], background=True) - self.assertFalse('background' in db.test.index_information()['keya_1']) - self.assertFalse(db.test.index_information()['keyb_1']['background']) - self.assertTrue(db.test.index_information()['keyc_1']['background']) + db.test.create_index([("keya", ASCENDING)]) + db.test.create_index([("keyb", ASCENDING)], background=False) + db.test.create_index([("keyc", ASCENDING)], background=True) + self.assertFalse("background" in db.test.index_information()["keya_1"]) + self.assertFalse(db.test.index_information()["keyb_1"]["background"]) + self.assertTrue(db.test.index_information()["keyc_1"]["background"]) def _drop_dups_setup(self, db): - db.drop_collection('test') - db.test.insert_one({'i': 1}) - db.test.insert_one({'i': 2}) - db.test.insert_one({'i': 2}) # duplicate - db.test.insert_one({'i': 3}) + db.drop_collection("test") + db.test.insert_one({"i": 1}) + db.test.insert_one({"i": 2}) + db.test.insert_one({"i": 2}) # duplicate + db.test.insert_one({"i": 3}) def test_index_dont_drop_dups(self): # Try *not* dropping duplicates @@ -534,11 +510,8 @@ def test_index_dont_drop_dups(self): # There's a duplicate def test_create(): - 
db.test.create_index( - [('i', ASCENDING)], - unique=True, - dropDups=False - ) + db.test.create_index([("i", ASCENDING)], unique=True, dropDups=False) + self.assertRaises(DuplicateKeyError, test_create) # Duplicate wasn't dropped @@ -549,12 +522,12 @@ def test_create(): # Get the plan dynamically because the explain format will change. def get_plan_stage(self, root, stage): - if root.get('stage') == stage: + if root.get("stage") == stage: return root elif "inputStage" in root: - return self.get_plan_stage(root['inputStage'], stage) + return self.get_plan_stage(root["inputStage"], stage) elif "inputStages" in root: - for i in root['inputStages']: + for i in root["inputStages"]: stage = self.get_plan_stage(i, stage) if stage: return stage @@ -562,8 +535,8 @@ def get_plan_stage(self, root, stage): # queryPlan (and slotBasedPlan) are new in 5.0. return self.get_plan_stage(root["queryPlan"], stage) elif "shards" in root: - for i in root['shards']: - stage = self.get_plan_stage(i['winningPlan'], stage) + for i in root["shards"]: + stage = self.get_plan_stage(i["winningPlan"], stage) if stage: return stage return {} @@ -573,52 +546,52 @@ def test_index_filter(self): db.drop_collection("test") # Test bad filter spec on create. - self.assertRaises(OperationFailure, db.test.create_index, "x", - partialFilterExpression=5) - self.assertRaises(OperationFailure, db.test.create_index, "x", - partialFilterExpression={"x": {"$asdasd": 3}}) - self.assertRaises(OperationFailure, db.test.create_index, "x", - partialFilterExpression={"$and": 5}) - - self.assertEqual("x_1", db.test.create_index( - [('x', ASCENDING)], partialFilterExpression={"a": {"$lte": 1.5}})) + self.assertRaises(OperationFailure, db.test.create_index, "x", partialFilterExpression=5) + self.assertRaises( + OperationFailure, + db.test.create_index, + "x", + partialFilterExpression={"x": {"$asdasd": 3}}, + ) + self.assertRaises( + OperationFailure, db.test.create_index, "x", partialFilterExpression={"$and": 5} + ) + + self.assertEqual( + "x_1", + db.test.create_index([("x", ASCENDING)], partialFilterExpression={"a": {"$lte": 1.5}}), + ) db.test.insert_one({"x": 5, "a": 2}) db.test.insert_one({"x": 6, "a": 1}) # Operations that use the partial index. explain = db.test.find({"x": 6, "a": 1}).explain() - stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], - 'IXSCAN') - self.assertEqual("x_1", stage.get('indexName')) - self.assertTrue(stage.get('isPartial')) + stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "IXSCAN") + self.assertEqual("x_1", stage.get("indexName")) + self.assertTrue(stage.get("isPartial")) explain = db.test.find({"x": {"$gt": 1}, "a": 1}).explain() - stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], - 'IXSCAN') - self.assertEqual("x_1", stage.get('indexName')) - self.assertTrue(stage.get('isPartial')) + stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "IXSCAN") + self.assertEqual("x_1", stage.get("indexName")) + self.assertTrue(stage.get("isPartial")) explain = db.test.find({"x": 6, "a": {"$lte": 1}}).explain() - stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], - 'IXSCAN') - self.assertEqual("x_1", stage.get('indexName')) - self.assertTrue(stage.get('isPartial')) + stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "IXSCAN") + self.assertEqual("x_1", stage.get("indexName")) + self.assertTrue(stage.get("isPartial")) # Operations that do not use the partial index. 
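        # (The planner can only use a partial index when the query predicate
        # guarantees that every matching document satisfies
        # partialFilterExpression; {"$lte": 1.6} is weaker than the filter's
        # {"$lte": 1.5}, so these queries must fall back to a COLLSCAN.)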
explain = db.test.find({"x": 6, "a": {"$lte": 1.6}}).explain() - stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], - 'COLLSCAN') + stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "COLLSCAN") self.assertNotEqual({}, stage) explain = db.test.find({"x": 6}).explain() - stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], - 'COLLSCAN') + stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "COLLSCAN") self.assertNotEqual({}, stage) # Test drop_indexes. db.test.drop_index("x_1") explain = db.test.find({"x": 6, "a": 1}).explain() - stage = self.get_plan_stage(explain['queryPlanner']['winningPlan'], - 'COLLSCAN') + stage = self.get_plan_stage(explain["queryPlanner"]["winningPlan"], "COLLSCAN") self.assertNotEqual({}, stage) def test_field_selection(self): @@ -681,8 +654,8 @@ def test_options(self): db.create_collection("test", capped=True, size=4096) result = db.test.options() # mongos 2.2.x adds an $auth field when auth is enabled. - result.pop('$auth', None) - self.assertEqual(result, {"capped": True, 'size': 4096}) + result.pop("$auth", None) + self.assertEqual(result, {"capped": True, "size": 4096}) db.drop_collection("test") def test_insert_one(self): @@ -707,19 +680,16 @@ def test_insert_one(self): self.assertIsNotNone(db.test.find_one({"_id": document["_id"]})) self.assertEqual(2, db.test.count_documents({})) - db = db.client.get_database(db.name, - write_concern=WriteConcern(w=0)) + db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) result = db.test.insert_one(document) self.assertTrue(isinstance(result, InsertOneResult)) self.assertTrue(isinstance(result.inserted_id, ObjectId)) self.assertEqual(document["_id"], result.inserted_id) self.assertFalse(result.acknowledged) # The insert failed duplicate key... 
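Here `w=0` makes the insert fire-and-forget: `result.acknowledged` is `False` and the duplicate-key failure happens only server-side, so the test polls for the expected state instead of inspecting the result object. A sketch of that pattern, assuming a running mongod; `wait_until` below is a simplified stand-in for the suite's helper of the same name.

```python
import time

from pymongo import MongoClient
from pymongo.write_concern import WriteConcern


def wait_until(predicate, msg, timeout=10):
    # Poll until the predicate holds or the timeout expires.
    start = time.monotonic()
    while not predicate():
        if time.monotonic() - start > timeout:
            raise AssertionError("timed out waiting for %s" % msg)
        time.sleep(0.1)


db = MongoClient().get_database("demo", write_concern=WriteConcern(w=0))
db.test.drop()
result = db.test.insert_one({"_id": 1})
assert result.acknowledged is False  # fire-and-forget: no server reply
# Server-side errors are invisible with w=0, so verify the outcome by polling.
wait_until(lambda: db.test.count_documents({"_id": 1}) == 1, "document visible")
```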
- wait_until(lambda: 2 == db.test.count_documents({}), - 'forcing duplicate key error') + wait_until(lambda: 2 == db.test.count_documents({}), "forcing duplicate key error") - document = RawBSONDocument( - encode({'_id': ObjectId(), 'foo': 'bar'})) + document = RawBSONDocument(encode({"_id": ObjectId(), "foo": "bar"})) result = db.test.insert_one(document) self.assertTrue(isinstance(result, InsertOneResult)) self.assertEqual(result.inserted_id, None) @@ -737,7 +707,7 @@ def test_insert_many(self): _id = doc["_id"] self.assertTrue(isinstance(_id, ObjectId)) self.assertTrue(_id in result.inserted_ids) - self.assertEqual(1, db.test.count_documents({'_id': _id})) + self.assertEqual(1, db.test.count_documents({"_id": _id})) self.assertTrue(result.acknowledged) docs = [{"_id": i} for i in range(5)] @@ -752,15 +722,13 @@ def test_insert_many(self): self.assertEqual(1, db.test.count_documents({"_id": _id})) self.assertTrue(result.acknowledged) - docs = [RawBSONDocument(encode({"_id": i + 5})) - for i in range(5)] + docs = [RawBSONDocument(encode({"_id": i + 5})) for i in range(5)] result = db.test.insert_many(docs) self.assertTrue(isinstance(result, InsertManyResult)) self.assertTrue(isinstance(result.inserted_ids, list)) self.assertEqual([], result.inserted_ids) - db = db.client.get_database(db.name, - write_concern=WriteConcern(w=0)) + db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) docs = [{} for _ in range(5)] result = db.test.insert_many(docs) self.assertTrue(isinstance(result, InsertManyResult)) @@ -772,11 +740,11 @@ def test_insert_many_generator(self): coll.delete_many({}) def gen(): - yield {'a': 1, 'b': 1} - yield {'a': 1, 'b': 2} - yield {'a': 2, 'b': 3} - yield {'a': 3, 'b': 5} - yield {'a': 5, 'b': 8} + yield {"a": 1, "b": 1} + yield {"a": 1, "b": 2} + yield {"a": 2, "b": 3} + yield {"a": 3, "b": 5} + yield {"a": 5, "b": 8} result = coll.insert_many(gen()) self.assertEqual(5, len(result.inserted_ids)) @@ -784,21 +752,17 @@ def gen(): def test_insert_many_invalid(self): db = self.db - with self.assertRaisesRegex( - TypeError, "documents must be a non-empty list"): + with self.assertRaisesRegex(TypeError, "documents must be a non-empty list"): db.test.insert_many({}) - with self.assertRaisesRegex( - TypeError, "documents must be a non-empty list"): + with self.assertRaisesRegex(TypeError, "documents must be a non-empty list"): db.test.insert_many([]) - with self.assertRaisesRegex( - TypeError, "documents must be a non-empty list"): + with self.assertRaisesRegex(TypeError, "documents must be a non-empty list"): db.test.insert_many(1) - with self.assertRaisesRegex( - TypeError, "documents must be a non-empty list"): - db.test.insert_many(RawBSONDocument(encode({'_id': 2}))) + with self.assertRaisesRegex(TypeError, "documents must be a non-empty list"): + db.test.insert_many(RawBSONDocument(encode({"_id": 2}))) def test_delete_one(self): self.db.test.drop() @@ -819,13 +783,12 @@ def test_delete_one(self): self.assertTrue(result.acknowledged) self.assertEqual(1, self.db.test.count_documents({})) - db = self.db.client.get_database(self.db.name, - write_concern=WriteConcern(w=0)) + db = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) result = db.test.delete_one({"z": 1}) self.assertTrue(isinstance(result, DeleteResult)) self.assertRaises(InvalidOperation, lambda: result.deleted_count) self.assertFalse(result.acknowledged) - wait_until(lambda: 0 == db.test.count_documents({}), 'delete 1 documents') + wait_until(lambda: 0 == 
db.test.count_documents({}), "delete 1 documents") def test_delete_many(self): self.db.test.drop() @@ -841,25 +804,20 @@ def test_delete_many(self): self.assertTrue(result.acknowledged) self.assertEqual(0, self.db.test.count_documents({"x": 1})) - db = self.db.client.get_database(self.db.name, - write_concern=WriteConcern(w=0)) + db = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) result = db.test.delete_many({"y": 1}) self.assertTrue(isinstance(result, DeleteResult)) self.assertRaises(InvalidOperation, lambda: result.deleted_count) self.assertFalse(result.acknowledged) - wait_until( - lambda: 0 == db.test.count_documents({}), 'delete 2 documents') + wait_until(lambda: 0 == db.test.count_documents({}), "delete 2 documents") def test_command_document_too_large(self): - large = '*' * (client_context.max_bson_size + _COMMAND_OVERHEAD) + large = "*" * (client_context.max_bson_size + _COMMAND_OVERHEAD) coll = self.db.test - self.assertRaises( - DocumentTooLarge, coll.insert_one, {'data': large}) + self.assertRaises(DocumentTooLarge, coll.insert_one, {"data": large}) # update_one and update_many are the same - self.assertRaises( - DocumentTooLarge, coll.replace_one, {}, {'data': large}) - self.assertRaises( - DocumentTooLarge, coll.delete_one, {'data': large}) + self.assertRaises(DocumentTooLarge, coll.replace_one, {}, {"data": large}) + self.assertRaises(DocumentTooLarge, coll.delete_one, {"data": large}) def test_write_large_document(self): max_size = client_context.max_bson_size @@ -868,42 +826,38 @@ def test_write_large_document(self): half_str = "x" * half_size self.assertEqual(max_size, 16777216) - self.assertRaises(OperationFailure, self.db.test.insert_one, - {"foo": max_str}) - self.assertRaises(OperationFailure, self.db.test.replace_one, - {}, {"foo": max_str}, upsert=True) - self.assertRaises(OperationFailure, self.db.test.insert_many, - [{"x": 1}, {"foo": max_str}]) + self.assertRaises(OperationFailure, self.db.test.insert_one, {"foo": max_str}) + self.assertRaises( + OperationFailure, self.db.test.replace_one, {}, {"foo": max_str}, upsert=True + ) + self.assertRaises(OperationFailure, self.db.test.insert_many, [{"x": 1}, {"foo": max_str}]) self.db.test.insert_many([{"foo": half_str}, {"foo": half_str}]) self.db.test.insert_one({"bar": "x"}) # Use w=0 here to test legacy doc size checking in all server versions unack_coll = self.db.test.with_options(write_concern=WriteConcern(w=0)) - self.assertRaises(DocumentTooLarge, unack_coll.replace_one, - {"bar": "x"}, {"bar": "x" * (max_size - 14)}) + self.assertRaises( + DocumentTooLarge, unack_coll.replace_one, {"bar": "x"}, {"bar": "x" * (max_size - 14)} + ) self.db.test.replace_one({"bar": "x"}, {"bar": "x" * (max_size - 32)}) def test_insert_bypass_document_validation(self): db = self.db db.test.drop() db.create_collection("test", validator={"a": {"$exists": True}}) - db_w0 = self.db.client.get_database( - self.db.name, write_concern=WriteConcern(w=0)) + db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) # Test insert_one - self.assertRaises(OperationFailure, db.test.insert_one, - {"_id": 1, "x": 100}) - result = db.test.insert_one({"_id": 1, "x": 100}, - bypass_document_validation=True) + self.assertRaises(OperationFailure, db.test.insert_one, {"_id": 1, "x": 100}) + result = db.test.insert_one({"_id": 1, "x": 100}, bypass_document_validation=True) self.assertTrue(isinstance(result, InsertOneResult)) self.assertEqual(1, result.inserted_id) - result = db.test.insert_one({"_id":2, 
"a":0}) + result = db.test.insert_one({"_id": 2, "a": 0}) self.assertTrue(isinstance(result, InsertOneResult)) self.assertEqual(2, result.inserted_id) db_w0.test.insert_one({"y": 1}, bypass_document_validation=True) - wait_until(lambda: db_w0.test.find_one({"y": 1}), - "find w:0 inserted document") + wait_until(lambda: db_w0.test.find_one({"y": 1}), "find w:0 inserted document") # Test insert_many docs = [{"_id": i, "x": 100 - i} for i in range(3, 100)] @@ -928,25 +882,25 @@ def test_insert_bypass_document_validation(self): self.assertEqual(1, db.test.count_documents({"a": doc["a"]})) self.assertTrue(result.acknowledged) - self.assertRaises(OperationFailure, db_w0.test.insert_many, - [{"x": 1}, {"x": 2}], - bypass_document_validation=True) + self.assertRaises( + OperationFailure, + db_w0.test.insert_many, + [{"x": 1}, {"x": 2}], + bypass_document_validation=True, + ) def test_replace_bypass_document_validation(self): db = self.db db.test.drop() db.create_collection("test", validator={"a": {"$exists": True}}) - db_w0 = self.db.client.get_database( - self.db.name, write_concern=WriteConcern(w=0)) + db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) # Test replace_one db.test.insert_one({"a": 101}) - self.assertRaises(OperationFailure, db.test.replace_one, - {"a": 101}, {"y": 1}) + self.assertRaises(OperationFailure, db.test.replace_one, {"a": 101}, {"y": 1}) self.assertEqual(0, db.test.count_documents({"y": 1})) self.assertEqual(1, db.test.count_documents({"a": 101})) - db.test.replace_one({"a": 101}, {"y": 1}, - bypass_document_validation=True) + db.test.replace_one({"a": 101}, {"y": 1}, bypass_document_validation=True) self.assertEqual(0, db.test.count_documents({"a": 101})) self.assertEqual(1, db.test.count_documents({"y": 1})) db.test.replace_one({"y": 1}, {"a": 102}) @@ -955,123 +909,107 @@ def test_replace_bypass_document_validation(self): self.assertEqual(1, db.test.count_documents({"a": 102})) db.test.insert_one({"y": 1}, bypass_document_validation=True) - self.assertRaises(OperationFailure, db.test.replace_one, - {"y": 1}, {"x": 101}) + self.assertRaises(OperationFailure, db.test.replace_one, {"y": 1}, {"x": 101}) self.assertEqual(0, db.test.count_documents({"x": 101})) self.assertEqual(1, db.test.count_documents({"y": 1})) - db.test.replace_one({"y": 1}, {"x": 101}, - bypass_document_validation=True) + db.test.replace_one({"y": 1}, {"x": 101}, bypass_document_validation=True) self.assertEqual(0, db.test.count_documents({"y": 1})) self.assertEqual(1, db.test.count_documents({"x": 101})) - db.test.replace_one({"x": 101}, {"a": 103}, - bypass_document_validation=False) + db.test.replace_one({"x": 101}, {"a": 103}, bypass_document_validation=False) self.assertEqual(0, db.test.count_documents({"x": 101})) self.assertEqual(1, db.test.count_documents({"a": 103})) db.test.insert_one({"y": 1}, bypass_document_validation=True) - db_w0.test.replace_one({"y": 1}, {"x": 1}, - bypass_document_validation=True) - wait_until(lambda: db_w0.test.find_one({"x": 1}), - "find w:0 replaced document") + db_w0.test.replace_one({"y": 1}, {"x": 1}, bypass_document_validation=True) + wait_until(lambda: db_w0.test.find_one({"x": 1}), "find w:0 replaced document") def test_update_bypass_document_validation(self): db = self.db db.test.drop() db.test.insert_one({"z": 5}) - db.command(SON([("collMod", "test"), - ("validator", {"z": {"$gte": 0}})])) - db_w0 = self.db.client.get_database( - self.db.name, write_concern=WriteConcern(w=0)) + db.command(SON([("collMod", "test"), 
("validator", {"z": {"$gte": 0}})])) + db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) # Test update_one - self.assertRaises(OperationFailure, db.test.update_one, - {"z": 5}, {"$inc": {"z": -10}}) + self.assertRaises(OperationFailure, db.test.update_one, {"z": 5}, {"$inc": {"z": -10}}) self.assertEqual(0, db.test.count_documents({"z": -5})) self.assertEqual(1, db.test.count_documents({"z": 5})) - db.test.update_one({"z": 5}, {"$inc": {"z": -10}}, - bypass_document_validation=True) + db.test.update_one({"z": 5}, {"$inc": {"z": -10}}, bypass_document_validation=True) self.assertEqual(0, db.test.count_documents({"z": 5})) self.assertEqual(1, db.test.count_documents({"z": -5})) - db.test.update_one({"z": -5}, {"$inc": {"z": 6}}, - bypass_document_validation=False) + db.test.update_one({"z": -5}, {"$inc": {"z": 6}}, bypass_document_validation=False) self.assertEqual(1, db.test.count_documents({"z": 1})) self.assertEqual(0, db.test.count_documents({"z": -5})) - db.test.insert_one({"z": -10}, - bypass_document_validation=True) - self.assertRaises(OperationFailure, db.test.update_one, - {"z": -10}, {"$inc": {"z": 1}}) + db.test.insert_one({"z": -10}, bypass_document_validation=True) + self.assertRaises(OperationFailure, db.test.update_one, {"z": -10}, {"$inc": {"z": 1}}) self.assertEqual(0, db.test.count_documents({"z": -9})) self.assertEqual(1, db.test.count_documents({"z": -10})) - db.test.update_one({"z": -10}, {"$inc": {"z": 1}}, - bypass_document_validation=True) + db.test.update_one({"z": -10}, {"$inc": {"z": 1}}, bypass_document_validation=True) self.assertEqual(1, db.test.count_documents({"z": -9})) self.assertEqual(0, db.test.count_documents({"z": -10})) - db.test.update_one({"z": -9}, {"$inc": {"z": 9}}, - bypass_document_validation=False) + db.test.update_one({"z": -9}, {"$inc": {"z": 9}}, bypass_document_validation=False) self.assertEqual(0, db.test.count_documents({"z": -9})) self.assertEqual(1, db.test.count_documents({"z": 0})) db.test.insert_one({"y": 1, "x": 0}, bypass_document_validation=True) - db_w0.test.update_one({"y": 1}, {"$inc": {"x": 1}}, - bypass_document_validation=True) - wait_until(lambda: db_w0.test.find_one({"y": 1, "x": 1}), - "find w:0 updated document") + db_w0.test.update_one({"y": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) + wait_until(lambda: db_w0.test.find_one({"y": 1, "x": 1}), "find w:0 updated document") # Test update_many db.test.insert_many([{"z": i} for i in range(3, 101)]) - db.test.insert_one({"y": 0}, - bypass_document_validation=True) - self.assertRaises(OperationFailure, db.test.update_many, {}, - {"$inc": {"z": -100}}) + db.test.insert_one({"y": 0}, bypass_document_validation=True) + self.assertRaises(OperationFailure, db.test.update_many, {}, {"$inc": {"z": -100}}) self.assertEqual(100, db.test.count_documents({"z": {"$gte": 0}})) self.assertEqual(0, db.test.count_documents({"z": {"$lt": 0}})) self.assertEqual(0, db.test.count_documents({"y": 0, "z": -100})) - db.test.update_many({"z": {"$gte": 0}}, {"$inc": {"z": -100}}, - bypass_document_validation=True) + db.test.update_many( + {"z": {"$gte": 0}}, {"$inc": {"z": -100}}, bypass_document_validation=True + ) self.assertEqual(0, db.test.count_documents({"z": {"$gt": 0}})) self.assertEqual(100, db.test.count_documents({"z": {"$lte": 0}})) - db.test.update_many({"z": {"$gt": -50}}, {"$inc": {"z": 100}}, - bypass_document_validation=False) + db.test.update_many( + {"z": {"$gt": -50}}, {"$inc": {"z": 100}}, bypass_document_validation=False + ) 
self.assertEqual(50, db.test.count_documents({"z": {"$gt": 0}})) self.assertEqual(50, db.test.count_documents({"z": {"$lt": 0}})) - db.test.insert_many([{"z": -i} for i in range(50)], - bypass_document_validation=True) - self.assertRaises(OperationFailure, db.test.update_many, - {}, {"$inc": {"z": 1}}) + db.test.insert_many([{"z": -i} for i in range(50)], bypass_document_validation=True) + self.assertRaises(OperationFailure, db.test.update_many, {}, {"$inc": {"z": 1}}) self.assertEqual(100, db.test.count_documents({"z": {"$lte": 0}})) self.assertEqual(50, db.test.count_documents({"z": {"$gt": 1}})) - db.test.update_many({"z": {"$gte": 0}}, {"$inc": {"z": -100}}, - bypass_document_validation=True) + db.test.update_many( + {"z": {"$gte": 0}}, {"$inc": {"z": -100}}, bypass_document_validation=True + ) self.assertEqual(0, db.test.count_documents({"z": {"$gt": 0}})) self.assertEqual(150, db.test.count_documents({"z": {"$lte": 0}})) - db.test.update_many({"z": {"$lte": 0}}, {"$inc": {"z": 100}}, - bypass_document_validation=False) + db.test.update_many( + {"z": {"$lte": 0}}, {"$inc": {"z": 100}}, bypass_document_validation=False + ) self.assertEqual(150, db.test.count_documents({"z": {"$gte": 0}})) self.assertEqual(0, db.test.count_documents({"z": {"$lt": 0}})) db.test.insert_one({"m": 1, "x": 0}, bypass_document_validation=True) db.test.insert_one({"m": 1, "x": 0}, bypass_document_validation=True) - db_w0.test.update_many({"m": 1}, {"$inc": {"x": 1}}, - bypass_document_validation=True) + db_w0.test.update_many({"m": 1}, {"$inc": {"x": 1}}, bypass_document_validation=True) wait_until( - lambda: db_w0.test.count_documents({"m": 1, "x": 1}) == 2, - "find w:0 updated documents") + lambda: db_w0.test.count_documents({"m": 1, "x": 1}) == 2, "find w:0 updated documents" + ) def test_bypass_document_validation_bulk_write(self): db = self.db db.test.drop() db.create_collection("test", validator={"a": {"$gte": 0}}) - db_w0 = self.db.client.get_database( - self.db.name, write_concern=WriteConcern(w=0)) - - ops = [InsertOne({"a": -10}), - InsertOne({"a": -11}), - InsertOne({"a": -12}), - UpdateOne({"a": {"$lte": -10}}, {"$inc": {"a": 1}}), - UpdateMany({"a": {"$lte": -10}}, {"$inc": {"a": 1}}), - ReplaceOne({"a": {"$lte": -10}}, {"a": -1})] + db_w0 = self.db.client.get_database(self.db.name, write_concern=WriteConcern(w=0)) + + ops = [ + InsertOne({"a": -10}), + InsertOne({"a": -11}), + InsertOne({"a": -12}), + UpdateOne({"a": {"$lte": -10}}, {"$inc": {"a": 1}}), + UpdateMany({"a": {"$lte": -10}}, {"$inc": {"a": 1}}), + ReplaceOne({"a": {"$lte": -10}}, {"a": -1}), + ] db.test.bulk_write(ops, bypass_document_validation=True) self.assertEqual(3, db.test.count_documents({})) @@ -1083,22 +1021,22 @@ def test_bypass_document_validation_bulk_write(self): for op in ops: self.assertRaises(BulkWriteError, db.test.bulk_write, [op]) - self.assertRaises(OperationFailure, db_w0.test.bulk_write, ops, - bypass_document_validation=True) + self.assertRaises( + OperationFailure, db_w0.test.bulk_write, ops, bypass_document_validation=True + ) def test_find_by_default_dct(self): db = self.db - db.test.insert_one({'foo': 'bar'}) - dct = defaultdict(dict, [('foo', 'bar')]) + db.test.insert_one({"foo": "bar"}) + dct = defaultdict(dict, [("foo", "bar")]) self.assertIsNotNone(db.test.find_one(dct)) - self.assertEqual(dct, defaultdict(dict, [('foo', 'bar')])) + self.assertEqual(dct, defaultdict(dict, [("foo", "bar")])) def test_find_w_fields(self): db = self.db db.test.delete_many({}) - db.test.insert_one({"x": 1, "mike": 
"awesome", - "extra thing": "abcdefghijklmnopqrstuvwxyz"}) + db.test.insert_one({"x": 1, "mike": "awesome", "extra thing": "abcdefghijklmnopqrstuvwxyz"}) self.assertEqual(1, db.test.count_documents({})) doc = next(db.test.find({})) self.assertTrue("x" in doc) @@ -1126,9 +1064,7 @@ def test_fields_specifier_as_dict(self): db.test.insert_one({"x": [1, 2, 3], "mike": "awesome"}) self.assertEqual([1, 2, 3], db.test.find_one()["x"]) - self.assertEqual([2, 3], - db.test.find_one( - projection={"x": {"$slice": -2}})["x"]) + self.assertEqual([2, 3], db.test.find_one(projection={"x": {"$slice": -2}})["x"]) self.assertTrue("x" not in db.test.find_one(projection={"x": 0})) self.assertTrue("mike" in db.test.find_one(projection={"x": 0})) @@ -1142,14 +1078,10 @@ def test_find_w_regex(self): db.test.insert_one({"x": "hello_test"}) self.assertEqual(len(list(db.test.find())), 4) - self.assertEqual(len(list(db.test.find({"x": - re.compile("^hello.*")}))), 4) - self.assertEqual(len(list(db.test.find({"x": - re.compile("ello")}))), 4) - self.assertEqual(len(list(db.test.find({"x": - re.compile("^hello$")}))), 0) - self.assertEqual(len(list(db.test.find({"x": - re.compile("^hello_mi.*$")}))), 2) + self.assertEqual(len(list(db.test.find({"x": re.compile("^hello.*")}))), 4) + self.assertEqual(len(list(db.test.find({"x": re.compile("ello")}))), 4) + self.assertEqual(len(list(db.test.find({"x": re.compile("^hello$")}))), 0) + self.assertEqual(len(list(db.test.find({"x": re.compile("^hello_mi.*$")}))), 2) def test_id_can_be_anything(self): db = self.db @@ -1213,83 +1145,74 @@ def test_write_error_text_handling(self): db.test.create_index("text", unique=True) # Test workaround for SERVER-24007 - data = (b'a\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83' - b'\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83') + data = ( + b"a\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" 
+ b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + b"\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83\xe2\x98\x83" + ) text = utf_8_decode(data, None, True) db.test.insert_one({"text": text}) # Should raise DuplicateKeyError, not InvalidBSON - self.assertRaises(DuplicateKeyError, - db.test.insert_one, - {"text": text}) + self.assertRaises(DuplicateKeyError, db.test.insert_one, {"text": text}) - self.assertRaises(DuplicateKeyError, - db.test.replace_one, - {"_id": ObjectId()}, - {"text": text}, - upsert=True) + self.assertRaises( + DuplicateKeyError, db.test.replace_one, {"_id": ObjectId()}, {"text": text}, upsert=True + ) # Should raise BulkWriteError, not InvalidBSON - self.assertRaises(BulkWriteError, - db.test.insert_many, - [{"text": text}]) + self.assertRaises(BulkWriteError, db.test.insert_many, [{"text": text}]) def test_write_error_unicode(self): coll = self.db.test self.addCleanup(coll.drop) - coll.create_index('a', unique=True) - coll.insert_one({'a': 'unicode \U0001f40d'}) - with self.assertRaisesRegex( - DuplicateKeyError, - 'E11000 duplicate key error') as ctx: - coll.insert_one({'a': 'unicode \U0001f40d'}) + coll.create_index("a", unique=True) + coll.insert_one({"a": "unicode \U0001f40d"}) + with self.assertRaisesRegex(DuplicateKeyError, "E11000 duplicate key error") as ctx: + coll.insert_one({"a": "unicode \U0001f40d"}) # Once more for good measure. - self.assertIn('E11000 duplicate key error', - str(ctx.exception)) + self.assertIn("E11000 duplicate key error", str(ctx.exception)) def test_wtimeout(self): # Ensure setting wtimeout doesn't disable write concern altogether. # See SERVER-12596. 
collection = self.db.test collection.drop() - collection.insert_one({'_id': 1}) + collection.insert_one({"_id": 1}) - coll = collection.with_options( - write_concern=WriteConcern(w=1, wtimeout=1000)) - self.assertRaises(DuplicateKeyError, coll.insert_one, {'_id': 1}) + coll = collection.with_options(write_concern=WriteConcern(w=1, wtimeout=1000)) + self.assertRaises(DuplicateKeyError, coll.insert_one, {"_id": 1}) - coll = collection.with_options( - write_concern=WriteConcern(wtimeout=1000)) - self.assertRaises(DuplicateKeyError, coll.insert_one, {'_id': 1}) + coll = collection.with_options(write_concern=WriteConcern(wtimeout=1000)) + self.assertRaises(DuplicateKeyError, coll.insert_one, {"_id": 1}) def test_error_code(self): try: @@ -1315,16 +1238,13 @@ def test_index_on_subfield(self): db.test.insert_one({"hello": {"a": 4, "b": 5}}) db.test.insert_one({"hello": {"a": 7, "b": 2}}) - self.assertRaises(DuplicateKeyError, - db.test.insert_one, - {"hello": {"a": 4, "b": 10}}) + self.assertRaises(DuplicateKeyError, db.test.insert_one, {"hello": {"a": 4, "b": 10}}) def test_replace_one(self): db = self.db db.drop_collection("test") - self.assertRaises(ValueError, - lambda: db.test.replace_one({}, {"$set": {"x": 1}})) + self.assertRaises(ValueError, lambda: db.test.replace_one({}, {"$set": {"x": 1}})) id1 = db.test.insert_one({"x": 1}).inserted_id result = db.test.replace_one({"x": 1}, {"y": 1}) @@ -1356,8 +1276,7 @@ def test_replace_one(self): self.assertTrue(result.acknowledged) self.assertEqual(1, db.test.count_documents({"y": 2})) - db = db.client.get_database(db.name, - write_concern=WriteConcern(w=0)) + db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) result = db.test.replace_one({"x": 0}, {"y": 0}) self.assertTrue(isinstance(result, UpdateResult)) self.assertRaises(InvalidOperation, lambda: result.matched_count) @@ -1369,8 +1288,7 @@ def test_update_one(self): db = self.db db.drop_collection("test") - self.assertRaises(ValueError, - lambda: db.test.update_one({}, {"x": 1})) + self.assertRaises(ValueError, lambda: db.test.update_one({}, {"x": 1})) id1 = db.test.insert_one({"x": 5}).inserted_id result = db.test.update_one({}, {"$inc": {"x": 1}}) @@ -1398,8 +1316,7 @@ def test_update_one(self): self.assertTrue(isinstance(result.upserted_id, ObjectId)) self.assertTrue(result.acknowledged) - db = db.client.get_database(db.name, - write_concern=WriteConcern(w=0)) + db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) result = db.test.update_one({"x": 0}, {"$inc": {"x": 1}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertRaises(InvalidOperation, lambda: result.matched_count) @@ -1411,8 +1328,7 @@ def test_update_many(self): db = self.db db.drop_collection("test") - self.assertRaises(ValueError, - lambda: db.test.update_many({}, {"x": 1})) + self.assertRaises(ValueError, lambda: db.test.update_many({}, {"x": 1})) db.test.insert_one({"x": 4, "y": 3}) db.test.insert_one({"x": 5, "y": 5}) @@ -1441,8 +1357,7 @@ def test_update_many(self): self.assertTrue(isinstance(result.upserted_id, ObjectId)) self.assertTrue(result.acknowledged) - db = db.client.get_database(db.name, - write_concern=WriteConcern(w=0)) + db = db.client.get_database(db.name, write_concern=WriteConcern(w=0)) result = db.test.update_many({"x": 0}, {"$inc": {"x": 1}}) self.assertTrue(isinstance(result, UpdateResult)) self.assertRaises(InvalidOperation, lambda: result.matched_count) @@ -1455,28 +1370,28 @@ def test_update_check_keys(self): 
self.assertTrue(self.db.test.insert_one({"hello": "world"})) # Modify shouldn't check keys... - self.assertTrue(self.db.test.update_one({"hello": "world"}, - {"$set": {"foo.bar": "baz"}}, - upsert=True)) + self.assertTrue( + self.db.test.update_one({"hello": "world"}, {"$set": {"foo.bar": "baz"}}, upsert=True) + ) # I know this seems like testing the server but I'd like to be notified # by CI if the server's behavior changes here. doc = SON([("$set", {"foo.bar": "bim"}), ("hello", "world")]) - self.assertRaises(OperationFailure, self.db.test.update_one, - {"hello": "world"}, doc, upsert=True) + self.assertRaises( + OperationFailure, self.db.test.update_one, {"hello": "world"}, doc, upsert=True + ) # This is going to cause keys to be checked and raise InvalidDocument. # That's OK assuming the server's behavior in the previous assert # doesn't change. If the behavior changes checking the first key for # '$' in update won't be good enough anymore. doc = SON([("hello", "world"), ("$set", {"foo.bar": "bim"})]) - self.assertRaises(OperationFailure, self.db.test.replace_one, - {"hello": "world"}, doc, upsert=True) + self.assertRaises( + OperationFailure, self.db.test.replace_one, {"hello": "world"}, doc, upsert=True + ) # Replace with empty document - self.assertNotEqual(0, - self.db.test.replace_one( - {"hello": "world"}, {}).matched_count) + self.assertNotEqual(0, self.db.test.replace_one({"hello": "world"}, {}).matched_count) def test_acknowledged_delete(self): db = self.db @@ -1510,10 +1425,9 @@ def test_count_documents(self): self.assertEqual(db.test.count_documents({}), 0) db.test.insert_many([{}, {}]) self.assertEqual(db.test.count_documents({}), 2) - db.test.insert_many([{'foo': 'bar'}, {'foo': 'baz'}]) - self.assertEqual(db.test.count_documents({'foo': 'bar'}), 1) - self.assertEqual( - db.test.count_documents({'foo': re.compile(r'ba.*')}), 2) + db.test.insert_many([{"foo": "bar"}, {"foo": "baz"}]) + self.assertEqual(db.test.count_documents({"foo": "bar"}), 1) + self.assertEqual(db.test.count_documents({"foo": re.compile(r"ba.*")}), 2) def test_estimated_document_count(self): db = self.db @@ -1529,39 +1443,37 @@ def test_estimated_document_count(self): def test_aggregate(self): db = self.db db.drop_collection("test") - db.test.insert_one({'foo': [1, 2]}) + db.test.insert_one({"foo": [1, 2]}) self.assertRaises(TypeError, db.test.aggregate, "wow") pipeline = {"$project": {"_id": False, "foo": True}} result = db.test.aggregate([pipeline]) self.assertTrue(isinstance(result, CommandCursor)) - self.assertEqual([{'foo': [1, 2]}], list(result)) + self.assertEqual([{"foo": [1, 2]}], list(result)) # Test write concern. 
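When a pipeline ends in `$out`, `aggregate` becomes a write and the collection's write concern is forwarded to the server, which is what the suite's `write_concern_collection()` helper exercises. The following is an illustrative sketch, not the helper itself, assuming a deployment where `w=99` cannot be satisfied:

```python
from pymongo import MongoClient
from pymongo.errors import OperationFailure
from pymongo.write_concern import WriteConcern

coll = MongoClient().demo.get_collection(
    "test", write_concern=WriteConcern(w=99, wtimeout=1000)
)
try:
    # $out makes this a write, so the unsatisfiable write concern is enforced.
    coll.aggregate([{"$out": "output-collection"}])
except OperationFailure:
    print("write concern was forwarded with $out")
```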
with self.write_concern_collection() as coll: - coll.aggregate([{'$out': 'output-collection'}]) + coll.aggregate([{"$out": "output-collection"}]) def test_aggregate_raw_bson(self): db = self.db db.drop_collection("test") - db.test.insert_one({'foo': [1, 2]}) + db.test.insert_one({"foo": [1, 2]}) self.assertRaises(TypeError, db.test.aggregate, "wow") pipeline = {"$project": {"_id": False, "foo": True}} - coll = db.get_collection( - 'test', - codec_options=CodecOptions(document_class=RawBSONDocument)) + coll = db.get_collection("test", codec_options=CodecOptions(document_class=RawBSONDocument)) result = coll.aggregate([pipeline]) self.assertTrue(isinstance(result, CommandCursor)) first_result = next(result) self.assertIsInstance(first_result, RawBSONDocument) - self.assertEqual([1, 2], list(first_result['foo'])) + self.assertEqual([1, 2], list(first_result["foo"])) def test_aggregation_cursor_validation(self): db = self.db - projection = {'$project': {'_id': '$_id'}} + projection = {"$project": {"_id": "$_id"}} cursor = db.test.aggregate([projection], cursor={}) self.assertTrue(isinstance(cursor, CommandCursor)) @@ -1572,20 +1484,17 @@ def test_aggregation_cursor(self): db = self.client.get_database( db.name, read_preference=ReadPreference.SECONDARY, - write_concern=WriteConcern(w=self.w)) + write_concern=WriteConcern(w=self.w), + ) for collection_size in (10, 1000): db.drop_collection("test") - db.test.insert_many([{'_id': i} for i in range(collection_size)]) + db.test.insert_many([{"_id": i} for i in range(collection_size)]) expected_sum = sum(range(collection_size)) # Use batchSize to ensure multiple getMore messages - cursor = db.test.aggregate( - [{'$project': {'_id': '$_id'}}], - batchSize=5) + cursor = db.test.aggregate([{"$project": {"_id": "$_id"}}], batchSize=5) - self.assertEqual( - expected_sum, - sum(doc['_id'] for doc in cursor)) + self.assertEqual(expected_sum, sum(doc["_id"] for doc in cursor)) # Test that batchSize is handled properly. 
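`batchSize` caps each server reply, so a small value forces many `getMore` round trips; summing the `_id`s proves that no document is lost or duplicated across batches. A compact version of the same check, assuming a local mongod:

```python
from pymongo import MongoClient

coll = MongoClient().demo.test
coll.drop()
coll.insert_many([{"_id": i} for i in range(1000)])

# batchSize=5 forces roughly 200 round trips; the cursor must still yield
# every document exactly once.
cursor = coll.aggregate([{"$project": {"_id": "$_id"}}], batchSize=5)
assert sum(doc["_id"] for doc in cursor) == sum(range(1000))
```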
cursor = db.test.aggregate([], batchSize=5) @@ -1603,7 +1512,7 @@ def test_aggregation_cursor_alive(self): self.db.test.delete_many({}) self.db.test.insert_many([{} for _ in range(3)]) self.addCleanup(self.db.test.delete_many, {}) - cursor = self.db.test.aggregate(pipeline=[], cursor={'batchSize': 2}) + cursor = self.db.test.aggregate(pipeline=[], cursor={"batchSize": 2}) n = 0 while True: cursor.next() @@ -1617,15 +1526,14 @@ def test_aggregation_cursor_alive(self): def test_large_limit(self): db = self.db db.drop_collection("test_large_limit") - db.test_large_limit.create_index([('x', 1)]) + db.test_large_limit.create_index([("x", 1)]) my_str = "mongomongo" * 1000 - db.test_large_limit.insert_many( - {"x": i, "y": my_str} for i in range(2000)) + db.test_large_limit.insert_many({"x": i, "y": my_str} for i in range(2000)) i = 0 y = 0 - for doc in db.test_large_limit.find(limit=1900).sort([('x', 1)]): + for doc in db.test_large_limit.find(limit=1900).sort([("x", 1)]): i += 1 y += doc["x"] @@ -1679,7 +1587,7 @@ def test_rename(self): db.foo.rename("test", dropTarget=True) with self.write_concern_collection() as coll: - coll.rename('foo') + coll.rename("foo") def test_find_one(self): db = self.db @@ -1691,8 +1599,7 @@ def test_find_one(self): self.assertEqual(db.test.find_one(_id), db.test.find_one()) self.assertEqual(db.test.find_one(None), db.test.find_one()) self.assertEqual(db.test.find_one({}), db.test.find_one()) - self.assertEqual(db.test.find_one({"hello": "world"}), - db.test.find_one()) + self.assertEqual(db.test.find_one({"hello": "world"}), db.test.find_one()) self.assertTrue("hello" in db.test.find_one(projection=["hello"])) self.assertTrue("hello" not in db.test.find_one(projection=["foo"])) @@ -1706,8 +1613,7 @@ def test_find_one(self): self.assertTrue("hello" in db.test.find_one(projection=frozenset(["hello"]))) self.assertTrue("hello" not in db.test.find_one(projection=frozenset(["foo"]))) - self.assertEqual(["_id"], list(db.test.find_one(projection={'_id': - True}))) + self.assertEqual(["_id"], list(db.test.find_one(projection={"_id": True}))) self.assertTrue("hello" in list(db.test.find_one(projection={}))) self.assertTrue("hello" in list(db.test.find_one(projection=[]))) @@ -1760,16 +1666,13 @@ def test_cursor_timeout(self): def test_exhaust(self): if is_mongos(self.db.client): - self.assertRaises(InvalidOperation, - self.db.test.find, - cursor_type=CursorType.EXHAUST) + self.assertRaises(InvalidOperation, self.db.test.find, cursor_type=CursorType.EXHAUST) return # Limit is incompatible with exhaust. - self.assertRaises(InvalidOperation, - self.db.test.find, - cursor_type=CursorType.EXHAUST, - limit=5) + self.assertRaises( + InvalidOperation, self.db.test.find, cursor_type=CursorType.EXHAUST, limit=5 + ) cur = self.db.test.find(cursor_type=CursorType.EXHAUST) self.assertRaises(InvalidOperation, cur.limit, 5) cur = self.db.test.find(limit=5) @@ -1780,7 +1683,7 @@ def test_exhaust(self): self.db.drop_collection("test") # Insert enough documents to require more than one batch - self.db.test.insert_many([{'i': i} for i in range(150)]) + self.db.test.insert_many([{"i": i} for i in range(150)]) client = rs_or_single_client(maxPoolSize=1) self.addCleanup(client.close) @@ -1802,8 +1705,7 @@ def test_exhaust(self): # If the Cursor instance is discarded before being completely iterated # and the socket has pending data (more_to_come=True) we have to close # and discard the socket. 
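Exhaust cursors make the server stream batches without waiting for `getMore` requests, which is why they are rejected on mongos, cannot be combined with `limit`, and why a cursor abandoned mid-stream must have its socket discarded (the hunk below). A minimal sketch, assuming a non-mongos deployment:

```python
from pymongo import MongoClient
from pymongo.cursor import CursorType
from pymongo.errors import InvalidOperation

coll = MongoClient().demo.test
try:
    # Rejected up front: the server streams batches until the cursor is
    # drained, so a limit cannot be honored.
    coll.find(cursor_type=CursorType.EXHAUST, limit=5)
except InvalidOperation:
    print("limit cannot be combined with exhaust")

for doc in coll.find(cursor_type=CursorType.EXHAUST, batch_size=100):
    pass  # draining fully returns the socket to the pool
```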
- cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST, - batch_size=2) + cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST, batch_size=2) if client_context.version.at_least(4, 2): # On 4.2+ we use OP_MSG which only sets more_to_come=True after the # first getMore. @@ -1812,12 +1714,12 @@ def test_exhaust(self): else: next(cur) self.assertEqual(0, len(pool.sockets)) - if sys.platform.startswith('java') or 'PyPy' in sys.version: + if sys.platform.startswith("java") or "PyPy" in sys.version: # Don't wait for GC or use gc.collect(), it's unreliable. cur.close() cur = None # Wait until the background thread returns the socket. - wait_until(lambda: pool.active_sockets == 0, 'return socket') + wait_until(lambda: pool.active_sockets == 0, "return socket") # The socket should be discarded. self.assertEqual(0, len(pool.sockets)) @@ -1832,11 +1734,11 @@ def test_distinct(self): self.assertEqual([1, 2, 3], distinct) - distinct = test.find({'a': {'$gt': 1}}).distinct("a") + distinct = test.find({"a": {"$gt": 1}}).distinct("a") distinct.sort() self.assertEqual([2, 3], distinct) - distinct = test.distinct('a', {'a': {'$gt': 1}}) + distinct = test.distinct("a", {"a": {"$gt": 1}}) distinct.sort() self.assertEqual([2, 3], distinct) @@ -1857,19 +1759,15 @@ def test_query_on_query_field(self): self.db.test.insert_one({"query": "foo"}) self.db.test.insert_one({"bar": "foo"}) - self.assertEqual(1, - self.db.test.count_documents({"query": {"$ne": None}})) - self.assertEqual(1, - len(list(self.db.test.find({"query": {"$ne": None}}))) - ) + self.assertEqual(1, self.db.test.count_documents({"query": {"$ne": None}})) + self.assertEqual(1, len(list(self.db.test.find({"query": {"$ne": None}})))) def test_min_query(self): self.db.drop_collection("test") self.db.test.insert_many([{"x": 1}, {"x": 2}]) self.db.test.create_index("x") - cursor = self.db.test.find({"$min": {"x": 2}, "$query": {}}, - hint="x_1") + cursor = self.db.test.find({"$min": {"x": 2}, "$query": {}}, hint="x_1") docs = list(cursor) self.assertEqual(1, len(docs)) @@ -1886,24 +1784,30 @@ def test_numerous_inserts(self): def test_insert_many_large_batch(self): # Tests legacy insert. db = self.client.test_insert_large_batch - self.addCleanup(self.client.drop_database, 'test_insert_large_batch') + self.addCleanup(self.client.drop_database, "test_insert_large_batch") max_bson_size = client_context.max_bson_size # Write commands are limited to 16MB + 16k per batch - big_string = 'x' * int(max_bson_size / 2) + big_string = "x" * int(max_bson_size / 2) # Batch insert that requires 2 batches. - successful_insert = [{'x': big_string}, {'x': big_string}, - {'x': big_string}, {'x': big_string}] + successful_insert = [ + {"x": big_string}, + {"x": big_string}, + {"x": big_string}, + {"x": big_string}, + ] db.collection_0.insert_many(successful_insert) self.assertEqual(4, db.collection_0.count_documents({})) db.collection_0.drop() # Test that inserts fail after first error. - insert_second_fails = [{'_id': 'id0', 'x': big_string}, - {'_id': 'id0', 'x': big_string}, - {'_id': 'id1', 'x': big_string}, - {'_id': 'id2', 'x': big_string}] + insert_second_fails = [ + {"_id": "id0", "x": big_string}, + {"_id": "id0", "x": big_string}, + {"_id": "id1", "x": big_string}, + {"_id": "id2", "x": big_string}, + ] with self.assertRaises(BulkWriteError): db.collection_1.insert_many(insert_second_fails) @@ -1913,25 +1817,27 @@ def test_insert_many_large_batch(self): db.collection_1.drop() # 2 batches, 2nd insert fails, unacknowledged, ordered. 
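The comment above describes ordered semantics: the server stops an ordered batch at the first write error, while an unordered batch attempts every document and reports all failures together. A standalone sketch of both modes, assuming a local mongod:

```python
from pymongo import MongoClient
from pymongo.errors import BulkWriteError

coll = MongoClient().demo.test
coll.drop()
docs = [{"_id": "id0"}, {"_id": "id0"}, {"_id": "id1"}, {"_id": "id2"}]

try:
    coll.insert_many(docs)  # ordered=True is the default
except BulkWriteError:
    pass
assert coll.count_documents({}) == 1  # stopped at the first duplicate

coll.drop()
try:
    coll.insert_many(docs, ordered=False)
except BulkWriteError:
    pass
assert coll.count_documents({}) == 3  # only the duplicate itself failed
```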
- unack_coll = db.collection_2.with_options( - write_concern=WriteConcern(w=0)) + unack_coll = db.collection_2.with_options(write_concern=WriteConcern(w=0)) unack_coll.insert_many(insert_second_fails) - wait_until(lambda: 1 == db.collection_2.count_documents({}), - 'insert 1 document', timeout=60) + wait_until( + lambda: 1 == db.collection_2.count_documents({}), "insert 1 document", timeout=60 + ) db.collection_2.drop() # 2 batches, ids of docs 0 and 1 are dupes, ids of docs 2 and 3 are # dupes. Acknowledged, unordered. - insert_two_failures = [{'_id': 'id0', 'x': big_string}, - {'_id': 'id0', 'x': big_string}, - {'_id': 'id1', 'x': big_string}, - {'_id': 'id1', 'x': big_string}] + insert_two_failures = [ + {"_id": "id0", "x": big_string}, + {"_id": "id0", "x": big_string}, + {"_id": "id1", "x": big_string}, + {"_id": "id1", "x": big_string}, + ] with self.assertRaises(OperationFailure) as context: db.collection_3.insert_many(insert_two_failures, ordered=False) - self.assertIn('id1', str(context.exception)) + self.assertIn("id1", str(context.exception)) # Only the first and third documents should be inserted. self.assertEqual(2, db.collection_3.count_documents({})) @@ -1939,13 +1845,13 @@ def test_insert_many_large_batch(self): db.collection_3.drop() # 2 batches, 2 errors, unacknowledged, unordered. - unack_coll = db.collection_4.with_options( - write_concern=WriteConcern(w=0)) + unack_coll = db.collection_4.with_options(write_concern=WriteConcern(w=0)) unack_coll.insert_many(insert_two_failures, ordered=False) # Only the first and third documents are inserted. - wait_until(lambda: 2 == db.collection_4.count_documents({}), - 'insert 2 documents', timeout=60) + wait_until( + lambda: 2 == db.collection_4.count_documents({}), "insert 2 documents", timeout=60 + ) db.collection_4.drop() @@ -1973,210 +1879,229 @@ class BadGetAttr(dict): def __getattr__(self, name): pass - bad = BadGetAttr([('foo', 'bar')]) - c.insert_one({'bad': bad}) - self.assertEqual('bar', c.find_one()['bad']['foo']) + bad = BadGetAttr([("foo", "bar")]) + c.insert_one({"bad": bad}) + self.assertEqual("bar", c.find_one()["bad"]["foo"]) def test_array_filters_validation(self): # array_filters must be a list. 
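`array_filters` binds the `$[identifier]` placeholders in an update document and must be a list of filter documents; it also requires acknowledged writes, hence the `TypeError` and `ConfigurationError` cases below. A sketch of a valid filtered positional update, assuming MongoDB 3.6+:

```python
from pymongo import MongoClient

coll = MongoClient().demo.test
coll.drop()
coll.insert_one({"y": [{"b": 1}, {"b": 2}]})

# The $[i] placeholder is bound by the {"i.b": 1} filter, so only array
# elements with b == 1 are updated.
coll.update_one({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}])
assert coll.find_one()["y"] == [{"b": 5}, {"b": 2}]
```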
c = self.db.test with self.assertRaises(TypeError): - c.update_one({}, {'$set': {'a': 1}}, array_filters={}) + c.update_one({}, {"$set": {"a": 1}}, array_filters={}) with self.assertRaises(TypeError): - c.update_many({}, {'$set': {'a': 1}}, array_filters={}) + c.update_many({}, {"$set": {"a": 1}}, array_filters={}) with self.assertRaises(TypeError): - c.find_one_and_update({}, {'$set': {'a': 1}}, array_filters={}) + c.find_one_and_update({}, {"$set": {"a": 1}}, array_filters={}) def test_array_filters_unacknowledged(self): c_w0 = self.db.test.with_options(write_concern=WriteConcern(w=0)) with self.assertRaises(ConfigurationError): - c_w0.update_one({}, {'$set': {'y.$[i].b': 5}}, - array_filters=[{'i.b': 1}]) + c_w0.update_one({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) with self.assertRaises(ConfigurationError): - c_w0.update_many({}, {'$set': {'y.$[i].b': 5}}, - array_filters=[{'i.b': 1}]) + c_w0.update_many({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) with self.assertRaises(ConfigurationError): - c_w0.find_one_and_update({}, {'$set': {'y.$[i].b': 5}}, - array_filters=[{'i.b': 1}]) + c_w0.find_one_and_update({}, {"$set": {"y.$[i].b": 5}}, array_filters=[{"i.b": 1}]) def test_find_one_and(self): c = self.db.test c.drop() - c.insert_one({'_id': 1, 'i': 1}) - - self.assertEqual({'_id': 1, 'i': 1}, - c.find_one_and_update({'_id': 1}, {'$inc': {'i': 1}})) - self.assertEqual({'_id': 1, 'i': 3}, - c.find_one_and_update( - {'_id': 1}, {'$inc': {'i': 1}}, - return_document=ReturnDocument.AFTER)) - - self.assertEqual({'_id': 1, 'i': 3}, - c.find_one_and_delete({'_id': 1})) - self.assertEqual(None, c.find_one({'_id': 1})) - - self.assertEqual(None, - c.find_one_and_update({'_id': 1}, {'$inc': {'i': 1}})) - self.assertEqual({'_id': 1, 'i': 1}, - c.find_one_and_update( - {'_id': 1}, {'$inc': {'i': 1}}, - return_document=ReturnDocument.AFTER, - upsert=True)) - self.assertEqual({'_id': 1, 'i': 2}, - c.find_one_and_update( - {'_id': 1}, {'$inc': {'i': 1}}, - return_document=ReturnDocument.AFTER)) - - self.assertEqual({'_id': 1, 'i': 3}, - c.find_one_and_replace( - {'_id': 1}, {'i': 3, 'j': 1}, - projection=['i'], - return_document=ReturnDocument.AFTER)) - self.assertEqual({'i': 4}, - c.find_one_and_update( - {'_id': 1}, {'$inc': {'i': 1}}, - projection={'i': 1, '_id': 0}, - return_document=ReturnDocument.AFTER)) + c.insert_one({"_id": 1, "i": 1}) + + self.assertEqual({"_id": 1, "i": 1}, c.find_one_and_update({"_id": 1}, {"$inc": {"i": 1}})) + self.assertEqual( + {"_id": 1, "i": 3}, + c.find_one_and_update( + {"_id": 1}, {"$inc": {"i": 1}}, return_document=ReturnDocument.AFTER + ), + ) + + self.assertEqual({"_id": 1, "i": 3}, c.find_one_and_delete({"_id": 1})) + self.assertEqual(None, c.find_one({"_id": 1})) + + self.assertEqual(None, c.find_one_and_update({"_id": 1}, {"$inc": {"i": 1}})) + self.assertEqual( + {"_id": 1, "i": 1}, + c.find_one_and_update( + {"_id": 1}, {"$inc": {"i": 1}}, return_document=ReturnDocument.AFTER, upsert=True + ), + ) + self.assertEqual( + {"_id": 1, "i": 2}, + c.find_one_and_update( + {"_id": 1}, {"$inc": {"i": 1}}, return_document=ReturnDocument.AFTER + ), + ) + + self.assertEqual( + {"_id": 1, "i": 3}, + c.find_one_and_replace( + {"_id": 1}, {"i": 3, "j": 1}, projection=["i"], return_document=ReturnDocument.AFTER + ), + ) + self.assertEqual( + {"i": 4}, + c.find_one_and_update( + {"_id": 1}, + {"$inc": {"i": 1}}, + projection={"i": 1, "_id": 0}, + return_document=ReturnDocument.AFTER, + ), + ) c.drop() for j in range(5): - c.insert_one({'j': j, 
'i': 0}) + c.insert_one({"j": j, "i": 0}) - sort = [('j', DESCENDING)] - self.assertEqual(4, c.find_one_and_update({}, - {'$inc': {'i': 1}}, - sort=sort)['j']) + sort = [("j", DESCENDING)] + self.assertEqual(4, c.find_one_and_update({}, {"$inc": {"i": 1}}, sort=sort)["j"]) def test_find_one_and_write_concern(self): listener = EventListener() db = single_client(event_listeners=[listener])[self.db.name] # non-default WriteConcern. - c_w0 = db.get_collection( - 'test', write_concern=WriteConcern(w=0)) + c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) # default WriteConcern. - c_default = db.get_collection('test', write_concern=WriteConcern()) + c_default = db.get_collection("test", write_concern=WriteConcern()) results = listener.results # Authenticate the client and throw out auth commands from the listener. - db.command('ping') + db.command("ping") results.clear() - c_w0.find_one_and_update( - {'_id': 1}, {'$set': {'foo': 'bar'}}) - self.assertEqual( - {'w': 0}, results['started'][0].command['writeConcern']) + c_w0.find_one_and_update({"_id": 1}, {"$set": {"foo": "bar"}}) + self.assertEqual({"w": 0}, results["started"][0].command["writeConcern"]) results.clear() - c_w0.find_one_and_replace({'_id': 1}, {'foo': 'bar'}) - self.assertEqual( - {'w': 0}, results['started'][0].command['writeConcern']) + c_w0.find_one_and_replace({"_id": 1}, {"foo": "bar"}) + self.assertEqual({"w": 0}, results["started"][0].command["writeConcern"]) results.clear() - c_w0.find_one_and_delete({'_id': 1}) - self.assertEqual( - {'w': 0}, results['started'][0].command['writeConcern']) + c_w0.find_one_and_delete({"_id": 1}) + self.assertEqual({"w": 0}, results["started"][0].command["writeConcern"]) results.clear() # Test write concern errors. if client_context.is_rs: c_wc_error = db.get_collection( - 'test', - write_concern=WriteConcern( - w=len(client_context.nodes) + 1)) + "test", write_concern=WriteConcern(w=len(client_context.nodes) + 1) + ) self.assertRaises( WriteConcernError, c_wc_error.find_one_and_update, - {'_id': 1}, {'$set': {'foo': 'bar'}}) + {"_id": 1}, + {"$set": {"foo": "bar"}}, + ) self.assertRaises( WriteConcernError, c_wc_error.find_one_and_replace, - {'w': 0}, results['started'][0].command['writeConcern']) + {"w": 0}, + results["started"][0].command["writeConcern"], + ) self.assertRaises( WriteConcernError, c_wc_error.find_one_and_delete, - {'w': 0}, results['started'][0].command['writeConcern']) + {"w": 0}, + results["started"][0].command["writeConcern"], + ) results.clear() - c_default.find_one_and_update({'_id': 1}, {'$set': {'foo': 'bar'}}) - self.assertNotIn('writeConcern', results['started'][0].command) + c_default.find_one_and_update({"_id": 1}, {"$set": {"foo": "bar"}}) + self.assertNotIn("writeConcern", results["started"][0].command) results.clear() - c_default.find_one_and_replace({'_id': 1}, {'foo': 'bar'}) - self.assertNotIn('writeConcern', results['started'][0].command) + c_default.find_one_and_replace({"_id": 1}, {"foo": "bar"}) + self.assertNotIn("writeConcern", results["started"][0].command) results.clear() - c_default.find_one_and_delete({'_id': 1}) - self.assertNotIn('writeConcern', results['started'][0].command) + c_default.find_one_and_delete({"_id": 1}) + self.assertNotIn("writeConcern", results["started"][0].command) results.clear() def test_find_with_nested(self): c = self.db.test c.drop() - c.insert_many([{'i': i} for i in range(5)]) # [0, 1, 2, 3, 4] + c.insert_many([{"i": i} for i in range(5)]) # [0, 1, 2, 3, 4] self.assertEqual( [2], - [i['i'] for i in 
c.find({ - '$and': [ + [ + i["i"] + for i in c.find( { - # This clause gives us [1,2,4] - '$or': [ - {'i': {'$lte': 2}}, - {'i': {'$gt': 3}}, - ], - }, - { - # This clause gives us [2,3] - '$or': [ - {'i': 2}, - {'i': 3}, + "$and": [ + { + # This clause gives us [1,2,4] + "$or": [ + {"i": {"$lte": 2}}, + {"i": {"$gt": 3}}, + ], + }, + { + # This clause gives us [2,3] + "$or": [ + {"i": 2}, + {"i": 3}, + ] + }, ] - }, - ] - })] + } + ) + ], ) self.assertEqual( [0, 1, 2], - [i['i'] for i in c.find({ - '$or': [ - { - # This clause gives us [2] - '$and': [ - {'i': {'$gte': 2}}, - {'i': {'$lt': 3}}, - ], - }, + [ + i["i"] + for i in c.find( { - # This clause gives us [0,1] - '$and': [ - {'i': {'$gt': -100}}, - {'i': {'$lt': 2}}, + "$or": [ + { + # This clause gives us [2] + "$and": [ + {"i": {"$gte": 2}}, + {"i": {"$lt": 3}}, + ], + }, + { + # This clause gives us [0,1] + "$and": [ + {"i": {"$gt": -100}}, + {"i": {"$lt": 2}}, + ] + }, ] - }, - ] - })] + } + ) + ], ) def test_find_regex(self): c = self.db.test c.drop() - c.insert_one({'r': re.compile('.*')}) + c.insert_one({"r": re.compile(".*")}) - self.assertTrue(isinstance(c.find_one()['r'], Regex)) + self.assertTrue(isinstance(c.find_one()["r"], Regex)) for doc in c.find(): - self.assertTrue(isinstance(doc['r'], Regex)) + self.assertTrue(isinstance(doc["r"], Regex)) def test_find_command_generation(self): - cmd = _gen_find_command('coll', {'$query': {'foo': 1}, '$dumb': 2}, - None, 0, 0, 0, None, DEFAULT_READ_CONCERN, - None, None) + cmd = _gen_find_command( + "coll", + {"$query": {"foo": 1}, "$dumb": 2}, + None, + 0, + 0, + 0, + None, + DEFAULT_READ_CONCERN, + None, + None, + ) self.assertEqual( - cmd.to_dict(), - SON([('find', 'coll'), - ('$dumb', 2), - ('filter', {'foo': 1})]).to_dict()) + cmd.to_dict(), SON([("find", "coll"), ("$dumb", 2), ("filter", {"foo": 1})]).to_dict() + ) def test_bool(self): with self.assertRaises(NotImplementedError): - bool(Collection(self.db, 'test')) + bool(Collection(self.db, "test")) if __name__ == "__main__": diff --git a/test/test_collection_management.py b/test/test_collection_management.py index 342e612583..c5e29eda8a 100644 --- a/test/test_collection_management.py +++ b/test/test_collection_management.py @@ -20,12 +20,10 @@ sys.path[0:0] = [""] from test import unittest - from test.unified_format import generate_test_classes # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'collection_management') +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "collection_management") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_command_monitoring_legacy.py b/test/test_command_monitoring_legacy.py index 7ff80d75e5..14db5abe2a 100644 --- a/test/test_command_monitoring_legacy.py +++ b/test/test_command_monitoring_legacy.py @@ -20,30 +20,31 @@ sys.path[0:0] = [""] -import pymongo +from test import client_context, unittest +from test.utils import ( + EventListener, + parse_read_preference, + rs_or_single_client, + wait_until, +) -from pymongo import MongoClient +import pymongo from bson import json_util +from pymongo import MongoClient from pymongo.errors import OperationFailure from pymongo.write_concern import WriteConcern -from test import unittest, client_context -from test.utils import (rs_or_single_client, wait_until, EventListener, - parse_read_preference) # Location of JSON test specifications. 
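The legacy runner translates spec operation names like `bulkWrite` into driver method names using the two-step regex helper defined just below; a worked example of what those substitutions produce:

```python
import re


def camel_to_snake(camel):
    # First pass splits before capitalized words, second pass handles any
    # remaining lowercase/uppercase boundaries.
    snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower()


assert camel_to_snake("bulkWrite") == "bulk_write"
assert camel_to_snake("findOneAndUpdate") == "find_one_and_update"
```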
-_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'command_monitoring') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_monitoring") def camel_to_snake(camel): # Regex to convert CamelCase to snake_case. - snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel) - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower() + snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() class TestAllScenarios(unittest.TestCase): - @classmethod @client_context.require_connection def setUpClass(cls): @@ -59,9 +60,9 @@ def tearDown(self): def format_actual_results(results): - started = results['started'] - succeeded = results['succeeded'] - failed = results['failed'] + started = results["started"] + succeeded = results["succeeded"] + failed = results["failed"] msg = "\nStarted: %r" % (started[0].command if len(started) else None,) msg += "\nSucceeded: %r" % (succeeded[0].reply if len(succeeded) else None,) msg += "\nFailed: %r" % (failed[0].failure if len(failed) else None,) @@ -70,51 +71,51 @@ def format_actual_results(results): def create_test(scenario_def, test): def run_scenario(self): - dbname = scenario_def['database_name'] - collname = scenario_def['collection_name'] + dbname = scenario_def["database_name"] + collname = scenario_def["collection_name"] coll = self.client[dbname][collname] coll.drop() - coll.insert_many(scenario_def['data']) + coll.insert_many(scenario_def["data"]) self.listener.results.clear() - name = camel_to_snake(test['operation']['name']) - if 'read_preference' in test['operation']: - coll = coll.with_options(read_preference=parse_read_preference( - test['operation']['read_preference'])) - if 'collectionOptions' in test['operation']: - colloptions = test['operation']['collectionOptions'] - if 'writeConcern' in colloptions: - concern = colloptions['writeConcern'] - coll = coll.with_options( - write_concern=WriteConcern(**concern)) - - test_args = test['operation']['arguments'] - if 'options' in test_args: - options = test_args.pop('options') + name = camel_to_snake(test["operation"]["name"]) + if "read_preference" in test["operation"]: + coll = coll.with_options( + read_preference=parse_read_preference(test["operation"]["read_preference"]) + ) + if "collectionOptions" in test["operation"]: + colloptions = test["operation"]["collectionOptions"] + if "writeConcern" in colloptions: + concern = colloptions["writeConcern"] + coll = coll.with_options(write_concern=WriteConcern(**concern)) + + test_args = test["operation"]["arguments"] + if "options" in test_args: + options = test_args.pop("options") test_args.update(options) args = {} for arg in test_args: args[camel_to_snake(arg)] = test_args[arg] - if name == 'count': - self.skipTest('PyMongo does not support count') - elif name == 'bulk_write': + if name == "count": + self.skipTest("PyMongo does not support count") + elif name == "bulk_write": bulk_args = [] - for request in args['requests']: - opname = request['name'] + for request in args["requests"]: + opname = request["name"] klass = opname[0:1].upper() + opname[1:] - arg = getattr(pymongo, klass)(**request['arguments']) + arg = getattr(pymongo, klass)(**request["arguments"]) bulk_args.append(arg) try: - coll.bulk_write(bulk_args, args.get('ordered', True)) + coll.bulk_write(bulk_args, args.get("ordered", True)) except OperationFailure: pass - elif name == 'find': - if 'sort' in args: - args['sort'] = list(args['sort'].items()) - if 'hint' in args: - 
args['hint'] = list(args['hint'].items()) - for arg in 'skip', 'limit': + elif name == "find": + if "sort" in args: + args["sort"] = list(args["sort"].items()) + if "hint" in args: + args["hint"] = list(args["hint"].items()) + for arg in "skip", "limit": if arg in args: args[arg] = int(args[arg]) try: @@ -129,73 +130,73 @@ def run_scenario(self): pass res = self.listener.results - for expectation in test['expectations']: + for expectation in test["expectations"]: event_type = next(iter(expectation)) if event_type == "command_started_event": - event = res['started'][0] if len(res['started']) else None + event = res["started"][0] if len(res["started"]) else None if event is not None: # The tests substitute 42 for any number other than 0. - if (event.command_name == 'getMore' - and event.command['getMore']): - event.command['getMore'] = 42 - elif event.command_name == 'killCursors': - event.command['cursors'] = [42] - elif event.command_name == 'update': + if event.command_name == "getMore" and event.command["getMore"]: + event.command["getMore"] = 42 + elif event.command_name == "killCursors": + event.command["cursors"] = [42] + elif event.command_name == "update": # TODO: remove this once PYTHON-1744 is done. # Add upsert and multi fields back into # expectations. - updates = expectation[event_type]['command'][ - 'updates'] + updates = expectation[event_type]["command"]["updates"] for update in updates: - update.setdefault('upsert', False) - update.setdefault('multi', False) + update.setdefault("upsert", False) + update.setdefault("multi", False) elif event_type == "command_succeeded_event": - event = ( - res['succeeded'].pop(0) if len(res['succeeded']) else None) + event = res["succeeded"].pop(0) if len(res["succeeded"]) else None if event is not None: reply = event.reply # The tests substitute 42 for any number other than 0, # and "" for any error message. - if 'writeErrors' in reply: - for doc in reply['writeErrors']: + if "writeErrors" in reply: + for doc in reply["writeErrors"]: # Remove any new fields the server adds. The tests # only have index, code, and errmsg. - diff = set(doc) - set(['index', 'code', 'errmsg']) + diff = set(doc) - set(["index", "code", "errmsg"]) for field in diff: doc.pop(field) - doc['code'] = 42 - doc['errmsg'] = "" - elif 'cursor' in reply: - if reply['cursor']['id']: - reply['cursor']['id'] = 42 - elif event.command_name == 'killCursors': + doc["code"] = 42 + doc["errmsg"] = "" + elif "cursor" in reply: + if reply["cursor"]["id"]: + reply["cursor"]["id"] = 42 + elif event.command_name == "killCursors": # Make the tests continue to pass when the killCursors # command is actually in use. - if 'cursorsKilled' in reply: - reply.pop('cursorsKilled') - reply['cursorsUnknown'] = [42] + if "cursorsKilled" in reply: + reply.pop("cursorsKilled") + reply["cursorsUnknown"] = [42] # Found succeeded event. Pop related started event. - res['started'].pop(0) + res["started"].pop(0) elif event_type == "command_failed_event": - event = res['failed'].pop(0) if len(res['failed']) else None + event = res["failed"].pop(0) if len(res["failed"]) else None if event is not None: # Found failed event. Pop related started event. - res['started'].pop(0) + res["started"].pop(0) else: self.fail("Unknown event type") if event is None: - event_name = event_type.split('_')[1] + event_name = event_type.split("_")[1] self.fail( "Expected %s event for %s command. 
Actual " - "results:%s" % ( + "results:%s" + % ( event_name, - expectation[event_type]['command_name'], - format_actual_results(res))) + expectation[event_type]["command_name"], + format_actual_results(res), + ) + ) for attr, expected in expectation[event_type].items(): - if 'options' in expected: - options = expected.pop('options') + if "options" in expected: + options = expected.pop("options") expected.update(options) actual = getattr(event, attr) if isinstance(expected, dict): @@ -208,35 +209,33 @@ def run_scenario(self): def create_tests(): - for dirpath, _, filenames in os.walk(os.path.join(_TEST_PATH, 'legacy')): + for dirpath, _, filenames in os.walk(os.path.join(_TEST_PATH, "legacy")): dirname = os.path.split(dirpath)[-1] for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: scenario_def = json_util.loads(scenario_stream.read()) - assert bool(scenario_def.get('tests')), "tests cannot be empty" + assert bool(scenario_def.get("tests")), "tests cannot be empty" # Construct test from scenario. - for test in scenario_def['tests']: + for test in scenario_def["tests"]: new_test = create_test(scenario_def, test) if "ignore_if_server_version_greater_than" in test: version = test["ignore_if_server_version_greater_than"] - ver = tuple(int(elt) for elt in version.split('.')) - new_test = client_context.require_version_max(*ver)( - new_test) + ver = tuple(int(elt) for elt in version.split(".")) + new_test = client_context.require_version_max(*ver)(new_test) if "ignore_if_server_version_less_than" in test: version = test["ignore_if_server_version_less_than"] - ver = tuple(int(elt) for elt in version.split('.')) - new_test = client_context.require_version_min(*ver)( - new_test) + ver = tuple(int(elt) for elt in version.split(".")) + new_test = client_context.require_version_min(*ver)(new_test) if "ignore_if_topology_type" in test: types = set(test["ignore_if_topology_type"]) if "sharded" in types: - new_test = client_context.require_no_mongos(None)( - new_test) + new_test = client_context.require_no_mongos(None)(new_test) - test_name = 'test_%s_%s_%s' % ( + test_name = "test_%s_%s_%s" % ( dirname, os.path.splitext(filename)[0], - str(test['description'].replace(" ", "_"))) + str(test["description"].replace(" ", "_")), + ) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) diff --git a/test/test_command_monitoring_unified.py b/test/test_command_monitoring_unified.py index 9390c9fec6..46e1e4724c 100644 --- a/test/test_command_monitoring_unified.py +++ b/test/test_command_monitoring_unified.py @@ -22,16 +22,16 @@ from test import unittest from test.unified_format import generate_test_classes - # Location of JSON test specifications. 
-_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'command_monitoring') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_monitoring") -globals().update(generate_test_classes( - os.path.join(_TEST_PATH, 'unified'), - module=__name__,)) +globals().update( + generate_test_classes( + os.path.join(_TEST_PATH, "unified"), + module=__name__, + ) +) if __name__ == "__main__": diff --git a/test/test_common.py b/test/test_common.py index dcd618c509..af42089806 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -19,13 +19,14 @@ sys.path[0:0] = [""] -from bson.binary import Binary, PYTHON_LEGACY, STANDARD, UuidRepresentation +from test import IntegrationTest, client_context, unittest +from test.utils import connected, rs_or_single_client, single_client + +from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.objectid import ObjectId from pymongo.errors import OperationFailure from pymongo.write_concern import WriteConcern -from test import client_context, unittest, IntegrationTest -from test.utils import connected, rs_or_single_client, single_client @client_context.require_connection @@ -34,81 +35,79 @@ def setUpModule(): class TestCommon(IntegrationTest): - def test_uuid_representation(self): coll = self.db.uuid coll.drop() # Test property - self.assertEqual(UuidRepresentation.UNSPECIFIED, - coll.codec_options.uuid_representation) + self.assertEqual(UuidRepresentation.UNSPECIFIED, coll.codec_options.uuid_representation) # Test basic query uu = uuid.uuid4() # Insert as binary subtype 3 - coll = self.db.get_collection( - "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) legacy_opts = coll.codec_options - coll.insert_one({'uu': uu}) - self.assertEqual(uu, coll.find_one({'uu': uu})['uu']) - coll = self.db.get_collection( - "uuid", CodecOptions(uuid_representation=STANDARD)) + coll.insert_one({"uu": uu}) + self.assertEqual(uu, coll.find_one({"uu": uu})["uu"]) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) self.assertEqual(STANDARD, coll.codec_options.uuid_representation) - self.assertEqual(None, coll.find_one({'uu': uu})) + self.assertEqual(None, coll.find_one({"uu": uu})) uul = Binary.from_uuid(uu, PYTHON_LEGACY) - self.assertEqual(uul, coll.find_one({'uu': uul})['uu']) + self.assertEqual(uul, coll.find_one({"uu": uul})["uu"]) # Test count_documents - self.assertEqual(0, coll.count_documents({'uu': uu})) - coll = self.db.get_collection( - "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - self.assertEqual(1, coll.count_documents({'uu': uu})) + self.assertEqual(0, coll.count_documents({"uu": uu})) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + self.assertEqual(1, coll.count_documents({"uu": uu})) # Test delete - coll = self.db.get_collection( - "uuid", CodecOptions(uuid_representation=STANDARD)) - coll.delete_one({'uu': uu}) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) + coll.delete_one({"uu": uu}) self.assertEqual(1, coll.count_documents({})) - coll = self.db.get_collection( - "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - coll.delete_one({'uu': uu}) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + coll.delete_one({"uu": uu}) self.assertEqual(0, coll.count_documents({})) # Test update_one 
- coll.insert_one({'_id': uu, 'i': 1}) - coll = self.db.get_collection( - "uuid", CodecOptions(uuid_representation=STANDARD)) - coll.update_one({'_id': uu}, {'$set': {'i': 2}}) - coll = self.db.get_collection( - "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - self.assertEqual(1, coll.find_one({'_id': uu})['i']) - coll.update_one({'_id': uu}, {'$set': {'i': 2}}) - self.assertEqual(2, coll.find_one({'_id': uu})['i']) + coll.insert_one({"_id": uu, "i": 1}) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) + coll.update_one({"_id": uu}, {"$set": {"i": 2}}) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + self.assertEqual(1, coll.find_one({"_id": uu})["i"]) + coll.update_one({"_id": uu}, {"$set": {"i": 2}}) + self.assertEqual(2, coll.find_one({"_id": uu})["i"]) # Test Cursor.distinct - self.assertEqual([2], coll.find({'_id': uu}).distinct('i')) - coll = self.db.get_collection( - "uuid", CodecOptions(uuid_representation=STANDARD)) - self.assertEqual([], coll.find({'_id': uu}).distinct('i')) + self.assertEqual([2], coll.find({"_id": uu}).distinct("i")) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD)) + self.assertEqual([], coll.find({"_id": uu}).distinct("i")) # Test findAndModify - self.assertEqual(None, coll.find_one_and_update({'_id': uu}, - {'$set': {'i': 5}})) - coll = self.db.get_collection( - "uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) - self.assertEqual(2, coll.find_one_and_update({'_id': uu}, - {'$set': {'i': 5}})['i']) - self.assertEqual(5, coll.find_one({'_id': uu})['i']) + self.assertEqual(None, coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}})) + coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY)) + self.assertEqual(2, coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}})["i"]) + self.assertEqual(5, coll.find_one({"_id": uu})["i"]) # Test command - self.assertEqual(5, self.db.command( - 'findAndModify', 'uuid', - update={'$set': {'i': 6}}, - query={'_id': uu}, codec_options=legacy_opts)['value']['i']) - self.assertEqual(6, self.db.command( - 'findAndModify', 'uuid', - update={'$set': {'i': 7}}, - query={'_id': Binary.from_uuid(uu, PYTHON_LEGACY)})['value']['i']) + self.assertEqual( + 5, + self.db.command( + "findAndModify", + "uuid", + update={"$set": {"i": 6}}, + query={"_id": uu}, + codec_options=legacy_opts, + )["value"]["i"], + ) + self.assertEqual( + 6, + self.db.command( + "findAndModify", + "uuid", + update={"$set": {"i": 7}}, + query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)}, + )["value"]["i"], + ) def test_write_concern(self): c = rs_or_single_client(connect=False) @@ -119,7 +118,7 @@ def test_write_concern(self): self.assertEqual(wc, c.write_concern) # Can we override back to the server default? 
- db = c.get_database('pymongo_test', write_concern=WriteConcern()) + db = c.get_database("pymongo_test", write_concern=WriteConcern()) self.assertEqual(db.write_concern, WriteConcern()) db = c.pymongo_test @@ -128,7 +127,7 @@ def test_write_concern(self): self.assertEqual(wc, coll.write_concern) cwc = WriteConcern(j=True) - coll = db.get_collection('test', write_concern=cwc) + coll = db.get_collection("test", write_concern=cwc) self.assertEqual(cwc, coll.write_concern) self.assertEqual(wc, db.write_concern) @@ -149,21 +148,22 @@ def test_mongo_client(self): self.assertTrue(new_coll.insert_one(doc)) self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client("mongodb://%s/" % (pair,), - replicaSet=client_context.replica_set_name) + m = rs_or_single_client( + "mongodb://%s/" % (pair,), replicaSet=client_context.replica_set_name + ) coll = m.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client("mongodb://%s/?w=0" % (pair,), - replicaSet=client_context.replica_set_name) + m = rs_or_single_client( + "mongodb://%s/?w=0" % (pair,), replicaSet=client_context.replica_set_name + ) coll = m.pymongo_test.write_concern_test coll.insert_one(doc) # Equality tests direct = connected(single_client(w=0)) - direct2 = connected(single_client("mongodb://%s/?w=0" % (pair,), - **self.credentials)) + direct2 = connected(single_client("mongodb://%s/?w=0" % (pair,), **self.credentials)) self.assertEqual(direct, direct2) self.assertFalse(direct != direct2) diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 894b14becd..ded872515e 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -18,19 +18,19 @@ sys.path[0:0] = [""] +from test import IntegrationTest, client_context, unittest +from test.utils import ( + CMAPListener, + ensure_all_connected, + repl_set_step_down, + rs_or_single_client, +) + from bson import SON from pymongo import monitoring from pymongo.errors import NotPrimaryError from pymongo.write_concern import WriteConcern -from test import (client_context, - unittest, - IntegrationTest) -from test.utils import (CMAPListener, - ensure_all_connected, - repl_set_step_down, - rs_or_single_client) - class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): @classmethod @@ -38,9 +38,9 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): def setUpClass(cls): super(TestConnectionsSurvivePrimaryStepDown, cls).setUpClass() cls.listener = CMAPListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener], - retryWrites=False, - heartbeatFrequencyMS=500) + cls.client = rs_or_single_client( + event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 + ) # Ensure connections to all servers in replica set. 
This is to test # that the is_writable flag is properly updated for sockets that @@ -48,10 +48,8 @@ def setUpClass(cls): ensure_all_connected(cls.client) cls.listener.reset() - cls.db = cls.client.get_database( - "step-down", write_concern=WriteConcern("majority")) - cls.coll = cls.db.get_collection( - "step-down", write_concern=WriteConcern("majority")) + cls.db = cls.client.get_database("step-down", write_concern=WriteConcern("majority")) + cls.coll = cls.db.get_collection("step-down", write_concern=WriteConcern("majority")) @classmethod def tearDownClass(cls): @@ -69,17 +67,15 @@ def set_fail_point(self, command_args): self.client.admin.command(cmd) def verify_pool_cleared(self): - self.assertEqual( - self.listener.event_count(monitoring.PoolClearedEvent), 1) + self.assertEqual(self.listener.event_count(monitoring.PoolClearedEvent), 1) def verify_pool_not_cleared(self): - self.assertEqual( - self.listener.event_count(monitoring.PoolClearedEvent), 0) + self.assertEqual(self.listener.event_count(monitoring.PoolClearedEvent), 0) @client_context.require_version_min(4, 2, -1) def test_get_more_iteration(self): # Insert 5 documents with WC majority. - self.coll.insert_many([{'data': k} for k in range(5)]) + self.coll.insert_many([{"data": k} for k in range(5)]) # Start a find operation and retrieve first batch of results. batch_size = 2 cursor = self.coll.find(batch_size=batch_size) @@ -104,14 +100,14 @@ def test_get_more_iteration(self): def run_scenario(self, error_code, retry, pool_status_checker): # Set fail point. - self.set_fail_point({"mode": {"times": 1}, - "data": {"failCommands": ["insert"], - "errorCode": error_code}}) + self.set_fail_point( + {"mode": {"times": 1}, "data": {"failCommands": ["insert"], "errorCode": error_code}} + ) self.addCleanup(self.set_fail_point, {"mode": "off"}) # Insert record and verify failure. with self.assertRaises(NotPrimaryError) as exc: self.coll.insert_one({"test": 1}) - self.assertEqual(exc.exception.details['code'], error_code) + self.assertEqual(exc.exception.details["code"], error_code) # Retry before CMAPListener assertion if retry_before=True. if retry: self.coll.insert_one({"test": 1}) diff --git a/test/test_create_entities.py b/test/test_create_entities.py index b82b730aef..ad0ac9347e 100644 --- a/test/test_create_entities.py +++ b/test/test_create_entities.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import unittest - from test.unified_format import UnifiedSpecTestMixinV1 @@ -26,23 +25,18 @@ def test_store_events_as_entities(self): { "client": { "id": "client0", - "storeEventsAsEntities": [ - { - "id": "events1", - "events": [ - "PoolCreatedEvent", - ] - } - ] + "storeEventsAsEntities": [ + { + "id": "events1", + "events": [ + "PoolCreatedEvent", + ], + } + ], } }, ], - "tests": [ - { - "description": "foo", - "operations": [] - } - ] + "tests": [{"description": "foo", "operations": []}], } self.scenario_runner.TEST_SPEC = spec self.scenario_runner.setUp() @@ -63,27 +57,18 @@ def test_store_all_others_as_entities(self): { "client": { "id": "client0", - "uriOptions": { - "retryReads": True - }, - } - }, - { - "database": { - "id": "database0", - "client": "client0", - "databaseName": "dat" + "uriOptions": {"retryReads": True}, } }, + {"database": {"id": "database0", "client": "client0", "databaseName": "dat"}}, { "collection": { "id": "collection0", "database": "database0", - "collectionName": "dat" + "collectionName": "dat", } - } + }, ], - "tests": [ { "description": "test loops", @@ -99,33 +84,21 @@ def test_store_all_others_as_entities(self): "numIterations": 5, "operations": [ { - "name": "insertOne", - "object": "collection0", - "arguments": { - "document": { - "_id": 1, - "x": 44 - } - } - + "name": "insertOne", + "object": "collection0", + "arguments": {"document": {"_id": 1, "x": 44}}, }, { "name": "insertOne", "object": "collection0", - "arguments": { - "document": { - "_id": 1, - "x": 44 - } - } - - } - ] - } + "arguments": {"document": {"_id": 1, "x": 44}}, + }, + ], + }, } - ] + ], } - ] + ], } self.scenario_runner.TEST_SPEC = spec diff --git a/test/test_crud_unified.py b/test/test_crud_unified.py index a435c1caa1..cc9a521b3b 100644 --- a/test/test_crud_unified.py +++ b/test/test_crud_unified.py @@ -20,16 +20,13 @@ sys.path[0:0] = [""] from test import unittest - from test.unified_format import generate_test_classes # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'crud', 'unified') +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "crud", "unified") # Generate unified tests. 
-globals().update(generate_test_classes( - TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True)) +globals().update(generate_test_classes(TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True)) if __name__ == "__main__": unittest.main() diff --git a/test/test_crud_v1.py b/test/test_crud_v1.py index 5a63e030fe..8f3bf2d998 100644 --- a/test/test_crud_v1.py +++ b/test/test_crud_v1.py @@ -19,26 +19,32 @@ sys.path[0:0] = [""] -from pymongo import operations, WriteConcern +from test import IntegrationTest, client_context, unittest +from test.utils import ( + TestCreator, + camel_to_snake, + camel_to_snake_args, + camel_to_upper_camel, + drop_collections, +) + +from pymongo import WriteConcern, operations from pymongo.command_cursor import CommandCursor from pymongo.cursor import Cursor from pymongo.errors import PyMongoError +from pymongo.operations import ( + DeleteMany, + DeleteOne, + InsertOne, + ReplaceOne, + UpdateMany, + UpdateOne, +) from pymongo.read_concern import ReadConcern -from pymongo.results import _WriteResult, BulkWriteResult -from pymongo.operations import (InsertOne, - DeleteOne, - DeleteMany, - ReplaceOne, - UpdateOne, - UpdateMany) - -from test import client_context, unittest, IntegrationTest -from test.utils import (camel_to_snake, camel_to_upper_camel, - camel_to_snake_args, drop_collections, TestCreator) +from pymongo.results import BulkWriteResult, _WriteResult # Location of JSON test specifications. -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'crud', 'v1') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "crud", "v1") class TestAllScenarios(IntegrationTest): @@ -51,8 +57,7 @@ def check_result(self, expected_result, result): prop = camel_to_snake(res) msg = "%s : %r != %r" % (prop, expected_result, result) # SPEC-869: Only BulkWriteResult has upserted_count. - if (prop == "upserted_count" - and not isinstance(result, BulkWriteResult)): + if prop == "upserted_count" and not isinstance(result, BulkWriteResult): if result.upserted_id is not None: upserted_count = 1 else: @@ -61,8 +66,7 @@ def check_result(self, expected_result, result): elif prop == "inserted_ids": # BulkWriteResult does not have inserted_ids. if isinstance(result, BulkWriteResult): - self.assertEqual(len(expected_result[res]), - result.inserted_count) + self.assertEqual(len(expected_result[res]), result.inserted_count) else: # InsertManyResult may be compared to [id1] from the # crud spec or {"0": id1} from the retryable write spec. @@ -78,8 +82,7 @@ def check_result(self, expected_result, result): expected_ids[int(str_index)] = ids[str_index] self.assertEqual(expected_ids, result.upserted_ids, msg) else: - self.assertEqual( - getattr(result, prop), expected_result[res], msg) + self.assertEqual(getattr(result, prop), expected_result[res], msg) else: self.assertEqual(result, expected_result) @@ -87,16 +90,16 @@ def check_result(self, expected_result, result): def run_operation(collection, test): # Convert command from CamelCase to pymongo.collection method. - operation = camel_to_snake(test['operation']['name']) + operation = camel_to_snake(test["operation"]["name"]) cmd = getattr(collection, operation) # Convert arguments to snake_case and handle special cases. 
- arguments = test['operation']['arguments'] + arguments = test["operation"]["arguments"] options = arguments.pop("options", {}) for option_name in options: arguments[camel_to_snake(option_name)] = options[option_name] - if operation == 'count': - raise unittest.SkipTest('PyMongo does not support count') + if operation == "count": + raise unittest.SkipTest("PyMongo does not support count") if operation == "bulk_write": # Parse each request into a bulk write model. requests = [] @@ -137,15 +140,15 @@ def create_test(scenario_def, test, name): def run_scenario(self): # Cleanup state and load data (if provided). drop_collections(self.db) - data = scenario_def.get('data') + data = scenario_def.get("data") if data: - self.db.test.with_options( - write_concern=WriteConcern(w="majority")).insert_many( - scenario_def['data']) + self.db.test.with_options(write_concern=WriteConcern(w="majority")).insert_many( + scenario_def["data"] + ) # Run operations and check results or errors. - expected_result = test.get('outcome', {}).get('result') - expected_error = test.get('outcome', {}).get('error') + expected_result = test.get("outcome", {}).get("result") + expected_error = test.get("outcome", {}).get("error") if expected_error is True: with self.assertRaises(PyMongoError): run_operation(self.db.test, test) @@ -155,16 +158,15 @@ def run_scenario(self): check_result(self, expected_result, result) # Assert final state is expected. - expected_c = test['outcome'].get('collection') + expected_c = test["outcome"].get("collection") if expected_c is not None: - expected_name = expected_c.get('name') + expected_name = expected_c.get("name") if expected_name is not None: db_coll = self.db[expected_name] else: db_coll = self.db.test - db_coll = db_coll.with_options( - read_concern=ReadConcern(level="local")) - self.assertEqual(list(db_coll.find()), expected_c['data']) + db_coll = db_coll.with_options(read_concern=ReadConcern(level="local")) + self.assertEqual(list(db_coll.find()), expected_c["data"]) return run_scenario @@ -175,53 +177,68 @@ def run_scenario(self): class TestWriteOpsComparison(unittest.TestCase): def test_InsertOneEquals(self): - self.assertEqual(InsertOne({'foo': 42}), InsertOne({'foo': 42})) + self.assertEqual(InsertOne({"foo": 42}), InsertOne({"foo": 42})) def test_InsertOneNotEquals(self): - self.assertNotEqual(InsertOne({'foo': 42}), InsertOne({'foo': 23})) + self.assertNotEqual(InsertOne({"foo": 42}), InsertOne({"foo": 23})) def test_DeleteOneEquals(self): - self.assertEqual(DeleteOne({'foo': 42}), DeleteOne({'foo': 42})) + self.assertEqual(DeleteOne({"foo": 42}), DeleteOne({"foo": 42})) def test_DeleteOneNotEquals(self): - self.assertNotEqual(DeleteOne({'foo': 42}), DeleteOne({'foo': 23})) + self.assertNotEqual(DeleteOne({"foo": 42}), DeleteOne({"foo": 23})) def test_DeleteManyEquals(self): - self.assertEqual(DeleteMany({'foo': 42}), DeleteMany({'foo': 42})) + self.assertEqual(DeleteMany({"foo": 42}), DeleteMany({"foo": 42})) def test_DeleteManyNotEquals(self): - self.assertNotEqual(DeleteMany({'foo': 42}), DeleteMany({'foo': 23})) + self.assertNotEqual(DeleteMany({"foo": 42}), DeleteMany({"foo": 23})) def test_DeleteOneNotEqualsDeleteMany(self): - self.assertNotEqual(DeleteOne({'foo': 42}), DeleteMany({'foo': 42})) + self.assertNotEqual(DeleteOne({"foo": 42}), DeleteMany({"foo": 42})) def test_ReplaceOneEquals(self): - self.assertEqual(ReplaceOne({'foo': 42}, {'bar': 42}, upsert=False), - ReplaceOne({'foo': 42}, {'bar': 42}, upsert=False)) + self.assertEqual( + ReplaceOne({"foo": 42}, 
{"bar": 42}, upsert=False), + ReplaceOne({"foo": 42}, {"bar": 42}, upsert=False), + ) def test_ReplaceOneNotEquals(self): - self.assertNotEqual(ReplaceOne({'foo': 42}, {'bar': 42}, upsert=False), - ReplaceOne({'foo': 42}, {'bar': 42}, upsert=True)) + self.assertNotEqual( + ReplaceOne({"foo": 42}, {"bar": 42}, upsert=False), + ReplaceOne({"foo": 42}, {"bar": 42}, upsert=True), + ) def test_UpdateOneEquals(self): - self.assertEqual(UpdateOne({'foo': 42}, {'$set': {'bar': 42}}), - UpdateOne({'foo': 42}, {'$set': {'bar': 42}})) + self.assertEqual( + UpdateOne({"foo": 42}, {"$set": {"bar": 42}}), + UpdateOne({"foo": 42}, {"$set": {"bar": 42}}), + ) def test_UpdateOneNotEquals(self): - self.assertNotEqual(UpdateOne({'foo': 42}, {'$set': {'bar': 42}}), - UpdateOne({'foo': 42}, {'$set': {'bar': 23}})) + self.assertNotEqual( + UpdateOne({"foo": 42}, {"$set": {"bar": 42}}), + UpdateOne({"foo": 42}, {"$set": {"bar": 23}}), + ) def test_UpdateManyEquals(self): - self.assertEqual(UpdateMany({'foo': 42}, {'$set': {'bar': 42}}), - UpdateMany({'foo': 42}, {'$set': {'bar': 42}})) + self.assertEqual( + UpdateMany({"foo": 42}, {"$set": {"bar": 42}}), + UpdateMany({"foo": 42}, {"$set": {"bar": 42}}), + ) def test_UpdateManyNotEquals(self): - self.assertNotEqual(UpdateMany({'foo': 42}, {'$set': {'bar': 42}}), - UpdateMany({'foo': 42}, {'$set': {'bar': 23}})) + self.assertNotEqual( + UpdateMany({"foo": 42}, {"$set": {"bar": 42}}), + UpdateMany({"foo": 42}, {"$set": {"bar": 23}}), + ) def test_UpdateOneNotEqualsUpdateMany(self): - self.assertNotEqual(UpdateOne({'foo': 42}, {'$set': {'bar': 42}}), - UpdateMany({'foo': 42}, {'$set': {'bar': 42}})) + self.assertNotEqual( + UpdateOne({"foo": 42}, {"$set": {"bar": 42}}), + UpdateMany({"foo": 42}, {"$set": {"bar": 42}}), + ) + if __name__ == "__main__": unittest.main() diff --git a/test/test_cursor.py b/test/test_cursor.py index d56f9fc27d..fc0accb711 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -19,41 +19,46 @@ import random import re import sys -import time import threading +import time sys.path[0:0] = [""] +from test import IntegrationTest, client_context, unittest +from test.utils import ( + AllowListEventListener, + EventListener, + OvertCommandListener, + ignore_deprecations, + rs_or_single_client, +) + from bson import decode_all from bson.code import Code from bson.son import SON -from pymongo import (ASCENDING, - DESCENDING) +from pymongo import ASCENDING, DESCENDING from pymongo.collation import Collation from pymongo.cursor import Cursor, CursorType -from pymongo.errors import (ConfigurationError, - ExecutionTimeout, - InvalidOperation, - OperationFailure) +from pymongo.errors import ( + ConfigurationError, + ExecutionTimeout, + InvalidOperation, + OperationFailure, +) from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern -from test import (client_context, - unittest, - IntegrationTest) -from test.utils import (EventListener, - OvertCommandListener, - ignore_deprecations, - rs_or_single_client, - AllowListEventListener) class TestCursor(IntegrationTest): def test_deepcopy_cursor_littered_with_regexes(self): - cursor = self.db.test.find({ - "x": re.compile("^hmmm.*"), - "y": [re.compile("^hmm.*")], - "z": {"a": [re.compile("^hm.*")]}, - re.compile("^key.*"): {"a": [re.compile("^hm.*")]}}) + cursor = self.db.test.find( + { + "x": re.compile("^hmmm.*"), + "y": [re.compile("^hmm.*")], + "z": {"a": [re.compile("^hm.*")]}, + re.compile("^key.*"): {"a": [re.compile("^hm.*")]}, + } + ) cursor2 = 
copy.deepcopy(cursor) self.assertEqual(cursor._Cursor__spec, cursor2._Cursor__spec) @@ -64,19 +69,15 @@ def test_add_remove_option(self): cursor.add_option(2) cursor2 = self.db.test.find(cursor_type=CursorType.TAILABLE) self.assertEqual(2, cursor2._Cursor__query_flags) - self.assertEqual(cursor._Cursor__query_flags, - cursor2._Cursor__query_flags) + self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.add_option(32) cursor2 = self.db.test.find(cursor_type=CursorType.TAILABLE_AWAIT) self.assertEqual(34, cursor2._Cursor__query_flags) - self.assertEqual(cursor._Cursor__query_flags, - cursor2._Cursor__query_flags) + self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.add_option(128) - cursor2 = self.db.test.find( - cursor_type=CursorType.TAILABLE_AWAIT).add_option(128) + cursor2 = self.db.test.find(cursor_type=CursorType.TAILABLE_AWAIT).add_option(128) self.assertEqual(162, cursor2._Cursor__query_flags) - self.assertEqual(cursor._Cursor__query_flags, - cursor2._Cursor__query_flags) + self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) self.assertEqual(162, cursor._Cursor__query_flags) cursor.add_option(128) @@ -85,13 +86,11 @@ def test_add_remove_option(self): cursor.remove_option(128) cursor2 = self.db.test.find(cursor_type=CursorType.TAILABLE_AWAIT) self.assertEqual(34, cursor2._Cursor__query_flags) - self.assertEqual(cursor._Cursor__query_flags, - cursor2._Cursor__query_flags) + self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.remove_option(32) cursor2 = self.db.test.find(cursor_type=CursorType.TAILABLE) self.assertEqual(2, cursor2._Cursor__query_flags) - self.assertEqual(cursor._Cursor__query_flags, - cursor2._Cursor__query_flags) + self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) self.assertEqual(2, cursor._Cursor__query_flags) cursor.remove_option(32) @@ -101,8 +100,7 @@ def test_add_remove_option(self): cursor = self.db.test.find(no_cursor_timeout=True) self.assertEqual(16, cursor._Cursor__query_flags) cursor2 = self.db.test.find().add_option(16) - self.assertEqual(cursor._Cursor__query_flags, - cursor2._Cursor__query_flags) + self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.remove_option(16) self.assertEqual(0, cursor._Cursor__query_flags) @@ -110,8 +108,7 @@ def test_add_remove_option(self): cursor = self.db.test.find(cursor_type=CursorType.TAILABLE_AWAIT) self.assertEqual(34, cursor._Cursor__query_flags) cursor2 = self.db.test.find().add_option(34) - self.assertEqual(cursor._Cursor__query_flags, - cursor2._Cursor__query_flags) + self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.remove_option(32) self.assertEqual(2, cursor._Cursor__query_flags) @@ -119,8 +116,7 @@ def test_add_remove_option(self): cursor = self.db.test.find(allow_partial_results=True) self.assertEqual(128, cursor._Cursor__query_flags) cursor2 = self.db.test.find().add_option(128) - self.assertEqual(cursor._Cursor__query_flags, - cursor2._Cursor__query_flags) + self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) cursor.remove_option(128) self.assertEqual(0, cursor._Cursor__query_flags) @@ -133,8 +129,7 @@ def test_add_remove_option_exhaust(self): cursor = self.db.test.find(cursor_type=CursorType.EXHAUST) self.assertEqual(64, cursor._Cursor__query_flags) cursor2 = self.db.test.find().add_option(64) - self.assertEqual(cursor._Cursor__query_flags, - cursor2._Cursor__query_flags) + 
self.assertEqual(cursor._Cursor__query_flags, cursor2._Cursor__query_flags) self.assertTrue(cursor._Cursor__exhaust) cursor.remove_option(64) self.assertEqual(0, cursor._Cursor__query_flags) @@ -145,7 +140,7 @@ def test_allow_disk_use(self): db.pymongo_test.drop() coll = db.pymongo_test - self.assertRaises(TypeError, coll.find().allow_disk_use, 'baz') + self.assertRaises(TypeError, coll.find().allow_disk_use, "baz") cursor = coll.find().allow_disk_use(True) self.assertEqual(True, cursor._Cursor__allow_disk_use) @@ -156,7 +151,7 @@ def test_max_time_ms(self): db = self.db db.pymongo_test.drop() coll = db.pymongo_test - self.assertRaises(TypeError, coll.find().max_time_ms, 'foo') + self.assertRaises(TypeError, coll.find().max_time_ms, "foo") coll.insert_one({"amalia": 1}) coll.insert_one({"amalia": 2}) @@ -177,12 +172,9 @@ def test_max_time_ms(self): self.assertTrue(coll.find_one(max_time_ms=1000)) client = self.client - if (not client_context.is_mongos - and client_context.test_commands_enabled): + if not client_context.is_mongos and client_context.test_commands_enabled: # Cursor parses server timeout error in response to initial query. - client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="alwaysOn") + client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: cursor = coll.find().max_time_ms(1) try: @@ -191,19 +183,16 @@ def test_max_time_ms(self): pass else: self.fail("ExecutionTimeout not raised") - self.assertRaises(ExecutionTimeout, - coll.find_one, max_time_ms=1) + self.assertRaises(ExecutionTimeout, coll.find_one, max_time_ms=1) finally: - client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="off") + client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") def test_max_await_time_ms(self): db = self.db db.pymongo_test.drop() coll = db.create_collection("pymongo_test", capped=True, size=4096) - self.assertRaises(TypeError, coll.find().max_await_time_ms, 'foo') + self.assertRaises(TypeError, coll.find().max_await_time_ms, "foo") coll.insert_one({"amalia": 1}) coll.insert_one({"amalia": 2}) @@ -221,95 +210,91 @@ def test_max_await_time_ms(self): self.assertEqual(None, cursor._Cursor__max_await_time_ms) # If cursor is tailable_await and timeout is set - cursor = coll.find( - cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99) + cursor = coll.find(cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99) self.assertEqual(99, cursor._Cursor__max_await_time_ms) - cursor = coll.find( - cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms( - 10).max_await_time_ms(90) + cursor = ( + coll.find(cursor_type=CursorType.TAILABLE_AWAIT) + .max_await_time_ms(10) + .max_await_time_ms(90) + ) self.assertEqual(90, cursor._Cursor__max_await_time_ms) - listener = AllowListEventListener('find', 'getMore') - coll = rs_or_single_client( - event_listeners=[listener])[self.db.name].pymongo_test + listener = AllowListEventListener("find", "getMore") + coll = rs_or_single_client(event_listeners=[listener])[self.db.name].pymongo_test results = listener.results # Tailable_await defaults. list(coll.find(cursor_type=CursorType.TAILABLE_AWAIT)) # find - self.assertFalse('maxTimeMS' in results['started'][0].command) + self.assertFalse("maxTimeMS" in results["started"][0].command) # getMore - self.assertFalse('maxTimeMS' in results['started'][1].command) + self.assertFalse("maxTimeMS" in results["started"][1].command) results.clear() # Tailable_await with max_await_time_ms set. 
- list(coll.find( - cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99)) + list(coll.find(cursor_type=CursorType.TAILABLE_AWAIT).max_await_time_ms(99)) # find - self.assertEqual('find', results['started'][0].command_name) - self.assertFalse('maxTimeMS' in results['started'][0].command) + self.assertEqual("find", results["started"][0].command_name) + self.assertFalse("maxTimeMS" in results["started"][0].command) # getMore - self.assertEqual('getMore', results['started'][1].command_name) - self.assertTrue('maxTimeMS' in results['started'][1].command) - self.assertEqual(99, results['started'][1].command['maxTimeMS']) + self.assertEqual("getMore", results["started"][1].command_name) + self.assertTrue("maxTimeMS" in results["started"][1].command) + self.assertEqual(99, results["started"][1].command["maxTimeMS"]) results.clear() # Tailable_await with max_time_ms - list(coll.find( - cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(99)) + list(coll.find(cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(99)) # find - self.assertEqual('find', results['started'][0].command_name) - self.assertTrue('maxTimeMS' in results['started'][0].command) - self.assertEqual(99, results['started'][0].command['maxTimeMS']) + self.assertEqual("find", results["started"][0].command_name) + self.assertTrue("maxTimeMS" in results["started"][0].command) + self.assertEqual(99, results["started"][0].command["maxTimeMS"]) # getMore - self.assertEqual('getMore', results['started'][1].command_name) - self.assertFalse('maxTimeMS' in results['started'][1].command) + self.assertEqual("getMore", results["started"][1].command_name) + self.assertFalse("maxTimeMS" in results["started"][1].command) results.clear() # Tailable_await with both max_time_ms and max_await_time_ms - list(coll.find( - cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms( - 99).max_await_time_ms(99)) + list(coll.find(cursor_type=CursorType.TAILABLE_AWAIT).max_time_ms(99).max_await_time_ms(99)) # find - self.assertEqual('find', results['started'][0].command_name) - self.assertTrue('maxTimeMS' in results['started'][0].command) - self.assertEqual(99, results['started'][0].command['maxTimeMS']) + self.assertEqual("find", results["started"][0].command_name) + self.assertTrue("maxTimeMS" in results["started"][0].command) + self.assertEqual(99, results["started"][0].command["maxTimeMS"]) # getMore - self.assertEqual('getMore', results['started'][1].command_name) - self.assertTrue('maxTimeMS' in results['started'][1].command) - self.assertEqual(99, results['started'][1].command['maxTimeMS']) + self.assertEqual("getMore", results["started"][1].command_name) + self.assertTrue("maxTimeMS" in results["started"][1].command) + self.assertEqual(99, results["started"][1].command["maxTimeMS"]) results.clear() # Non tailable_await with max_await_time_ms list(coll.find(batch_size=1).max_await_time_ms(99)) # find - self.assertEqual('find', results['started'][0].command_name) - self.assertFalse('maxTimeMS' in results['started'][0].command) + self.assertEqual("find", results["started"][0].command_name) + self.assertFalse("maxTimeMS" in results["started"][0].command) # getMore - self.assertEqual('getMore', results['started'][1].command_name) - self.assertFalse('maxTimeMS' in results['started'][1].command) + self.assertEqual("getMore", results["started"][1].command_name) + self.assertFalse("maxTimeMS" in results["started"][1].command) results.clear() # Non tailable_await with max_time_ms list(coll.find(batch_size=1).max_time_ms(99)) # find - self.assertEqual('find', 
results['started'][0].command_name) - self.assertTrue('maxTimeMS' in results['started'][0].command) - self.assertEqual(99, results['started'][0].command['maxTimeMS']) + self.assertEqual("find", results["started"][0].command_name) + self.assertTrue("maxTimeMS" in results["started"][0].command) + self.assertEqual(99, results["started"][0].command["maxTimeMS"]) # getMore - self.assertEqual('getMore', results['started'][1].command_name) - self.assertFalse('maxTimeMS' in results['started'][1].command) + self.assertEqual("getMore", results["started"][1].command_name) + self.assertFalse("maxTimeMS" in results["started"][1].command) # Non tailable_await with both max_time_ms and max_await_time_ms list(coll.find(batch_size=1).max_time_ms(99).max_await_time_ms(88)) # find - self.assertEqual('find', results['started'][0].command_name) - self.assertTrue('maxTimeMS' in results['started'][0].command) - self.assertEqual(99, results['started'][0].command['maxTimeMS']) + self.assertEqual("find", results["started"][0].command_name) + self.assertTrue("maxTimeMS" in results["started"][0].command) + self.assertEqual(99, results["started"][0].command["maxTimeMS"]) # getMore - self.assertEqual('getMore', results['started'][1].command_name) - self.assertFalse('maxTimeMS' in results['started'][1].command) + self.assertEqual("getMore", results["started"][1].command_name) + self.assertFalse("maxTimeMS" in results["started"][1].command) @client_context.require_test_commands @client_context.require_no_mongos @@ -321,9 +306,7 @@ def test_max_time_ms_getmore(self): # Send initial query before turning on failpoint. next(cursor) - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="alwaysOn") + self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: try: # Iterate up to first getmore. 
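# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the "maxTimeAlwaysTimeOut" fail point
# pattern the surrounding hunks exercise: while the fail point is enabled,
# the server fails any operation carrying maxTimeMS as if it had timed out,
# so the ExecutionTimeout paths can be tested without a genuinely slow query.
# This sketch assumes a mongod started with test commands enabled
# (--setParameter enableTestCommands=1) -- the same precondition expressed by
# the @client_context.require_test_commands decorator above -- and the client
# and collection names here are illustrative, not taken from the test suite.
from pymongo import MongoClient
from pymongo.errors import ExecutionTimeout

client = MongoClient()
client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn")
try:
    try:
        client.test.coll.find_one(max_time_ms=1)
    except ExecutionTimeout:
        pass  # The server reported the artificial timeout, as expected.
finally:
    # Always turn the fail point back off so later operations are unaffected.
    client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off")
# ---------------------------------------------------------------------------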
@@ -333,9 +316,7 @@ def test_max_time_ms_getmore(self): else: self.fail("ExecutionTimeout not raised") finally: - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="off") + self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") def test_explain(self): a = self.db.test.find() @@ -351,10 +332,9 @@ def test_explain_with_read_concern(self): listener = AllowListEventListener("explain") client = rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) - coll = client.pymongo_test.test.with_options( - read_concern=ReadConcern(level="local")) + coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(coll.find().explain()) - started = listener.results['started'] + started = listener.results["started"] self.assertEqual(len(started), 1) self.assertNotIn("readConcern", started[0].command) @@ -365,23 +345,26 @@ def test_hint(self): db.test.insert_many([{"num": i, "foo": i} for i in range(100)]) - self.assertRaises(OperationFailure, - db.test.find({"num": 17, "foo": 17}) - .hint([("num", ASCENDING)]).explain) - self.assertRaises(OperationFailure, - db.test.find({"num": 17, "foo": 17}) - .hint([("foo", ASCENDING)]).explain) + self.assertRaises( + OperationFailure, + db.test.find({"num": 17, "foo": 17}).hint([("num", ASCENDING)]).explain, + ) + self.assertRaises( + OperationFailure, + db.test.find({"num": 17, "foo": 17}).hint([("foo", ASCENDING)]).explain, + ) spec = [("num", DESCENDING)] index = db.test.create_index(spec) first = next(db.test.find()) - self.assertEqual(0, first.get('num')) + self.assertEqual(0, first.get("num")) first = next(db.test.find().hint(spec)) - self.assertEqual(99, first.get('num')) - self.assertRaises(OperationFailure, - db.test.find({"num": 17, "foo": 17}) - .hint([("foo", ASCENDING)]).explain) + self.assertEqual(99, first.get("num")) + self.assertRaises( + OperationFailure, + db.test.find({"num": 17, "foo": 17}).hint([("foo", ASCENDING)]).explain, + ) a = db.test.find({"num": 17}) a.hint(spec) @@ -395,11 +378,11 @@ def test_hint_by_name(self): db.test.insert_many([{"i": i} for i in range(100)]) - db.test.create_index([('i', DESCENDING)], name='fooindex') + db.test.create_index([("i", DESCENDING)], name="fooindex") first = next(db.test.find()) - self.assertEqual(0, first.get('i')) - first = next(db.test.find().hint('fooindex')) - self.assertEqual(99, first.get('i')) + self.assertEqual(0, first.get("i")) + first = next(db.test.find().hint("fooindex")) + self.assertEqual(99, first.get("i")) def test_limit(self): db = self.db @@ -702,8 +685,7 @@ def test_sort(self): self.assertRaises(TypeError, db.test.find().sort, 5) self.assertRaises(ValueError, db.test.find().sort, []) self.assertRaises(TypeError, db.test.find().sort, [], ASCENDING) - self.assertRaises(TypeError, db.test.find().sort, - [("hello", DESCENDING)], DESCENDING) + self.assertRaises(TypeError, db.test.find().sort, [("hello", DESCENDING)], DESCENDING) db.test.drop() @@ -724,8 +706,7 @@ def test_sort(self): self.assertEqual(desc, expect) desc = [i["x"] for i in db.test.find().sort([("x", DESCENDING)])] self.assertEqual(desc, expect) - desc = [i["x"] for i in - db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)] + desc = [i["x"] for i in db.test.find().sort("x", ASCENDING).sort("x", DESCENDING)] self.assertEqual(desc, expect) expected = [(1, 5), (2, 5), (0, 3), (7, 3), (9, 2), (2, 1), (3, 1)] @@ -736,9 +717,9 @@ def test_sort(self): for (a, b) in shuffled: db.test.insert_one({"a": a, "b": 
b}) - result = [(i["a"], i["b"]) for i in - db.test.find().sort([("b", DESCENDING), - ("a", ASCENDING)])] + result = [ + (i["a"], i["b"]) for i in db.test.find().sort([("b", DESCENDING), ("a", ASCENDING)]) + ] self.assertEqual(result, expected) a = db.test.find() @@ -758,42 +739,34 @@ def test_where(self): db.test.insert_many([{"x": i} for i in range(10)]) - self.assertEqual(3, len(list(db.test.find().where('this.x < 3')))) - self.assertEqual(3, - len(list(db.test.find().where(Code('this.x < 3'))))) + self.assertEqual(3, len(list(db.test.find().where("this.x < 3")))) + self.assertEqual(3, len(list(db.test.find().where(Code("this.x < 3"))))) - code_with_scope = Code('this.x < i', {"i": 3}) + code_with_scope = Code("this.x < i", {"i": 3}) if client_context.version.at_least(4, 3, 3): # MongoDB 4.4 removed support for Code with scope. with self.assertRaises(OperationFailure): list(db.test.find().where(code_with_scope)) - code_with_empty_scope = Code('this.x < 3', {}) + code_with_empty_scope = Code("this.x < 3", {}) with self.assertRaises(OperationFailure): list(db.test.find().where(code_with_empty_scope)) else: - self.assertEqual( - 3, len(list(db.test.find().where(code_with_scope)))) + self.assertEqual(3, len(list(db.test.find().where(code_with_scope)))) self.assertEqual(10, len(list(db.test.find()))) - self.assertEqual([0, 1, 2], - [a["x"] for a in - db.test.find().where('this.x < 3')]) - self.assertEqual([], - [a["x"] for a in - db.test.find({"x": 5}).where('this.x < 3')]) - self.assertEqual([5], - [a["x"] for a in - db.test.find({"x": 5}).where('this.x > 3')]) - - cursor = db.test.find().where('this.x < 3').where('this.x > 7') + self.assertEqual([0, 1, 2], [a["x"] for a in db.test.find().where("this.x < 3")]) + self.assertEqual([], [a["x"] for a in db.test.find({"x": 5}).where("this.x < 3")]) + self.assertEqual([5], [a["x"] for a in db.test.find({"x": 5}).where("this.x > 3")]) + + cursor = db.test.find().where("this.x < 3").where("this.x > 7") self.assertEqual([8, 9], [a["x"] for a in cursor]) a = db.test.find() - b = a.where('this.x > 3') + b = a.where("this.x > 3") for _ in a: break - self.assertRaises(InvalidOperation, a.where, 'this.x < 3') + self.assertRaises(InvalidOperation, a.where, "this.x < 3") def test_rewind(self): self.db.test.insert_many([{"x": i} for i in range(1, 4)]) @@ -866,26 +839,28 @@ def test_clone(self): self.assertNotEqual(cursor, cursor.clone()) # Just test attributes - cursor = self.db.test.find({"x": re.compile("^hello.*")}, - projection={'_id': False}, - skip=1, - no_cursor_timeout=True, - cursor_type=CursorType.TAILABLE_AWAIT, - sort=[("x", 1)], - allow_partial_results=True, - oplog_replay=True, - batch_size=123, - collation={'locale': 'en_US'}, - hint=[("_id", 1)], - max_scan=100, - max_time_ms=1000, - return_key=True, - show_record_id=True, - snapshot=True, - allow_disk_use=True).limit(2) - cursor.min([('a', 1)]).max([('b', 3)]) + cursor = self.db.test.find( + {"x": re.compile("^hello.*")}, + projection={"_id": False}, + skip=1, + no_cursor_timeout=True, + cursor_type=CursorType.TAILABLE_AWAIT, + sort=[("x", 1)], + allow_partial_results=True, + oplog_replay=True, + batch_size=123, + collation={"locale": "en_US"}, + hint=[("_id", 1)], + max_scan=100, + max_time_ms=1000, + return_key=True, + show_record_id=True, + snapshot=True, + allow_disk_use=True, + ).limit(2) + cursor.min([("a", 1)]).max([("b", 3)]) cursor.add_option(128) - cursor.comment('hi!') + cursor.comment("hi!") # Every attribute should be the same. 
cursor2 = cursor.clone() @@ -893,17 +868,17 @@ def test_clone(self): # Shallow copies can so can mutate cursor2 = copy.copy(cursor) - cursor2._Cursor__projection['cursor2'] = False - self.assertTrue('cursor2' in cursor._Cursor__projection) + cursor2._Cursor__projection["cursor2"] = False + self.assertTrue("cursor2" in cursor._Cursor__projection) # Deepcopies and shouldn't mutate cursor3 = copy.deepcopy(cursor) - cursor3._Cursor__projection['cursor3'] = False - self.assertFalse('cursor3' in cursor._Cursor__projection) + cursor3._Cursor__projection["cursor3"] = False + self.assertFalse("cursor3" in cursor._Cursor__projection) cursor4 = cursor.clone() - cursor4._Cursor__projection['cursor4'] = False - self.assertFalse('cursor4' in cursor._Cursor__projection) + cursor4._Cursor__projection["cursor4"] = False + self.assertFalse("cursor4" in cursor._Cursor__projection) # Test memo when deepcopying queries query = {"hello": "world"} @@ -912,14 +887,12 @@ def test_clone(self): cursor2 = copy.deepcopy(cursor) - self.assertNotEqual(id(cursor._Cursor__spec), - id(cursor2._Cursor__spec)) - self.assertEqual(id(cursor2._Cursor__spec['reflexive']), - id(cursor2._Cursor__spec)) + self.assertNotEqual(id(cursor._Cursor__spec), id(cursor2._Cursor__spec)) + self.assertEqual(id(cursor2._Cursor__spec["reflexive"]), id(cursor2._Cursor__spec)) self.assertEqual(len(cursor2._Cursor__spec), 2) # Ensure hints are cloned as the correct type - cursor = self.db.test.find().hint([('z', 1), ("a", 1)]) + cursor = self.db.test.find().hint([("z", 1), ("a", 1)]) cursor2 = copy.deepcopy(cursor) self.assertTrue(isinstance(cursor2._Cursor__hint, SON)) self.assertEqual(cursor._Cursor__hint, cursor2._Cursor__hint) @@ -947,46 +920,38 @@ def test_getitem_slice_index(self): self.assertRaises(IndexError, lambda: self.db.test.find()[1:2:2]) for a, b in zip(count(0), self.db.test.find()): - self.assertEqual(a, b['i']) + self.assertEqual(a, b["i"]) self.assertEqual(100, len(list(self.db.test.find()[0:]))) for a, b in zip(count(0), self.db.test.find()[0:]): - self.assertEqual(a, b['i']) + self.assertEqual(a, b["i"]) self.assertEqual(80, len(list(self.db.test.find()[20:]))) for a, b in zip(count(20), self.db.test.find()[20:]): - self.assertEqual(a, b['i']) + self.assertEqual(a, b["i"]) for a, b in zip(count(99), self.db.test.find()[99:]): - self.assertEqual(a, b['i']) + self.assertEqual(a, b["i"]) for i in self.db.test.find()[1000:]: self.fail() self.assertEqual(5, len(list(self.db.test.find()[20:25]))) - self.assertEqual(5, len(list( - self.db.test.find()[20:25]))) + self.assertEqual(5, len(list(self.db.test.find()[20:25]))) for a, b in zip(count(20), self.db.test.find()[20:25]): - self.assertEqual(a, b['i']) + self.assertEqual(a, b["i"]) self.assertEqual(80, len(list(self.db.test.find()[40:45][20:]))) for a, b in zip(count(20), self.db.test.find()[40:45][20:]): - self.assertEqual(a, b['i']) - - self.assertEqual(80, - len(list(self.db.test.find()[40:45].limit(0).skip(20)) - ) - ) - for a, b in zip(count(20), - self.db.test.find()[40:45].limit(0).skip(20)): - self.assertEqual(a, b['i']) - - self.assertEqual(80, - len(list(self.db.test.find().limit(10).skip(40)[20:])) - ) - for a, b in zip(count(20), - self.db.test.find().limit(10).skip(40)[20:]): - self.assertEqual(a, b['i']) + self.assertEqual(a, b["i"]) + + self.assertEqual(80, len(list(self.db.test.find()[40:45].limit(0).skip(20)))) + for a, b in zip(count(20), self.db.test.find()[40:45].limit(0).skip(20)): + self.assertEqual(a, b["i"]) + + self.assertEqual(80, 
len(list(self.db.test.find().limit(10).skip(40)[20:]))) + for a, b in zip(count(20), self.db.test.find().limit(10).skip(40)[20:]): + self.assertEqual(a, b["i"]) self.assertEqual(1, len(list(self.db.test.find()[:1]))) self.assertEqual(5, len(list(self.db.test.find()[:5]))) @@ -995,10 +960,7 @@ def test_getitem_slice_index(self): self.assertEqual(1, len(list(self.db.test.find()[99:1000]))) self.assertEqual(0, len(list(self.db.test.find()[10:10]))) self.assertEqual(0, len(list(self.db.test.find()[:0]))) - self.assertEqual(80, - len(list(self.db.test.find()[10:10].limit(0).skip(20)) - ) - ) + self.assertEqual(80, len(list(self.db.test.find()[10:10].limit(0).skip(20)))) self.assertRaises(IndexError, lambda: self.db.test.find()[10:8]) @@ -1006,17 +968,16 @@ def test_getitem_numeric_index(self): self.db.drop_collection("test") self.db.test.insert_many([{"i": i} for i in range(100)]) - self.assertEqual(0, self.db.test.find()[0]['i']) - self.assertEqual(50, self.db.test.find()[50]['i']) - self.assertEqual(50, self.db.test.find().skip(50)[0]['i']) - self.assertEqual(50, self.db.test.find().skip(49)[1]['i']) - self.assertEqual(50, self.db.test.find()[50]['i']) - self.assertEqual(99, self.db.test.find()[99]['i']) + self.assertEqual(0, self.db.test.find()[0]["i"]) + self.assertEqual(50, self.db.test.find()[50]["i"]) + self.assertEqual(50, self.db.test.find().skip(50)[0]["i"]) + self.assertEqual(50, self.db.test.find().skip(49)[1]["i"]) + self.assertEqual(50, self.db.test.find()[50]["i"]) + self.assertEqual(99, self.db.test.find()[99]["i"]) self.assertRaises(IndexError, lambda x: self.db.test.find()[x], -1) self.assertRaises(IndexError, lambda x: self.db.test.find()[x], 100) - self.assertRaises(IndexError, - lambda x: self.db.test.find().skip(50)[x], 50) + self.assertRaises(IndexError, lambda x: self.db.test.find().skip(50)[x], 50) def test_len(self): self.assertRaises(TypeError, len, self.db.test.find()) @@ -1032,7 +993,7 @@ def set_coll(): def test_get_more(self): db = self.db db.drop_collection("test") - db.test.insert_many([{'i': i} for i in range(10)]) + db.test.insert_many([{"i": i} for i in range(10)]) self.assertEqual(10, len(list(db.test.find().batch_size(5)))) def test_tailable(self): @@ -1075,8 +1036,10 @@ def test_tailable(self): self.assertEqual(3, db.test.count_documents({})) # __getitem__(index) - for cursor in (db.test.find(cursor_type=CursorType.TAILABLE), - db.test.find(cursor_type=CursorType.TAILABLE_AWAIT)): + for cursor in ( + db.test.find(cursor_type=CursorType.TAILABLE), + db.test.find(cursor_type=CursorType.TAILABLE_AWAIT), + ): self.assertEqual(4, cursor[0]["x"]) self.assertEqual(5, cursor[1]["x"]) self.assertEqual(6, cursor[2]["x"]) @@ -1106,6 +1069,7 @@ def iterate_cursor(): while cursor.alive: for doc in cursor: pass + t = threading.Thread(target=iterate_cursor) t.start() time.sleep(1) @@ -1114,12 +1078,10 @@ def iterate_cursor(): t.join(3) self.assertFalse(t.is_alive()) - def test_distinct(self): self.db.drop_collection("test") - self.db.test.insert_many( - [{"a": 1}, {"a": 2}, {"a": 2}, {"a": 2}, {"a": 3}]) + self.db.test.insert_many([{"a": 1}, {"a": 2}, {"a": 2}, {"a": 2}, {"a": 3}]) distinct = self.db.test.find({"a": {"$lt": 3}}).distinct("a") distinct.sort() @@ -1145,8 +1107,7 @@ def test_max_scan(self): self.assertEqual(100, len(list(self.db.test.find()))) self.assertEqual(50, len(list(self.db.test.find().max_scan(50)))) - self.assertEqual(50, len(list(self.db.test.find() - .max_scan(90).max_scan(50)))) + self.assertEqual(50, 
len(list(self.db.test.find().max_scan(90).max_scan(50)))) def test_with_statement(self): self.db.drop_collection("test") @@ -1165,28 +1126,32 @@ def test_with_statement(self): @client_context.require_no_mongos def test_comment(self): self.client.drop_database(self.db) - self.db.command('profile', 2) # Profile ALL commands. + self.db.command("profile", 2) # Profile ALL commands. try: - list(self.db.test.find().comment('foo')) + list(self.db.test.find().comment("foo")) count = self.db.system.profile.count_documents( - {'ns': 'pymongo_test.test', 'op': 'query', - 'command.comment': 'foo'}) + {"ns": "pymongo_test.test", "op": "query", "command.comment": "foo"} + ) self.assertEqual(count, 1) - self.db.test.find().comment('foo').distinct('type') + self.db.test.find().comment("foo").distinct("type") count = self.db.system.profile.count_documents( - {'ns': 'pymongo_test.test', 'op': 'command', - 'command.distinct': 'test', - 'command.comment': 'foo'}) + { + "ns": "pymongo_test.test", + "op": "command", + "command.distinct": "test", + "command.comment": "foo", + } + ) self.assertEqual(count, 1) finally: - self.db.command('profile', 0) # Turn off profiling. + self.db.command("profile", 0) # Turn off profiling. self.db.system.profile.drop() self.db.test.insert_many([{}, {}]) cursor = self.db.test.find() next(cursor) - self.assertRaises(InvalidOperation, cursor.comment, 'hello') + self.assertRaises(InvalidOperation, cursor.comment, "hello") def test_alive(self): self.db.test.delete_many({}) @@ -1230,8 +1195,7 @@ def assertCursorKilled(): self.assertEqual(1, len(results["started"])) self.assertEqual("killCursors", results["started"][0].command_name) self.assertEqual(1, len(results["succeeded"])) - self.assertEqual("killCursors", - results["succeeded"][0].command_name) + self.assertEqual("killCursors", results["succeeded"][0].command_name) assertCursorKilled() results.clear() @@ -1254,9 +1218,8 @@ def test_delete_not_initialized(self): cursor.__del__() # no error def test_getMore_does_not_send_readPreference(self): - listener = AllowListEventListener('find', 'getMore') - client = rs_or_single_client( - event_listeners=[listener]) + listener = AllowListEventListener("find", "getMore") + client = rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) coll = client[self.db.name].test @@ -1265,21 +1228,21 @@ def test_getMore_does_not_send_readPreference(self): self.addCleanup(coll.drop) list(coll.find(batch_size=3)) - started = listener.results['started'] + started = listener.results["started"] self.assertEqual(2, len(started)) - self.assertEqual('find', started[0].command_name) - self.assertIn('$readPreference', started[0].command) - self.assertEqual('getMore', started[1].command_name) - self.assertNotIn('$readPreference', started[1].command) + self.assertEqual("find", started[0].command_name) + self.assertIn("$readPreference", started[0].command) + self.assertEqual("getMore", started[1].command_name) + self.assertNotIn("$readPreference", started[1].command) class TestRawBatchCursor(IntegrationTest): def test_find_raw(self): c = self.db.test c.drop() - docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) - batches = list(c.find_raw_batches().sort('_id')) + batches = list(c.find_raw_batches().sort("_id")) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1287,24 +1250,27 @@ def test_find_raw(self): def test_find_raw_transaction(self): c = self.db.test c.drop() - docs = 
[{'_id': i, 'x': 3.0 * i} for i in range(10)] + docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) listener = OvertCommandListener() client = rs_or_single_client(event_listeners=[listener]) with client.start_session() as session: with session.start_transaction(): - batches = list(client[self.db.name].test.find_raw_batches( - session=session).sort('_id')) - cmd = listener.results['started'][0] - self.assertEqual(cmd.command_name, 'find') - self.assertIn('$clusterTime', cmd.command) - self.assertEqual(cmd.command['startTransaction'], True) - self.assertEqual(cmd.command['txnNumber'], 1) + batches = list( + client[self.db.name].test.find_raw_batches(session=session).sort("_id") + ) + cmd = listener.results["started"][0] + self.assertEqual(cmd.command_name, "find") + self.assertIn("$clusterTime", cmd.command) + self.assertEqual(cmd.command["startTransaction"], True) + self.assertEqual(cmd.command["txnNumber"], 1) # Ensure we update $clusterTime from the command response. - last_cmd = listener.results['succeeded'][-1] - self.assertEqual(last_cmd.reply['$clusterTime']['clusterTime'], - session.cluster_time['clusterTime']) + last_cmd = listener.results["succeeded"][-1] + self.assertEqual( + last_cmd.reply["$clusterTime"]["clusterTime"], + session.cluster_time["clusterTime"], + ) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1314,47 +1280,42 @@ def test_find_raw_transaction(self): def test_find_raw_retryable_reads(self): c = self.db.test c.drop() - docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], - retryReads=True) - with self.fail_point({ - 'mode': {'times': 1}, 'data': {'failCommands': ['find'], - 'closeConnection': True}}): - batches = list( - client[self.db.name].test.find_raw_batches().sort('_id')) + client = rs_or_single_client(event_listeners=[listener], retryReads=True) + with self.fail_point( + {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} + ): + batches = list(client[self.db.name].test.find_raw_batches().sort("_id")) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) - self.assertEqual(len(listener.results['started']), 2) - for cmd in listener.results['started']: - self.assertEqual(cmd.command_name, 'find') + self.assertEqual(len(listener.results["started"]), 2) + for cmd in listener.results["started"]: + self.assertEqual(cmd.command_name, "find") @client_context.require_version_min(5, 0, 0) @client_context.require_no_standalone def test_find_raw_snapshot_reads(self): - c = self.db.get_collection( - "test", write_concern=WriteConcern(w="majority")) + c = self.db.get_collection("test", write_concern=WriteConcern(w="majority")) c.drop() - docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], - retryReads=True) + client = rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] with client.start_session(snapshot=True) as session: - db.test.distinct('x', {}, session=session) - batches = list(db.test.find_raw_batches( - session=session).sort('_id')) + db.test.distinct("x", {}, session=session) + batches = list(db.test.find_raw_batches(session=session).sort("_id")) self.assertEqual(1, len(batches)) 
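# A minimal standalone sketch (illustrative names, assumes a local mongod;
# not part of this patch) of the API under test here: find_raw_batches()
# returns undecoded BSON batches, and the caller decodes them explicitly
# with bson.decode_all().
from bson import decode_all
from pymongo import MongoClient

coll = MongoClient().demo.test
coll.insert_many([{"_id": i, "x": 3.0 * i} for i in range(10)])
for batch in coll.find_raw_batches().sort("_id"):
    # Each batch is one bytes blob of concatenated BSON documents.
    print(decode_all(batch))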
self.assertEqual(docs, decode_all(batches[0])) - find_cmd = listener.results['started'][1].command - self.assertEqual(find_cmd['readConcern']['level'], 'snapshot') - self.assertIsNotNone(find_cmd['readConcern']['atClusterTime']) + find_cmd = listener.results["started"][1].command + self.assertEqual(find_cmd["readConcern"]["level"], "snapshot") + self.assertIsNotNone(find_cmd["readConcern"]["atClusterTime"]) def test_explain(self): c = self.db.test @@ -1379,13 +1340,13 @@ def test_clone(self): def test_exhaust(self): c = self.db.test c.drop() - c.insert_many({'_id': i} for i in range(200)) - result = b''.join(c.find_raw_batches(cursor_type=CursorType.EXHAUST)) - self.assertEqual([{'_id': i} for i in range(200)], decode_all(result)) + c.insert_many({"_id": i} for i in range(200)) + result = b"".join(c.find_raw_batches(cursor_type=CursorType.EXHAUST)) + self.assertEqual([{"_id": i} for i in range(200)], decode_all(result)) def test_server_error(self): with self.assertRaises(OperationFailure) as exc: - next(self.db.test.find_raw_batches({'x': {'$bad': 1}})) + next(self.db.test.find_raw_batches({"x": {"$bad": 1}})) # The server response was decoded, not left raw. self.assertIsInstance(exc.exception.details, dict) @@ -1395,12 +1356,11 @@ def test_get_item(self): self.db.test.find_raw_batches()[0] def test_collation(self): - next(self.db.test.find_raw_batches(collation=Collation('en_US'))) + next(self.db.test.find_raw_batches(collation=Collation("en_US"))) - @client_context.require_no_mmap # MMAPv1 does not support read concern + @client_context.require_no_mmap # MMAPv1 does not support read concern def test_read_concern(self): - self.db.get_collection( - "test", write_concern=WriteConcern(w="majority")).insert_one({}) + self.db.get_collection("test", write_concern=WriteConcern(w="majority")).insert_one({}) c = self.db.get_collection("test", read_concern=ReadConcern("majority")) next(c.find_raw_batches()) @@ -1409,7 +1369,7 @@ def test_monitoring(self): client = rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() - c.insert_many([{'_id': i} for i in range(10)]) + c.insert_many([{"_id": i} for i in range(10)]) listener.results.clear() cursor = c.find_raw_batches(batch_size=4) @@ -1417,19 +1377,18 @@ def test_monitoring(self): # First raw batch of 4 documents. next(cursor) - started = listener.results['started'][0] - succeeded = listener.results['succeeded'][0] - self.assertEqual(0, len(listener.results['failed'])) - self.assertEqual('find', started.command_name) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('find', succeeded.command_name) + started = listener.results["started"][0] + succeeded = listener.results["succeeded"][0] + self.assertEqual(0, len(listener.results["failed"])) + self.assertEqual("find", started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("find", succeeded.command_name) csr = succeeded.reply["cursor"] self.assertEqual(csr["ns"], "pymongo_test.test") # The batch is a list of one raw bytes object. 
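# A minimal sketch (illustrative, not from the test suite) of the listener
# machinery these monitoring tests rely on: subclass CommandListener and
# pass it to the client to observe find/getMore/killCursors commands.
from pymongo import MongoClient, monitoring

class PrintListener(monitoring.CommandListener):
    def started(self, event):
        print("started:", event.command_name)
    def succeeded(self, event):
        print("succeeded:", event.command_name)
    def failed(self, event):
        print("failed:", event.command_name)

client = MongoClient(event_listeners=[PrintListener()])
client.demo.test.find_one()  # prints "started: find" then "succeeded: find"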
self.assertEqual(len(csr["firstBatch"]), 1) - self.assertEqual(decode_all(csr["firstBatch"][0]), - [{'_id': i} for i in range(0, 4)]) + self.assertEqual(decode_all(csr["firstBatch"][0]), [{"_id": i} for i in range(0, 4)]) listener.results.clear() @@ -1437,17 +1396,16 @@ def test_monitoring(self): next(cursor) try: results = listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertEqual('getMore', started.command_name) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('getMore', succeeded.command_name) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertEqual("getMore", started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("getMore", succeeded.command_name) csr = succeeded.reply["cursor"] self.assertEqual(csr["ns"], "pymongo_test.test") self.assertEqual(len(csr["nextBatch"]), 1) - self.assertEqual(decode_all(csr["nextBatch"][0]), - [{'_id': i} for i in range(4, 8)]) + self.assertEqual(decode_all(csr["nextBatch"][0]), [{"_id": i} for i in range(4, 8)]) finally: # Finish the cursor. tuple(cursor) @@ -1461,9 +1419,9 @@ def setUpClass(cls): def test_aggregate_raw(self): c = self.db.test c.drop() - docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) - batches = list(c.aggregate_raw_batches([{'$sort': {'_id': 1}}])) + batches = list(c.aggregate_raw_batches([{"$sort": {"_id": 1}}])) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1471,24 +1429,29 @@ def test_aggregate_raw(self): def test_aggregate_raw_transaction(self): c = self.db.test c.drop() - docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) listener = OvertCommandListener() client = rs_or_single_client(event_listeners=[listener]) with client.start_session() as session: with session.start_transaction(): - batches = list(client[self.db.name].test.aggregate_raw_batches( - [{'$sort': {'_id': 1}}], session=session)) - cmd = listener.results['started'][0] - self.assertEqual(cmd.command_name, 'aggregate') - self.assertIn('$clusterTime', cmd.command) - self.assertEqual(cmd.command['startTransaction'], True) - self.assertEqual(cmd.command['txnNumber'], 1) + batches = list( + client[self.db.name].test.aggregate_raw_batches( + [{"$sort": {"_id": 1}}], session=session + ) + ) + cmd = listener.results["started"][0] + self.assertEqual(cmd.command_name, "aggregate") + self.assertIn("$clusterTime", cmd.command) + self.assertEqual(cmd.command["startTransaction"], True) + self.assertEqual(cmd.command["txnNumber"], 1) # Ensure we update $clusterTime from the command response. 
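# A minimal sketch (assumes a replica set; standalone servers do not gossip
# $clusterTime) of the property asserted below: after any command, the
# session's cluster_time mirrors the $clusterTime of the server's reply.
from pymongo import MongoClient

client = MongoClient()
with client.start_session() as session:
    client.admin.command("ping", session=session)
    print(session.cluster_time)  # e.g. {"clusterTime": Timestamp(...), ...}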
- last_cmd = listener.results['succeeded'][-1] - self.assertEqual(last_cmd.reply['$clusterTime']['clusterTime'], - session.cluster_time['clusterTime']) + last_cmd = listener.results["succeeded"][-1] + self.assertEqual( + last_cmd.reply["$clusterTime"]["clusterTime"], + session.cluster_time["clusterTime"], + ) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1497,62 +1460,63 @@ def test_aggregate_raw_transaction(self): def test_aggregate_raw_retryable_reads(self): c = self.db.test c.drop() - docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], - retryReads=True) - with self.fail_point({ - 'mode': {'times': 1}, 'data': {'failCommands': ['aggregate'], - 'closeConnection': True}}): - batches = list(client[self.db.name].test.aggregate_raw_batches( - [{'$sort': {'_id': 1}}])) + client = rs_or_single_client(event_listeners=[listener], retryReads=True) + with self.fail_point( + {"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}} + ): + batches = list(client[self.db.name].test.aggregate_raw_batches([{"$sort": {"_id": 1}}])) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) - self.assertEqual(len(listener.results['started']), 3) - cmds = listener.results['started'] - self.assertEqual(cmds[0].command_name, 'aggregate') - self.assertEqual(cmds[1].command_name, 'aggregate') + self.assertEqual(len(listener.results["started"]), 3) + cmds = listener.results["started"] + self.assertEqual(cmds[0].command_name, "aggregate") + self.assertEqual(cmds[1].command_name, "aggregate") @client_context.require_version_min(5, 0, -1) @client_context.require_no_standalone def test_aggregate_raw_snapshot_reads(self): - c = self.db.get_collection( - "test", write_concern=WriteConcern(w="majority")) + c = self.db.get_collection("test", write_concern=WriteConcern(w="majority")) c.drop() - docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], - retryReads=True) + client = rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] with client.start_session(snapshot=True) as session: - db.test.distinct('x', {}, session=session) - batches = list(db.test.aggregate_raw_batches( - [{'$sort': {'_id': 1}}], session=session)) + db.test.distinct("x", {}, session=session) + batches = list(db.test.aggregate_raw_batches([{"$sort": {"_id": 1}}], session=session)) self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) - find_cmd = listener.results['started'][1].command - self.assertEqual(find_cmd['readConcern']['level'], 'snapshot') - self.assertIsNotNone(find_cmd['readConcern']['atClusterTime']) + find_cmd = listener.results["started"][1].command + self.assertEqual(find_cmd["readConcern"]["level"], "snapshot") + self.assertIsNotNone(find_cmd["readConcern"]["atClusterTime"]) def test_server_error(self): c = self.db.test c.drop() - docs = [{'_id': i, 'x': 3.0 * i} for i in range(10)] + docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) - c.insert_one({'_id': 10, 'x': 'not a number'}) + c.insert_one({"_id": 10, "x": "not a number"}) with self.assertRaises(OperationFailure) as exc: - list(self.db.test.aggregate_raw_batches([{ - 
'$sort': {'_id': 1}, - }, { - '$project': {'x': {'$multiply': [2, '$x']}} - }], batchSize=4)) + list( + self.db.test.aggregate_raw_batches( + [ + { + "$sort": {"_id": 1}, + }, + {"$project": {"x": {"$multiply": [2, "$x"]}}}, + ], + batchSize=4, + ) + ) # The server response was decoded, not left raw. self.assertIsInstance(exc.exception.details, dict) @@ -1562,25 +1526,25 @@ def test_get_item(self): self.db.test.aggregate_raw_batches([])[0] def test_collation(self): - next(self.db.test.aggregate_raw_batches([], collation=Collation('en_US'))) + next(self.db.test.aggregate_raw_batches([], collation=Collation("en_US"))) def test_monitoring(self): listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() - c.insert_many([{'_id': i} for i in range(10)]) + c.insert_many([{"_id": i} for i in range(10)]) listener.results.clear() - cursor = c.aggregate_raw_batches([{'$sort': {'_id': 1}}], batchSize=4) + cursor = c.aggregate_raw_batches([{"$sort": {"_id": 1}}], batchSize=4) # Start cursor, no initial batch. - started = listener.results['started'][0] - succeeded = listener.results['succeeded'][0] - self.assertEqual(0, len(listener.results['failed'])) - self.assertEqual('aggregate', started.command_name) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('aggregate', succeeded.command_name) + started = listener.results["started"][0] + succeeded = listener.results["succeeded"][0] + self.assertEqual(0, len(listener.results["failed"])) + self.assertEqual("aggregate", started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("aggregate", succeeded.command_name) csr = succeeded.reply["cursor"] self.assertEqual(csr["ns"], "pymongo_test.test") @@ -1592,18 +1556,17 @@ def test_monitoring(self): n = 0 for batch in cursor: results = listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertEqual('getMore', started.command_name) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('getMore', succeeded.command_name) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertEqual("getMore", started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("getMore", succeeded.command_name) csr = succeeded.reply["cursor"] self.assertEqual(csr["ns"], "pymongo_test.test") self.assertEqual(len(csr["nextBatch"]), 1) self.assertEqual(csr["nextBatch"][0], batch) - self.assertEqual(decode_all(batch), - [{'_id': i} for i in range(n, min(n + 4, 10))]) + self.assertEqual(decode_all(batch), [{"_id": i} for i in range(n, min(n + 4, 10))]) n += 4 listener.results.clear() diff --git a/test/test_custom_types.py b/test/test_custom_types.py index 5db208ab7e..550b322020 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -17,39 +17,43 @@ import datetime import sys import tempfile - from collections import OrderedDict from decimal import Decimal from random import random sys.path[0:0] = [""] -from bson import (Decimal128, - decode, - decode_all, - decode_file_iter, - decode_iter, - encode, - RE_TYPE, - _BUILT_IN_TYPES, - _dict_to_bson, - _bson_to_dict) -from bson.codec_options import (CodecOptions, TypeCodec, TypeDecoder, - TypeEncoder, TypeRegistry) +from test import client_context, unittest +from test.test_client import IntegrationTest +from test.utils import 
rs_client + +from bson import ( + _BUILT_IN_TYPES, + RE_TYPE, + Decimal128, + _bson_to_dict, + _dict_to_bson, + decode, + decode_all, + decode_file_iter, + decode_iter, + encode, +) +from bson.codec_options import ( + CodecOptions, + TypeCodec, + TypeDecoder, + TypeEncoder, + TypeRegistry, +) from bson.errors import InvalidDocument from bson.int64 import Int64 from bson.raw_bson import RawBSONDocument - from gridfs import GridIn, GridOut - from pymongo.collection import ReturnDocument from pymongo.errors import DuplicateKeyError from pymongo.message import _CursorAddress -from test import client_context, unittest -from test.test_client import IntegrationTest -from test.utils import rs_client - class DecimalEncoder(TypeEncoder): @property @@ -73,8 +77,7 @@ class DecimalCodec(DecimalDecoder, DecimalEncoder): pass -DECIMAL_CODECOPTS = CodecOptions( - type_registry=TypeRegistry([DecimalCodec()])) +DECIMAL_CODECOPTS = CodecOptions(type_registry=TypeRegistry([DecimalCodec()])) class UndecipherableInt64Type(object): @@ -90,39 +93,55 @@ def __eq__(self, other): class UndecipherableIntDecoder(TypeDecoder): bson_type = Int64 + def transform_bson(self, value): return UndecipherableInt64Type(value) class UndecipherableIntEncoder(TypeEncoder): python_type = UndecipherableInt64Type + def transform_python(self, value): return Int64(value.value) UNINT_DECODER_CODECOPTS = CodecOptions( - type_registry=TypeRegistry([UndecipherableIntDecoder(), ])) + type_registry=TypeRegistry( + [ + UndecipherableIntDecoder(), + ] + ) +) -UNINT_CODECOPTS = CodecOptions(type_registry=TypeRegistry( - [UndecipherableIntDecoder(), UndecipherableIntEncoder()])) +UNINT_CODECOPTS = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntDecoder(), UndecipherableIntEncoder()]) +) class UppercaseTextDecoder(TypeDecoder): bson_type = str + def transform_bson(self, value): return value.upper() -UPPERSTR_DECODER_CODECOPTS = CodecOptions(type_registry=TypeRegistry( - [UppercaseTextDecoder(),])) +UPPERSTR_DECODER_CODECOPTS = CodecOptions( + type_registry=TypeRegistry( + [ + UppercaseTextDecoder(), + ] + ) +) def type_obfuscating_decoder_factory(rt_type): class ResumeTokenToNanDecoder(TypeDecoder): bson_type = rt_type + def transform_bson(self, value): return "NaN" + return ResumeTokenToNanDecoder @@ -133,40 +152,39 @@ def roundtrip(self, doc): self.assertEqual(doc, rt_document) def test_encode_decode_roundtrip(self): - self.roundtrip({'average': Decimal('56.47')}) - self.roundtrip({'average': {'b': Decimal('56.47')}}) - self.roundtrip({'average': [Decimal('56.47')]}) - self.roundtrip({'average': [[Decimal('56.47')]]}) - self.roundtrip({'average': [{'b': Decimal('56.47')}]}) + self.roundtrip({"average": Decimal("56.47")}) + self.roundtrip({"average": {"b": Decimal("56.47")}}) + self.roundtrip({"average": [Decimal("56.47")]}) + self.roundtrip({"average": [[Decimal("56.47")]]}) + self.roundtrip({"average": [{"b": Decimal("56.47")}]}) def test_decode_all(self): documents = [] for dec in range(3): - documents.append({'average': Decimal('56.4%s' % (dec,))}) + documents.append({"average": Decimal("56.4%s" % (dec,))}) bsonstream = bytes() for doc in documents: bsonstream += encode(doc, codec_options=self.codecopts) - self.assertEqual( - decode_all(bsonstream, self.codecopts), documents) + self.assertEqual(decode_all(bsonstream, self.codecopts), documents) def test__bson_to_dict(self): - document = {'average': Decimal('56.47')} + document = {"average": Decimal("56.47")} rawbytes = encode(document, codec_options=self.codecopts) 
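# A standalone sketch of the codec pattern these tests exercise, mirroring
# the DecimalCodec defined above (no server needed; names are illustrative).
from decimal import Decimal
from bson import Decimal128, decode, encode
from bson.codec_options import CodecOptions, TypeCodec, TypeRegistry

class MyDecimalCodec(TypeCodec):
    python_type = Decimal    # encode decimal.Decimal ...
    bson_type = Decimal128   # ... as BSON Decimal128, and decode it back.
    def transform_python(self, value):
        return Decimal128(value)
    def transform_bson(self, value):
        return value.to_decimal()

opts = CodecOptions(type_registry=TypeRegistry([MyDecimalCodec()]))
raw = encode({"average": Decimal("56.47")}, codec_options=opts)
assert decode(raw, codec_options=opts) == {"average": Decimal("56.47")}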
decoded_document = _bson_to_dict(rawbytes, self.codecopts) self.assertEqual(document, decoded_document) def test__dict_to_bson(self): - document = {'average': Decimal('56.47')} + document = {"average": Decimal("56.47")} rawbytes = encode(document, codec_options=self.codecopts) encoded_document = _dict_to_bson(document, False, self.codecopts) self.assertEqual(encoded_document, rawbytes) def _generate_multidocument_bson_stream(self): inp_num = [str(random() * 100)[:4] for _ in range(10)] - docs = [{'n': Decimal128(dec)} for dec in inp_num] - edocs = [{'n': Decimal(dec)} for dec in inp_num] + docs = [{"n": Decimal128(dec)} for dec in inp_num] + edocs = [{"n": Decimal(dec)} for dec in inp_num] bsonstream = b"" for doc in docs: bsonstream += encode(doc) @@ -174,8 +192,7 @@ def _generate_multidocument_bson_stream(self): def test_decode_iter(self): expected, bson_data = self._generate_multidocument_bson_stream() - for expected_doc, decoded_doc in zip( - expected, decode_iter(bson_data, self.codecopts)): + for expected_doc, decoded_doc in zip(expected, decode_iter(bson_data, self.codecopts)): self.assertEqual(expected_doc, decoded_doc) def test_decode_file_iter(self): @@ -184,26 +201,24 @@ def test_decode_file_iter(self): fileobj.write(bson_data) fileobj.seek(0) - for expected_doc, decoded_doc in zip( - expected, decode_file_iter(fileobj, self.codecopts)): + for expected_doc, decoded_doc in zip(expected, decode_file_iter(fileobj, self.codecopts)): self.assertEqual(expected_doc, decoded_doc) fileobj.close() -class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, - unittest.TestCase): +class TestCustomPythonBSONTypeToBSONMonolithicCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): cls.codecopts = DECIMAL_CODECOPTS -class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, - unittest.TestCase): +class TestCustomPythonBSONTypeToBSONMultiplexedCodec(CustomBSONTypeTests, unittest.TestCase): @classmethod def setUpClass(cls): codec_options = CodecOptions( - type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder()))) + type_registry=TypeRegistry((DecimalEncoder(), DecimalDecoder())) + ) cls.codecopts = codec_options @@ -214,29 +229,29 @@ def _get_codec_options(self, fallback_encoder): def test_simple(self): codecopts = self._get_codec_options(lambda x: Decimal128(x)) - document = {'average': Decimal('56.47')} + document = {"average": Decimal("56.47")} bsonbytes = encode(document, codec_options=codecopts) - exp_document = {'average': Decimal128('56.47')} + exp_document = {"average": Decimal128("56.47")} exp_bsonbytes = encode(exp_document) self.assertEqual(bsonbytes, exp_bsonbytes) def test_erroring_fallback_encoder(self): - codecopts = self._get_codec_options(lambda _: 1/0) + codecopts = self._get_codec_options(lambda _: 1 / 0) # fallback converter should not be invoked when encoding known types. encode( - {'a': 1, 'b': Decimal128('1.01'), 'c': {'arr': ['abc', 3.678]}}, - codec_options=codecopts) + {"a": 1, "b": Decimal128("1.01"), "c": {"arr": ["abc", 3.678]}}, codec_options=codecopts + ) # expect an error when encoding a custom type. 
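# A minimal sketch (illustrative, not part of this patch) of the behavior
# asserted in these fallback tests: the fallback encoder runs only for types
# the registry cannot already encode, so built-in values never reach it.
from decimal import Decimal
from bson import Decimal128, encode
from bson.codec_options import CodecOptions, TypeRegistry

def to_decimal128(value):  # assumed fallback; receives the unencodable value
    return Decimal128(value)

opts = CodecOptions(type_registry=TypeRegistry(fallback_encoder=to_decimal128))
encode({"n": Decimal("1.5")}, codec_options=opts)  # fallback converts to Decimal128
encode({"n": 1.5}, codec_options=opts)             # plain float, fallback never called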
- document = {'average': Decimal('56.47')} + document = {"average": Decimal("56.47")} with self.assertRaises(ZeroDivisionError): encode(document, codec_options=codecopts) def test_noop_fallback_encoder(self): codecopts = self._get_codec_options(lambda x: x) - document = {'average': Decimal('56.47')} + document = {"average": Decimal("56.47")} with self.assertRaises(InvalidDocument): encode(document, codec_options=codecopts) @@ -246,8 +261,9 @@ def fallback_encoder(value): return Decimal128(value) except: raise TypeError("cannot encode type %s" % (type(value))) + codecopts = self._get_codec_options(fallback_encoder) - document = {'average': Decimal} + document = {"average": Decimal} with self.assertRaises(TypeError): encode(document, codec_options=codecopts) @@ -255,8 +271,9 @@ def fallback_encoder(value): class TestBSONTypeEnDeCodecs(unittest.TestCase): def test_instantiation(self): msg = "Can't instantiate abstract class" + def run_test(base, attrs, fail): - codec = type('testcodec', (base,), attrs) + codec = type("testcodec", (base,), attrs) if fail: with self.assertRaisesRegex(TypeError, msg): codec() @@ -266,24 +283,46 @@ def run_test(base, attrs, fail): class MyType(object): pass - run_test(TypeEncoder, {'python_type': MyType,}, fail=True) - run_test(TypeEncoder, {'transform_python': lambda s, x: x}, fail=True) - run_test(TypeEncoder, {'transform_python': lambda s, x: x, - 'python_type': MyType}, fail=False) - - run_test(TypeDecoder, {'bson_type': Decimal128, }, fail=True) - run_test(TypeDecoder, {'transform_bson': lambda s, x: x}, fail=True) - run_test(TypeDecoder, {'transform_bson': lambda s, x: x, - 'bson_type': Decimal128}, fail=False) - - run_test(TypeCodec, {'bson_type': Decimal128, - 'python_type': MyType}, fail=True) - run_test(TypeCodec, {'transform_bson': lambda s, x: x, - 'transform_python': lambda s, x: x}, fail=True) - run_test(TypeCodec, {'python_type': MyType, - 'transform_python': lambda s, x: x, - 'transform_bson': lambda s, x: x, - 'bson_type': Decimal128}, fail=False) + run_test( + TypeEncoder, + { + "python_type": MyType, + }, + fail=True, + ) + run_test(TypeEncoder, {"transform_python": lambda s, x: x}, fail=True) + run_test( + TypeEncoder, {"transform_python": lambda s, x: x, "python_type": MyType}, fail=False + ) + + run_test( + TypeDecoder, + { + "bson_type": Decimal128, + }, + fail=True, + ) + run_test(TypeDecoder, {"transform_bson": lambda s, x: x}, fail=True) + run_test( + TypeDecoder, {"transform_bson": lambda s, x: x, "bson_type": Decimal128}, fail=False + ) + + run_test(TypeCodec, {"bson_type": Decimal128, "python_type": MyType}, fail=True) + run_test( + TypeCodec, + {"transform_bson": lambda s, x: x, "transform_python": lambda s, x: x}, + fail=True, + ) + run_test( + TypeCodec, + { + "python_type": MyType, + "transform_python": lambda s, x: x, + "transform_bson": lambda s, x: x, + "bson_type": Decimal128, + }, + fail=False, + ) def test_type_checks(self): self.assertTrue(issubclass(TypeCodec, TypeEncoder)) @@ -316,6 +355,7 @@ def fallback_encoder_A2BSON(value): # transforms B into something encodable class B2BSON(TypeEncoder): python_type = TypeB + def transform_python(self, value): return value.value @@ -324,6 +364,7 @@ def transform_python(self, value): # BSON-encodable. class A2B(TypeEncoder): python_type = TypeA + def transform_python(self, value): return TypeB(value.value) @@ -332,6 +373,7 @@ def transform_python(self, value): # BSON-encodable. 
class B2A(TypeEncoder): python_type = TypeB + def transform_python(self, value): return TypeA(value.value) @@ -344,37 +386,37 @@ def transform_python(self, value): cls.A2B = A2B def test_encode_fallback_then_custom(self): - codecopts = CodecOptions(type_registry=TypeRegistry( - [self.B2BSON()], fallback_encoder=self.fallback_encoder_A2B)) - testdoc = {'x': self.TypeA(123)} - expected_bytes = encode({'x': 123}) + codecopts = CodecOptions( + type_registry=TypeRegistry([self.B2BSON()], fallback_encoder=self.fallback_encoder_A2B) + ) + testdoc = {"x": self.TypeA(123)} + expected_bytes = encode({"x": 123}) - self.assertEqual(encode(testdoc, codec_options=codecopts), - expected_bytes) + self.assertEqual(encode(testdoc, codec_options=codecopts), expected_bytes) def test_encode_custom_then_fallback(self): - codecopts = CodecOptions(type_registry=TypeRegistry( - [self.B2A()], fallback_encoder=self.fallback_encoder_A2BSON)) - testdoc = {'x': self.TypeB(123)} - expected_bytes = encode({'x': 123}) + codecopts = CodecOptions( + type_registry=TypeRegistry([self.B2A()], fallback_encoder=self.fallback_encoder_A2BSON) + ) + testdoc = {"x": self.TypeB(123)} + expected_bytes = encode({"x": 123}) - self.assertEqual(encode(testdoc, codec_options=codecopts), - expected_bytes) + self.assertEqual(encode(testdoc, codec_options=codecopts), expected_bytes) def test_chaining_encoders_fails(self): - codecopts = CodecOptions(type_registry=TypeRegistry( - [self.A2B(), self.B2BSON()])) + codecopts = CodecOptions(type_registry=TypeRegistry([self.A2B(), self.B2BSON()])) with self.assertRaises(InvalidDocument): - encode({'x': self.TypeA(123)}, codec_options=codecopts) + encode({"x": self.TypeA(123)}, codec_options=codecopts) def test_infinite_loop_exceeds_max_recursion_depth(self): - codecopts = CodecOptions(type_registry=TypeRegistry( - [self.B2A()], fallback_encoder=self.fallback_encoder_A2B)) + codecopts = CodecOptions( + type_registry=TypeRegistry([self.B2A()], fallback_encoder=self.fallback_encoder_A2B) + ) # Raises max recursion depth exceeded error with self.assertRaises(RuntimeError): - encode({'x': self.TypeA(100)}, codec_options=codecopts) + encode({"x": self.TypeA(100)}, codec_options=codecopts) class TestTypeRegistry(unittest.TestCase): @@ -429,29 +471,34 @@ def fallback_encoder(value): def test_simple(self): codec_instances = [codec() for codec in self.codecs] + def assert_proper_initialization(type_registry, codec_instances): - self.assertEqual(type_registry._encoder_map, { - self.types[0]: codec_instances[0].transform_python, - self.types[1]: codec_instances[1].transform_python}) - self.assertEqual(type_registry._decoder_map, { - int: codec_instances[0].transform_bson, - str: codec_instances[1].transform_bson}) self.assertEqual( - type_registry._fallback_encoder, self.fallback_encoder) + type_registry._encoder_map, + { + self.types[0]: codec_instances[0].transform_python, + self.types[1]: codec_instances[1].transform_python, + }, + ) + self.assertEqual( + type_registry._decoder_map, + {int: codec_instances[0].transform_bson, str: codec_instances[1].transform_bson}, + ) + self.assertEqual(type_registry._fallback_encoder, self.fallback_encoder) type_registry = TypeRegistry(codec_instances, self.fallback_encoder) assert_proper_initialization(type_registry, codec_instances) type_registry = TypeRegistry( - fallback_encoder=self.fallback_encoder, type_codecs=codec_instances) + fallback_encoder=self.fallback_encoder, type_codecs=codec_instances + ) assert_proper_initialization(type_registry, codec_instances) # 
Ensure codec list held by the type registry doesn't change if we # mutate the initial list. codec_instances_copy = list(codec_instances) codec_instances.pop(0) - self.assertListEqual( - type_registry._TypeRegistry__type_codecs, codec_instances_copy) + self.assertListEqual(type_registry._TypeRegistry__type_codecs, codec_instances_copy) def test_simple_separate_codecs(self): class MyIntEncoder(TypeEncoder): @@ -471,72 +518,83 @@ def transform_bson(self, value): self.assertEqual( type_registry._encoder_map, - {MyIntEncoder.python_type: codec_instances[1].transform_python}) + {MyIntEncoder.python_type: codec_instances[1].transform_python}, + ) self.assertEqual( - type_registry._decoder_map, - {MyIntDecoder.bson_type: codec_instances[0].transform_bson}) + type_registry._decoder_map, {MyIntDecoder.bson_type: codec_instances[0].transform_bson} + ) def test_initialize_fail(self): - err_msg = ("Expected an instance of TypeEncoder, TypeDecoder, " - "or TypeCodec, got .* instead") + err_msg = ( + "Expected an instance of TypeEncoder, TypeDecoder, " "or TypeCodec, got .* instead" + ) with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry(self.codecs) with self.assertRaisesRegex(TypeError, err_msg): - TypeRegistry([type('AnyType', (object,), {})()]) + TypeRegistry([type("AnyType", (object,), {})()]) err_msg = "fallback_encoder %r is not a callable" % (True,) with self.assertRaisesRegex(TypeError, err_msg): TypeRegistry([], True) - err_msg = "fallback_encoder %r is not a callable" % ('hello',) + err_msg = "fallback_encoder %r is not a callable" % ("hello",) with self.assertRaisesRegex(TypeError, err_msg): - TypeRegistry(fallback_encoder='hello') + TypeRegistry(fallback_encoder="hello") def test_type_registry_repr(self): codec_instances = [codec() for codec in self.codecs] type_registry = TypeRegistry(codec_instances) - r = ("TypeRegistry(type_codecs=%r, fallback_encoder=%r)" % ( - codec_instances, None)) + r = "TypeRegistry(type_codecs=%r, fallback_encoder=%r)" % (codec_instances, None) self.assertEqual(r, repr(type_registry)) def test_type_registry_eq(self): codec_instances = [codec() for codec in self.codecs] - self.assertEqual( - TypeRegistry(codec_instances), TypeRegistry(codec_instances)) + self.assertEqual(TypeRegistry(codec_instances), TypeRegistry(codec_instances)) codec_instances_2 = [codec() for codec in self.codecs] - self.assertNotEqual( - TypeRegistry(codec_instances), TypeRegistry(codec_instances_2)) + self.assertNotEqual(TypeRegistry(codec_instances), TypeRegistry(codec_instances_2)) def test_builtin_types_override_fails(self): def run_test(base, attrs): - msg = (r"TypeEncoders cannot change how built-in types " - r"are encoded \(encoder .* transforms type .*\)") + msg = ( + r"TypeEncoders cannot change how built-in types " + r"are encoded \(encoder .* transforms type .*\)" + ) for pytype in _BUILT_IN_TYPES: - attrs.update({'python_type': pytype, - 'transform_python': lambda x: x}) - codec = type('testcodec', (base, ), attrs) + attrs.update({"python_type": pytype, "transform_python": lambda x: x}) + codec = type("testcodec", (base,), attrs) codec_instance = codec() with self.assertRaisesRegex(TypeError, msg): - TypeRegistry([codec_instance,]) + TypeRegistry( + [ + codec_instance, + ] + ) # Test only some subtypes as not all can be subclassed. 
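# A standalone sketch (illustrative, not part of this patch) of the rule the
# test above enforces: TypeRegistry raises TypeError for any encoder that
# targets a built-in type such as int.
from bson.codec_options import TypeEncoder, TypeRegistry

class IntToStrEncoder(TypeEncoder):  # hypothetical codec that must be rejected
    python_type = int
    def transform_python(self, value):
        return str(value)

try:
    TypeRegistry([IntToStrEncoder()])
except TypeError as exc:
    print(exc)  # "TypeEncoders cannot change how built-in types are encoded ..."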
- if pytype in [bool, type(None), RE_TYPE,]: + if pytype in [ + bool, + type(None), + RE_TYPE, + ]: continue class MyType(pytype): pass - attrs.update({'python_type': MyType, - 'transform_python': lambda x: x}) - codec = type('testcodec', (base, ), attrs) + + attrs.update({"python_type": MyType, "transform_python": lambda x: x}) + codec = type("testcodec", (base,), attrs) codec_instance = codec() with self.assertRaisesRegex(TypeError, msg): - TypeRegistry([codec_instance,]) + TypeRegistry( + [ + codec_instance, + ] + ) run_test(TypeEncoder, {}) - run_test(TypeCodec, {'bson_type': Decimal128, - 'transform_bson': lambda x: x}) + run_test(TypeCodec, {"bson_type": Decimal128, "transform_bson": lambda x: x}) class TestCollectionWCustomType(IntegrationTest): @@ -548,115 +606,127 @@ def tearDown(self): def test_command_errors_w_custom_type_decoder(self): db = self.db - test_doc = {'_id': 1, 'data': 'a'} - test = db.get_collection('test', - codec_options=UNINT_DECODER_CODECOPTS) + test_doc = {"_id": 1, "data": "a"} + test = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) result = test.insert_one(test_doc) - self.assertEqual(result.inserted_id, test_doc['_id']) + self.assertEqual(result.inserted_id, test_doc["_id"]) with self.assertRaises(DuplicateKeyError): test.insert_one(test_doc) def test_find_w_custom_type_decoder(self): db = self.db - input_docs = [ - {'x': Int64(k)} for k in [1, 2, 3]] + input_docs = [{"x": Int64(k)} for k in [1, 2, 3]] for doc in input_docs: db.test.insert_one(doc) - test = db.get_collection( - 'test', codec_options=UNINT_DECODER_CODECOPTS) + test = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) for doc in test.find({}, batch_size=1): - self.assertIsInstance(doc['x'], UndecipherableInt64Type) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) def test_find_w_custom_type_decoder_and_document_class(self): def run_test(doc_cls): db = self.db - input_docs = [ - {'x': Int64(k)} for k in [1, 2, 3]] + input_docs = [{"x": Int64(k)} for k in [1, 2, 3]] for doc in input_docs: db.test.insert_one(doc) - test = db.get_collection('test', codec_options=CodecOptions( - type_registry=TypeRegistry([UndecipherableIntDecoder()]), - document_class=doc_cls)) + test = db.get_collection( + "test", + codec_options=CodecOptions( + type_registry=TypeRegistry([UndecipherableIntDecoder()]), document_class=doc_cls + ), + ) for doc in test.find({}, batch_size=1): self.assertIsInstance(doc, doc_cls) - self.assertIsInstance(doc['x'], UndecipherableInt64Type) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) for doc_cls in [RawBSONDocument, OrderedDict]: run_test(doc_cls) def test_aggregate_w_custom_type_decoder(self): db = self.db - db.test.insert_many([ - {'status': 'in progress', 'qty': Int64(1)}, - {'status': 'complete', 'qty': Int64(10)}, - {'status': 'in progress', 'qty': Int64(1)}, - {'status': 'complete', 'qty': Int64(10)}, - {'status': 'in progress', 'qty': Int64(1)},]) - test = db.get_collection( - 'test', codec_options=UNINT_DECODER_CODECOPTS) + db.test.insert_many( + [ + {"status": "in progress", "qty": Int64(1)}, + {"status": "complete", "qty": Int64(10)}, + {"status": "in progress", "qty": Int64(1)}, + {"status": "complete", "qty": Int64(10)}, + {"status": "in progress", "qty": Int64(1)}, + ] + ) + test = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) pipeline = [ - {'$match': {'status': 'complete'}}, - {'$group': {'_id': "$status", 'total_qty': {"$sum": "$qty"}}},] + {"$match": {"status": "complete"}}, + {"$group": {"_id": 
"$status", "total_qty": {"$sum": "$qty"}}}, + ] result = test.aggregate(pipeline) res = list(result)[0] - self.assertEqual(res['_id'], 'complete') - self.assertIsInstance(res['total_qty'], UndecipherableInt64Type) - self.assertEqual(res['total_qty'].value, 20) + self.assertEqual(res["_id"], "complete") + self.assertIsInstance(res["total_qty"], UndecipherableInt64Type) + self.assertEqual(res["total_qty"].value, 20) def test_distinct_w_custom_type(self): self.db.drop_collection("test") - test = self.db.get_collection('test', codec_options=UNINT_CODECOPTS) + test = self.db.get_collection("test", codec_options=UNINT_CODECOPTS) values = [ UndecipherableInt64Type(1), UndecipherableInt64Type(2), UndecipherableInt64Type(3), - {"b": UndecipherableInt64Type(3)}] + {"b": UndecipherableInt64Type(3)}, + ] test.insert_many({"a": val} for val in values) self.assertEqual(values, test.distinct("a")) def test_find_one_and__w_custom_type_decoder(self): db = self.db - c = db.get_collection('test', codec_options=UNINT_DECODER_CODECOPTS) - c.insert_one({'_id': 1, 'x': Int64(1)}) - - doc = c.find_one_and_update({'_id': 1}, {'$inc': {'x': 1}}, - return_document=ReturnDocument.AFTER) - self.assertEqual(doc['_id'], 1) - self.assertIsInstance(doc['x'], UndecipherableInt64Type) - self.assertEqual(doc['x'].value, 2) - - doc = c.find_one_and_replace({'_id': 1}, {'x': Int64(3), 'y': True}, - return_document=ReturnDocument.AFTER) - self.assertEqual(doc['_id'], 1) - self.assertIsInstance(doc['x'], UndecipherableInt64Type) - self.assertEqual(doc['x'].value, 3) - self.assertEqual(doc['y'], True) - - doc = c.find_one_and_delete({'y': True}) - self.assertEqual(doc['_id'], 1) - self.assertIsInstance(doc['x'], UndecipherableInt64Type) - self.assertEqual(doc['x'].value, 3) + c = db.get_collection("test", codec_options=UNINT_DECODER_CODECOPTS) + c.insert_one({"_id": 1, "x": Int64(1)}) + + doc = c.find_one_and_update( + {"_id": 1}, {"$inc": {"x": 1}}, return_document=ReturnDocument.AFTER + ) + self.assertEqual(doc["_id"], 1) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + self.assertEqual(doc["x"].value, 2) + + doc = c.find_one_and_replace( + {"_id": 1}, {"x": Int64(3), "y": True}, return_document=ReturnDocument.AFTER + ) + self.assertEqual(doc["_id"], 1) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + self.assertEqual(doc["x"].value, 3) + self.assertEqual(doc["y"], True) + + doc = c.find_one_and_delete({"y": True}) + self.assertEqual(doc["_id"], 1) + self.assertIsInstance(doc["x"], UndecipherableInt64Type) + self.assertEqual(doc["x"].value, 3) self.assertIsNone(c.find_one()) class TestGridFileCustomType(IntegrationTest): def setUp(self): - self.db.drop_collection('fs.files') - self.db.drop_collection('fs.chunks') + self.db.drop_collection("fs.files") + self.db.drop_collection("fs.chunks") def test_grid_out_custom_opts(self): db = self.db.with_options(codec_options=UPPERSTR_DECODER_CODECOPTS) - one = GridIn(db.fs, _id=5, filename="my_file", - contentType="text/html", chunkSize=1000, aliases=["foo"], - metadata={"foo": 'red', "bar": 'blue'}, bar=3, - baz="hello") + one = GridIn( + db.fs, + _id=5, + filename="my_file", + contentType="text/html", + chunkSize=1000, + aliases=["foo"], + metadata={"foo": "red", "bar": "blue"}, + bar=3, + baz="hello", + ) one.write(b"hello world") one.close() @@ -670,12 +740,21 @@ def test_grid_out_custom_opts(self): self.assertEqual(1000, two.chunk_size) self.assertTrue(isinstance(two.upload_date, datetime.datetime)) self.assertEqual(["foo"], two.aliases) - 
self.assertEqual({"foo": 'red', "bar": 'blue'}, two.metadata) + self.assertEqual({"foo": "red", "bar": "blue"}, two.metadata) self.assertEqual(3, two.bar) self.assertEqual(None, two.md5) - for attr in ["_id", "name", "content_type", "length", "chunk_size", - "upload_date", "aliases", "metadata", "md5"]: + for attr in [ + "_id", + "name", + "content_type", + "length", + "chunk_size", + "upload_date", + "aliases", + "metadata", + "md5", + ]: self.assertRaises(AttributeError, setattr, two, attr, 5) @@ -683,11 +762,10 @@ class ChangeStreamsWCustomTypesTestMixin(object): def change_stream(self, *args, **kwargs): return self.watched_target.watch(*args, **kwargs) - def insert_and_check(self, change_stream, insert_doc, - expected_doc): + def insert_and_check(self, change_stream, insert_doc, expected_doc): self.input_target.insert_one(insert_doc) change = next(change_stream) - self.assertEqual(change['fullDocument'], expected_doc) + self.assertEqual(change["fullDocument"], expected_doc) def kill_change_stream_cursor(self, change_stream): # Cause a cursor not found error on the next getMore. @@ -697,18 +775,21 @@ def kill_change_stream_cursor(self, change_stream): client._close_cursor_now(cursor.cursor_id, address) def test_simple(self): - codecopts = CodecOptions(type_registry=TypeRegistry([ - UndecipherableIntEncoder(), UppercaseTextDecoder()])) + codecopts = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntEncoder(), UppercaseTextDecoder()]) + ) self.create_targets(codec_options=codecopts) input_docs = [ - {'_id': UndecipherableInt64Type(1), 'data': 'hello'}, - {'_id': 2, 'data': 'world'}, - {'_id': UndecipherableInt64Type(3), 'data': '!'},] + {"_id": UndecipherableInt64Type(1), "data": "hello"}, + {"_id": 2, "data": "world"}, + {"_id": UndecipherableInt64Type(3), "data": "!"}, + ] expected_docs = [ - {'_id': 1, 'data': 'HELLO'}, - {'_id': 2, 'data': 'WORLD'}, - {'_id': 3, 'data': '!'},] + {"_id": 1, "data": "HELLO"}, + {"_id": 2, "data": "WORLD"}, + {"_id": 3, "data": "!"}, + ] change_stream = self.change_stream() @@ -719,22 +800,22 @@ def test_simple(self): self.insert_and_check(change_stream, input_docs[2], expected_docs[2]) def test_custom_type_in_pipeline(self): - codecopts = CodecOptions(type_registry=TypeRegistry([ - UndecipherableIntEncoder(), UppercaseTextDecoder()])) + codecopts = CodecOptions( + type_registry=TypeRegistry([UndecipherableIntEncoder(), UppercaseTextDecoder()]) + ) self.create_targets(codec_options=codecopts) input_docs = [ - {'_id': UndecipherableInt64Type(1), 'data': 'hello'}, - {'_id': 2, 'data': 'world'}, - {'_id': UndecipherableInt64Type(3), 'data': '!'}] - expected_docs = [ - {'_id': 2, 'data': 'WORLD'}, - {'_id': 3, 'data': '!'}] + {"_id": UndecipherableInt64Type(1), "data": "hello"}, + {"_id": 2, "data": "world"}, + {"_id": UndecipherableInt64Type(3), "data": "!"}, + ] + expected_docs = [{"_id": 2, "data": "WORLD"}, {"_id": 3, "data": "!"}] # UndecipherableInt64Type should be encoded with the TypeRegistry. 
change_stream = self.change_stream( - [{'$match': {'documentKey._id': { - '$gte': UndecipherableInt64Type(2)}}}]) + [{"$match": {"documentKey._id": {"$gte": UndecipherableInt64Type(2)}}}] + ) self.input_target.insert_one(input_docs[0]) self.insert_and_check(change_stream, input_docs[1], expected_docs[0]) @@ -747,17 +828,17 @@ def test_break_resume_token(self): change_stream = self.change_stream() self.input_target.insert_one({"data": "test"}) change = next(change_stream) - resume_token_decoder = type_obfuscating_decoder_factory( - type(change['_id']['_data'])) + resume_token_decoder = type_obfuscating_decoder_factory(type(change["_id"]["_data"])) # Custom-decoding the resumeToken type breaks resume tokens. - codecopts = CodecOptions(type_registry=TypeRegistry([ - resume_token_decoder(), UndecipherableIntEncoder()])) + codecopts = CodecOptions( + type_registry=TypeRegistry([resume_token_decoder(), UndecipherableIntEncoder()]) + ) # Re-create targets, change stream and proceed. self.create_targets(codec_options=codecopts) - docs = [{'_id': 1}, {'_id': 2}, {'_id': 3}] + docs = [{"_id": 1}, {"_id": 2}, {"_id": 3}] change_stream = self.change_stream() self.insert_and_check(change_stream, docs[0], docs[0]) @@ -768,27 +849,27 @@ def test_break_resume_token(self): def test_document_class(self): def run_test(doc_cls): - codecopts = CodecOptions(type_registry=TypeRegistry([ - UppercaseTextDecoder(), UndecipherableIntEncoder()]), - document_class=doc_cls) + codecopts = CodecOptions( + type_registry=TypeRegistry([UppercaseTextDecoder(), UndecipherableIntEncoder()]), + document_class=doc_cls, + ) self.create_targets(codec_options=codecopts) change_stream = self.change_stream() - doc = {'a': UndecipherableInt64Type(101), 'b': 'xyz'} + doc = {"a": UndecipherableInt64Type(101), "b": "xyz"} self.input_target.insert_one(doc) change = next(change_stream) self.assertIsInstance(change, doc_cls) - self.assertEqual(change['fullDocument']['a'], 101) - self.assertEqual(change['fullDocument']['b'], 'XYZ') + self.assertEqual(change["fullDocument"]["a"], 101) + self.assertEqual(change["fullDocument"]["b"], "XYZ") for doc_cls in [OrderedDict, RawBSONDocument]: run_test(doc_cls) -class TestCollectionChangeStreamsWCustomTypes( - IntegrationTest, ChangeStreamsWCustomTypesTestMixin): +class TestCollectionChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): @classmethod @client_context.require_no_mmap @client_context.require_no_standalone @@ -800,16 +881,14 @@ def tearDown(self): self.input_target.drop() def create_targets(self, *args, **kwargs): - self.watched_target = self.db.get_collection( - 'test', *args, **kwargs) + self.watched_target = self.db.get_collection("test", *args, **kwargs) self.input_target = self.watched_target # Ensure the collection exists and is empty. 
self.input_target.insert_one({}) self.input_target.delete_many({}) -class TestDatabaseChangeStreamsWCustomTypes( - IntegrationTest, ChangeStreamsWCustomTypesTestMixin): +class TestDatabaseChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_no_mmap @@ -823,15 +902,13 @@ def tearDown(self): self.client.drop_database(self.watched_target) def create_targets(self, *args, **kwargs): - self.watched_target = self.client.get_database( - self.db.name, *args, **kwargs) + self.watched_target = self.client.get_database(self.db.name, *args, **kwargs) self.input_target = self.watched_target.test # Insert a record to ensure db, coll are created. - self.input_target.insert_one({'data': 'dummy'}) + self.input_target.insert_one({"data": "dummy"}) -class TestClusterChangeStreamsWCustomTypes( - IntegrationTest, ChangeStreamsWCustomTypesTestMixin): +class TestClusterChangeStreamsWCustomTypes(IntegrationTest, ChangeStreamsWCustomTypesTestMixin): @classmethod @client_context.require_version_min(4, 0, 0) @client_context.require_no_mmap @@ -845,15 +922,15 @@ def tearDown(self): self.client.drop_database(self.db) def create_targets(self, *args, **kwargs): - codec_options = kwargs.pop('codec_options', None) + codec_options = kwargs.pop("codec_options", None) if codec_options: - kwargs['type_registry'] = codec_options.type_registry - kwargs['document_class'] = codec_options.document_class + kwargs["type_registry"] = codec_options.type_registry + kwargs["document_class"] = codec_options.document_class self.watched_target = rs_client(*args, **kwargs) self.addCleanup(self.watched_target.close) self.input_target = self.watched_target[self.db.name].test # Insert a record to ensure db, coll are created. - self.input_target.insert_one({'data': 'dummy'}) + self.input_target.insert_one({"data": "dummy"}) if __name__ == "__main__": diff --git a/test/test_data_lake.py b/test/test_data_lake.py index 2954efe651..863b3a4f59 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -19,33 +19,37 @@ sys.path[0:0] = [""] -from pymongo.auth import MECHANISMS -from test import client_context, unittest, IntegrationTest +from test import IntegrationTest, client_context, unittest from test.crud_v2_format import TestCrudV2 from test.utils import ( - rs_client_noauth, rs_or_single_client, OvertCommandListener, TestCreator) + OvertCommandListener, + TestCreator, + rs_client_noauth, + rs_or_single_client, +) +from pymongo.auth import MECHANISMS # Location of JSON test specifications. -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), "data_lake") +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data_lake") class TestDataLakeMustConnect(IntegrationTest): def test_connected_to_data_lake(self): - data_lake = os.environ.get('DATA_LAKE') + data_lake = os.environ.get("DATA_LAKE") if not data_lake: - self.skipTest('DATA_LAKE is not set') + self.skipTest("DATA_LAKE is not set") - self.assertTrue(client_context.is_data_lake, - 'client context.is_data_lake must be True when ' - 'DATA_LAKE is set') + self.assertTrue( + client_context.is_data_lake, + "client context.is_data_lake must be True when " "DATA_LAKE is set", + ) class TestDataLakeProse(IntegrationTest): # Default test database and collection names. 
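# A minimal sketch (assumes a replica set; names are illustrative) of the
# three scopes the mixin subclasses above cover: PyMongo exposes watch() on
# collections, databases, and the client itself.
from pymongo import MongoClient

client = MongoClient()
streams = [
    client.demo.test.watch(),  # collection-level change stream
    client.demo.watch(),       # database-level (MongoDB 4.0+)
    client.watch(),            # cluster-level (MongoDB 4.0+)
]
for stream in streams:
    stream.close()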
-    TEST_DB = 'test'
-    TEST_COLLECTION = 'driverdata'
+    TEST_DB = "test"
+    TEST_COLLECTION = "driverdata"

     @classmethod
     @client_context.require_data_lake
@@ -56,8 +60,7 @@ def setUpClass(cls):
     def test_1(self):
         listener = OvertCommandListener()
         client = rs_or_single_client(event_listeners=[listener])
-        cursor = client[self.TEST_DB][self.TEST_COLLECTION].find(
-            {}, batch_size=2)
+        cursor = client[self.TEST_DB][self.TEST_COLLECTION].find({}, batch_size=2)
         next(cursor)

         # find command assertions
@@ -69,13 +72,12 @@ def test_1(self):
         # killCursors command assertions
         cursor.close()
         started = listener.results["started"][-1]
-        self.assertEqual(started.command_name, 'killCursors')
+        self.assertEqual(started.command_name, "killCursors")
         succeeded = listener.results["succeeded"][-1]
-        self.assertEqual(succeeded.command_name, 'killCursors')
+        self.assertEqual(succeeded.command_name, "killCursors")

         self.assertIn(cursor_id, started.command["cursors"])
-        target_ns = ".".join([started.command['$db'],
-                              started.command['killCursors']])
+        target_ns = ".".join([started.command["$db"], started.command["killCursors"]])
         self.assertEqual(cursor_ns, target_ns)

         self.assertIn(cursor_id, succeeded.reply["cursorsKilled"])
@@ -83,19 +85,19 @@ def test_1(self):
     # Test no auth
     def test_2(self):
         client = rs_client_noauth()
-        client.admin.command('ping')
+        client.admin.command("ping")

     # Test with auth
     def test_3(self):
-        for mechanism in ['SCRAM-SHA-1', 'SCRAM-SHA-256']:
+        for mechanism in ["SCRAM-SHA-1", "SCRAM-SHA-256"]:
             client = rs_or_single_client(authMechanism=mechanism)
             client[self.TEST_DB][self.TEST_COLLECTION].find_one()


 class DataLakeTestSpec(TestCrudV2):
     # Default test database and collection names.
-    TEST_DB = 'test'
-    TEST_COLLECTION = 'driverdata'
+    TEST_DB = "test"
+    TEST_COLLECTION = "driverdata"

     @classmethod
     @client_context.require_data_lake
diff --git a/test/test_database.py b/test/test_database.py
index 4adccc1b58..57f276b89e 100644
--- a/test/test_database.py
+++ b/test/test_database.py
@@ -20,43 +20,43 @@

 sys.path[0:0] = [""]

+from test import IntegrationTest, SkipTest, client_context, unittest
+from test.test_custom_types import DECIMAL_CODECOPTS
+from test.utils import (
+    IMPOSSIBLE_WRITE_CONCERN,
+    DeprecationFilter,
+    OvertCommandListener,
+    ignore_deprecations,
+    rs_or_single_client,
+    server_started_with_auth,
+    wait_until,
+)
+
 from bson.codec_options import CodecOptions
-from bson.int64 import Int64
-from bson.regex import Regex
 from bson.dbref import DBRef
+from bson.int64 import Int64
 from bson.objectid import ObjectId
+from bson.regex import Regex
 from bson.son import SON
-from pymongo import (auth,
-                     helpers)
+from pymongo import auth, helpers
 from pymongo.collection import Collection
 from pymongo.database import Database
-from pymongo.errors import (CollectionInvalid,
-                            ConfigurationError,
-                            ExecutionTimeout,
-                            InvalidName,
-                            OperationFailure,
-                            WriteConcernError)
+from pymongo.errors import (
+    CollectionInvalid,
+    ConfigurationError,
+    ExecutionTimeout,
+    InvalidName,
+    OperationFailure,
+    WriteConcernError,
+)
 from pymongo.mongo_client import MongoClient
 from pymongo.read_concern import ReadConcern
 from pymongo.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern
-from test import (client_context,
-                  SkipTest,
-                  unittest,
-                  IntegrationTest)
-from test.utils import (ignore_deprecations,
-                        rs_or_single_client,
-                        server_started_with_auth,
-                        wait_until,
-                        DeprecationFilter,
-                        IMPOSSIBLE_WRITE_CONCERN,
-                        OvertCommandListener)
-from test.test_custom_types import DECIMAL_CODECOPTS
DECIMAL_CODECOPTS class TestDatabaseNoConnect(unittest.TestCase): - """Test Database features on a client that does not connect. - """ + """Test Database features on a client that does not connect.""" @classmethod def setUpClass(cls): @@ -67,18 +67,17 @@ def test_name(self): self.assertRaises(InvalidName, Database, self.client, "my db") self.assertRaises(InvalidName, Database, self.client, 'my"db') self.assertRaises(InvalidName, Database, self.client, "my\x00db") - self.assertRaises(InvalidName, Database, - self.client, "my\u0000db") + self.assertRaises(InvalidName, Database, self.client, "my\u0000db") self.assertEqual("name", Database(self.client, "name").name) def test_get_collection(self): codec_options = CodecOptions(tz_aware=True) write_concern = WriteConcern(w=2, j=True) - read_concern = ReadConcern('majority') + read_concern = ReadConcern("majority") coll = self.client.pymongo_test.get_collection( - 'foo', codec_options, ReadPreference.SECONDARY, write_concern, - read_concern) - self.assertEqual('foo', coll.name) + "foo", codec_options, ReadPreference.SECONDARY, write_concern, read_concern + ) + self.assertEqual("foo", coll.name) self.assertEqual(codec_options, coll.codec_options) self.assertEqual(ReadPreference.SECONDARY, coll.read_preference) self.assertEqual(write_concern, coll.write_concern) @@ -86,7 +85,7 @@ def test_get_collection(self): def test_getattr(self): db = self.client.pymongo_test - self.assertTrue(isinstance(db['_does_not_exist'], Collection)) + self.assertTrue(isinstance(db["_does_not_exist"], Collection)) with self.assertRaises(AttributeError) as context: db._does_not_exist @@ -94,24 +93,19 @@ def test_getattr(self): # Message should be: "AttributeError: Database has no attribute # '_does_not_exist'. To access the _does_not_exist collection, # use database['_does_not_exist']". 
-        self.assertIn("has no attribute '_does_not_exist'",
-                      str(context.exception))
+        self.assertIn("has no attribute '_does_not_exist'", str(context.exception))

    def test_iteration(self):
        self.assertRaises(TypeError, next, self.client.pymongo_test)


class TestDatabase(IntegrationTest):
-
    def test_equality(self):
-        self.assertNotEqual(Database(self.client, "test"),
-                            Database(self.client, "mike"))
-        self.assertEqual(Database(self.client, "test"),
-                         Database(self.client, "test"))
+        self.assertNotEqual(Database(self.client, "test"), Database(self.client, "mike"))
+        self.assertEqual(Database(self.client, "test"), Database(self.client, "test"))

        # Explicitly test inequality
-        self.assertFalse(Database(self.client, "test") !=
-                         Database(self.client, "test"))
+        self.assertFalse(Database(self.client, "test") != Database(self.client, "test"))

    def test_hashable(self):
        self.assertIn(self.client.test, {Database(self.client, "test")})
@@ -124,9 +118,10 @@ def test_get_coll(self):
        self.assertEqual(db.test.mike, db["test.mike"])

    def test_repr(self):
-        self.assertEqual(repr(Database(self.client, "pymongo_test")),
-                         "Database(%r, %s)" % (self.client,
-                                               repr("pymongo_test")))
+        self.assertEqual(
+            repr(Database(self.client, "pymongo_test")),
+            "Database(%r, %s)" % (self.client, repr("pymongo_test")),
+        )

    def test_create_collection(self):
        db = Database(self.client, "pymongo_test")
@@ -163,7 +158,8 @@ def test_list_collection_names(self):
            db.systemcoll.test.insert_one({})

        no_system_collections = db.list_collection_names(
-            filter={"name": {"$regex": r"^(?!system\.)"}})
+            filter={"name": {"$regex": r"^(?!system\.)"}}
+        )
        for coll in no_system_collections:
            self.assertTrue(not coll.startswith("system."))
        self.assertIn("systemcoll.test", no_system_collections)
@@ -190,15 +186,14 @@ def test_list_collection_names_filter(self):
        self.addCleanup(client.drop_database, db.name)

        # Should not send nameOnly.
-        for filter in ({'options.capped': True},
-                       {'options.capped': True, 'name': 'capped'}):
+        for filter in ({"options.capped": True}, {"options.capped": True, "name": "capped"}):
            results.clear()
            names = db.list_collection_names(filter=filter)
            self.assertEqual(names, ["capped"])
            self.assertNotIn("nameOnly", results["started"][0].command)

        # Should send nameOnly (except on 2.6).
-        for filter in (None, {}, {'name': {'$in': ['capped', 'non_capped']}}):
+        for filter in (None, {}, {"name": {"$in": ["capped", "non_capped"]}}):
            results.clear()
            names = db.list_collection_names(filter=filter)
            self.assertIn("capped", names)
@@ -236,8 +231,10 @@ def test_list_collections(self):
        coll_cnt = {}

        # Check that no unexpected collections exist.
-        if (len(set(colls) - set(["test","test.mike"])) == 0 or
-            len(set(colls) - set(["test","test.mike","system.indexes"])) == 0):
+        if (
+            len(set(colls) - set(["test", "test.mike"])) == 0
+            or len(set(colls) - set(["test", "test.mike", "system.indexes"])) == 0
+        ):
            self.assertTrue(True)
        else:
            self.assertTrue(False)
@@ -251,7 +248,7 @@ def test_list_collections(self):
        db.drop_collection("test")
        db.create_collection("test", capped=True, size=4096)

-        results = db.list_collections(filter={'options.capped': True})
+        results = db.list_collections(filter={"options.capped": True})
        colls = [result["name"] for result in results]

        # Check that only capped collections are present
@@ -274,8 +271,10 @@ def test_list_collections(self):
        coll_cnt = {}

        # Check that no unexpected collections exist.
- if (len(set(colls) - set(["test"])) == 0 or - len(set(colls) - set(["test","system.indexes"])) == 0): + if ( + len(set(colls) - set(["test"])) == 0 + or len(set(colls) - set(["test", "system.indexes"])) == 0 + ): self.assertTrue(True) else: self.assertTrue(False) @@ -284,13 +283,13 @@ def test_list_collections(self): def test_list_collection_names_single_socket(self): client = rs_or_single_client(maxPoolSize=1) - client.drop_database('test_collection_names_single_socket') + client.drop_database("test_collection_names_single_socket") db = client.test_collection_names_single_socket for i in range(200): db.create_collection(str(i)) db.list_collection_names() # Must not hang. - client.drop_database('test_collection_names_single_socket') + client.drop_database("test_collection_names_single_socket") def test_drop_collection(self): db = Database(self.client, "pymongo_test") @@ -322,10 +321,9 @@ def test_drop_collection(self): db.drop_collection(db.test.doesnotexist) if client_context.is_rs: - db_wc = Database(self.client, 'pymongo_test', - write_concern=IMPOSSIBLE_WRITE_CONCERN) + db_wc = Database(self.client, "pymongo_test", write_concern=IMPOSSIBLE_WRITE_CONCERN) with self.assertRaises(WriteConcernError): - db_wc.drop_collection('test') + db_wc.drop_collection("test") def test_validate_collection(self): db = self.client.pymongo_test @@ -335,10 +333,8 @@ def test_validate_collection(self): db.test.insert_one({"dummy": "object"}) - self.assertRaises(OperationFailure, db.validate_collection, - "test.doesnotexist") - self.assertRaises(OperationFailure, db.validate_collection, - db.test.doesnotexist) + self.assertRaises(OperationFailure, db.validate_collection, "test.doesnotexist") + self.assertRaises(OperationFailure, db.validate_collection, db.test.doesnotexist) self.assertTrue(db.validate_collection("test")) self.assertTrue(db.validate_collection(db.test)) @@ -354,10 +350,9 @@ def test_validate_collection_background(self): coll = db.test self.assertTrue(db.validate_collection(coll, background=False)) # The inMemory storage engine does not support background=True. - if client_context.storage_engine != 'inMemory': + if client_context.storage_engine != "inMemory": self.assertTrue(db.validate_collection(coll, background=True)) - self.assertTrue( - db.validate_collection(coll, scandata=True, background=True)) + self.assertTrue(db.validate_collection(coll, scandata=True, background=True)) # The server does not support background=True with full=True. # Assert that we actually send the background option by checking # that this combination fails. 
@@ -378,24 +373,25 @@ def test_command(self): def test_command_with_regex(self): db = self.client.pymongo_test db.test.drop() - db.test.insert_one({'r': re.compile('.*')}) - db.test.insert_one({'r': Regex('.*')}) + db.test.insert_one({"r": re.compile(".*")}) + db.test.insert_one({"r": Regex(".*")}) - result = db.command('aggregate', 'test', pipeline=[], cursor={}) - for doc in result['cursor']['firstBatch']: - self.assertTrue(isinstance(doc['r'], Regex)) + result = db.command("aggregate", "test", pipeline=[], cursor={}) + for doc in result["cursor"]["firstBatch"]: + self.assertTrue(isinstance(doc["r"], Regex)) def test_password_digest(self): self.assertRaises(TypeError, auth._password_digest, 5) self.assertRaises(TypeError, auth._password_digest, True) self.assertRaises(TypeError, auth._password_digest, None) - self.assertTrue(isinstance(auth._password_digest("mike", "password"), - str)) - self.assertEqual(auth._password_digest("mike", "password"), - "cd7e45b3b2767dc2fa9b6b548457ed00") - self.assertEqual(auth._password_digest("Gustave", "Dor\xe9"), - "81e0e2364499209f466e75926a162d73") + self.assertTrue(isinstance(auth._password_digest("mike", "password"), str)) + self.assertEqual( + auth._password_digest("mike", "password"), "cd7e45b3b2767dc2fa9b6b548457ed00" + ) + self.assertEqual( + auth._password_digest("Gustave", "Dor\xe9"), "81e0e2364499209f466e75926a162d73" + ) def test_id_ordering(self): # PyMongo attempts to have _id show up first @@ -406,11 +402,11 @@ def test_id_ordering(self): # with hash randomization enabled (e.g. tox). db = self.client.pymongo_test db.test.drop() - db.test.insert_one(SON([("hello", "world"), - ("_id", 5)])) + db.test.insert_one(SON([("hello", "world"), ("_id", 5)])) db = self.client.get_database( - "pymongo_test", codec_options=CodecOptions(document_class=SON)) + "pymongo_test", codec_options=CodecOptions(document_class=SON) + ) cursor = db.test.find() for x in cursor: for (k, v) in x.items(): @@ -429,10 +425,8 @@ def test_deref(self): obj = {"x": True} key = db.test.insert_one(obj).inserted_id self.assertEqual(obj, db.dereference(DBRef("test", key))) - self.assertEqual(obj, - db.dereference(DBRef("test", key, "pymongo_test"))) - self.assertRaises(ValueError, - db.dereference, DBRef("test", key, "foo")) + self.assertEqual(obj, db.dereference(DBRef("test", key, "pymongo_test"))) + self.assertRaises(ValueError, db.dereference, DBRef("test", key, "foo")) self.assertEqual(None, db.dereference(DBRef("test", 4))) obj = {"_id": 4} @@ -445,10 +439,11 @@ def test_deref_kwargs(self): db.test.insert_one({"_id": 4, "foo": "bar"}) db = self.client.get_database( - "pymongo_test", codec_options=CodecOptions(document_class=SON)) - self.assertEqual(SON([("foo", "bar")]), - db.dereference(DBRef("test", 4), - projection={"_id": False})) + "pymongo_test", codec_options=CodecOptions(document_class=SON) + ) + self.assertEqual( + SON([("foo", "bar")]), db.dereference(DBRef("test", 4), projection={"_id": False}) + ) # TODO some of these tests belong in the collection level testing. 
def test_insert_find_one(self): @@ -482,12 +477,12 @@ def test_long(self): db = self.client.pymongo_test db.test.drop() db.test.insert_one({"x": 9223372036854775807}) - retrieved = db.test.find_one()['x'] + retrieved = db.test.find_one()["x"] self.assertEqual(Int64(9223372036854775807), retrieved) self.assertIsInstance(retrieved, Int64) db.test.delete_many({}) db.test.insert_one({"x": Int64(1)}) - retrieved = db.test.find_one()['x'] + retrieved = db.test.find_one()["x"] self.assertEqual(Int64(1), retrieved) self.assertIsInstance(retrieved, Int64) @@ -529,11 +524,10 @@ def test_command_response_without_ok(self): # Sometimes (SERVER-10891) the server's response to a badly-formatted # command document will have no 'ok' field. We should raise # OperationFailure instead of KeyError. - self.assertRaises(OperationFailure, - helpers._check_command_response, {}, None) + self.assertRaises(OperationFailure, helpers._check_command_response, {}, None) try: - helpers._check_command_response({'$err': 'foo'}, None) + helpers._check_command_response({"$err": "foo"}, None) except OperationFailure as e: self.assertEqual(e.args[0], "foo, full error: {'$err': 'foo'}") else: @@ -541,64 +535,59 @@ def test_command_response_without_ok(self): def test_mongos_response(self): error_document = { - 'ok': 0, - 'errmsg': 'outer', - 'raw': {'shard0/host0,host1': {'ok': 0, 'errmsg': 'inner'}}} + "ok": 0, + "errmsg": "outer", + "raw": {"shard0/host0,host1": {"ok": 0, "errmsg": "inner"}}, + } with self.assertRaises(OperationFailure) as context: helpers._check_command_response(error_document, None) - self.assertIn('inner', str(context.exception)) + self.assertIn("inner", str(context.exception)) # If a shard has no primary and you run a command like dbstats, which # cannot be run on a secondary, mongos's response includes empty "raw" # errors. See SERVER-15428. - error_document = { - 'ok': 0, - 'errmsg': 'outer', - 'raw': {'shard0/host0,host1': {}}} + error_document = {"ok": 0, "errmsg": "outer", "raw": {"shard0/host0,host1": {}}} with self.assertRaises(OperationFailure) as context: helpers._check_command_response(error_document, None) - self.assertIn('outer', str(context.exception)) + self.assertIn("outer", str(context.exception)) # Raw error has ok: 0 but no errmsg. Not a known case, but test it. - error_document = { - 'ok': 0, - 'errmsg': 'outer', - 'raw': {'shard0/host0,host1': {'ok': 0}}} + error_document = {"ok": 0, "errmsg": "outer", "raw": {"shard0/host0,host1": {"ok": 0}}} with self.assertRaises(OperationFailure) as context: helpers._check_command_response(error_document, None) - self.assertIn('outer', str(context.exception)) + self.assertIn("outer", str(context.exception)) @client_context.require_test_commands @client_context.require_no_mongos def test_command_max_time_ms(self): - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="alwaysOn") + self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn") try: db = self.client.pymongo_test - db.command('count', 'test') - self.assertRaises(ExecutionTimeout, db.command, - 'count', 'test', maxTimeMS=1) - pipeline = [{'$project': {'name': 1, 'count': 1}}] + db.command("count", "test") + self.assertRaises(ExecutionTimeout, db.command, "count", "test", maxTimeMS=1) + pipeline = [{"$project": {"name": 1, "count": 1}}] # Database command helper. 
- db.command('aggregate', 'test', pipeline=pipeline, cursor={}) - self.assertRaises(ExecutionTimeout, db.command, - 'aggregate', 'test', - pipeline=pipeline, cursor={}, maxTimeMS=1) + db.command("aggregate", "test", pipeline=pipeline, cursor={}) + self.assertRaises( + ExecutionTimeout, + db.command, + "aggregate", + "test", + pipeline=pipeline, + cursor={}, + maxTimeMS=1, + ) # Collection helper. db.test.aggregate(pipeline=pipeline) - self.assertRaises(ExecutionTimeout, - db.test.aggregate, pipeline, maxTimeMS=1) + self.assertRaises(ExecutionTimeout, db.test.aggregate, pipeline, maxTimeMS=1) finally: - self.client.admin.command("configureFailPoint", - "maxTimeAlwaysTimeOut", - mode="off") + self.client.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off") def test_with_options(self): codec_options = DECIMAL_CODECOPTS @@ -607,13 +596,22 @@ def test_with_options(self): read_concern = ReadConcern(level="majority") # List of all options to compare. - allopts = ['name', 'client', 'codec_options', - 'read_preference', 'write_concern', 'read_concern'] + allopts = [ + "name", + "client", + "codec_options", + "read_preference", + "write_concern", + "read_concern", + ] db1 = self.client.get_database( - 'with_options_test', codec_options=codec_options, - read_preference=read_preference, write_concern=write_concern, - read_concern=read_concern) + "with_options_test", + codec_options=codec_options, + read_preference=read_preference, + write_concern=write_concern, + read_concern=read_concern, + ) # Case 1: swap no options db2 = db1.with_options() @@ -621,22 +619,25 @@ def test_with_options(self): self.assertEqual(getattr(db1, opt), getattr(db2, opt)) # Case 2: swap all options - newopts = {'codec_options': CodecOptions(), - 'read_preference': ReadPreference.PRIMARY, - 'write_concern': WriteConcern(w=1), - 'read_concern': ReadConcern(level="local")} + newopts = { + "codec_options": CodecOptions(), + "read_preference": ReadPreference.PRIMARY, + "write_concern": WriteConcern(w=1), + "read_concern": ReadConcern(level="local"), + } db2 = db1.with_options(**newopts) for opt in newopts: - self.assertEqual( - getattr(db2, opt), newopts.get(opt, getattr(db1, opt))) + self.assertEqual(getattr(db2, opt), newopts.get(opt, getattr(db1, opt))) class TestDatabaseAggregation(IntegrationTest): def setUp(self): - self.pipeline = [{"$listLocalSessions": {}}, - {"$limit": 1}, - {"$addFields": {"dummy": "dummy field"}}, - {"$project": {"_id": 0, "dummy": 1}}] + self.pipeline = [ + {"$listLocalSessions": {}}, + {"$limit": 1}, + {"$addFields": {"dummy": "dummy field"}}, + {"$project": {"_id": 0, "dummy": 1}}, + ] self.result = {"dummy": "dummy field"} self.admin = self.client.admin @@ -655,8 +656,7 @@ def test_database_aggregation_fake_cursor(self): # SERVER-43287 disallows writing with $out to the admin db, use # $merge instead. 
db_name = "pymongo_test" - write_stage = { - "$merge": {"into": {"db": db_name, "coll": coll_name}}} + write_stage = {"$merge": {"into": {"db": db_name, "coll": coll_name}}} output_coll = self.client[db_name][coll_name] output_coll.drop() self.addCleanup(output_coll.drop) diff --git a/test/test_dbref.py b/test/test_dbref.py index 964947351e..9a00707524 100644 --- a/test/test_dbref.py +++ b/test/test_dbref.py @@ -16,14 +16,15 @@ import pickle import sys + sys.path[0:0] = [""] -from bson import encode, decode -from bson.dbref import DBRef -from bson.objectid import ObjectId +from copy import deepcopy from test import unittest -from copy import deepcopy +from bson import decode, encode +from bson.dbref import DBRef +from bson.objectid import ObjectId class TestDBRef(unittest.TestCase): @@ -56,53 +57,45 @@ def bar(): self.assertRaises(AttributeError, bar) def test_repr(self): - self.assertEqual(repr(DBRef("coll", - ObjectId("1234567890abcdef12345678"))), - "DBRef('coll', ObjectId('1234567890abcdef12345678'))") - self.assertEqual(repr(DBRef("coll", - ObjectId("1234567890abcdef12345678"))), - "DBRef(%s, ObjectId('1234567890abcdef12345678'))" - % (repr('coll'),) - ) - self.assertEqual(repr(DBRef("coll", 5, foo="bar")), - "DBRef('coll', 5, foo='bar')") - self.assertEqual(repr(DBRef("coll", - ObjectId("1234567890abcdef12345678"), "foo")), - "DBRef('coll', ObjectId('1234567890abcdef12345678'), " - "'foo')") + self.assertEqual( + repr(DBRef("coll", ObjectId("1234567890abcdef12345678"))), + "DBRef('coll', ObjectId('1234567890abcdef12345678'))", + ) + self.assertEqual( + repr(DBRef("coll", ObjectId("1234567890abcdef12345678"))), + "DBRef(%s, ObjectId('1234567890abcdef12345678'))" % (repr("coll"),), + ) + self.assertEqual(repr(DBRef("coll", 5, foo="bar")), "DBRef('coll', 5, foo='bar')") + self.assertEqual( + repr(DBRef("coll", ObjectId("1234567890abcdef12345678"), "foo")), + "DBRef('coll', ObjectId('1234567890abcdef12345678'), " "'foo')", + ) def test_equality(self): obj_id = ObjectId("1234567890abcdef12345678") - self.assertEqual(DBRef('foo', 5), DBRef('foo', 5)) + self.assertEqual(DBRef("foo", 5), DBRef("foo", 5)) self.assertEqual(DBRef("coll", obj_id), DBRef("coll", obj_id)) - self.assertNotEqual(DBRef("coll", obj_id), - DBRef("coll", obj_id, "foo")) + self.assertNotEqual(DBRef("coll", obj_id), DBRef("coll", obj_id, "foo")) self.assertNotEqual(DBRef("coll", obj_id), DBRef("col", obj_id)) - self.assertNotEqual(DBRef("coll", obj_id), - DBRef("coll", ObjectId(b"123456789011"))) + self.assertNotEqual(DBRef("coll", obj_id), DBRef("coll", ObjectId(b"123456789011"))) self.assertNotEqual(DBRef("coll", obj_id), 4) - self.assertNotEqual(DBRef("coll", obj_id, "foo"), - DBRef("coll", obj_id, "bar")) + self.assertNotEqual(DBRef("coll", obj_id, "foo"), DBRef("coll", obj_id, "bar")) # Explicitly test inequality - self.assertFalse(DBRef('foo', 5) != DBRef('foo', 5)) + self.assertFalse(DBRef("foo", 5) != DBRef("foo", 5)) self.assertFalse(DBRef("coll", obj_id) != DBRef("coll", obj_id)) - self.assertFalse(DBRef("coll", obj_id, "foo") != - DBRef("coll", obj_id, "foo")) + self.assertFalse(DBRef("coll", obj_id, "foo") != DBRef("coll", obj_id, "foo")) def test_kwargs(self): - self.assertEqual(DBRef("coll", 5, foo="bar"), - DBRef("coll", 5, foo="bar")) + self.assertEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5, foo="bar")) self.assertNotEqual(DBRef("coll", 5, foo="bar"), DBRef("coll", 5)) - self.assertNotEqual(DBRef("coll", 5, foo="bar"), - DBRef("coll", 5, foo="baz")) + self.assertNotEqual(DBRef("coll", 5, foo="bar"), 
DBRef("coll", 5, foo="baz")) self.assertEqual("bar", DBRef("coll", 5, foo="bar").foo) - self.assertRaises(AttributeError, getattr, - DBRef("coll", 5, foo="bar"), "bar") + self.assertRaises(AttributeError, getattr, DBRef("coll", 5, foo="bar"), "bar") def test_deepcopy(self): - a = DBRef('coll', 'asdf', 'db', x=[1]) + a = DBRef("coll", "asdf", "db", x=[1]) b = deepcopy(a) self.assertEqual(a, b) @@ -115,19 +108,19 @@ def test_deepcopy(self): self.assertEqual(b.x, [2]) def test_pickling(self): - dbr = DBRef('coll', 5, foo='bar') + dbr = DBRef("coll", 5, foo="bar") for protocol in [0, 1, 2, -1]: pkl = pickle.dumps(dbr, protocol=protocol) dbr2 = pickle.loads(pkl) self.assertEqual(dbr, dbr2) def test_dbref_hash(self): - dbref_1a = DBRef('collection', 'id', 'database') - dbref_1b = DBRef('collection', 'id', 'database') + dbref_1a = DBRef("collection", "id", "database") + dbref_1b = DBRef("collection", "id", "database") self.assertEqual(hash(dbref_1a), hash(dbref_1b)) - dbref_2a = DBRef('collection', 'id', 'database', custom='custom') - dbref_2b = DBRef('collection', 'id', 'database', custom='custom') + dbref_2a = DBRef("collection", "id", "database", custom="custom") + dbref_2b = DBRef("collection", "id", "database", custom="custom") self.assertEqual(hash(dbref_2a), hash(dbref_2b)) self.assertNotEqual(hash(dbref_1a), hash(dbref_2a)) @@ -156,12 +149,12 @@ def test_decoding_1_2_3(self): {"foo": 1, "$ref": "coll0", "$id": 1, "$db": "db0", "bar": 1}, ]: with self.subTest(doc=doc): - decoded = decode(encode({'dbref': doc})) - dbref = decoded['dbref'] + decoded = decode(encode({"dbref": doc})) + dbref = decoded["dbref"] self.assertIsInstance(dbref, DBRef) - self.assertEqual(dbref.collection, doc['$ref']) - self.assertEqual(dbref.id, doc['$id']) - self.assertEqual(dbref.database, doc.get('$db')) + self.assertEqual(dbref.collection, doc["$ref"]) + self.assertEqual(dbref.id, doc["$id"]) + self.assertEqual(dbref.database, doc.get("$db")) for extra in set(doc.keys()) - {"$ref", "$id", "$db"}: self.assertEqual(getattr(dbref, extra), doc[extra]) @@ -178,8 +171,8 @@ def test_decoding_4_5(self): {"$ref": "coll0", "$id": 1, "$db": 1}, ]: with self.subTest(doc=doc): - decoded = decode(encode({'dbref': doc})) - dbref = decoded['dbref'] + decoded = decode(encode({"dbref": doc})) + dbref = decoded["dbref"] self.assertIsInstance(dbref, dict) def test_encoding_1_2(self): @@ -198,9 +191,9 @@ def test_encoding_1_2(self): ]: with self.subTest(doc=doc): # Decode the test input to a DBRef via a BSON roundtrip. - encoded_doc = encode({'dbref': doc}) + encoded_doc = encode({"dbref": doc}) decoded = decode(encoded_doc) - dbref = decoded['dbref'] + dbref = decoded["dbref"] self.assertIsInstance(dbref, DBRef) # Encode the DBRef. encoded_dbref = encode(decoded) @@ -221,9 +214,9 @@ def test_encoding_3(self): ]: with self.subTest(doc=doc): # Decode the test input to a DBRef via a BSON roundtrip. - encoded_doc = encode({'dbref': doc}) + encoded_doc = encode({"dbref": doc}) decoded = decode(encoded_doc) - dbref = decoded['dbref'] + dbref = decoded["dbref"] self.assertIsInstance(dbref, DBRef) # Encode the DBRef. 
encoded_dbref = encode(decoded) diff --git a/test/test_decimal128.py b/test/test_decimal128.py index 4ff25935dd..28f382554c 100644 --- a/test/test_decimal128.py +++ b/test/test_decimal128.py @@ -16,40 +16,38 @@ import pickle import sys - from decimal import Decimal sys.path[0:0] = [""] -from bson.decimal128 import Decimal128, create_decimal128_context from test import client_context, unittest -class TestDecimal128(unittest.TestCase): +from bson.decimal128 import Decimal128, create_decimal128_context + +class TestDecimal128(unittest.TestCase): @client_context.require_connection def test_round_trip(self): coll = client_context.client.pymongo_test.test coll.drop() - dec128 = Decimal128.from_bid( - b'\x00@cR\xbf\xc6\x01\x00\x00\x00\x00\x00\x00\x00\x1c0') - coll.insert_one({'dec128': dec128}) - doc = coll.find_one({'dec128': dec128}) + dec128 = Decimal128.from_bid(b"\x00@cR\xbf\xc6\x01\x00\x00\x00\x00\x00\x00\x00\x1c0") + coll.insert_one({"dec128": dec128}) + doc = coll.find_one({"dec128": dec128}) self.assertIsNotNone(doc) - self.assertEqual(doc['dec128'], dec128) + self.assertEqual(doc["dec128"], dec128) def test_pickle(self): - dec128 = Decimal128.from_bid( - b'\x00@cR\xbf\xc6\x01\x00\x00\x00\x00\x00\x00\x00\x1c0') + dec128 = Decimal128.from_bid(b"\x00@cR\xbf\xc6\x01\x00\x00\x00\x00\x00\x00\x00\x1c0") for protocol in range(pickle.HIGHEST_PROTOCOL + 1): pkl = pickle.dumps(dec128, protocol=protocol) self.assertEqual(dec128, pickle.loads(pkl)) def test_special(self): - dnan = Decimal('NaN') - dnnan = Decimal('-NaN') - dsnan = Decimal('sNaN') - dnsnan = Decimal('-sNaN') + dnan = Decimal("NaN") + dnnan = Decimal("-NaN") + dsnan = Decimal("sNaN") + dnsnan = Decimal("-sNaN") dnan128 = Decimal128(dnan) dnnan128 = Decimal128(dnnan) dsnan128 = Decimal128(dsnan) @@ -69,5 +67,5 @@ def test_decimal128_context(self): self.assertEqual("0E-6176", str(ctx.copy().create_decimal("1E-6177"))) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index c26c0df309..46ba83287d 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -21,40 +21,41 @@ sys.path[0:0] = [""] -from bson import json_util, Timestamp -from pymongo import (common, - monitoring) -from pymongo.errors import (AutoReconnect, - ConfigurationError, - NetworkTimeout, - NotPrimaryError, - OperationFailure) -from pymongo.helpers import (_check_command_response, - _check_write_command_response) +from test import IntegrationTest, unittest +from test.pymongo_mocks import DummyMonitor +from test.utils import ( + CMAPListener, + HeartbeatEventListener, + TestCreator, + assertion_context, + client_context, + get_pool, + rs_or_single_client, + server_name_to_type, + single_client, + wait_until, +) +from test.utils_spec_runner import SpecRunner, SpecRunnerThread + +from bson import Timestamp, json_util +from pymongo import common, monitoring +from pymongo.errors import ( + AutoReconnect, + ConfigurationError, + NetworkTimeout, + NotPrimaryError, + OperationFailure, +) from pymongo.hello import Hello, HelloCompat -from pymongo.server_description import ServerDescription, SERVER_TYPE +from pymongo.helpers import _check_command_response, _check_write_command_response +from pymongo.server_description import SERVER_TYPE, ServerDescription from pymongo.settings import TopologySettings from pymongo.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE from pymongo.uri_parser 
import parse_uri -from test import unittest, IntegrationTest -from test.utils import (assertion_context, - CMAPListener, - client_context, - get_pool, - HeartbeatEventListener, - server_name_to_type, - rs_or_single_client, - single_client, - TestCreator, - wait_until) -from test.utils_spec_runner import SpecRunner, SpecRunnerThread -from test.pymongo_mocks import DummyMonitor - # Location of JSON test specifications. -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'discovery_and_monitoring') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring") def create_mock_topology(uri, monitor_class=DummyMonitor): @@ -62,19 +63,20 @@ def create_mock_topology(uri, monitor_class=DummyMonitor): replica_set_name = None direct_connection = None load_balanced = None - if 'replicaset' in parsed_uri['options']: - replica_set_name = parsed_uri['options']['replicaset'] - if 'directConnection' in parsed_uri['options']: - direct_connection = parsed_uri['options']['directConnection'] - if 'loadBalanced' in parsed_uri['options']: - load_balanced = parsed_uri['options']['loadBalanced'] + if "replicaset" in parsed_uri["options"]: + replica_set_name = parsed_uri["options"]["replicaset"] + if "directConnection" in parsed_uri["options"]: + direct_connection = parsed_uri["options"]["directConnection"] + if "loadBalanced" in parsed_uri["options"]: + load_balanced = parsed_uri["options"]["loadBalanced"] topology_settings = TopologySettings( - parsed_uri['nodelist'], + parsed_uri["nodelist"], replica_set_name=replica_set_name, monitor_class=monitor_class, direct_connection=direct_connection, - load_balanced=load_balanced) + load_balanced=load_balanced, + ) c = Topology(topology_settings) c.open() @@ -82,43 +84,42 @@ def create_mock_topology(uri, monitor_class=DummyMonitor): def got_hello(topology, server_address, hello_response): - server_description = ServerDescription( - server_address, Hello(hello_response), 0) + server_description = ServerDescription(server_address, Hello(hello_response), 0) topology.on_change(server_description) def got_app_error(topology, app_error): - server_address = common.partition_node(app_error['address']) + server_address = common.partition_node(app_error["address"]) server = topology.get_server_by_address(server_address) - error_type = app_error['type'] - generation = app_error.get( - 'generation', server.pool.gen.get_overall()) - when = app_error['when'] - max_wire_version = app_error['maxWireVersion'] + error_type = app_error["type"] + generation = app_error.get("generation", server.pool.gen.get_overall()) + when = app_error["when"] + max_wire_version = app_error["maxWireVersion"] # XXX: We could get better test coverage by mocking the errors on the # Pool/SocketInfo. 
try: - if error_type == 'command': - _check_command_response(app_error['response'], max_wire_version) - _check_write_command_response(app_error['response']) - elif error_type == 'network': - raise AutoReconnect('mock non-timeout network error') - elif error_type == 'timeout': - raise NetworkTimeout('mock network timeout error') + if error_type == "command": + _check_command_response(app_error["response"], max_wire_version) + _check_write_command_response(app_error["response"]) + elif error_type == "network": + raise AutoReconnect("mock non-timeout network error") + elif error_type == "timeout": + raise NetworkTimeout("mock network timeout error") else: - raise AssertionError('unknown error type: %s' % (error_type,)) + raise AssertionError("unknown error type: %s" % (error_type,)) assert False except (AutoReconnect, NotPrimaryError, OperationFailure) as e: - if when == 'beforeHandshakeCompletes': + if when == "beforeHandshakeCompletes": completed_handshake = False - elif when == 'afterHandshakeCompletes': + elif when == "afterHandshakeCompletes": completed_handshake = True else: - assert False, 'Unknown when field %s' % (when,) + assert False, "Unknown when field %s" % (when,) topology.handle_error( - server_address, _ErrorContext(e, max_wire_version, generation, - completed_handshake, None)) + server_address, + _ErrorContext(e, max_wire_version, generation, completed_handshake, None), + ) def get_type(topology, hostname): @@ -139,14 +140,12 @@ def server_type_name(server_type): def check_outcome(self, topology, outcome): - expected_servers = outcome['servers'] + expected_servers = outcome["servers"] # Check weak equality before proceeding. - self.assertEqual( - len(topology.description.server_descriptions()), - len(expected_servers)) + self.assertEqual(len(topology.description.server_descriptions()), len(expected_servers)) - if outcome.get('compatible') is False: + if outcome.get("compatible") is False: with self.assertRaises(ConfigurationError): topology.description.check_compatible() else: @@ -160,59 +159,55 @@ def check_outcome(self, topology, outcome): self.assertTrue(topology.has_server(node)) actual_server = topology.get_server_by_address(node) actual_server_description = actual_server.description - expected_server_type = server_name_to_type(expected_server['type']) + expected_server_type = server_name_to_type(expected_server["type"]) self.assertEqual( server_type_name(expected_server_type), - server_type_name(actual_server_description.server_type)) + server_type_name(actual_server_description.server_type), + ) - self.assertEqual( - expected_server.get('setName'), - actual_server_description.replica_set_name) + self.assertEqual(expected_server.get("setName"), actual_server_description.replica_set_name) - self.assertEqual( - expected_server.get('setVersion'), - actual_server_description.set_version) + self.assertEqual(expected_server.get("setVersion"), actual_server_description.set_version) - self.assertEqual( - expected_server.get('electionId'), - actual_server_description.election_id) + self.assertEqual(expected_server.get("electionId"), actual_server_description.election_id) self.assertEqual( - expected_server.get('topologyVersion'), - actual_server_description.topology_version) + expected_server.get("topologyVersion"), actual_server_description.topology_version + ) - expected_pool = expected_server.get('pool') + expected_pool = expected_server.get("pool") if expected_pool: - self.assertEqual( - expected_pool.get('generation'), - actual_server.pool.gen.get_overall()) + 
self.assertEqual(expected_pool.get("generation"), actual_server.pool.gen.get_overall()) - self.assertEqual(outcome['setName'], topology.description.replica_set_name) - self.assertEqual(outcome.get('logicalSessionTimeoutMinutes'), - topology.description.logical_session_timeout_minutes) + self.assertEqual(outcome["setName"], topology.description.replica_set_name) + self.assertEqual( + outcome.get("logicalSessionTimeoutMinutes"), + topology.description.logical_session_timeout_minutes, + ) - expected_topology_type = getattr(TOPOLOGY_TYPE, outcome['topologyType']) - self.assertEqual(topology_type_name(expected_topology_type), - topology_type_name(topology.description.topology_type)) + expected_topology_type = getattr(TOPOLOGY_TYPE, outcome["topologyType"]) + self.assertEqual( + topology_type_name(expected_topology_type), + topology_type_name(topology.description.topology_type), + ) def create_test(scenario_def): def run_scenario(self): - c = create_mock_topology(scenario_def['uri']) + c = create_mock_topology(scenario_def["uri"]) - for i, phase in enumerate(scenario_def['phases']): + for i, phase in enumerate(scenario_def["phases"]): # Including the phase description makes failures easier to debug. - description = phase.get('description', str(i)) - with assertion_context('phase: %s' % (description,)): - for response in phase.get('responses', []): - got_hello( - c, common.partition_node(response[0]), response[1]) + description = phase.get("description", str(i)) + with assertion_context("phase: %s" % (description,)): + for response in phase.get("responses", []): + got_hello(c, common.partition_node(response[0]), response[1]) - for app_error in phase.get('applicationErrors', []): + for app_error in phase.get("applicationErrors", []): got_app_error(c, app_error) - check_outcome(self, c, phase['outcome']) + check_outcome(self, c, phase["outcome"]) return run_scenario @@ -227,8 +222,7 @@ def create_tests(): # Construct test from scenario. new_test = create_test(scenario_def) - test_name = 'test_%s_%s' % ( - dirname, os.path.splitext(filename)[0]) + test_name = "test_%s_%s" % (dirname, os.path.splitext(filename)[0]) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) @@ -239,17 +233,16 @@ def create_tests(): class TestClusterTimeComparison(unittest.TestCase): def test_cluster_time_comparison(self): - t = create_mock_topology('mongodb://host') + t = create_mock_topology("mongodb://host") def send_cluster_time(time, inc, should_update): old = t.max_cluster_time() - new = {'clusterTime': Timestamp(time, inc)} - got_hello(t, - ('host', 27017), - {'ok': 1, - 'minWireVersion': 0, - 'maxWireVersion': 6, - '$clusterTime': new}) + new = {"clusterTime": Timestamp(time, inc)} + got_hello( + t, + ("host", 27017), + {"ok": 1, "minWireVersion": 0, "maxWireVersion": 6, "$clusterTime": new}, + ) actual = t.max_cluster_time() if should_update: @@ -265,7 +258,6 @@ def send_cluster_time(time, inc, should_update): class TestIgnoreStaleErrors(IntegrationTest): - def test_ignore_stale_connection_errors(self): N_THREADS = 5 barrier = threading.Barrier(N_THREADS, timeout=30) @@ -273,22 +265,22 @@ def test_ignore_stale_connection_errors(self): self.addCleanup(client.close) # Wait for initial discovery. 
- client.admin.command('ping') + client.admin.command("ping") pool = get_pool(client) starting_generation = pool.gen.get_overall() - wait_until(lambda: len(pool.sockets) == N_THREADS, 'created sockets') + wait_until(lambda: len(pool.sockets) == N_THREADS, "created sockets") def mock_command(*args, **kwargs): # Synchronize all threads to ensure they use the same generation. barrier.wait() - raise AutoReconnect('mock SocketInfo.command error') + raise AutoReconnect("mock SocketInfo.command error") for sock in pool.sockets: sock.command = mock_command def insert_command(i): try: - client.test.command('insert', 'test', documents=[{'i': i}]) + client.test.command("insert", "test", documents=[{"i": i}]) except AutoReconnect as exc: pass @@ -301,11 +293,10 @@ def insert_command(i): t.join() # Expect a single pool reset for the network error - self.assertEqual( - starting_generation+1, pool.gen.get_overall()) + self.assertEqual(starting_generation + 1, pool.gen.get_overall()) # Server should be selectable. - client.admin.command('ping') + client.admin.command("ping") class CMAPHeartbeatListener(HeartbeatEventListener, CMAPListener): @@ -317,51 +308,51 @@ class TestPoolManagement(IntegrationTest): def test_pool_unpause(self): # This test implements the prose test "Connection Pool Management" listener = CMAPHeartbeatListener() - client = single_client(appName="SDAMPoolManagementTest", - heartbeatFrequencyMS=500, - event_listeners=[listener]) + client = single_client( + appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] + ) self.addCleanup(client.close) # Assert that ConnectionPoolReadyEvent occurs after the first # ServerHeartbeatSucceededEvent. listener.wait_for_event(monitoring.PoolReadyEvent, 1) pool_ready = listener.events_by_type(monitoring.PoolReadyEvent)[0] - hb_succeeded = listener.events_by_type( - monitoring.ServerHeartbeatSucceededEvent)[0] - self.assertGreater( - listener.events.index(pool_ready), - listener.events.index(hb_succeeded)) + hb_succeeded = listener.events_by_type(monitoring.ServerHeartbeatSucceededEvent)[0] + self.assertGreater(listener.events.index(pool_ready), listener.events.index(hb_succeeded)) listener.reset() fail_hello = { - 'mode': {'times': 2}, - 'data': { - 'failCommands': [HelloCompat.LEGACY_CMD, 'hello'], - 'errorCode': 1234, - 'appName': 'SDAMPoolManagementTest', + "mode": {"times": 2}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 1234, + "appName": "SDAMPoolManagementTest", }, } with self.fail_point(fail_hello): listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1) listener.wait_for_event(monitoring.PoolClearedEvent, 1) - listener.wait_for_event( - monitoring.ServerHeartbeatSucceededEvent, 1) + listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1) listener.wait_for_event(monitoring.PoolReadyEvent, 1) class TestIntegration(SpecRunner): # Location of JSON test specifications. 
TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'discovery_and_monitoring_integration') + os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring_integration" + ) def _event_count(self, event): - if event == 'ServerMarkedUnknownEvent': + if event == "ServerMarkedUnknownEvent": + def marked_unknown(e): - return (isinstance(e, monitoring.ServerDescriptionChangedEvent) - and not e.new_description.is_server_type_known) + return ( + isinstance(e, monitoring.ServerDescriptionChangedEvent) + and not e.new_description.is_server_type_known + ) + return len(self.server_listener.matching(marked_unknown)) # Only support CMAP events for now. - self.assertTrue(event.startswith('Pool') or event.startswith('Conn')) + self.assertTrue(event.startswith("Pool") or event.startswith("Conn")) event_type = getattr(monitoring, event) return self.pool_listener.event_count(event_type) @@ -370,50 +361,48 @@ def assert_event_count(self, event, count): Assert the given event was published exactly `count` times. """ - self.assertEqual(self._event_count(event), count, - 'expected %s not %r' % (count, event)) + self.assertEqual(self._event_count(event), count, "expected %s not %r" % (count, event)) def wait_for_event(self, event, count): """Run the waitForEvent test operation. Wait for a number of events to be published, or fail. """ - wait_until(lambda: self._event_count(event) >= count, - 'find %s %s event(s)' % (count, event)) + wait_until( + lambda: self._event_count(event) >= count, "find %s %s event(s)" % (count, event) + ) def configure_fail_point(self, fail_point): - """Run the configureFailPoint test operation. - """ + """Run the configureFailPoint test operation.""" self.set_fail_point(fail_point) - self.addCleanup(self.set_fail_point, { - 'configureFailPoint': fail_point['configureFailPoint'], - 'mode': 'off'}) + self.addCleanup( + self.set_fail_point, + {"configureFailPoint": fail_point["configureFailPoint"], "mode": "off"}, + ) def run_admin_command(self, command, **kwargs): - """Run the runAdminCommand test operation. - """ + """Run the runAdminCommand test operation.""" self.client.admin.command(command, **kwargs) def record_primary(self): - """Run the recordPrimary test operation. - """ + """Run the recordPrimary test operation.""" self._previous_primary = self.scenario_client.primary def wait_for_primary_change(self, timeout_ms): - """Run the waitForPrimaryChange test operation. - """ + """Run the waitForPrimaryChange test operation.""" + def primary_changed(): primary = self.scenario_client.primary if primary is None: return False return primary != self._previous_primary - timeout = timeout_ms/1000.0 - wait_until(primary_changed, 'change primary', timeout=timeout) + + timeout = timeout_ms / 1000.0 + wait_until(primary_changed, "change primary", timeout=timeout) def wait(self, ms): - """Run the "wait" test operation. 
- """ - time.sleep(ms/1000.0) + """Run the "wait" test operation.""" + time.sleep(ms / 1000.0) def start_thread(self, name): """Run the 'startThread' thread operation.""" @@ -424,8 +413,7 @@ def start_thread(self, name): def run_on_thread(self, sessions, collection, name, operation): """Run the 'runOnThread' operation.""" thread = self.targets[name] - thread.schedule(lambda: self._run_op( - sessions, collection, operation, False)) + thread.schedule(lambda: self._run_op(sessions, collection, operation, False)) def wait_for_thread(self, name): """Run the 'waitForThread' operation.""" @@ -434,8 +422,7 @@ def wait_for_thread(self, name): thread.join(60) if thread.exc: raise thread.exc - self.assertFalse( - thread.is_alive(), 'Thread %s is still running' % (name,)) + self.assertFalse(thread.is_alive(), "Thread %s is still running" % (name,)) def create_spec_test(scenario_def, test, name): diff --git a/test/test_dns.py b/test/test_dns.py index 8404c2aa69..d47e115f41 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -21,18 +21,20 @@ sys.path[0:0] = [""] +from test import client_context, unittest +from test.utils import wait_until + from pymongo.common import validate_read_preference_tags -from pymongo.srv_resolver import _HAVE_DNSPYTHON from pymongo.errors import ConfigurationError from pymongo.mongo_client import MongoClient +from pymongo.srv_resolver import _HAVE_DNSPYTHON from pymongo.uri_parser import parse_uri, split_hosts -from test import client_context, unittest -from test.utils import wait_until class TestDNSRepl(unittest.TestCase): - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'srv_seedlist', 'replica-set') + TEST_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set" + ) load_balanced = False @client_context.require_replica_set @@ -41,8 +43,9 @@ def setUp(self): class TestDNSLoadBalanced(unittest.TestCase): - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'srv_seedlist', 'load-balanced') + TEST_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced" + ) load_balanced = True @client_context.require_load_balancer @@ -51,8 +54,7 @@ def setUp(self): class TestDNSSharded(unittest.TestCase): - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'srv_seedlist', 'sharded') + TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded") load_balanced = False @client_context.require_mongos @@ -61,77 +63,74 @@ def setUp(self): def create_test(test_case): - def run_test(self): if not _HAVE_DNSPYTHON: raise unittest.SkipTest("DNS tests require the dnspython module") - uri = test_case['uri'] - seeds = test_case.get('seeds') - num_seeds = test_case.get('numSeeds', len(seeds or [])) - hosts = test_case.get('hosts') + uri = test_case["uri"] + seeds = test_case.get("seeds") + num_seeds = test_case.get("numSeeds", len(seeds or [])) + hosts = test_case.get("hosts") num_hosts = test_case.get("numHosts", len(hosts or [])) - options = test_case.get('options', {}) - if 'ssl' in options: - options['tls'] = options.pop('ssl') - parsed_options = test_case.get('parsed_options') + options = test_case.get("options", {}) + if "ssl" in options: + options["tls"] = options.pop("ssl") + parsed_options = test_case.get("parsed_options") # See DRIVERS-1324, unless tls is explicitly set to False we need TLS. 
- needs_tls = not (options and (options.get('ssl') == False or - options.get('tls') == False)) + needs_tls = not (options and (options.get("ssl") == False or options.get("tls") == False)) if needs_tls and not client_context.tls: - self.skipTest('this test requires a TLS cluster') + self.skipTest("this test requires a TLS cluster") if not needs_tls and client_context.tls: - self.skipTest('this test requires a non-TLS cluster') + self.skipTest("this test requires a non-TLS cluster") if seeds: - seeds = split_hosts(','.join(seeds)) + seeds = split_hosts(",".join(seeds)) if hosts: - hosts = frozenset(split_hosts(','.join(hosts))) + hosts = frozenset(split_hosts(",".join(hosts))) if seeds or num_seeds: result = parse_uri(uri, validate=True) if seeds is not None: - self.assertEqual(sorted(result['nodelist']), sorted(seeds)) + self.assertEqual(sorted(result["nodelist"]), sorted(seeds)) if num_seeds is not None: - self.assertEqual(len(result['nodelist']), num_seeds) + self.assertEqual(len(result["nodelist"]), num_seeds) if options: - opts = result['options'] - if 'readpreferencetags' in opts: + opts = result["options"] + if "readpreferencetags" in opts: rpts = validate_read_preference_tags( - 'readPreferenceTags', opts.pop('readpreferencetags')) - opts['readPreferenceTags'] = rpts - self.assertEqual(result['options'], options) + "readPreferenceTags", opts.pop("readpreferencetags") + ) + opts["readPreferenceTags"] = rpts + self.assertEqual(result["options"], options) if parsed_options: for opt, expected in parsed_options.items(): - if opt == 'user': - self.assertEqual(result['username'], expected) - elif opt == 'password': - self.assertEqual(result['password'], expected) - elif opt == 'auth_database' or opt == 'db': - self.assertEqual(result['database'], expected) + if opt == "user": + self.assertEqual(result["username"], expected) + elif opt == "password": + self.assertEqual(result["password"], expected) + elif opt == "auth_database" or opt == "db": + self.assertEqual(result["database"], expected) hostname = next(iter(client_context.client.nodes))[0] # The replica set members must be configured as 'localhost'. - if hostname == 'localhost': + if hostname == "localhost": copts = client_context.default_client_options.copy() # Remove tls since SRV parsing should add it automatically. - copts.pop('tls', None) + copts.pop("tls", None) if client_context.tls: # Our test certs don't support the SRV hosts used in these # tests. - copts['tlsAllowInvalidHostnames'] = True + copts["tlsAllowInvalidHostnames"] = True client = MongoClient(uri, **copts) if num_seeds is not None: - self.assertEqual(len(client._topology_settings.seeds), - num_seeds) + self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: - wait_until( - lambda: hosts == client.nodes, - 'match test hosts to client nodes') + wait_until(lambda: hosts == client.nodes, "match test hosts to client nodes") if num_hosts is not None: - wait_until(lambda: num_hosts == len(client.nodes), - "wait to connect to num_hosts") + wait_until( + lambda: num_hosts == len(client.nodes), "wait to connect to num_hosts" + ) # XXX: we should block until SRV poller runs at least once # and re-run these assertions. 
else: @@ -146,11 +145,11 @@ def run_test(self): def create_tests(cls): - for filename in glob.glob(os.path.join(cls.TEST_PATH, '*.json')): + for filename in glob.glob(os.path.join(cls.TEST_PATH, "*.json")): test_suffix, _ = os.path.splitext(os.path.basename(filename)) with open(filename) as dns_test_file: test_method = create_test(json.load(dns_test_file)) - setattr(cls, 'test_' + test_suffix, test_method) + setattr(cls, "test_" + test_suffix, test_method) create_tests(TestDNSRepl) @@ -159,26 +158,33 @@ def create_tests(cls): class TestParsingErrors(unittest.TestCase): - @unittest.skipUnless(_HAVE_DNSPYTHON, "DNS tests require the dnspython module") def test_invalid_host(self): self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb is not", - MongoClient, "mongodb+srv://mongodb") + MongoClient, + "mongodb+srv://mongodb", + ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb.com is not", - MongoClient, "mongodb+srv://mongodb.com") + MongoClient, + "mongodb+srv://mongodb.com", + ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: an IP address is not", - MongoClient, "mongodb+srv://127.0.0.1") + MongoClient, + "mongodb+srv://127.0.0.1", + ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: an IP address is not", - MongoClient, "mongodb+srv://[::1]") + MongoClient, + "mongodb+srv://[::1]", + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_encryption.py b/test/test_encryption.py index a0b34259f3..21cf9ee41f 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -18,8 +18,8 @@ import copy import os import re -import ssl import socket +import ssl import sys import textwrap import traceback @@ -27,135 +27,133 @@ sys.path[0:0] = [""] +from test import ( + CA_PEM, + CLIENT_PEM, + IntegrationTest, + PyMongoTestCase, + client_context, + unittest, +) +from test.test_bulk import BulkTestBase +from test.utils import ( + AllowListEventListener, + OvertCommandListener, + TestCreator, + TopologyEventListener, + camel_to_snake_args, + rs_or_single_client, + wait_until, +) +from test.utils_spec_runner import SpecRunner + from bson import encode, json_util -from bson.binary import (Binary, - UuidRepresentation, - JAVA_LEGACY, - STANDARD, - UUID_SUBTYPE) +from bson.binary import JAVA_LEGACY, STANDARD, UUID_SUBTYPE, Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.errors import BSONError from bson.json_util import JSONOptions from bson.son import SON - from pymongo import encryption from pymongo.cursor import CursorType -from pymongo.encryption import (Algorithm, - ClientEncryption) -from pymongo.encryption_options import AutoEncryptionOpts, _HAVE_PYMONGOCRYPT -from pymongo.errors import (BulkWriteError, - ConfigurationError, - EncryptionError, - InvalidOperation, - OperationFailure, - ServerSelectionTimeoutError, - WriteError) +from pymongo.encryption import Algorithm, ClientEncryption +from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + EncryptionError, + InvalidOperation, + OperationFailure, + ServerSelectionTimeoutError, + WriteError, +) from pymongo.mongo_client import MongoClient from pymongo.operations import InsertOne, ReplaceOne, UpdateOne from pymongo.write_concern import WriteConcern -from test import (unittest, CA_PEM, CLIENT_PEM, - client_context, - IntegrationTest, - PyMongoTestCase) -from test.test_bulk import BulkTestBase -from test.utils 
import (TestCreator, - camel_to_snake_args, - OvertCommandListener, - TopologyEventListener, - AllowListEventListener, - rs_or_single_client, - wait_until) -from test.utils_spec_runner import SpecRunner - def get_client_opts(client): return client._MongoClient__options -KMS_PROVIDERS = {'local': {'key': b'\x00'*96}} +KMS_PROVIDERS = {"local": {"key": b"\x00" * 96}} class TestAutoEncryptionOpts(PyMongoTestCase): - @unittest.skipIf(_HAVE_PYMONGOCRYPT, 'pymongocrypt is installed') + @unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed") def test_init_requires_pymongocrypt(self): with self.assertRaises(ConfigurationError): - AutoEncryptionOpts({}, 'keyvault.datakeys') + AutoEncryptionOpts({}, "keyvault.datakeys") - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_init(self): - opts = AutoEncryptionOpts({}, 'keyvault.datakeys') + opts = AutoEncryptionOpts({}, "keyvault.datakeys") self.assertEqual(opts._kms_providers, {}) - self.assertEqual(opts._key_vault_namespace, 'keyvault.datakeys') + self.assertEqual(opts._key_vault_namespace, "keyvault.datakeys") self.assertEqual(opts._key_vault_client, None) self.assertEqual(opts._schema_map, None) self.assertEqual(opts._bypass_auto_encryption, False) - self.assertEqual(opts._mongocryptd_uri, 'mongodb://localhost:27020') + self.assertEqual(opts._mongocryptd_uri, "mongodb://localhost:27020") self.assertEqual(opts._mongocryptd_bypass_spawn, False) - self.assertEqual(opts._mongocryptd_spawn_path, 'mongocryptd') - self.assertEqual( - opts._mongocryptd_spawn_args, ['--idleShutdownTimeoutSecs=60']) + self.assertEqual(opts._mongocryptd_spawn_path, "mongocryptd") + self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"]) self.assertEqual(opts._kms_ssl_contexts, {}) - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_init_spawn_args(self): # User can override idleShutdownTimeoutSecs opts = AutoEncryptionOpts( - {}, 'keyvault.datakeys', - mongocryptd_spawn_args=['--idleShutdownTimeoutSecs=88']) - self.assertEqual( - opts._mongocryptd_spawn_args, ['--idleShutdownTimeoutSecs=88']) + {}, "keyvault.datakeys", mongocryptd_spawn_args=["--idleShutdownTimeoutSecs=88"] + ) + self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=88"]) # idleShutdownTimeoutSecs is added by default - opts = AutoEncryptionOpts( - {}, 'keyvault.datakeys', mongocryptd_spawn_args=[]) - self.assertEqual( - opts._mongocryptd_spawn_args, ['--idleShutdownTimeoutSecs=60']) + opts = AutoEncryptionOpts({}, "keyvault.datakeys", mongocryptd_spawn_args=[]) + self.assertEqual(opts._mongocryptd_spawn_args, ["--idleShutdownTimeoutSecs=60"]) # Also added when other options are given opts = AutoEncryptionOpts( - {}, 'keyvault.datakeys', - mongocryptd_spawn_args=['--quiet', '--port=27020']) + {}, "keyvault.datakeys", mongocryptd_spawn_args=["--quiet", "--port=27020"] + ) self.assertEqual( opts._mongocryptd_spawn_args, - ['--quiet', '--port=27020', '--idleShutdownTimeoutSecs=60']) + ["--quiet", "--port=27020", "--idleShutdownTimeoutSecs=60"], + ) - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_init_kms_tls_options(self): # Error cases: - with self.assertRaisesRegex( - TypeError, r'kms_tls_options\["kmip"\] must be a dict'): - 
AutoEncryptionOpts({}, 'k.d', kms_tls_options={'kmip': 1}) + with self.assertRaisesRegex(TypeError, r'kms_tls_options\["kmip"\] must be a dict'): + AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": 1}) for tls_opts in [ - {'kmip': {'tls': True, 'tlsInsecure': True}}, - {'kmip': {'tls': True, 'tlsAllowInvalidCertificates': True}}, - {'kmip': {'tls': True, 'tlsAllowInvalidHostnames': True}}, - {'kmip': {'tls': True, 'tlsDisableOCSPEndpointCheck': True}}]: - with self.assertRaisesRegex( - ConfigurationError, 'Insecure TLS options prohibited'): - opts = AutoEncryptionOpts({}, 'k.d', kms_tls_options=tls_opts) + {"kmip": {"tls": True, "tlsInsecure": True}}, + {"kmip": {"tls": True, "tlsAllowInvalidCertificates": True}}, + {"kmip": {"tls": True, "tlsAllowInvalidHostnames": True}}, + {"kmip": {"tls": True, "tlsDisableOCSPEndpointCheck": True}}, + ]: + with self.assertRaisesRegex(ConfigurationError, "Insecure TLS options prohibited"): + opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts) with self.assertRaises(FileNotFoundError): - AutoEncryptionOpts({}, 'k.d', kms_tls_options={ - 'kmip': {'tlsCAFile': 'does-not-exist'}}) + AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tlsCAFile": "does-not-exist"}}) # Success cases: for tls_opts in [None, {}]: - opts = AutoEncryptionOpts({}, 'k.d', kms_tls_options=tls_opts) + opts = AutoEncryptionOpts({}, "k.d", kms_tls_options=tls_opts) self.assertEqual(opts._kms_ssl_contexts, {}) - opts = AutoEncryptionOpts( - {}, 'k.d', kms_tls_options={'kmip': {'tls': True}, 'aws': {}}) - ctx = opts._kms_ssl_contexts['kmip'] + opts = AutoEncryptionOpts({}, "k.d", kms_tls_options={"kmip": {"tls": True}, "aws": {}}) + ctx = opts._kms_ssl_contexts["kmip"] # On < 3.7 we check hostnames manually. if sys.version_info[:2] >= (3, 7): self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) - ctx = opts._kms_ssl_contexts['aws'] + ctx = opts._kms_ssl_contexts["aws"] if sys.version_info[:2] >= (3, 7): self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) opts = AutoEncryptionOpts( - {}, 'k.d', kms_tls_options={'kmip': { - 'tlsCAFile': CA_PEM, 'tlsCertificateKeyFile': CLIENT_PEM}}) - ctx = opts._kms_ssl_contexts['kmip'] + {}, + "k.d", + kms_tls_options={"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}, + ) + ctx = opts._kms_ssl_contexts["kmip"] if sys.version_info[:2] >= (3, 7): self.assertEqual(ctx.check_hostname, True) self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED) @@ -171,9 +169,9 @@ def test_default(self): self.addCleanup(client.close) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_kwargs(self): - opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') + opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = MongoClient(auto_encryption_opts=opts, connect=False) self.addCleanup(client.close) self.assertEqual(get_client_opts(client).auto_encryption_opts, opts) @@ -183,7 +181,7 @@ class EncryptionIntegrationTest(IntegrationTest): """Base class for encryption integration tests.""" @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") @client_context.require_version_min(4, 2, -1) def setUpClass(cls): super(EncryptionIntegrationTest, 
cls).setUpClass() @@ -198,16 +196,14 @@ def assertBinaryUUID(self, val): # Location of JSON test files. -BASE = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'client-side-encryption') -SPEC_PATH = os.path.join(BASE, 'spec') +BASE = os.path.join(os.path.dirname(os.path.realpath(__file__)), "client-side-encryption") +SPEC_PATH = os.path.join(BASE, "spec") OPTS = CodecOptions(uuid_representation=STANDARD) # Use SON to preserve the order of fields while parsing json. Use tz_aware # =False to match how CodecOptions decodes dates. -JSON_OPTS = JSONOptions(document_class=SON, uuid_representation=STANDARD, - tz_aware=False) +JSON_OPTS = JSONOptions(document_class=SON, uuid_representation=STANDARD, tz_aware=False) def read(*paths): @@ -224,38 +220,39 @@ def bson_data(*paths): class TestClientSimple(EncryptionIntegrationTest): - def _test_auto_encrypt(self, opts): client = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) # Create the encrypted field's data key. key_vault = create_key_vault( - self.client.keyvault.datakeys, - json_data('custom', 'key-document-local.json')) + self.client.keyvault.datakeys, json_data("custom", "key-document-local.json") + ) self.addCleanup(key_vault.drop) # Collection.insert_one/insert_many auto encrypts. - docs = [{'_id': 0, 'ssn': '000'}, - {'_id': 1, 'ssn': '111'}, - {'_id': 2, 'ssn': '222'}, - {'_id': 3, 'ssn': '333'}, - {'_id': 4, 'ssn': '444'}, - {'_id': 5, 'ssn': '555'}] + docs = [ + {"_id": 0, "ssn": "000"}, + {"_id": 1, "ssn": "111"}, + {"_id": 2, "ssn": "222"}, + {"_id": 3, "ssn": "333"}, + {"_id": 4, "ssn": "444"}, + {"_id": 5, "ssn": "555"}, + ] encrypted_coll = client.pymongo_test.test encrypted_coll.insert_one(docs[0]) encrypted_coll.insert_many(docs[1:3]) unack = encrypted_coll.with_options(write_concern=WriteConcern(w=0)) unack.insert_one(docs[3]) unack.insert_many(docs[4:], ordered=False) - wait_until(lambda: self.db.test.count_documents({}) == len(docs), - 'insert documents with w=0') + wait_until( + lambda: self.db.test.count_documents({}) == len(docs), "insert documents with w=0" + ) # Database.command auto decrypts. - res = client.pymongo_test.command( - 'find', 'test', filter={'ssn': '000'}) - decrypted_docs = res['cursor']['firstBatch'] - self.assertEqual(decrypted_docs, [{'_id': 0, 'ssn': '000'}]) + res = client.pymongo_test.command("find", "test", filter={"ssn": "000"}) + decrypted_docs = res["cursor"]["firstBatch"] + self.assertEqual(decrypted_docs, [{"_id": 0, "ssn": "000"}]) # Collection.find auto decrypts. decrypted_docs = list(encrypted_coll.find()) @@ -274,51 +271,48 @@ def _test_auto_encrypt(self, opts): self.assertEqual(decrypted_docs, docs) # Collection.distinct auto decrypts. - decrypted_ssns = encrypted_coll.distinct('ssn') - self.assertEqual(set(decrypted_ssns), set(d['ssn'] for d in docs)) + decrypted_ssns = encrypted_coll.distinct("ssn") + self.assertEqual(set(decrypted_ssns), set(d["ssn"] for d in docs)) # Make sure the field is actually encrypted. for encrypted_doc in self.db.test.find(): - self.assertIsInstance(encrypted_doc['_id'], int) - self.assertEncrypted(encrypted_doc['ssn']) + self.assertIsInstance(encrypted_doc["_id"], int) + self.assertEncrypted(encrypted_doc["ssn"]) # Attempt to encrypt an unencodable object. with self.assertRaises(BSONError): - encrypted_coll.insert_one({'unencodeable': object()}) + encrypted_coll.insert_one({"unencodeable": object()}) def test_auto_encrypt(self): # Configure the encrypted field via jsonSchema. 
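# A hedged aside on JSON_OPTS above, as a runnable illustration (the
# JSON input here is made up): tz_aware=False makes json_util return
# naive datetimes, matching how the default CodecOptions decodes BSON
# dates, and document_class=SON keeps field order stable for comparisons.
from bson import json_util
from bson.json_util import JSONOptions
from bson.son import SON

opts = JSONOptions(document_class=SON, tz_aware=False)
doc = json_util.loads('{"when": {"$date": "2021-01-01T00:00:00Z"}}', json_options=opts)
assert isinstance(doc, SON)
assert doc["when"].tzinfo is None  # naive, like a default-decoded BSON date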
- json_schema = json_data('custom', 'schema.json') + json_schema = json_data("custom", "schema.json") create_with_schema(self.db.test, json_schema) self.addCleanup(self.db.test.drop) - opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') + opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") self._test_auto_encrypt(opts) def test_auto_encrypt_local_schema_map(self): # Configure the encrypted field via the local schema_map option. - schemas = {'pymongo_test.test': json_data('custom', 'schema.json')} - opts = AutoEncryptionOpts( - KMS_PROVIDERS, 'keyvault.datakeys', schema_map=schemas) + schemas = {"pymongo_test.test": json_data("custom", "schema.json")} + opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas) self._test_auto_encrypt(opts) def test_use_after_close(self): - opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') + opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) - client.admin.command('ping') + client.admin.command("ping") client.close() - with self.assertRaisesRegex(InvalidOperation, - 'Cannot use MongoClient after close'): - client.admin.command('ping') + with self.assertRaisesRegex(InvalidOperation, "Cannot use MongoClient after close"): + client.admin.command("ping") class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest): - def test_upsert_uuid_standard_encrypte(self): - opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') + opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) @@ -326,126 +320,131 @@ def test_upsert_uuid_standard_encrypte(self): encrypted_coll = client.pymongo_test.test coll = encrypted_coll.with_options(codec_options=options) uuids = [uuid.uuid4() for _ in range(3)] - result = coll.bulk_write([ - UpdateOne({'_id': uuids[0]}, {'$set': {'a': 0}}, upsert=True), - ReplaceOne({'a': 1}, {'_id': uuids[1]}, upsert=True), - # This is just here to make the counts right in all cases. - ReplaceOne({'_id': uuids[2]}, {'_id': uuids[2]}, upsert=True), - ]) + result = coll.bulk_write( + [ + UpdateOne({"_id": uuids[0]}, {"$set": {"a": 0}}, upsert=True), + ReplaceOne({"a": 1}, {"_id": uuids[1]}, upsert=True), + # This is just here to make the counts right in all cases. 
+ ReplaceOne({"_id": uuids[2]}, {"_id": uuids[2]}, upsert=True), + ] + ) self.assertEqualResponse( - {'nMatched': 0, - 'nModified': 0, - 'nUpserted': 3, - 'nInserted': 0, - 'nRemoved': 0, - 'upserted': [{'index': 0, '_id': uuids[0]}, - {'index': 1, '_id': uuids[1]}, - {'index': 2, '_id': uuids[2]}]}, - result.bulk_api_result) + { + "nMatched": 0, + "nModified": 0, + "nUpserted": 3, + "nInserted": 0, + "nRemoved": 0, + "upserted": [ + {"index": 0, "_id": uuids[0]}, + {"index": 1, "_id": uuids[1]}, + {"index": 2, "_id": uuids[2]}, + ], + }, + result.bulk_api_result, + ) class TestClientMaxWireVersion(IntegrationTest): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def setUpClass(cls): super(TestClientMaxWireVersion, cls).setUpClass() @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): - opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') + opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) - msg = 'Auto-encryption requires a minimum MongoDB version of 4.2' + msg = "Auto-encryption requires a minimum MongoDB version of 4.2" with self.assertRaisesRegex(ConfigurationError, msg): client.test.test.insert_one({}) with self.assertRaisesRegex(ConfigurationError, msg): - client.admin.command('ping') + client.admin.command("ping") with self.assertRaisesRegex(ConfigurationError, msg): client.test.test.find_one({}) with self.assertRaisesRegex(ConfigurationError, msg): client.test.test.bulk_write([InsertOne({})]) def test_raise_unsupported_error(self): - opts = AutoEncryptionOpts(KMS_PROVIDERS, 'keyvault.datakeys') + opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) - msg = 'find_raw_batches does not support auto encryption' + msg = "find_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): client.test.test.find_raw_batches({}) - msg = 'aggregate_raw_batches does not support auto encryption' + msg = "aggregate_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): client.test.test.aggregate_raw_batches([]) if client_context.is_mongos: - msg = 'Exhaust cursors are not supported by mongos' + msg = "Exhaust cursors are not supported by mongos" else: - msg = 'exhaust cursors do not support auto encryption' + msg = "exhaust cursors do not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): next(client.test.test.find(cursor_type=CursorType.EXHAUST)) class TestExplicitSimple(EncryptionIntegrationTest): - def test_encrypt_decrypt(self): client_encryption = ClientEncryption( - KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) + KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS + ) self.addCleanup(client_encryption.close) # Use standard UUID representation. - key_vault = client_context.client.keyvault.get_collection( - 'datakeys', codec_options=OPTS) + key_vault = client_context.client.keyvault.get_collection("datakeys", codec_options=OPTS) self.addCleanup(key_vault.drop) # Create the encrypted field's data key. 
- key_id = client_encryption.create_data_key( - 'local', key_alt_names=['name']) + key_id = client_encryption.create_data_key("local", key_alt_names=["name"]) self.assertBinaryUUID(key_id) - self.assertTrue(key_vault.find_one({'_id': key_id})) + self.assertTrue(key_vault.find_one({"_id": key_id})) # Create an unused data key to make sure filtering works. - unused_key_id = client_encryption.create_data_key( - 'local', key_alt_names=['unused']) + unused_key_id = client_encryption.create_data_key("local", key_alt_names=["unused"]) self.assertBinaryUUID(unused_key_id) - self.assertTrue(key_vault.find_one({'_id': unused_key_id})) + self.assertTrue(key_vault.find_one({"_id": unused_key_id})) - doc = {'_id': 0, 'ssn': '000'} + doc = {"_id": 0, "ssn": "000"} encrypted_ssn = client_encryption.encrypt( - doc['ssn'], Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=key_id) + doc["ssn"], Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id + ) # Ensure encryption via key_alt_name for the same key produces the # same output. encrypted_ssn2 = client_encryption.encrypt( - doc['ssn'], Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_alt_name='name') + doc["ssn"], Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="name" + ) self.assertEqual(encrypted_ssn, encrypted_ssn2) # Test decryption. decrypted_ssn = client_encryption.decrypt(encrypted_ssn) - self.assertEqual(decrypted_ssn, doc['ssn']) + self.assertEqual(decrypted_ssn, doc["ssn"]) def test_validation(self): client_encryption = ClientEncryption( - KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) + KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS + ) self.addCleanup(client_encryption.close) - msg = 'value to decrypt must be a bson.binary.Binary with subtype 6' + msg = "value to decrypt must be a bson.binary.Binary with subtype 6" with self.assertRaisesRegex(TypeError, msg): - client_encryption.decrypt('str') + client_encryption.decrypt("str") with self.assertRaisesRegex(TypeError, msg): - client_encryption.decrypt(Binary(b'123')) + client_encryption.decrypt(Binary(b"123")) - msg = 'key_id must be a bson.binary.Binary with subtype 4' + msg = "key_id must be a bson.binary.Binary with subtype 4" algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic with self.assertRaisesRegex(TypeError, msg): - client_encryption.encrypt('str', algo, key_id=uuid.uuid4()) + client_encryption.encrypt("str", algo, key_id=uuid.uuid4()) with self.assertRaisesRegex(TypeError, msg): - client_encryption.encrypt('str', algo, key_id=Binary(b'123')) + client_encryption.encrypt("str", algo, key_id=Binary(b"123")) def test_bson_errors(self): client_encryption = ClientEncryption( - KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) + KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS + ) self.addCleanup(client_encryption.close) # Attempt to encrypt an unencodable object. 
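# The validation hunk above hinges on BSON binary subtypes: data key
# ids are UUIDs stored with subtype 4, while ciphertext is subtype 6.
# A self-contained illustration:
import uuid
from bson.binary import Binary, UUID_SUBTYPE

key_id = Binary(uuid.uuid4().bytes, UUID_SUBTYPE)
assert key_id.subtype == 4  # acceptable as key_id for encrypt()
assert Binary(b"", 6).subtype == 6  # only subtype 6 can be decrypted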
@@ -454,37 +453,38 @@ def test_bson_errors(self): client_encryption.encrypt( unencodable_value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=Binary(uuid.uuid4().bytes, UUID_SUBTYPE)) + key_id=Binary(uuid.uuid4().bytes, UUID_SUBTYPE), + ) def test_codec_options(self): - with self.assertRaisesRegex(TypeError, 'codec_options must be'): - ClientEncryption( - KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, None) + with self.assertRaisesRegex(TypeError, "codec_options must be"): + ClientEncryption(KMS_PROVIDERS, "keyvault.datakeys", client_context.client, None) opts = CodecOptions(uuid_representation=JAVA_LEGACY) client_encryption_legacy = ClientEncryption( - KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, opts) + KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts + ) self.addCleanup(client_encryption_legacy.close) # Create the encrypted field's data key. - key_id = client_encryption_legacy.create_data_key('local') + key_id = client_encryption_legacy.create_data_key("local") # Encrypt a UUID with JAVA_LEGACY codec options. value = uuid.uuid4() encrypted_legacy = client_encryption_legacy.encrypt( - value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=key_id) - decrypted_value_legacy = client_encryption_legacy.decrypt( - encrypted_legacy) + value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id + ) + decrypted_value_legacy = client_encryption_legacy.decrypt(encrypted_legacy) self.assertEqual(decrypted_value_legacy, value) # Encrypt the same UUID with STANDARD codec options. client_encryption = ClientEncryption( - KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) + KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS + ) self.addCleanup(client_encryption.close) encrypted_standard = client_encryption.encrypt( - value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=key_id) + value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id + ) decrypted_standard = client_encryption.decrypt(encrypted_standard) self.assertEqual(decrypted_standard, value) @@ -492,163 +492,160 @@ def test_codec_options(self): self.assertNotEqual(encrypted_standard, encrypted_legacy) # Test that codec_options is applied during decryption. self.assertEqual( - client_encryption_legacy.decrypt(encrypted_standard), - Binary.from_uuid(value)) - self.assertNotEqual( - client_encryption.decrypt(encrypted_legacy), value) + client_encryption_legacy.decrypt(encrypted_standard), Binary.from_uuid(value) + ) + self.assertNotEqual(client_encryption.decrypt(encrypted_legacy), value) def test_close(self): client_encryption = ClientEncryption( - KMS_PROVIDERS, 'keyvault.datakeys', client_context.client, OPTS) + KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS + ) client_encryption.close() # Close can be called multiple times. 
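# Why the legacy and standard ciphertexts compared above differ: the
# same UUID serializes to different bytes (and a different subtype)
# under JAVA_LEGACY versus STANDARD handling, so deterministic
# encryption of the two encodings cannot match. Hedged illustration:
import uuid
from bson.binary import JAVA_LEGACY, Binary

val = uuid.uuid4()
standard = Binary.from_uuid(val)  # STANDARD representation: subtype 4
legacy = Binary.from_uuid(val, JAVA_LEGACY)  # subtype 3, reordered bytes
assert standard != legacy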
client_encryption.close() algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic - msg = 'Cannot use closed ClientEncryption' + msg = "Cannot use closed ClientEncryption" with self.assertRaisesRegex(InvalidOperation, msg): - client_encryption.create_data_key('local') + client_encryption.create_data_key("local") with self.assertRaisesRegex(InvalidOperation, msg): - client_encryption.encrypt('val', algo, key_alt_name='name') + client_encryption.encrypt("val", algo, key_alt_name="name") with self.assertRaisesRegex(InvalidOperation, msg): - client_encryption.decrypt(Binary(b'', 6)) + client_encryption.decrypt(Binary(b"", 6)) def test_with_statement(self): with ClientEncryption( - KMS_PROVIDERS, 'keyvault.datakeys', - client_context.client, OPTS) as client_encryption: + KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS + ) as client_encryption: pass - with self.assertRaisesRegex( - InvalidOperation, 'Cannot use closed ClientEncryption'): - client_encryption.create_data_key('local') + with self.assertRaisesRegex(InvalidOperation, "Cannot use closed ClientEncryption"): + client_encryption.create_data_key("local") # Spec tests AWS_CREDS = { - 'accessKeyId': os.environ.get('FLE_AWS_KEY', ''), - 'secretAccessKey': os.environ.get('FLE_AWS_SECRET', '') + "accessKeyId": os.environ.get("FLE_AWS_KEY", ""), + "secretAccessKey": os.environ.get("FLE_AWS_SECRET", ""), } AWS_TEMP_CREDS = { - 'accessKeyId': os.environ.get('CSFLE_AWS_TEMP_ACCESS_KEY_ID', ''), - 'secretAccessKey': os.environ.get('CSFLE_AWS_TEMP_SECRET_ACCESS_KEY', ''), - 'sessionToken': os.environ.get('CSFLE_AWS_TEMP_SESSION_TOKEN', '') + "accessKeyId": os.environ.get("CSFLE_AWS_TEMP_ACCESS_KEY_ID", ""), + "secretAccessKey": os.environ.get("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", ""), + "sessionToken": os.environ.get("CSFLE_AWS_TEMP_SESSION_TOKEN", ""), } AWS_TEMP_NO_SESSION_CREDS = { - 'accessKeyId': os.environ.get('CSFLE_AWS_TEMP_ACCESS_KEY_ID', ''), - 'secretAccessKey': os.environ.get('CSFLE_AWS_TEMP_SECRET_ACCESS_KEY', '') + "accessKeyId": os.environ.get("CSFLE_AWS_TEMP_ACCESS_KEY_ID", ""), + "secretAccessKey": os.environ.get("CSFLE_AWS_TEMP_SECRET_ACCESS_KEY", ""), } AZURE_CREDS = { - 'tenantId': os.environ.get('FLE_AZURE_TENANTID', ''), - 'clientId': os.environ.get('FLE_AZURE_CLIENTID', ''), - 'clientSecret': os.environ.get('FLE_AZURE_CLIENTSECRET', '')} + "tenantId": os.environ.get("FLE_AZURE_TENANTID", ""), + "clientId": os.environ.get("FLE_AZURE_CLIENTID", ""), + "clientSecret": os.environ.get("FLE_AZURE_CLIENTSECRET", ""), +} GCP_CREDS = { - 'email': os.environ.get('FLE_GCP_EMAIL', ''), - 'privateKey': os.environ.get('FLE_GCP_PRIVATEKEY', '')} + "email": os.environ.get("FLE_GCP_EMAIL", ""), + "privateKey": os.environ.get("FLE_GCP_PRIVATEKEY", ""), +} -KMIP = {'endpoint': os.environ.get('FLE_KMIP_ENDPOINT', 'localhost:5698')} -KMS_TLS_OPTS = {'kmip': {'tlsCAFile': CA_PEM, - 'tlsCertificateKeyFile': CLIENT_PEM}} +KMIP = {"endpoint": os.environ.get("FLE_KMIP_ENDPOINT", "localhost:5698")} +KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}} class TestSpec(SpecRunner): - @classmethod - @unittest.skipUnless(_HAVE_PYMONGOCRYPT, 'pymongocrypt is not installed') + @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def setUpClass(cls): super(TestSpec, cls).setUpClass() def parse_auto_encrypt_opts(self, opts): """Parse clientOptions.autoEncryptOpts.""" opts = camel_to_snake_args(opts) - kms_providers = opts['kms_providers'] - if 'aws' in kms_providers: - kms_providers['aws'] = 
AWS_CREDS
+        kms_providers = opts["kms_providers"]
+        if "aws" in kms_providers:
+            kms_providers["aws"] = AWS_CREDS
             if not any(AWS_CREDS.values()):
-                self.skipTest('AWS environment credentials are not set')
-        if 'awsTemporary' in kms_providers:
-            kms_providers['aws'] = AWS_TEMP_CREDS
-            del kms_providers['awsTemporary']
+                self.skipTest("AWS environment credentials are not set")
+        if "awsTemporary" in kms_providers:
+            kms_providers["aws"] = AWS_TEMP_CREDS
+            del kms_providers["awsTemporary"]
             if not any(AWS_TEMP_CREDS.values()):
-                self.skipTest('AWS Temp environment credentials are not set')
-        if 'awsTemporaryNoSessionToken' in kms_providers:
-            kms_providers['aws'] = AWS_TEMP_NO_SESSION_CREDS
-            del kms_providers['awsTemporaryNoSessionToken']
+                self.skipTest("AWS Temp environment credentials are not set")
+        if "awsTemporaryNoSessionToken" in kms_providers:
+            kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
+            del kms_providers["awsTemporaryNoSessionToken"]
             if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
-                self.skipTest('AWS Temp environment credentials are not set')
-        if 'azure' in kms_providers:
-            kms_providers['azure'] = AZURE_CREDS
+                self.skipTest("AWS Temp environment credentials are not set")
+        if "azure" in kms_providers:
+            kms_providers["azure"] = AZURE_CREDS
             if not any(AZURE_CREDS.values()):
-                self.skipTest('Azure environment credentials are not set')
-        if 'gcp' in kms_providers:
-            kms_providers['gcp'] = GCP_CREDS
+                self.skipTest("Azure environment credentials are not set")
+        if "gcp" in kms_providers:
+            kms_providers["gcp"] = GCP_CREDS
             if not any(GCP_CREDS.values()):
-                self.skipTest('GCP environment credentials are not set')
-        if 'kmip' in kms_providers:
-            kms_providers['kmip'] = KMIP
-            opts['kms_tls_options'] = KMS_TLS_OPTS
-        if 'key_vault_namespace' not in opts:
-            opts['key_vault_namespace'] = 'keyvault.datakeys'
+                self.skipTest("GCP environment credentials are not set")
+        if "kmip" in kms_providers:
+            kms_providers["kmip"] = KMIP
+            opts["kms_tls_options"] = KMS_TLS_OPTS
+        if "key_vault_namespace" not in opts:
+            opts["key_vault_namespace"] = "keyvault.datakeys"
         opts = dict(opts)
         return AutoEncryptionOpts(**opts)

     def parse_client_options(self, opts):
         """Override clientOptions parsing to support autoEncryptOpts."""
-        encrypt_opts = opts.pop('autoEncryptOpts')
+        encrypt_opts = opts.pop("autoEncryptOpts")
         if encrypt_opts:
-            opts['auto_encryption_opts'] = self.parse_auto_encrypt_opts(
-                encrypt_opts)
+            opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
         return super(TestSpec, self).parse_client_options(opts)

     def get_object_name(self, op):
         """Default object is collection."""
-        return op.get('object', 'collection')
+        return op.get("object", "collection")

     def maybe_skip_scenario(self, test):
         super(TestSpec, self).maybe_skip_scenario(test)
-        desc = test['description'].lower()
-        if 'type=symbol' in desc:
-            self.skipTest('PyMongo does not support the symbol type')
+        desc = test["description"].lower()
+        if "type=symbol" in desc:
+            self.skipTest("PyMongo does not support the symbol type")

     def setup_scenario(self, scenario_def):
         """Override a test's setup."""
-        key_vault_data = scenario_def['key_vault_data']
+        key_vault_data = scenario_def["key_vault_data"]
         if key_vault_data:
             coll = client_context.client.get_database(
-                'keyvault',
-                write_concern=WriteConcern(w='majority'),
-                codec_options=OPTS)['datakeys']
+                "keyvault", write_concern=WriteConcern(w="majority"), codec_options=OPTS
+            )["datakeys"]
             coll.drop()
             coll.insert_many(key_vault_data)

         db_name = 
self.get_scenario_db_name(scenario_def) coll_name = self.get_scenario_coll_name(scenario_def) db = client_context.client.get_database( - db_name, write_concern=WriteConcern(w='majority'), - codec_options=OPTS) + db_name, write_concern=WriteConcern(w="majority"), codec_options=OPTS + ) coll = db[coll_name] coll.drop() - json_schema = scenario_def['json_schema'] + json_schema = scenario_def["json_schema"] if json_schema: db.create_collection( - coll_name, - validator={'$jsonSchema': json_schema}, codec_options=OPTS) + coll_name, validator={"$jsonSchema": json_schema}, codec_options=OPTS + ) else: db.create_collection(coll_name) - if scenario_def['data']: + if scenario_def["data"]: # Load data. - coll.insert_many(scenario_def['data']) + coll.insert_many(scenario_def["data"]) def allowable_errors(self, op): """Override expected error classes.""" errors = super(TestSpec, self).allowable_errors(op) # An updateOne test expects encryption to error when no $ operator # appears but pymongo raises a client side ValueError in this case. - if op['name'] == 'updateOne': + if op["name"] == "updateOne": errors += (ValueError,) return errors @@ -667,40 +664,36 @@ def run_scenario(self): # Prose Tests LOCAL_MASTER_KEY = base64.b64decode( - b'Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ' - b'5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk') + b"Mng0NCt4ZHVUYUJCa1kxNkVyNUR1QURhZ2h2UzR2d2RrZzh0cFBwM3R6NmdWMDFBMUN3YkQ" + b"5aXRRMkhGRGdQV09wOGVNYUMxT2k3NjZKelhaQmRCZGJkTXVyZG9uSjFk" +) ALL_KMS_PROVIDERS = { - 'aws': AWS_CREDS, - 'azure': AZURE_CREDS, - 'gcp': GCP_CREDS, - 'kmip': KMIP, - 'local': {'key': LOCAL_MASTER_KEY}} - -LOCAL_KEY_ID = Binary( - base64.b64decode(b'LOCALAAAAAAAAAAAAAAAAA=='), UUID_SUBTYPE) -AWS_KEY_ID = Binary( - base64.b64decode(b'AWSAAAAAAAAAAAAAAAAAAA=='), UUID_SUBTYPE) -AZURE_KEY_ID = Binary( - base64.b64decode(b'AZUREAAAAAAAAAAAAAAAAA=='), UUID_SUBTYPE) -GCP_KEY_ID = Binary( - base64.b64decode(b'GCPAAAAAAAAAAAAAAAAAAA=='), UUID_SUBTYPE) -KMIP_KEY_ID = Binary( - base64.b64decode(b'KMIPAAAAAAAAAAAAAAAAAA=='), UUID_SUBTYPE) + "aws": AWS_CREDS, + "azure": AZURE_CREDS, + "gcp": GCP_CREDS, + "kmip": KMIP, + "local": {"key": LOCAL_MASTER_KEY}, +} + +LOCAL_KEY_ID = Binary(base64.b64decode(b"LOCALAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE) +AWS_KEY_ID = Binary(base64.b64decode(b"AWSAAAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE) +AZURE_KEY_ID = Binary(base64.b64decode(b"AZUREAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE) +GCP_KEY_ID = Binary(base64.b64decode(b"GCPAAAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE) +KMIP_KEY_ID = Binary(base64.b64decode(b"KMIPAAAAAAAAAAAAAAAAAA=="), UUID_SUBTYPE) def create_with_schema(coll, json_schema): """Create and return a Collection with a jsonSchema.""" - coll.with_options(write_concern=WriteConcern(w='majority')).drop() + coll.with_options(write_concern=WriteConcern(w="majority")).drop() return coll.database.create_collection( - coll.name, validator={'$jsonSchema': json_schema}, codec_options=OPTS) + coll.name, validator={"$jsonSchema": json_schema}, codec_options=OPTS + ) def create_key_vault(vault, *data_keys): """Create the key vault collection with optional data keys.""" - vault = vault.with_options( - write_concern=WriteConcern(w='majority'), - codec_options=OPTS) + vault = vault.with_options(write_concern=WriteConcern(w="majority"), codec_options=OPTS) vault.drop() if data_keys: vault.insert_many(data_keys) @@ -712,27 +705,29 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): KMS_PROVIDERS = ALL_KMS_PROVIDERS MASTER_KEYS = { - 'aws': { - 
'region': 'us-east-1', - 'key': 'arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-' - '4bd9-9f25-e30687b580d0'}, - 'azure': { - 'keyVaultEndpoint': 'key-vault-csfle.vault.azure.net', - 'keyName': 'key-name-csfle'}, - 'gcp': { - 'projectId': 'devprod-drivers', - 'location': 'global', - 'keyRing': 'key-ring-csfle', - 'keyName': 'key-name-csfle'}, - 'kmip': {}, - 'local': None + "aws": { + "region": "us-east-1", + "key": "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-" "4bd9-9f25-e30687b580d0", + }, + "azure": { + "keyVaultEndpoint": "key-vault-csfle.vault.azure.net", + "keyName": "key-name-csfle", + }, + "gcp": { + "projectId": "devprod-drivers", + "location": "global", + "keyRing": "key-ring-csfle", + "keyName": "key-name-csfle", + }, + "kmip": {}, + "local": None, } @classmethod - @unittest.skipUnless(any([all(AWS_CREDS.values()), - all(AZURE_CREDS.values()), - all(GCP_CREDS.values())]), - 'No environment credentials are set') + @unittest.skipUnless( + any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), + "No environment credentials are set", + ) def setUpClass(cls): super(TestDataKeyDoubleEncryption, cls).setUpClass() cls.listener = OvertCommandListener() @@ -749,20 +744,21 @@ def setUpClass(cls): "encrypt": { "keyId": "/placeholder", "bsonType": "string", - "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Random" + "algorithm": "AEAD_AES_256_CBC_HMAC_SHA_512-Random", } } - } + }, } } opts = AutoEncryptionOpts( - cls.KMS_PROVIDERS, 'keyvault.datakeys', schema_map=schemas, - kms_tls_options=KMS_TLS_OPTS) + cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS + ) cls.client_encrypted = rs_or_single_client( - auto_encryption_opts=opts, uuidRepresentation='standard') + auto_encryption_opts=opts, uuidRepresentation="standard" + ) cls.client_encryption = ClientEncryption( - cls.KMS_PROVIDERS, 'keyvault.datakeys', cls.client, OPTS, - kms_tls_options=KMS_TLS_OPTS) + cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS + ) @classmethod def tearDownClass(cls): @@ -778,96 +774,98 @@ def run_test(self, provider_name): # Create data key. master_key = self.MASTER_KEYS[provider_name] datakey_id = self.client_encryption.create_data_key( - provider_name, master_key=master_key, - key_alt_names=['%s_altname' % (provider_name,)]) + provider_name, master_key=master_key, key_alt_names=["%s_altname" % (provider_name,)] + ) self.assertBinaryUUID(datakey_id) - cmd = self.listener.results['started'][-1] - self.assertEqual('insert', cmd.command_name) - self.assertEqual({'w': 'majority'}, cmd.command.get('writeConcern')) - docs = list(self.vault.find({'_id': datakey_id})) + cmd = self.listener.results["started"][-1] + self.assertEqual("insert", cmd.command_name) + self.assertEqual({"w": "majority"}, cmd.command.get("writeConcern")) + docs = list(self.vault.find({"_id": datakey_id})) self.assertEqual(len(docs), 1) - self.assertEqual(docs[0]['masterKey']['provider'], provider_name) + self.assertEqual(docs[0]["masterKey"]["provider"], provider_name) # Encrypt by key_id. 
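# The writeConcern assertion above leans on PyMongo's command
# monitoring API; a hedged sketch of observing the data key insert
# directly (the listener name is illustrative):
from pymongo import MongoClient, monitoring

class InsertListener(monitoring.CommandListener):
    def started(self, event):
        if event.command_name == "insert":
            print(event.command.get("writeConcern"))  # {'w': 'majority'}

    def succeeded(self, event):
        pass

    def failed(self, event):
        pass

client = MongoClient(event_listeners=[InsertListener()])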
encrypted = self.client_encryption.encrypt( - 'hello %s' % (provider_name,), + "hello %s" % (provider_name,), Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=datakey_id) + key_id=datakey_id, + ) self.assertEncrypted(encrypted) - self.client_encrypted.db.coll.insert_one( - {'_id': provider_name, 'value': encrypted}) - doc_decrypted = self.client_encrypted.db.coll.find_one( - {'_id': provider_name}) - self.assertEqual(doc_decrypted['value'], 'hello %s' % (provider_name,)) + self.client_encrypted.db.coll.insert_one({"_id": provider_name, "value": encrypted}) + doc_decrypted = self.client_encrypted.db.coll.find_one({"_id": provider_name}) + self.assertEqual(doc_decrypted["value"], "hello %s" % (provider_name,)) # Encrypt by key_alt_name. encrypted_altname = self.client_encryption.encrypt( - 'hello %s' % (provider_name,), + "hello %s" % (provider_name,), Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_alt_name='%s_altname' % (provider_name,)) + key_alt_name="%s_altname" % (provider_name,), + ) self.assertEqual(encrypted_altname, encrypted) # Explicitly encrypting an auto encrypted field. - msg = (r'Cannot encrypt element of type binData because schema ' - r'requires that type is one of: \[ string \]') + msg = ( + r"Cannot encrypt element of type binData because schema " + r"requires that type is one of: \[ string \]" + ) with self.assertRaisesRegex(EncryptionError, msg): - self.client_encrypted.db.coll.insert_one( - {'encrypted_placeholder': encrypted}) + self.client_encrypted.db.coll.insert_one({"encrypted_placeholder": encrypted}) def test_data_key_local(self): - self.run_test('local') + self.run_test("local") - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_data_key_aws(self): - self.run_test('aws') + self.run_test("aws") - @unittest.skipUnless(any(AZURE_CREDS.values()), - 'Azure environment credentials are not set') + @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") def test_data_key_azure(self): - self.run_test('azure') + self.run_test("azure") - @unittest.skipUnless(any(GCP_CREDS.values()), - 'GCP environment credentials are not set') + @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") def test_data_key_gcp(self): - self.run_test('gcp') + self.run_test("gcp") def test_data_key_kmip(self): - self.run_test('kmip') + self.run_test("kmip") class TestExternalKeyVault(EncryptionIntegrationTest): - @staticmethod def kms_providers(): - return {'local': {'key': LOCAL_MASTER_KEY}} + return {"local": {"key": LOCAL_MASTER_KEY}} def _test_external_key_vault(self, with_external_key_vault): self.client.db.coll.drop() vault = create_key_vault( self.client.keyvault.datakeys, - json_data('corpus', 'corpus-key-local.json'), - json_data('corpus', 'corpus-key-aws.json')) + json_data("corpus", "corpus-key-local.json"), + json_data("corpus", "corpus-key-aws.json"), + ) self.addCleanup(vault.drop) # Configure the encrypted field via the local schema_map option. 
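# A minimal sketch of routing key vault reads through a dedicated
# client, as the test below does (the credentials are deliberately
# fake, mirroring the fixture):
from pymongo import MongoClient
from pymongo.encryption_options import AutoEncryptionOpts

key_vault_client = MongoClient(username="fake-user", password="fake-pwd")
opts = AutoEncryptionOpts(
    {"local": {"key": b"\x00" * 96}},
    "keyvault.datakeys",
    key_vault_client=key_vault_client,  # data key lookups use this client
)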
- schemas = {'db.coll': json_data('external', 'external-schema.json')} + schemas = {"db.coll": json_data("external", "external-schema.json")} if with_external_key_vault: - key_vault_client = rs_or_single_client( - username='fake-user', password='fake-pwd') + key_vault_client = rs_or_single_client(username="fake-user", password="fake-pwd") self.addCleanup(key_vault_client.close) else: key_vault_client = client_context.client opts = AutoEncryptionOpts( - self.kms_providers(), 'keyvault.datakeys', schema_map=schemas, - key_vault_client=key_vault_client) + self.kms_providers(), + "keyvault.datakeys", + schema_map=schemas, + key_vault_client=key_vault_client, + ) client_encrypted = rs_or_single_client( - auto_encryption_opts=opts, uuidRepresentation='standard') + auto_encryption_opts=opts, uuidRepresentation="standard" + ) self.addCleanup(client_encrypted.close) client_encryption = ClientEncryption( - self.kms_providers(), 'keyvault.datakeys', key_vault_client, OPTS) + self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS + ) self.addCleanup(client_encryption.close) if with_external_key_vault: @@ -886,14 +884,15 @@ def _test_external_key_vault(self, with_external_key_vault): client_encryption.encrypt( "test", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=LOCAL_KEY_ID) + key_id=LOCAL_KEY_ID, + ) # AuthenticationFailed error. self.assertIsInstance(ctx.exception.cause, OperationFailure) self.assertEqual(ctx.exception.cause.code, 18) else: client_encryption.encrypt( - "test", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=LOCAL_KEY_ID) + "test", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=LOCAL_KEY_ID + ) def test_external_key_vault_1(self): self._test_external_key_vault(True) @@ -903,31 +902,28 @@ def test_external_key_vault_2(self): class TestViews(EncryptionIntegrationTest): - @staticmethod def kms_providers(): - return {'local': {'key': LOCAL_MASTER_KEY}} + return {"local": {"key": LOCAL_MASTER_KEY}} def test_views_are_prohibited(self): self.client.db.view.drop() - self.client.db.create_collection('view', viewOn='coll') + self.client.db.create_collection("view", viewOn="coll") self.addCleanup(self.client.db.view.drop) - opts = AutoEncryptionOpts(self.kms_providers(), 'keyvault.datakeys') + opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys") client_encrypted = rs_or_single_client( - auto_encryption_opts=opts, uuidRepresentation='standard') + auto_encryption_opts=opts, uuidRepresentation="standard" + ) self.addCleanup(client_encrypted.close) - with self.assertRaisesRegex( - EncryptionError, 'cannot auto encrypt a view'): + with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"): client_encrypted.db.view.insert_one({}) class TestCorpus(EncryptionIntegrationTest): - @classmethod - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def setUpClass(cls): super(TestCorpus, cls).setUpClass() @@ -938,141 +934,156 @@ def kms_providers(): @staticmethod def fix_up_schema(json_schema): """Remove deprecated symbol/dbPointer types from json schema.""" - for key in list(json_schema['properties']): - if '_symbol_' in key or '_dbPointer_' in key: - del json_schema['properties'][key] + for key in list(json_schema["properties"]): + if "_symbol_" in key or "_dbPointer_" in key: + del json_schema["properties"][key] return json_schema @staticmethod def fix_up_curpus(corpus): """Disallow 
deprecated symbol/dbPointer types from corpus test.""" for key in corpus: - if '_symbol_' in key or '_dbPointer_' in key: - corpus[key]['allowed'] = False + if "_symbol_" in key or "_dbPointer_" in key: + corpus[key]["allowed"] = False return corpus @staticmethod def fix_up_curpus_encrypted(corpus_encrypted, corpus): """Fix the expected values for deprecated symbol/dbPointer types.""" for key in corpus_encrypted: - if '_symbol_' in key or '_dbPointer_' in key: + if "_symbol_" in key or "_dbPointer_" in key: corpus_encrypted[key] = copy.deepcopy(corpus[key]) return corpus_encrypted def _test_corpus(self, opts): # Drop and create the collection 'db.coll' with jsonSchema. coll = create_with_schema( - self.client.db.coll, - self.fix_up_schema(json_data('corpus', 'corpus-schema.json'))) + self.client.db.coll, self.fix_up_schema(json_data("corpus", "corpus-schema.json")) + ) self.addCleanup(coll.drop) vault = create_key_vault( self.client.keyvault.datakeys, - json_data('corpus', 'corpus-key-local.json'), - json_data('corpus', 'corpus-key-aws.json'), - json_data('corpus', 'corpus-key-azure.json'), - json_data('corpus', 'corpus-key-gcp.json'), - json_data('corpus', 'corpus-key-kmip.json')) + json_data("corpus", "corpus-key-local.json"), + json_data("corpus", "corpus-key-aws.json"), + json_data("corpus", "corpus-key-azure.json"), + json_data("corpus", "corpus-key-gcp.json"), + json_data("corpus", "corpus-key-kmip.json"), + ) self.addCleanup(vault.drop) client_encrypted = rs_or_single_client( - auto_encryption_opts=opts, uuidRepresentation='standard') + auto_encryption_opts=opts, uuidRepresentation="standard" + ) self.addCleanup(client_encrypted.close) client_encryption = ClientEncryption( - self.kms_providers(), 'keyvault.datakeys', client_context.client, - OPTS, kms_tls_options=KMS_TLS_OPTS) + self.kms_providers(), + "keyvault.datakeys", + client_context.client, + OPTS, + kms_tls_options=KMS_TLS_OPTS, + ) self.addCleanup(client_encryption.close) - corpus = self.fix_up_curpus(json_data('corpus', 'corpus.json')) + corpus = self.fix_up_curpus(json_data("corpus", "corpus.json")) corpus_copied = SON() for key, value in corpus.items(): corpus_copied[key] = copy.deepcopy(value) - if key in ('_id', 'altname_aws', 'altname_azure', 'altname_gcp', - 'altname_local', 'altname_kmip'): + if key in ( + "_id", + "altname_aws", + "altname_azure", + "altname_gcp", + "altname_local", + "altname_kmip", + ): continue - if value['method'] == 'auto': + if value["method"] == "auto": continue - if value['method'] == 'explicit': - identifier = value['identifier'] - self.assertIn(identifier, ('id', 'altname')) - kms = value['kms'] - self.assertIn(kms, ('local', 'aws', 'azure', 'gcp', 'kmip')) - if identifier == 'id': - if kms == 'local': + if value["method"] == "explicit": + identifier = value["identifier"] + self.assertIn(identifier, ("id", "altname")) + kms = value["kms"] + self.assertIn(kms, ("local", "aws", "azure", "gcp", "kmip")) + if identifier == "id": + if kms == "local": kwargs = dict(key_id=LOCAL_KEY_ID) - elif kms == 'aws': + elif kms == "aws": kwargs = dict(key_id=AWS_KEY_ID) - elif kms == 'azure': + elif kms == "azure": kwargs = dict(key_id=AZURE_KEY_ID) - elif kms == 'gcp': + elif kms == "gcp": kwargs = dict(key_id=GCP_KEY_ID) else: kwargs = dict(key_id=KMIP_KEY_ID) else: kwargs = dict(key_alt_name=kms) - self.assertIn(value['algo'], ('det', 'rand')) - if value['algo'] == 'det': - algo = (Algorithm. 
- AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic) + self.assertIn(value["algo"], ("det", "rand")) + if value["algo"] == "det": + algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic else: algo = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random try: - encrypted_val = client_encryption.encrypt( - value['value'], algo, **kwargs) - if not value['allowed']: - self.fail('encrypt should have failed: %r: %r' % ( - key, value)) - corpus_copied[key]['value'] = encrypted_val + encrypted_val = client_encryption.encrypt(value["value"], algo, **kwargs) + if not value["allowed"]: + self.fail("encrypt should have failed: %r: %r" % (key, value)) + corpus_copied[key]["value"] = encrypted_val except Exception: - if value['allowed']: + if value["allowed"]: tb = traceback.format_exc() - self.fail('encrypt failed: %r: %r, traceback: %s' % ( - key, value, tb)) + self.fail("encrypt failed: %r: %r, traceback: %s" % (key, value, tb)) client_encrypted.db.coll.insert_one(corpus_copied) corpus_decrypted = client_encrypted.db.coll.find_one() self.assertEqual(corpus_decrypted, corpus) - corpus_encrypted_expected = self.fix_up_curpus_encrypted(json_data( - 'corpus', 'corpus-encrypted.json'), corpus) + corpus_encrypted_expected = self.fix_up_curpus_encrypted( + json_data("corpus", "corpus-encrypted.json"), corpus + ) corpus_encrypted_actual = coll.find_one() for key, value in corpus_encrypted_actual.items(): - if key in ('_id', 'altname_aws', 'altname_azure', - 'altname_gcp', 'altname_local', 'altname_kmip'): + if key in ( + "_id", + "altname_aws", + "altname_azure", + "altname_gcp", + "altname_local", + "altname_kmip", + ): continue - if value['algo'] == 'det': - self.assertEqual( - value['value'], corpus_encrypted_expected[key]['value'], - key) - elif value['algo'] == 'rand' and value['allowed']: - self.assertNotEqual( - value['value'], corpus_encrypted_expected[key]['value'], - key) - - if value['allowed']: - decrypt_actual = client_encryption.decrypt(value['value']) + if value["algo"] == "det": + self.assertEqual(value["value"], corpus_encrypted_expected[key]["value"], key) + elif value["algo"] == "rand" and value["allowed"]: + self.assertNotEqual(value["value"], corpus_encrypted_expected[key]["value"], key) + + if value["allowed"]: + decrypt_actual = client_encryption.decrypt(value["value"]) decrypt_expected = client_encryption.decrypt( - corpus_encrypted_expected[key]['value']) + corpus_encrypted_expected[key]["value"] + ) self.assertEqual(decrypt_actual, decrypt_expected, key) else: - self.assertEqual(value['value'], corpus[key]['value'], key) + self.assertEqual(value["value"], corpus[key]["value"], key) def test_corpus(self): - opts = AutoEncryptionOpts(self.kms_providers(), 'keyvault.datakeys', - kms_tls_options=KMS_TLS_OPTS) + opts = AutoEncryptionOpts( + self.kms_providers(), "keyvault.datakeys", kms_tls_options=KMS_TLS_OPTS + ) self._test_corpus(opts) def test_corpus_local_schema(self): # Configure the encrypted field via the local schema_map option. 
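# The det/rand branches above come down to the two CSFLE algorithms:
# deterministic encryption is repeatable (and therefore queryable by
# equality), random never is. Hedged sketch reusing names from the
# corpus test (client_encryption and LOCAL_KEY_ID as defined above):
from pymongo.encryption import Algorithm

det = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic
rand = Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Random
a = client_encryption.encrypt("v", det, key_id=LOCAL_KEY_ID)
b = client_encryption.encrypt("v", det, key_id=LOCAL_KEY_ID)
assert a == b  # same plaintext and key -> identical ciphertext
c = client_encryption.encrypt("v", rand, key_id=LOCAL_KEY_ID)
assert c != a  # random algorithm uses a fresh IV per call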
- schemas = {'db.coll': self.fix_up_schema( - json_data('corpus', 'corpus-schema.json'))} + schemas = {"db.coll": self.fix_up_schema(json_data("corpus", "corpus-schema.json"))} opts = AutoEncryptionOpts( - self.kms_providers(), 'keyvault.datakeys', schema_map=schemas, - kms_tls_options=KMS_TLS_OPTS) + self.kms_providers(), + "keyvault.datakeys", + schema_map=schemas, + kms_tls_options=KMS_TLS_OPTS, + ) self._test_corpus(opts) @@ -1090,24 +1101,26 @@ def setUpClass(cls): cls.coll = db.coll cls.coll.drop() # Configure the encrypted 'db.coll' collection via jsonSchema. - json_schema = json_data('limits', 'limits-schema.json') + json_schema = json_data("limits", "limits-schema.json") db.create_collection( - 'coll', validator={'$jsonSchema': json_schema}, codec_options=OPTS, - write_concern=WriteConcern(w='majority')) + "coll", + validator={"$jsonSchema": json_schema}, + codec_options=OPTS, + write_concern=WriteConcern(w="majority"), + ) # Create the key vault. coll = client_context.client.get_database( - 'keyvault', - write_concern=WriteConcern(w='majority'), - codec_options=OPTS)['datakeys'] + "keyvault", write_concern=WriteConcern(w="majority"), codec_options=OPTS + )["datakeys"] coll.drop() - coll.insert_one(json_data('limits', 'limits-key.json')) + coll.insert_one(json_data("limits", "limits-key.json")) - opts = AutoEncryptionOpts( - {'local': {'key': LOCAL_MASTER_KEY}}, 'keyvault.datakeys') + opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") cls.listener = OvertCommandListener() cls.client_encrypted = rs_or_single_client( - auto_encryption_opts=opts, event_listeners=[cls.listener]) + auto_encryption_opts=opts, event_listeners=[cls.listener] + ) cls.coll_encrypted = cls.client_encrypted.db.coll @classmethod @@ -1117,103 +1130,96 @@ def tearDownClass(cls): super(TestBsonSizeBatches, cls).tearDownClass() def test_01_insert_succeeds_under_2MiB(self): - doc = {'_id': 'over_2mib_under_16mib', 'unencrypted': 'a' * _2_MiB} + doc = {"_id": "over_2mib_under_16mib", "unencrypted": "a" * _2_MiB} self.coll_encrypted.insert_one(doc) # Same with bulk_write. - doc['_id'] = 'over_2mib_under_16mib_bulk' + doc["_id"] = "over_2mib_under_16mib_bulk" self.coll_encrypted.bulk_write([InsertOne(doc)]) def test_02_insert_succeeds_over_2MiB_post_encryption(self): - doc = {'_id': 'encryption_exceeds_2mib', - 'unencrypted': 'a' * ((2**21) - 2000)} - doc.update(json_data('limits', 'limits-doc.json')) + doc = {"_id": "encryption_exceeds_2mib", "unencrypted": "a" * ((2**21) - 2000)} + doc.update(json_data("limits", "limits-doc.json")) self.coll_encrypted.insert_one(doc) # Same with bulk_write. 
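# For reference, the size constants these batch-splitting tests rely
# on (a hedged restatement of _2_MiB and _16_MiB, which are defined
# earlier in the module):
_2_MiB = 2097152  # 2 MiB: auto encryption splits bulk writes at this size
_16_MiB = 16777216  # 16 MiB: the server's maxBsonObjectSize hard limit
# 'a' * (_2_MiB - 2000) leaves ~2 KB of headroom so a document only
# crosses a limit once the encrypted payload is added.
doc = {"_id": "x", "unencrypted": "a" * (_2_MiB - 2000)}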
- doc['_id'] = 'encryption_exceeds_2mib_bulk' + doc["_id"] = "encryption_exceeds_2mib_bulk" self.coll_encrypted.bulk_write([InsertOne(doc)]) def test_03_bulk_batch_split(self): - doc1 = {'_id': 'over_2mib_1', 'unencrypted': 'a' * _2_MiB} - doc2 = {'_id': 'over_2mib_2', 'unencrypted': 'a' * _2_MiB} + doc1 = {"_id": "over_2mib_1", "unencrypted": "a" * _2_MiB} + doc2 = {"_id": "over_2mib_2", "unencrypted": "a" * _2_MiB} self.listener.reset() self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual( - self.listener.started_command_names(), ['insert', 'insert']) + self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) def test_04_bulk_batch_split(self): - limits_doc = json_data('limits', 'limits-doc.json') - doc1 = {'_id': 'encryption_exceeds_2mib_1', - 'unencrypted': 'a' * (_2_MiB - 2000)} + limits_doc = json_data("limits", "limits-doc.json") + doc1 = {"_id": "encryption_exceeds_2mib_1", "unencrypted": "a" * (_2_MiB - 2000)} doc1.update(limits_doc) - doc2 = {'_id': 'encryption_exceeds_2mib_2', - 'unencrypted': 'a' * (_2_MiB - 2000)} + doc2 = {"_id": "encryption_exceeds_2mib_2", "unencrypted": "a" * (_2_MiB - 2000)} doc2.update(limits_doc) self.listener.reset() self.coll_encrypted.bulk_write([InsertOne(doc1), InsertOne(doc2)]) - self.assertEqual( - self.listener.started_command_names(), ['insert', 'insert']) + self.assertEqual(self.listener.started_command_names(), ["insert", "insert"]) def test_05_insert_succeeds_just_under_16MiB(self): - doc = {'_id': 'under_16mib', 'unencrypted': 'a' * (_16_MiB - 2000)} + doc = {"_id": "under_16mib", "unencrypted": "a" * (_16_MiB - 2000)} self.coll_encrypted.insert_one(doc) # Same with bulk_write. - doc['_id'] = 'under_16mib_bulk' + doc["_id"] = "under_16mib_bulk" self.coll_encrypted.bulk_write([InsertOne(doc)]) def test_06_insert_fails_over_16MiB(self): - limits_doc = json_data('limits', 'limits-doc.json') - doc = {'_id': 'encryption_exceeds_16mib', - 'unencrypted': 'a' * (_16_MiB - 2000)} + limits_doc = json_data("limits", "limits-doc.json") + doc = {"_id": "encryption_exceeds_16mib", "unencrypted": "a" * (_16_MiB - 2000)} doc.update(limits_doc) - with self.assertRaisesRegex(WriteError, 'object to insert too large'): + with self.assertRaisesRegex(WriteError, "object to insert too large"): self.coll_encrypted.insert_one(doc) # Same with bulk_write. 
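# How the 16 MiB bulk-write case below inspects the server error (a
# hedged sketch; coll and requests stand in for the test's fixtures):
from pymongo.errors import BulkWriteError

try:
    coll.bulk_write(requests)
except BulkWriteError as exc:
    for err in exc.details["writeErrors"]:
        print(err["code"], err["errmsg"])  # e.g. 2, 'object to insert too large'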
- doc['_id'] = 'encryption_exceeds_16mib_bulk' + doc["_id"] = "encryption_exceeds_16mib_bulk" with self.assertRaises(BulkWriteError) as ctx: self.coll_encrypted.bulk_write([InsertOne(doc)]) - err = ctx.exception.details['writeErrors'][0] - self.assertEqual(2, err['code']) - self.assertIn('object to insert too large', err['errmsg']) + err = ctx.exception.details["writeErrors"][0] + self.assertEqual(2, err["code"]) + self.assertIn("object to insert too large", err["errmsg"]) class TestCustomEndpoint(EncryptionIntegrationTest): """Prose tests for creating data keys with a custom endpoint.""" @classmethod - @unittest.skipUnless(any([all(AWS_CREDS.values()), - all(AZURE_CREDS.values()), - all(GCP_CREDS.values())]), - 'No environment credentials are set') + @unittest.skipUnless( + any([all(AWS_CREDS.values()), all(AZURE_CREDS.values()), all(GCP_CREDS.values())]), + "No environment credentials are set", + ) def setUpClass(cls): super(TestCustomEndpoint, cls).setUpClass() def setUp(self): - kms_providers = {'aws': AWS_CREDS, - 'azure': AZURE_CREDS, - 'gcp': GCP_CREDS, - 'kmip': KMIP} + kms_providers = {"aws": AWS_CREDS, "azure": AZURE_CREDS, "gcp": GCP_CREDS, "kmip": KMIP} self.client_encryption = ClientEncryption( kms_providers=kms_providers, - key_vault_namespace='keyvault.datakeys', + key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, codec_options=OPTS, - kms_tls_options=KMS_TLS_OPTS) + kms_tls_options=KMS_TLS_OPTS, + ) kms_providers_invalid = copy.deepcopy(kms_providers) - kms_providers_invalid['azure']['identityPlatformEndpoint'] = 'doesnotexist.invalid:443' - kms_providers_invalid['gcp']['endpoint'] = 'doesnotexist.invalid:443' - kms_providers_invalid['kmip']['endpoint'] = 'doesnotexist.local:5698' + kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443" + kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443" + kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698" self.client_encryption_invalid = ClientEncryption( kms_providers=kms_providers_invalid, - key_vault_namespace='keyvault.datakeys', + key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, codec_options=OPTS, - kms_tls_options=KMS_TLS_OPTS) + kms_tls_options=KMS_TLS_OPTS, + ) self._kmip_host_error = None self._invalid_host_error = None @@ -1222,131 +1228,134 @@ def tearDown(self): self.client_encryption_invalid.close() def run_test_expected_success(self, provider_name, master_key): - data_key_id = self.client_encryption.create_data_key( - provider_name, master_key=master_key) + data_key_id = self.client_encryption.create_data_key(provider_name, master_key=master_key) encrypted = self.client_encryption.encrypt( - 'test', Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=data_key_id) - self.assertEqual('test', self.client_encryption.decrypt(encrypted)) + "test", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id + ) + self.assertEqual("test", self.client_encryption.decrypt(encrypted)) - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_01_aws_region_key(self): self.run_test_expected_success( - 'aws', - {"region": "us-east-1", - "key": ("arn:aws:kms:us-east-1:579766882180:key/" - "89fcc2c4-08b0-4bd9-9f25-e30687b580d0")}) + "aws", + { + "region": "us-east-1", + "key": ( + "arn:aws:kms:us-east-1:579766882180:key/" 
"89fcc2c4-08b0-4bd9-9f25-e30687b580d0" + ), + }, + ) - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_02_aws_region_key_endpoint(self): self.run_test_expected_success( - 'aws', - {"region": "us-east-1", - "key": ("arn:aws:kms:us-east-1:579766882180:key/" - "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), - "endpoint": "kms.us-east-1.amazonaws.com"}) - - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + "aws", + { + "region": "us-east-1", + "key": ( + "arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0" + ), + "endpoint": "kms.us-east-1.amazonaws.com", + }, + ) + + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_03_aws_region_key_endpoint_port(self): self.run_test_expected_success( - 'aws', - {"region": "us-east-1", - "key": ("arn:aws:kms:us-east-1:579766882180:key/" - "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), - "endpoint": "kms.us-east-1.amazonaws.com:443"}) - - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + "aws", + { + "region": "us-east-1", + "key": ( + "arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0" + ), + "endpoint": "kms.us-east-1.amazonaws.com:443", + }, + ) + + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_04_aws_endpoint_invalid_port(self): master_key = { "region": "us-east-1", - "key": ("arn:aws:kms:us-east-1:579766882180:key/" - "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), - "endpoint": "kms.us-east-1.amazonaws.com:12345" + "key": ( + "arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0" + ), + "endpoint": "kms.us-east-1.amazonaws.com:12345", } with self.assertRaises(EncryptionError) as ctx: - self.client_encryption.create_data_key( - 'aws', master_key=master_key) + self.client_encryption.create_data_key("aws", master_key=master_key) self.assertIsInstance(ctx.exception.cause, socket.error) - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_05_aws_endpoint_wrong_region(self): master_key = { "region": "us-east-1", - "key": ("arn:aws:kms:us-east-1:579766882180:key/" - "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), - "endpoint": "kms.us-east-2.amazonaws.com" + "key": ( + "arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0" + ), + "endpoint": "kms.us-east-2.amazonaws.com", } # The full error should be something like: # "Credential should be scoped to a valid region, not 'us-east-1'" # but we only check for "us-east-1" to avoid breaking on slight # changes to AWS' error message. 
- with self.assertRaisesRegex(EncryptionError, 'us-east-1'): - self.client_encryption.create_data_key( - 'aws', master_key=master_key) + with self.assertRaisesRegex(EncryptionError, "us-east-1"): + self.client_encryption.create_data_key("aws", master_key=master_key) - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_06_aws_endpoint_invalid_host(self): master_key = { "region": "us-east-1", - "key": ("arn:aws:kms:us-east-1:579766882180:key/" - "89fcc2c4-08b0-4bd9-9f25-e30687b580d0"), - "endpoint": "doesnotexist.invalid" + "key": ( + "arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0" + ), + "endpoint": "doesnotexist.invalid", } with self.assertRaisesRegex(EncryptionError, self.invalid_host_error): - self.client_encryption.create_data_key( - 'aws', master_key=master_key) + self.client_encryption.create_data_key("aws", master_key=master_key) - @unittest.skipUnless(any(AZURE_CREDS.values()), - 'Azure environment credentials are not set') + @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") def test_07_azure(self): - master_key = {'keyVaultEndpoint': 'key-vault-csfle.vault.azure.net', - 'keyName': 'key-name-csfle'} - self.run_test_expected_success('azure', master_key) + master_key = { + "keyVaultEndpoint": "key-vault-csfle.vault.azure.net", + "keyName": "key-name-csfle", + } + self.run_test_expected_success("azure", master_key) # The full error should be something like: # "[Errno 8] nodename nor servname provided, or not known" with self.assertRaisesRegex(EncryptionError, self.invalid_host_error): - self.client_encryption_invalid.create_data_key( - 'azure', master_key=master_key) + self.client_encryption_invalid.create_data_key("azure", master_key=master_key) - @unittest.skipUnless(any(GCP_CREDS.values()), - 'GCP environment credentials are not set') + @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") def test_08_gcp_valid_endpoint(self): master_key = { "projectId": "devprod-drivers", "location": "global", "keyRing": "key-ring-csfle", "keyName": "key-name-csfle", - "endpoint": "cloudkms.googleapis.com:443"} - self.run_test_expected_success('gcp', master_key) + "endpoint": "cloudkms.googleapis.com:443", + } + self.run_test_expected_success("gcp", master_key) # The full error should be something like: # "[Errno 8] nodename nor servname provided, or not known" with self.assertRaisesRegex(EncryptionError, self.invalid_host_error): - self.client_encryption_invalid.create_data_key( - 'gcp', master_key=master_key) + self.client_encryption_invalid.create_data_key("gcp", master_key=master_key) - @unittest.skipUnless(any(GCP_CREDS.values()), - 'GCP environment credentials are not set') + @unittest.skipUnless(any(GCP_CREDS.values()), "GCP environment credentials are not set") def test_09_gcp_invalid_endpoint(self): master_key = { "projectId": "devprod-drivers", "location": "global", "keyRing": "key-ring-csfle", "keyName": "key-name-csfle", - "endpoint": "doesnotexist.invalid:443"} + "endpoint": "doesnotexist.invalid:443", + } # The full error should be something like: # "Invalid KMS response, no access_token returned. 
HTTP status=200" with self.assertRaisesRegex(EncryptionError, "Invalid KMS response"): - self.client_encryption.create_data_key( - 'gcp', master_key=master_key) + self.client_encryption.create_data_key("gcp", master_key=master_key) def dns_error(self, host, port): # The full error should be something like: @@ -1358,95 +1367,90 @@ def dns_error(self, host, port): @property def invalid_host_error(self): if self._invalid_host_error is None: - self._invalid_host_error = self.dns_error( - 'doesnotexist.invalid', 443) + self._invalid_host_error = self.dns_error("doesnotexist.invalid", 443) return self._invalid_host_error @property def kmip_host_error(self): if self._kmip_host_error is None: - self._kmip_host_error = self.dns_error('doesnotexist.local', 5698) + self._kmip_host_error = self.dns_error("doesnotexist.local", 5698) return self._kmip_host_error def test_10_kmip_invalid_endpoint(self): - key = {'keyId': '1'} - self.run_test_expected_success('kmip', key) + key = {"keyId": "1"} + self.run_test_expected_success("kmip", key) with self.assertRaisesRegex(EncryptionError, self.kmip_host_error): - self.client_encryption_invalid.create_data_key('kmip', key) + self.client_encryption_invalid.create_data_key("kmip", key) def test_11_kmip_master_key_endpoint(self): - key = {'keyId': '1', 'endpoint': KMIP['endpoint']} - self.run_test_expected_success('kmip', key) + key = {"keyId": "1", "endpoint": KMIP["endpoint"]} + self.run_test_expected_success("kmip", key) # Override invalid endpoint: - data_key_id = self.client_encryption_invalid.create_data_key( - 'kmip', master_key=key) + data_key_id = self.client_encryption_invalid.create_data_key("kmip", master_key=key) encrypted = self.client_encryption_invalid.encrypt( - 'test', Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=data_key_id) - self.assertEqual( - 'test', self.client_encryption_invalid.decrypt(encrypted)) + "test", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=data_key_id + ) + self.assertEqual("test", self.client_encryption_invalid.decrypt(encrypted)) def test_12_kmip_master_key_invalid_endpoint(self): - key = {'keyId': '1', 'endpoint': 'doesnotexist.local:5698'} + key = {"keyId": "1", "endpoint": "doesnotexist.local:5698"} with self.assertRaisesRegex(EncryptionError, self.kmip_host_error): - self.client_encryption.create_data_key('kmip', key) + self.client_encryption.create_data_key("kmip", key) class AzureGCPEncryptionTestMixin(object): DEK = None KMS_PROVIDER_MAP = None - KEYVAULT_DB = 'keyvault' - KEYVAULT_COLL = 'datakeys' + KEYVAULT_DB = "keyvault" + KEYVAULT_COLL = "datakeys" def setUp(self): - keyvault = self.client.get_database( - self.KEYVAULT_DB).get_collection( - self.KEYVAULT_COLL) + keyvault = self.client.get_database(self.KEYVAULT_DB).get_collection(self.KEYVAULT_COLL) create_key_vault(keyvault, self.DEK) def _test_explicit(self, expectation): client_encryption = ClientEncryption( self.KMS_PROVIDER_MAP, - '.'.join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), + ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), client_context.client, - OPTS) + OPTS, + ) self.addCleanup(client_encryption.close) ciphertext = client_encryption.encrypt( - 'string0', + "string0", algorithm=Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_id=Binary.from_uuid(self.DEK['_id'], STANDARD)) + key_id=Binary.from_uuid(self.DEK["_id"], STANDARD), + ) self.assertEqual(bytes(ciphertext), base64.b64decode(expectation)) - self.assertEqual(client_encryption.decrypt(ciphertext), 'string0') + 
self.assertEqual(client_encryption.decrypt(ciphertext), "string0") def _test_automatic(self, expectation_extjson, payload): encrypted_db = "db" encrypted_coll = "coll" - keyvault_namespace = '.'.join([self.KEYVAULT_DB, self.KEYVAULT_COLL]) + keyvault_namespace = ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]) encryption_opts = AutoEncryptionOpts( - self.KMS_PROVIDER_MAP, - keyvault_namespace, - schema_map=self.SCHEMA_MAP) + self.KMS_PROVIDER_MAP, keyvault_namespace, schema_map=self.SCHEMA_MAP + ) - insert_listener = AllowListEventListener('insert') + insert_listener = AllowListEventListener("insert") client = rs_or_single_client( - auto_encryption_opts=encryption_opts, - event_listeners=[insert_listener]) + auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] + ) self.addCleanup(client.close) coll = client.get_database(encrypted_db).get_collection( - encrypted_coll, codec_options=OPTS, - write_concern=WriteConcern("majority")) + encrypted_coll, codec_options=OPTS, write_concern=WriteConcern("majority") + ) coll.drop() - expected_document = json_util.loads( - expectation_extjson, json_options=JSON_OPTS) + expected_document = json_util.loads(expectation_extjson, json_options=JSON_OPTS) coll.insert_one(payload) - event = insert_listener.results['started'][0] - inserted_doc = event.command['documents'][0] + event = insert_listener.results["started"][0] + inserted_doc = event.command["documents"][0] for key, value in expected_document.items(): self.assertEqual(value, inserted_doc[key]) @@ -1456,108 +1460,112 @@ def _test_automatic(self, expectation_extjson, payload): self.assertEqual(output_doc[key], value) -class TestAzureEncryption(AzureGCPEncryptionTestMixin, - EncryptionIntegrationTest): +class TestAzureEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): @classmethod - @unittest.skipUnless(any(AZURE_CREDS.values()), - 'Azure environment credentials are not set') + @unittest.skipUnless(any(AZURE_CREDS.values()), "Azure environment credentials are not set") def setUpClass(cls): - cls.KMS_PROVIDER_MAP = {'azure': AZURE_CREDS} - cls.DEK = json_data(BASE, 'custom', 'azure-dek.json') - cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') + cls.KMS_PROVIDER_MAP = {"azure": AZURE_CREDS} + cls.DEK = json_data(BASE, "custom", "azure-dek.json") + cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") super(TestAzureEncryption, cls).setUpClass() def test_explicit(self): return self._test_explicit( - 'AQGVERPgAAAAAAAAAAAAAAAC5DbBSwPwfSlBrDtRuglvNvCXD1KzDuCKY2P+4bRFtHDjpTOE2XuytPAUaAbXf1orsPq59PVZmsbTZbt2CB8qaQ==') + "AQGVERPgAAAAAAAAAAAAAAAC5DbBSwPwfSlBrDtRuglvNvCXD1KzDuCKY2P+4bRFtHDjpTOE2XuytPAUaAbXf1orsPq59PVZmsbTZbt2CB8qaQ==" + ) def test_automatic(self): - expected_document_extjson = textwrap.dedent(""" + expected_document_extjson = textwrap.dedent( + """ {"secret_azure": { "$binary": { "base64": "AQGVERPgAAAAAAAAAAAAAAAC5DbBSwPwfSlBrDtRuglvNvCXD1KzDuCKY2P+4bRFtHDjpTOE2XuytPAUaAbXf1orsPq59PVZmsbTZbt2CB8qaQ==", "subType": "06"} - }}""") - return self._test_automatic( - expected_document_extjson, {"secret_azure": "string0"}) + }}""" + ) + return self._test_automatic(expected_document_extjson, {"secret_azure": "string0"}) -class TestGCPEncryption(AzureGCPEncryptionTestMixin, - EncryptionIntegrationTest): +class TestGCPEncryption(AzureGCPEncryptionTestMixin, EncryptionIntegrationTest): @classmethod - @unittest.skipUnless(any(GCP_CREDS.values()), - 'GCP environment credentials are not set') + @unittest.skipUnless(any(GCP_CREDS.values()), 
"GCP environment credentials are not set") def setUpClass(cls): - cls.KMS_PROVIDER_MAP = {'gcp': GCP_CREDS} - cls.DEK = json_data(BASE, 'custom', 'gcp-dek.json') - cls.SCHEMA_MAP = json_data(BASE, 'custom', 'azure-gcp-schema.json') + cls.KMS_PROVIDER_MAP = {"gcp": GCP_CREDS} + cls.DEK = json_data(BASE, "custom", "gcp-dek.json") + cls.SCHEMA_MAP = json_data(BASE, "custom", "azure-gcp-schema.json") super(TestGCPEncryption, cls).setUpClass() def test_explicit(self): return self._test_explicit( - 'ARgj/gAAAAAAAAAAAAAAAAACwFd+Y5Ojw45GUXNvbcIpN9YkRdoHDHkR4kssdn0tIMKlDQOLFkWFY9X07IRlXsxPD8DcTiKnl6XINK28vhcGlg==') + "ARgj/gAAAAAAAAAAAAAAAAACwFd+Y5Ojw45GUXNvbcIpN9YkRdoHDHkR4kssdn0tIMKlDQOLFkWFY9X07IRlXsxPD8DcTiKnl6XINK28vhcGlg==" + ) def test_automatic(self): - expected_document_extjson = textwrap.dedent(""" + expected_document_extjson = textwrap.dedent( + """ {"secret_gcp": { "$binary": { "base64": "ARgj/gAAAAAAAAAAAAAAAAACwFd+Y5Ojw45GUXNvbcIpN9YkRdoHDHkR4kssdn0tIMKlDQOLFkWFY9X07IRlXsxPD8DcTiKnl6XINK28vhcGlg==", "subType": "06"} - }}""") - return self._test_automatic( - expected_document_extjson, {"secret_gcp": "string0"}) + }}""" + ) + return self._test_automatic(expected_document_extjson, {"secret_gcp": "string0"}) # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests class TestDeadlockProse(EncryptionIntegrationTest): def setUp(self): self.client_test = rs_or_single_client( - maxPoolSize=1, readConcernLevel='majority', w='majority', - uuidRepresentation='standard') + maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" + ) self.addCleanup(self.client_test.close) self.client_keyvault_listener = OvertCommandListener() self.client_keyvault = rs_or_single_client( - maxPoolSize=1, readConcernLevel='majority', w='majority', - event_listeners=[self.client_keyvault_listener]) + maxPoolSize=1, + readConcernLevel="majority", + w="majority", + event_listeners=[self.client_keyvault_listener], + ) self.addCleanup(self.client_keyvault.close) self.client_test.keyvault.datakeys.drop() self.client_test.db.coll.drop() - self.client_test.keyvault.datakeys.insert_one( - json_data('external', 'external-key.json')) + self.client_test.keyvault.datakeys.insert_one(json_data("external", "external-key.json")) _ = self.client_test.db.create_collection( - 'coll', validator={'$jsonSchema': json_data( - 'external', 'external-schema.json')}, - codec_options=OPTS) + "coll", + validator={"$jsonSchema": json_data("external", "external-schema.json")}, + codec_options=OPTS, + ) client_encryption = ClientEncryption( - kms_providers={'local': {'key': LOCAL_MASTER_KEY}}, - key_vault_namespace='keyvault.datakeys', - key_vault_client=self.client_test, codec_options=OPTS) + kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, + key_vault_namespace="keyvault.datakeys", + key_vault_client=self.client_test, + codec_options=OPTS, + ) self.ciphertext = client_encryption.encrypt( - 'string0', Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, - key_alt_name='local') + "string0", Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_alt_name="local" + ) client_encryption.close() self.client_listener = OvertCommandListener() self.topology_listener = TopologyEventListener() - self.optargs = ({'local': {'key': LOCAL_MASTER_KEY}}, 'keyvault.datakeys') + self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") def _run_test(self, max_pool_size, auto_encryption_opts): client_encrypted = rs_or_single_client( - 
readConcernLevel='majority', - w='majority', + readConcernLevel="majority", + w="majority", maxPoolSize=max_pool_size, auto_encryption_opts=auto_encryption_opts, - event_listeners=[self.client_listener, self.topology_listener]) + event_listeners=[self.client_listener, self.topology_listener], + ) if auto_encryption_opts._bypass_auto_encryption == True: - self.client_test.db.coll.insert_one( - {"_id": 0, "encrypted": self.ciphertext}) + self.client_test.db.coll.insert_one({"_id": 0, "encrypted": self.ciphertext}) elif auto_encryption_opts._bypass_auto_encryption == False: - client_encrypted.db.coll.insert_one( - {"_id": 0, "encrypted": "string0"}) + client_encrypted.db.coll.insert_one({"_id": 0, "encrypted": "string0"}) else: raise RuntimeError("bypass_auto_encryption must be a bool") @@ -1567,162 +1575,170 @@ def _run_test(self, max_pool_size, auto_encryption_opts): self.addCleanup(client_encrypted.close) def test_case_1(self): - self._run_test(max_pool_size=1, - auto_encryption_opts=AutoEncryptionOpts( - *self.optargs, - bypass_auto_encryption=False, - key_vault_client=None)) + self._run_test( + max_pool_size=1, + auto_encryption_opts=AutoEncryptionOpts( + *self.optargs, bypass_auto_encryption=False, key_vault_client=None + ), + ) - cev = self.client_listener.results['started'] + cev = self.client_listener.results["started"] self.assertEqual(len(cev), 4) - self.assertEqual(cev[0].command_name, 'listCollections') - self.assertEqual(cev[0].database_name, 'db') - self.assertEqual(cev[1].command_name, 'find') - self.assertEqual(cev[1].database_name, 'keyvault') - self.assertEqual(cev[2].command_name, 'insert') - self.assertEqual(cev[2].database_name, 'db') - self.assertEqual(cev[3].command_name, 'find') - self.assertEqual(cev[3].database_name, 'db') + self.assertEqual(cev[0].command_name, "listCollections") + self.assertEqual(cev[0].database_name, "db") + self.assertEqual(cev[1].command_name, "find") + self.assertEqual(cev[1].database_name, "keyvault") + self.assertEqual(cev[2].command_name, "insert") + self.assertEqual(cev[2].database_name, "db") + self.assertEqual(cev[3].command_name, "find") + self.assertEqual(cev[3].database_name, "db") - self.assertEqual(len(self.topology_listener.results['opened']), 2) + self.assertEqual(len(self.topology_listener.results["opened"]), 2) def test_case_2(self): - self._run_test(max_pool_size=1, - auto_encryption_opts=AutoEncryptionOpts( - *self.optargs, - bypass_auto_encryption=False, - key_vault_client=self.client_keyvault)) + self._run_test( + max_pool_size=1, + auto_encryption_opts=AutoEncryptionOpts( + *self.optargs, bypass_auto_encryption=False, key_vault_client=self.client_keyvault + ), + ) - cev = self.client_listener.results['started'] + cev = self.client_listener.results["started"] self.assertEqual(len(cev), 3) - self.assertEqual(cev[0].command_name, 'listCollections') - self.assertEqual(cev[0].database_name, 'db') - self.assertEqual(cev[1].command_name, 'insert') - self.assertEqual(cev[1].database_name, 'db') - self.assertEqual(cev[2].command_name, 'find') - self.assertEqual(cev[2].database_name, 'db') - - cev = self.client_keyvault_listener.results['started'] + self.assertEqual(cev[0].command_name, "listCollections") + self.assertEqual(cev[0].database_name, "db") + self.assertEqual(cev[1].command_name, "insert") + self.assertEqual(cev[1].database_name, "db") + self.assertEqual(cev[2].command_name, "find") + self.assertEqual(cev[2].database_name, "db") + + cev = self.client_keyvault_listener.results["started"] self.assertEqual(len(cev), 1) - 
self.assertEqual(cev[0].command_name, 'find') - self.assertEqual(cev[0].database_name, 'keyvault') + self.assertEqual(cev[0].command_name, "find") + self.assertEqual(cev[0].database_name, "keyvault") - self.assertEqual(len(self.topology_listener.results['opened']), 2) + self.assertEqual(len(self.topology_listener.results["opened"]), 2) def test_case_3(self): - self._run_test(max_pool_size=1, - auto_encryption_opts=AutoEncryptionOpts( - *self.optargs, - bypass_auto_encryption=True, - key_vault_client=None)) + self._run_test( + max_pool_size=1, + auto_encryption_opts=AutoEncryptionOpts( + *self.optargs, bypass_auto_encryption=True, key_vault_client=None + ), + ) - cev = self.client_listener.results['started'] + cev = self.client_listener.results["started"] self.assertEqual(len(cev), 2) - self.assertEqual(cev[0].command_name, 'find') - self.assertEqual(cev[0].database_name, 'db') - self.assertEqual(cev[1].command_name, 'find') - self.assertEqual(cev[1].database_name, 'keyvault') + self.assertEqual(cev[0].command_name, "find") + self.assertEqual(cev[0].database_name, "db") + self.assertEqual(cev[1].command_name, "find") + self.assertEqual(cev[1].database_name, "keyvault") - self.assertEqual(len(self.topology_listener.results['opened']), 2) + self.assertEqual(len(self.topology_listener.results["opened"]), 2) def test_case_4(self): - self._run_test(max_pool_size=1, - auto_encryption_opts=AutoEncryptionOpts( - *self.optargs, - bypass_auto_encryption=True, - key_vault_client=self.client_keyvault)) + self._run_test( + max_pool_size=1, + auto_encryption_opts=AutoEncryptionOpts( + *self.optargs, bypass_auto_encryption=True, key_vault_client=self.client_keyvault + ), + ) - cev = self.client_listener.results['started'] + cev = self.client_listener.results["started"] self.assertEqual(len(cev), 1) - self.assertEqual(cev[0].command_name, 'find') - self.assertEqual(cev[0].database_name, 'db') + self.assertEqual(cev[0].command_name, "find") + self.assertEqual(cev[0].database_name, "db") - cev = self.client_keyvault_listener.results['started'] + cev = self.client_keyvault_listener.results["started"] self.assertEqual(len(cev), 1) - self.assertEqual(cev[0].command_name, 'find') - self.assertEqual(cev[0].database_name, 'keyvault') + self.assertEqual(cev[0].command_name, "find") + self.assertEqual(cev[0].database_name, "keyvault") - self.assertEqual(len(self.topology_listener.results['opened']), 1) + self.assertEqual(len(self.topology_listener.results["opened"]), 1) def test_case_5(self): - self._run_test(max_pool_size=None, - auto_encryption_opts=AutoEncryptionOpts( - *self.optargs, - bypass_auto_encryption=False, - key_vault_client=None)) + self._run_test( + max_pool_size=None, + auto_encryption_opts=AutoEncryptionOpts( + *self.optargs, bypass_auto_encryption=False, key_vault_client=None + ), + ) - cev = self.client_listener.results['started'] + cev = self.client_listener.results["started"] self.assertEqual(len(cev), 5) - self.assertEqual(cev[0].command_name, 'listCollections') - self.assertEqual(cev[0].database_name, 'db') - self.assertEqual(cev[1].command_name, 'listCollections') - self.assertEqual(cev[1].database_name, 'keyvault') - self.assertEqual(cev[2].command_name, 'find') - self.assertEqual(cev[2].database_name, 'keyvault') - self.assertEqual(cev[3].command_name, 'insert') - self.assertEqual(cev[3].database_name, 'db') - self.assertEqual(cev[4].command_name, 'find') - self.assertEqual(cev[4].database_name, 'db') - - self.assertEqual(len(self.topology_listener.results['opened']), 1) + 
self.assertEqual(cev[0].command_name, "listCollections") + self.assertEqual(cev[0].database_name, "db") + self.assertEqual(cev[1].command_name, "listCollections") + self.assertEqual(cev[1].database_name, "keyvault") + self.assertEqual(cev[2].command_name, "find") + self.assertEqual(cev[2].database_name, "keyvault") + self.assertEqual(cev[3].command_name, "insert") + self.assertEqual(cev[3].database_name, "db") + self.assertEqual(cev[4].command_name, "find") + self.assertEqual(cev[4].database_name, "db") + + self.assertEqual(len(self.topology_listener.results["opened"]), 1) def test_case_6(self): - self._run_test(max_pool_size=None, - auto_encryption_opts=AutoEncryptionOpts( - *self.optargs, - bypass_auto_encryption=False, - key_vault_client=self.client_keyvault)) + self._run_test( + max_pool_size=None, + auto_encryption_opts=AutoEncryptionOpts( + *self.optargs, bypass_auto_encryption=False, key_vault_client=self.client_keyvault + ), + ) - cev = self.client_listener.results['started'] + cev = self.client_listener.results["started"] self.assertEqual(len(cev), 3) - self.assertEqual(cev[0].command_name, 'listCollections') - self.assertEqual(cev[0].database_name, 'db') - self.assertEqual(cev[1].command_name, 'insert') - self.assertEqual(cev[1].database_name, 'db') - self.assertEqual(cev[2].command_name, 'find') - self.assertEqual(cev[2].database_name, 'db') - - cev = self.client_keyvault_listener.results['started'] + self.assertEqual(cev[0].command_name, "listCollections") + self.assertEqual(cev[0].database_name, "db") + self.assertEqual(cev[1].command_name, "insert") + self.assertEqual(cev[1].database_name, "db") + self.assertEqual(cev[2].command_name, "find") + self.assertEqual(cev[2].database_name, "db") + + cev = self.client_keyvault_listener.results["started"] self.assertEqual(len(cev), 1) - self.assertEqual(cev[0].command_name, 'find') - self.assertEqual(cev[0].database_name, 'keyvault') + self.assertEqual(cev[0].command_name, "find") + self.assertEqual(cev[0].database_name, "keyvault") - self.assertEqual(len(self.topology_listener.results['opened']), 1) + self.assertEqual(len(self.topology_listener.results["opened"]), 1) def test_case_7(self): - self._run_test(max_pool_size=None, - auto_encryption_opts=AutoEncryptionOpts( - *self.optargs, - bypass_auto_encryption=True, - key_vault_client=None)) + self._run_test( + max_pool_size=None, + auto_encryption_opts=AutoEncryptionOpts( + *self.optargs, bypass_auto_encryption=True, key_vault_client=None + ), + ) - cev = self.client_listener.results['started'] + cev = self.client_listener.results["started"] self.assertEqual(len(cev), 2) - self.assertEqual(cev[0].command_name, 'find') - self.assertEqual(cev[0].database_name, 'db') - self.assertEqual(cev[1].command_name, 'find') - self.assertEqual(cev[1].database_name, 'keyvault') + self.assertEqual(cev[0].command_name, "find") + self.assertEqual(cev[0].database_name, "db") + self.assertEqual(cev[1].command_name, "find") + self.assertEqual(cev[1].database_name, "keyvault") - self.assertEqual(len(self.topology_listener.results['opened']), 1) + self.assertEqual(len(self.topology_listener.results["opened"]), 1) def test_case_8(self): - self._run_test(max_pool_size=None, - auto_encryption_opts=AutoEncryptionOpts( - *self.optargs, - bypass_auto_encryption=True, - key_vault_client=self.client_keyvault)) + self._run_test( + max_pool_size=None, + auto_encryption_opts=AutoEncryptionOpts( + *self.optargs, bypass_auto_encryption=True, key_vault_client=self.client_keyvault + ), + ) - cev = 
self.client_listener.results['started'] + cev = self.client_listener.results["started"] self.assertEqual(len(cev), 1) - self.assertEqual(cev[0].command_name, 'find') - self.assertEqual(cev[0].database_name, 'db') + self.assertEqual(cev[0].command_name, "find") + self.assertEqual(cev[0].database_name, "db") - cev = self.client_keyvault_listener.results['started'] + cev = self.client_keyvault_listener.results["started"] self.assertEqual(len(cev), 1) - self.assertEqual(cev[0].command_name, 'find') - self.assertEqual(cev[0].database_name, 'keyvault') + self.assertEqual(cev[0].command_name, "find") + self.assertEqual(cev[0].database_name, "keyvault") - self.assertEqual(len(self.topology_listener.results['opened']), 1) + self.assertEqual(len(self.topology_listener.results["opened"]), 1) # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#bypass-spawning-mongocryptd @@ -1731,220 +1747,207 @@ def test_mongocryptd_bypass_spawn(self): # Lower the mongocryptd timeout to reduce the test run time. self._original_timeout = encryption._MONGOCRYPTD_TIMEOUT_MS encryption._MONGOCRYPTD_TIMEOUT_MS = 500 + def reset_timeout(): encryption._MONGOCRYPTD_TIMEOUT_MS = self._original_timeout + self.addCleanup(reset_timeout) # Configure the encrypted field via the local schema_map option. - schemas = {'db.coll': json_data('external', 'external-schema.json')} + schemas = {"db.coll": json_data("external", "external-schema.json")} opts = AutoEncryptionOpts( - {'local': {'key': LOCAL_MASTER_KEY}}, - 'keyvault.datakeys', + {"local": {"key": LOCAL_MASTER_KEY}}, + "keyvault.datakeys", schema_map=schemas, mongocryptd_bypass_spawn=True, - mongocryptd_uri='mongodb://localhost:27027/', + mongocryptd_uri="mongodb://localhost:27027/", mongocryptd_spawn_args=[ - '--pidfilepath=bypass-spawning-mongocryptd.pid', - '--port=27027'] + "--pidfilepath=bypass-spawning-mongocryptd.pid", + "--port=27027", + ], ) client_encrypted = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client_encrypted.close) - with self.assertRaisesRegex(EncryptionError, 'Timeout'): - client_encrypted.db.coll.insert_one({'encrypted': 'test'}) + with self.assertRaisesRegex(EncryptionError, "Timeout"): + client_encrypted.db.coll.insert_one({"encrypted": "test"}) def test_bypassAutoEncryption(self): opts = AutoEncryptionOpts( - {'local': {'key': LOCAL_MASTER_KEY}}, - 'keyvault.datakeys', + {"local": {"key": LOCAL_MASTER_KEY}}, + "keyvault.datakeys", bypass_auto_encryption=True, mongocryptd_spawn_args=[ - '--pidfilepath=bypass-spawning-mongocryptd.pid', - '--port=27027'] + "--pidfilepath=bypass-spawning-mongocryptd.pid", + "--port=27027", + ], ) client_encrypted = rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client_encrypted.close) client_encrypted.db.coll.insert_one({"unencrypted": "test"}) # Validate that mongocryptd was not spawned: - mongocryptd_client = MongoClient( - 'mongodb://localhost:27027/?serverSelectionTimeoutMS=500') + mongocryptd_client = MongoClient("mongodb://localhost:27027/?serverSelectionTimeoutMS=500") with self.assertRaises(ServerSelectionTimeoutError): - mongocryptd_client.admin.command('ping') + mongocryptd_client.admin.command("ping") # https://github.com/mongodb/specifications/tree/master/source/client-side-encryption/tests#kms-tls-tests class TestKmsTLSProse(EncryptionIntegrationTest): - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS 
environment credentials are not set") def setUp(self): super(TestKmsTLSProse, self).setUp() self.patch_system_certs(CA_PEM) self.client_encrypted = ClientEncryption( - {'aws': AWS_CREDS}, 'keyvault.datakeys', self.client, OPTS) + {"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS + ) self.addCleanup(self.client_encrypted.close) def test_invalid_kms_certificate_expired(self): key = { - "region": "us-east-1", - "key": "arn:aws:kms:us-east-1:579766882180:key/" - "89fcc2c4-08b0-4bd9-9f25-e30687b580d0", - "endpoint": "mongodb://127.0.0.1:8000", + "region": "us-east-1", + "key": "arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0", + "endpoint": "mongodb://127.0.0.1:8000", } # Some examples: # certificate verify failed: certificate has expired (_ssl.c:1129) # amazon1-2018 Python 3.6: certificate verify failed (_ssl.c:852) - with self.assertRaisesRegex( - EncryptionError, 'expired|certificate verify failed'): - self.client_encrypted.create_data_key('aws', master_key=key) + with self.assertRaisesRegex(EncryptionError, "expired|certificate verify failed"): + self.client_encrypted.create_data_key("aws", master_key=key) def test_invalid_hostname_in_kms_certificate(self): key = { - "region": "us-east-1", - "key": "arn:aws:kms:us-east-1:579766882180:key/" - "89fcc2c4-08b0-4bd9-9f25-e30687b580d0", - "endpoint": "mongodb://127.0.0.1:8001", + "region": "us-east-1", + "key": "arn:aws:kms:us-east-1:579766882180:key/" "89fcc2c4-08b0-4bd9-9f25-e30687b580d0", + "endpoint": "mongodb://127.0.0.1:8001", } # Some examples: # certificate verify failed: IP address mismatch, certificate is not valid for '127.0.0.1'. (_ssl.c:1129)" # hostname '127.0.0.1' doesn't match 'wronghost.com' - with self.assertRaisesRegex( - EncryptionError, 'IP address mismatch|wronghost'): - self.client_encrypted.create_data_key('aws', master_key=key) + with self.assertRaisesRegex(EncryptionError, "IP address mismatch|wronghost"): + self.client_encrypted.create_data_key("aws", master_key=key) # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#kms-tls-options-tests class TestKmsTLSOptions(EncryptionIntegrationTest): - @unittest.skipUnless(any(AWS_CREDS.values()), - 'AWS environment credentials are not set') + @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def setUp(self): super(TestKmsTLSOptions, self).setUp() # 1, create client with only tlsCAFile. providers = copy.deepcopy(ALL_KMS_PROVIDERS) - providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8002' - providers['gcp']['endpoint'] = '127.0.0.1:8002' + providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:8002" + providers["gcp"]["endpoint"] = "127.0.0.1:8002" kms_tls_opts_ca_only = { - 'aws': {'tlsCAFile': CA_PEM}, - 'azure': {'tlsCAFile': CA_PEM}, - 'gcp': {'tlsCAFile': CA_PEM}, - 'kmip': {'tlsCAFile': CA_PEM}, + "aws": {"tlsCAFile": CA_PEM}, + "azure": {"tlsCAFile": CA_PEM}, + "gcp": {"tlsCAFile": CA_PEM}, + "kmip": {"tlsCAFile": CA_PEM}, } self.client_encryption_no_client_cert = ClientEncryption( - providers, 'keyvault.datakeys', self.client, OPTS, - kms_tls_options=kms_tls_opts_ca_only) + providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only + ) self.addCleanup(self.client_encryption_no_client_cert.close) # 2, same providers as above but with tlsCertificateKeyFile. 
        kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only)
        for p in kms_tls_opts:
-            kms_tls_opts[p]['tlsCertificateKeyFile'] = CLIENT_PEM
+            kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM
        self.client_encryption_with_tls = ClientEncryption(
-            providers, 'keyvault.datakeys', self.client, OPTS,
-            kms_tls_options=kms_tls_opts)
+            providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
+        )
        self.addCleanup(self.client_encryption_with_tls.close)
        # 3, update endpoints to expired host.
        providers = copy.deepcopy(providers)
-        providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8000'
-        providers['gcp']['endpoint'] = '127.0.0.1:8000'
-        providers['kmip']['endpoint'] = '127.0.0.1:8000'
+        providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:8000"
+        providers["gcp"]["endpoint"] = "127.0.0.1:8000"
+        providers["kmip"]["endpoint"] = "127.0.0.1:8000"
        self.client_encryption_expired = ClientEncryption(
-            providers, 'keyvault.datakeys', self.client, OPTS,
-            kms_tls_options=kms_tls_opts_ca_only)
+            providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
+        )
        self.addCleanup(self.client_encryption_expired.close)
        # 4, update endpoints to invalid host.
        providers = copy.deepcopy(providers)
-        providers['azure']['identityPlatformEndpoint'] = '127.0.0.1:8001'
-        providers['gcp']['endpoint'] = '127.0.0.1:8001'
-        providers['kmip']['endpoint'] = '127.0.0.1:8001'
+        providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:8001"
+        providers["gcp"]["endpoint"] = "127.0.0.1:8001"
+        providers["kmip"]["endpoint"] = "127.0.0.1:8001"
        self.client_encryption_invalid_hostname = ClientEncryption(
-            providers, 'keyvault.datakeys', self.client, OPTS,
-            kms_tls_options=kms_tls_opts_ca_only)
+            providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only
+        )
        self.addCleanup(self.client_encryption_invalid_hostname.close)
        # Errors when client has no cert, some examples:
        # [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623)
-        self.cert_error = ('certificate required|SSL handshake failed|'
-                           'KMS connection closed')
+        self.cert_error = "certificate required|SSL handshake failed|" "KMS connection closed"
        # On Windows this error might be:
        # [WinError 10054] An existing connection was forcibly closed by the remote host
-        if sys.platform == 'win32':
-            self.cert_error += '|forcibly closed'
+        if sys.platform == "win32":
+            self.cert_error += "|forcibly closed"
        # On Windows Python 3.10+ this error might be:
        # EOF occurred in violation of protocol (_ssl.c:2384)
        if sys.version_info[:2] >= (3, 10):
-            self.cert_error += '|EOF'
+            self.cert_error += "|EOF"

    def test_01_aws(self):
        key = {
-            'region': 'us-east-1',
-            'key': 'arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0',
-            'endpoint': '127.0.0.1:8002',
+            "region": "us-east-1",
+            "key": "arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0",
+            "endpoint": "127.0.0.1:8002",
        }
        with self.assertRaisesRegex(EncryptionError, self.cert_error):
-            self.client_encryption_no_client_cert.create_data_key('aws', key)
+            self.client_encryption_no_client_cert.create_data_key("aws", key)
        # "parse error" here means that the TLS handshake succeeded. 
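        # A "parse error" can only be raised after a KMS reply is received,
        # which in turn requires the client certificate to have been accepted.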
- with self.assertRaisesRegex(EncryptionError, 'parse error'): - self.client_encryption_with_tls.create_data_key('aws', key) + with self.assertRaisesRegex(EncryptionError, "parse error"): + self.client_encryption_with_tls.create_data_key("aws", key) # Some examples: # certificate verify failed: certificate has expired (_ssl.c:1129) # amazon1-2018 Python 3.6: certificate verify failed (_ssl.c:852) - key['endpoint'] = '127.0.0.1:8000' - with self.assertRaisesRegex( - EncryptionError, 'expired|certificate verify failed'): - self.client_encryption_expired.create_data_key('aws', key) + key["endpoint"] = "127.0.0.1:8000" + with self.assertRaisesRegex(EncryptionError, "expired|certificate verify failed"): + self.client_encryption_expired.create_data_key("aws", key) # Some examples: # certificate verify failed: IP address mismatch, certificate is not valid for '127.0.0.1'. (_ssl.c:1129)" # hostname '127.0.0.1' doesn't match 'wronghost.com' - key['endpoint'] = '127.0.0.1:8001' - with self.assertRaisesRegex( - EncryptionError, 'IP address mismatch|wronghost'): - self.client_encryption_invalid_hostname.create_data_key('aws', key) + key["endpoint"] = "127.0.0.1:8001" + with self.assertRaisesRegex(EncryptionError, "IP address mismatch|wronghost"): + self.client_encryption_invalid_hostname.create_data_key("aws", key) def test_02_azure(self): - key = {'keyVaultEndpoint': 'doesnotexist.local', 'keyName': 'foo'} + key = {"keyVaultEndpoint": "doesnotexist.local", "keyName": "foo"} # Missing client cert error. with self.assertRaisesRegex(EncryptionError, self.cert_error): - self.client_encryption_no_client_cert.create_data_key('azure', key) + self.client_encryption_no_client_cert.create_data_key("azure", key) # "HTTP status=404" here means that the TLS handshake succeeded. - with self.assertRaisesRegex(EncryptionError, 'HTTP status=404'): - self.client_encryption_with_tls.create_data_key('azure', key) + with self.assertRaisesRegex(EncryptionError, "HTTP status=404"): + self.client_encryption_with_tls.create_data_key("azure", key) # Expired cert error. - with self.assertRaisesRegex( - EncryptionError, 'expired|certificate verify failed'): - self.client_encryption_expired.create_data_key('azure', key) + with self.assertRaisesRegex(EncryptionError, "expired|certificate verify failed"): + self.client_encryption_expired.create_data_key("azure", key) # Invalid cert hostname error. - with self.assertRaisesRegex( - EncryptionError, 'IP address mismatch|wronghost'): - self.client_encryption_invalid_hostname.create_data_key( - 'azure', key) + with self.assertRaisesRegex(EncryptionError, "IP address mismatch|wronghost"): + self.client_encryption_invalid_hostname.create_data_key("azure", key) def test_03_gcp(self): - key = {'projectId': 'foo', 'location': 'bar', 'keyRing': 'baz', - 'keyName': 'foo'} + key = {"projectId": "foo", "location": "bar", "keyRing": "baz", "keyName": "foo"} # Missing client cert error. with self.assertRaisesRegex(EncryptionError, self.cert_error): - self.client_encryption_no_client_cert.create_data_key('gcp', key) + self.client_encryption_no_client_cert.create_data_key("gcp", key) # "HTTP status=404" here means that the TLS handshake succeeded. - with self.assertRaisesRegex(EncryptionError, 'HTTP status=404'): - self.client_encryption_with_tls.create_data_key('gcp', key) + with self.assertRaisesRegex(EncryptionError, "HTTP status=404"): + self.client_encryption_with_tls.create_data_key("gcp", key) # Expired cert error. 
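        # As in test_01_aws, the regex accepts both the modern "certificate has
        # expired" message and the older generic "certificate verify failed" form.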
- with self.assertRaisesRegex( - EncryptionError, 'expired|certificate verify failed'): - self.client_encryption_expired.create_data_key('gcp', key) + with self.assertRaisesRegex(EncryptionError, "expired|certificate verify failed"): + self.client_encryption_expired.create_data_key("gcp", key) # Invalid cert hostname error. - with self.assertRaisesRegex( - EncryptionError, 'IP address mismatch|wronghost'): - self.client_encryption_invalid_hostname.create_data_key('gcp', key) + with self.assertRaisesRegex(EncryptionError, "IP address mismatch|wronghost"): + self.client_encryption_invalid_hostname.create_data_key("gcp", key) def test_04_kmip(self): # Missing client cert error. with self.assertRaisesRegex(EncryptionError, self.cert_error): - self.client_encryption_no_client_cert.create_data_key('kmip') - self.client_encryption_with_tls.create_data_key('kmip') + self.client_encryption_no_client_cert.create_data_key("kmip") + self.client_encryption_with_tls.create_data_key("kmip") # Expired cert error. - with self.assertRaisesRegex( - EncryptionError, 'expired|certificate verify failed'): - self.client_encryption_expired.create_data_key('kmip') + with self.assertRaisesRegex(EncryptionError, "expired|certificate verify failed"): + self.client_encryption_expired.create_data_key("kmip") # Invalid cert hostname error. - with self.assertRaisesRegex( - EncryptionError, 'IP address mismatch|wronghost'): - self.client_encryption_invalid_hostname.create_data_key('kmip') + with self.assertRaisesRegex(EncryptionError, "IP address mismatch|wronghost"): + self.client_encryption_invalid_hostname.create_data_key("kmip") if __name__ == "__main__": diff --git a/test/test_errors.py b/test/test_errors.py index 53c55f8167..8a225b6548 100644 --- a/test/test_errors.py +++ b/test/test_errors.py @@ -18,12 +18,14 @@ sys.path[0:0] = [""] -from pymongo.errors import (BulkWriteError, - EncryptionError, - NotPrimaryError, - OperationFailure) -from test import (PyMongoTestCase, - unittest) +from test import PyMongoTestCase, unittest + +from pymongo.errors import ( + BulkWriteError, + EncryptionError, + NotPrimaryError, + OperationFailure, +) class TestErrors(PyMongoTestCase): @@ -36,8 +38,7 @@ def test_not_primary_error(self): self.assertIn("full error", traceback.format_exc()) def test_operation_failure(self): - exc = OperationFailure("operation failure test", 10, - {"errmsg": "error"}) + exc = OperationFailure("operation failure test", 10, {"errmsg": "error"}) self.assertIn("full error", str(exc)) try: raise exc @@ -45,26 +46,26 @@ def test_operation_failure(self): self.assertIn("full error", traceback.format_exc()) def _test_unicode_strs(self, exc): - if sys.implementation.name == 'pypy' and sys.implementation.version < (7, 3, 7): + if sys.implementation.name == "pypy" and sys.implementation.version < (7, 3, 7): # PyPy used to display unicode in repr differently. 
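        # (Those builds rendered the character as the escape \U0001f40d inside
        # the dict repr, hence the doubled backslash in the expected string below.)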
- self.assertEqual("unicode \U0001f40d, full error: {" - "'errmsg': 'unicode \\U0001f40d'}", str(exc)) + self.assertEqual( + "unicode \U0001f40d, full error: {" "'errmsg': 'unicode \\U0001f40d'}", str(exc) + ) else: - self.assertEqual("unicode \U0001f40d, full error: {" - "'errmsg': 'unicode \U0001f40d'}", str(exc)) + self.assertEqual( + "unicode \U0001f40d, full error: {" "'errmsg': 'unicode \U0001f40d'}", str(exc) + ) try: raise exc except Exception: self.assertIn("full error", traceback.format_exc()) def test_unicode_strs_operation_failure(self): - exc = OperationFailure('unicode \U0001f40d', 10, - {"errmsg": 'unicode \U0001f40d'}) + exc = OperationFailure("unicode \U0001f40d", 10, {"errmsg": "unicode \U0001f40d"}) self._test_unicode_strs(exc) def test_unicode_strs_not_primary_error(self): - exc = NotPrimaryError('unicode \U0001f40d', - {"errmsg": 'unicode \U0001f40d'}) + exc = NotPrimaryError("unicode \U0001f40d", {"errmsg": "unicode \U0001f40d"}) self._test_unicode_strs(exc) def assertPyMongoErrorEqual(self, exc1, exc2): @@ -84,7 +85,7 @@ def test_pickle_NotPrimaryError(self): self.assertPyMongoErrorEqual(exc, pickle.loads(pickle.dumps(exc))) def test_pickle_OperationFailure(self): - exc = OperationFailure('error', code=5, details={}, max_wire_version=7) + exc = OperationFailure("error", code=5, details={}, max_wire_version=7) self.assertOperationFailureEqual(exc, pickle.loads(pickle.dumps(exc))) def test_pickle_BulkWriteError(self): @@ -93,8 +94,7 @@ def test_pickle_BulkWriteError(self): self.assertIn("batch op errors occurred", str(exc)) def test_pickle_EncryptionError(self): - cause = OperationFailure('error', code=5, details={}, - max_wire_version=7) + cause = OperationFailure("error", code=5, details={}, max_wire_version=7) exc = EncryptionError(cause) exc2 = pickle.loads(pickle.dumps(exc)) self.assertPyMongoErrorEqual(exc, exc2) diff --git a/test/test_examples.py b/test/test_examples.py index dcf9dd2de3..0cdca017a4 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -20,6 +20,9 @@ sys.path[0:0] = [""] +from test import IntegrationTest, client_context, unittest +from test.utils import rs_client + import pymongo from pymongo.errors import ConnectionFailure, OperationFailure from pymongo.read_concern import ReadConcern @@ -27,9 +30,6 @@ from pymongo.server_api import ServerApi from pymongo.write_concern import WriteConcern -from test import client_context, unittest, IntegrationTest -from test.utils import rs_client - class TestSampleShellCommands(IntegrationTest): @classmethod @@ -51,10 +51,13 @@ def test_first_three_examples(self): # Start Example 1 db.inventory.insert_one( - {"item": "canvas", - "qty": 100, - "tags": ["cotton"], - "size": {"h": 28, "w": 35.5, "uom": "cm"}}) + { + "item": "canvas", + "qty": 100, + "tags": ["cotton"], + "size": {"h": 28, "w": 35.5, "uom": "cm"}, + } + ) # End Example 1 self.assertEqual(db.inventory.count_documents({}), 1) @@ -66,19 +69,28 @@ def test_first_three_examples(self): self.assertEqual(len(list(cursor)), 1) # Start Example 3 - db.inventory.insert_many([ - {"item": "journal", - "qty": 25, - "tags": ["blank", "red"], - "size": {"h": 14, "w": 21, "uom": "cm"}}, - {"item": "mat", - "qty": 85, - "tags": ["gray"], - "size": {"h": 27.9, "w": 35.5, "uom": "cm"}}, - {"item": "mousepad", - "qty": 25, - "tags": ["gel", "blue"], - "size": {"h": 19, "w": 22.85, "uom": "cm"}}]) + db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "tags": ["blank", "red"], + "size": {"h": 14, "w": 21, "uom": "cm"}, + }, + { + "item": 
"mat", + "qty": 85, + "tags": ["gray"], + "size": {"h": 27.9, "w": 35.5, "uom": "cm"}, + }, + { + "item": "mousepad", + "qty": 25, + "tags": ["gel", "blue"], + "size": {"h": 19, "w": 22.85, "uom": "cm"}, + }, + ] + ) # End Example 3 self.assertEqual(db.inventory.count_documents({}), 4) @@ -87,26 +99,40 @@ def test_query_top_level_fields(self): db = self.db # Start Example 6 - db.inventory.insert_many([ - {"item": "journal", - "qty": 25, - "size": {"h": 14, "w": 21, "uom": "cm"}, - "status": "A"}, - {"item": "notebook", - "qty": 50, - "size": {"h": 8.5, "w": 11, "uom": "in"}, - "status": "A"}, - {"item": "paper", - "qty": 100, - "size": {"h": 8.5, "w": 11, "uom": "in"}, - "status": "D"}, - {"item": "planner", - "qty": 75, "size": {"h": 22.85, "w": 30, "uom": "cm"}, - "status": "D"}, - {"item": "postcard", - "qty": 45, - "size": {"h": 10, "w": 15.25, "uom": "cm"}, - "status": "A"}]) + db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "notebook", + "qty": 50, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "A", + }, + { + "item": "paper", + "qty": 100, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "status": "A", + }, + ] + ) # End Example 6 self.assertEqual(db.inventory.count_documents({}), 5) @@ -136,16 +162,15 @@ def test_query_top_level_fields(self): self.assertEqual(len(list(cursor)), 1) # Start Example 12 - cursor = db.inventory.find( - {"$or": [{"status": "A"}, {"qty": {"$lt": 30}}]}) + cursor = db.inventory.find({"$or": [{"status": "A"}, {"qty": {"$lt": 30}}]}) # End Example 12 self.assertEqual(len(list(cursor)), 3) # Start Example 13 - cursor = db.inventory.find({ - "status": "A", - "$or": [{"qty": {"$lt": 30}}, {"item": {"$regex": "^p"}}]}) + cursor = db.inventory.find( + {"status": "A", "$or": [{"qty": {"$lt": 30}}, {"item": {"$regex": "^p"}}]} + ) # End Example 13 self.assertEqual(len(list(cursor)), 2) @@ -157,39 +182,51 @@ def test_query_embedded_documents(self): # Subdocument key order matters in a few of these examples so we have # to use bson.son.SON instead of a Python dict. 
from bson.son import SON - db.inventory.insert_many([ - {"item": "journal", - "qty": 25, - "size": SON([("h", 14), ("w", 21), ("uom", "cm")]), - "status": "A"}, - {"item": "notebook", - "qty": 50, - "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), - "status": "A"}, - {"item": "paper", - "qty": 100, - "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), - "status": "D"}, - {"item": "planner", - "qty": 75, - "size": SON([("h", 22.85), ("w", 30), ("uom", "cm")]), - "status": "D"}, - {"item": "postcard", - "qty": 45, - "size": SON([("h", 10), ("w", 15.25), ("uom", "cm")]), - "status": "A"}]) + + db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "size": SON([("h", 14), ("w", 21), ("uom", "cm")]), + "status": "A", + }, + { + "item": "notebook", + "qty": 50, + "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), + "status": "A", + }, + { + "item": "paper", + "qty": 100, + "size": SON([("h", 8.5), ("w", 11), ("uom", "in")]), + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": SON([("h", 22.85), ("w", 30), ("uom", "cm")]), + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": SON([("h", 10), ("w", 15.25), ("uom", "cm")]), + "status": "A", + }, + ] + ) # End Example 14 # Start Example 15 - cursor = db.inventory.find( - {"size": SON([("h", 14), ("w", 21), ("uom", "cm")])}) + cursor = db.inventory.find({"size": SON([("h", 14), ("w", 21), ("uom", "cm")])}) # End Example 15 self.assertEqual(len(list(cursor)), 1) # Start Example 16 - cursor = db.inventory.find( - {"size": SON([("w", 21), ("h", 14), ("uom", "cm")])}) + cursor = db.inventory.find({"size": SON([("w", 21), ("h", 14), ("uom", "cm")])}) # End Example 16 self.assertEqual(len(list(cursor)), 0) @@ -207,8 +244,7 @@ def test_query_embedded_documents(self): self.assertEqual(len(list(cursor)), 4) # Start Example 19 - cursor = db.inventory.find( - {"size.h": {"$lt": 15}, "size.uom": "in", "status": "D"}) + cursor = db.inventory.find({"size.h": {"$lt": 15}, "size.uom": "in", "status": "D"}) # End Example 19 self.assertEqual(len(list(cursor)), 1) @@ -217,27 +253,20 @@ def test_query_arrays(self): db = self.db # Start Example 20 - db.inventory.insert_many([ - {"item": "journal", - "qty": 25, - "tags": ["blank", "red"], - "dim_cm": [14, 21]}, - {"item": "notebook", - "qty": 50, - "tags": ["red", "blank"], - "dim_cm": [14, 21]}, - {"item": "paper", - "qty": 100, - "tags": ["red", "blank", "plain"], - "dim_cm": [14, 21]}, - {"item": "planner", - "qty": 75, - "tags": ["blank", "red"], - "dim_cm": [22.85, 30]}, - {"item": "postcard", - "qty": 45, - "tags": ["blue"], - "dim_cm": [10, 15.25]}]) + db.inventory.insert_many( + [ + {"item": "journal", "qty": 25, "tags": ["blank", "red"], "dim_cm": [14, 21]}, + {"item": "notebook", "qty": 50, "tags": ["red", "blank"], "dim_cm": [14, 21]}, + { + "item": "paper", + "qty": 100, + "tags": ["red", "blank", "plain"], + "dim_cm": [14, 21], + }, + {"item": "planner", "qty": 75, "tags": ["blank", "red"], "dim_cm": [22.85, 30]}, + {"item": "postcard", "qty": 45, "tags": ["blue"], "dim_cm": [10, 15.25]}, + ] + ) # End Example 20 # Start Example 21 @@ -271,8 +300,7 @@ def test_query_arrays(self): self.assertEqual(len(list(cursor)), 4) # Start Example 26 - cursor = db.inventory.find( - {"dim_cm": {"$elemMatch": {"$gt": 22, "$lt": 30}}}) + cursor = db.inventory.find({"dim_cm": {"$elemMatch": {"$gt": 22, "$lt": 30}}}) # End Example 26 self.assertEqual(len(list(cursor)), 1) @@ -296,64 +324,74 @@ def test_query_array_of_documents(self): # Subdocument key order matters in a 
few of these examples so we have # to use bson.son.SON instead of a Python dict. from bson.son import SON - db.inventory.insert_many([ - {"item": "journal", - "instock": [ - SON([("warehouse", "A"), ("qty", 5)]), - SON([("warehouse", "C"), ("qty", 15)])]}, - {"item": "notebook", - "instock": [ - SON([("warehouse", "C"), ("qty", 5)])]}, - {"item": "paper", - "instock": [ - SON([("warehouse", "A"), ("qty", 60)]), - SON([("warehouse", "B"), ("qty", 15)])]}, - {"item": "planner", - "instock": [ - SON([("warehouse", "A"), ("qty", 40)]), - SON([("warehouse", "B"), ("qty", 5)])]}, - {"item": "postcard", - "instock": [ - SON([("warehouse", "B"), ("qty", 15)]), - SON([("warehouse", "C"), ("qty", 35)])]}]) + + db.inventory.insert_many( + [ + { + "item": "journal", + "instock": [ + SON([("warehouse", "A"), ("qty", 5)]), + SON([("warehouse", "C"), ("qty", 15)]), + ], + }, + {"item": "notebook", "instock": [SON([("warehouse", "C"), ("qty", 5)])]}, + { + "item": "paper", + "instock": [ + SON([("warehouse", "A"), ("qty", 60)]), + SON([("warehouse", "B"), ("qty", 15)]), + ], + }, + { + "item": "planner", + "instock": [ + SON([("warehouse", "A"), ("qty", 40)]), + SON([("warehouse", "B"), ("qty", 5)]), + ], + }, + { + "item": "postcard", + "instock": [ + SON([("warehouse", "B"), ("qty", 15)]), + SON([("warehouse", "C"), ("qty", 35)]), + ], + }, + ] + ) # End Example 29 # Start Example 30 - cursor = db.inventory.find( - {"instock": SON([("warehouse", "A"), ("qty", 5)])}) + cursor = db.inventory.find({"instock": SON([("warehouse", "A"), ("qty", 5)])}) # End Example 30 self.assertEqual(len(list(cursor)), 1) # Start Example 31 - cursor = db.inventory.find( - {"instock": SON([("qty", 5), ("warehouse", "A")])}) + cursor = db.inventory.find({"instock": SON([("qty", 5), ("warehouse", "A")])}) # End Example 31 self.assertEqual(len(list(cursor)), 0) # Start Example 32 - cursor = db.inventory.find({'instock.0.qty': {"$lte": 20}}) + cursor = db.inventory.find({"instock.0.qty": {"$lte": 20}}) # End Example 32 self.assertEqual(len(list(cursor)), 3) # Start Example 33 - cursor = db.inventory.find({'instock.qty': {"$lte": 20}}) + cursor = db.inventory.find({"instock.qty": {"$lte": 20}}) # End Example 33 self.assertEqual(len(list(cursor)), 5) # Start Example 34 - cursor = db.inventory.find( - {"instock": {"$elemMatch": {"qty": 5, "warehouse": "A"}}}) + cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": 5, "warehouse": "A"}}}) # End Example 34 self.assertEqual(len(list(cursor)), 1) # Start Example 35 - cursor = db.inventory.find( - {"instock": {"$elemMatch": {"qty": {"$gt": 10, "$lte": 20}}}}) + cursor = db.inventory.find({"instock": {"$elemMatch": {"qty": {"$gt": 10, "$lte": 20}}}}) # End Example 35 self.assertEqual(len(list(cursor)), 3) @@ -365,8 +403,7 @@ def test_query_array_of_documents(self): self.assertEqual(len(list(cursor)), 4) # Start Example 37 - cursor = db.inventory.find( - {"instock.qty": 5, "instock.warehouse": "A"}) + cursor = db.inventory.find({"instock.qty": 5, "instock.warehouse": "A"}) # End Example 37 self.assertEqual(len(list(cursor)), 2) @@ -400,29 +437,40 @@ def test_projection(self): db = self.db # Start Example 42 - db.inventory.insert_many([ - {"item": "journal", - "status": "A", - "size": {"h": 14, "w": 21, "uom": "cm"}, - "instock": [{"warehouse": "A", "qty": 5}]}, - {"item": "notebook", - "status": "A", - "size": {"h": 8.5, "w": 11, "uom": "in"}, - "instock": [{"warehouse": "C", "qty": 5}]}, - {"item": "paper", - "status": "D", - "size": {"h": 8.5, "w": 11, "uom": "in"}, - "instock": 
[{"warehouse": "A", "qty": 60}]}, - {"item": "planner", - "status": "D", - "size": {"h": 22.85, "w": 30, "uom": "cm"}, - "instock": [{"warehouse": "A", "qty": 40}]}, - {"item": "postcard", - "status": "A", - "size": {"h": 10, "w": 15.25, "uom": "cm"}, - "instock": [ - {"warehouse": "B", "qty": 15}, - {"warehouse": "C", "qty": 35}]}]) + db.inventory.insert_many( + [ + { + "item": "journal", + "status": "A", + "size": {"h": 14, "w": 21, "uom": "cm"}, + "instock": [{"warehouse": "A", "qty": 5}], + }, + { + "item": "notebook", + "status": "A", + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "instock": [{"warehouse": "C", "qty": 5}], + }, + { + "item": "paper", + "status": "D", + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "instock": [{"warehouse": "A", "qty": 60}], + }, + { + "item": "planner", + "status": "D", + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "instock": [{"warehouse": "A", "qty": 40}], + }, + { + "item": "postcard", + "status": "A", + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "instock": [{"warehouse": "B", "qty": 15}, {"warehouse": "C", "qty": 35}], + }, + ] + ) # End Example 42 # Start Example 43 @@ -432,8 +480,7 @@ def test_projection(self): self.assertEqual(len(list(cursor)), 3) # Start Example 44 - cursor = db.inventory.find( - {"status": "A"}, {"item": 1, "status": 1}) + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1}) # End Example 44 for doc in cursor: @@ -444,8 +491,7 @@ def test_projection(self): self.assertFalse("instock" in doc) # Start Example 45 - cursor = db.inventory.find( - {"status": "A"}, {"item": 1, "status": 1, "_id": 0}) + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1, "_id": 0}) # End Example 45 for doc in cursor: @@ -456,8 +502,7 @@ def test_projection(self): self.assertFalse("instock" in doc) # Start Example 46 - cursor = db.inventory.find( - {"status": "A"}, {"status": 0, "instock": 0}) + cursor = db.inventory.find({"status": "A"}, {"status": 0, "instock": 0}) # End Example 46 for doc in cursor: @@ -468,8 +513,7 @@ def test_projection(self): self.assertFalse("instock" in doc) # Start Example 47 - cursor = db.inventory.find( - {"status": "A"}, {"item": 1, "status": 1, "size.uom": 1}) + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1, "size.uom": 1}) # End Example 47 for doc in cursor: @@ -478,10 +522,10 @@ def test_projection(self): self.assertTrue("status" in doc) self.assertTrue("size" in doc) self.assertFalse("instock" in doc) - size = doc['size'] - self.assertTrue('uom' in size) - self.assertFalse('h' in size) - self.assertFalse('w' in size) + size = doc["size"] + self.assertTrue("uom" in size) + self.assertFalse("h" in size) + self.assertFalse("w" in size) # Start Example 48 cursor = db.inventory.find({"status": "A"}, {"size.uom": 0}) @@ -493,14 +537,13 @@ def test_projection(self): self.assertTrue("status" in doc) self.assertTrue("size" in doc) self.assertTrue("instock" in doc) - size = doc['size'] - self.assertFalse('uom' in size) - self.assertTrue('h' in size) - self.assertTrue('w' in size) + size = doc["size"] + self.assertFalse("uom" in size) + self.assertTrue("h" in size) + self.assertTrue("w" in size) # Start Example 49 - cursor = db.inventory.find( - {"status": "A"}, {"item": 1, "status": 1, "instock.qty": 1}) + cursor = db.inventory.find({"status": "A"}, {"item": 1, "status": 1, "instock.qty": 1}) # End Example 49 for doc in cursor: @@ -509,14 +552,14 @@ def test_projection(self): self.assertTrue("status" in doc) self.assertFalse("size" in doc) self.assertTrue("instock" in 
doc) - for subdoc in doc['instock']: - self.assertFalse('warehouse' in subdoc) - self.assertTrue('qty' in subdoc) + for subdoc in doc["instock"]: + self.assertFalse("warehouse" in subdoc) + self.assertTrue("qty" in subdoc) # Start Example 50 cursor = db.inventory.find( - {"status": "A"}, - {"item": 1, "status": 1, "instock": {"$slice": -1}}) + {"status": "A"}, {"item": 1, "status": 1, "instock": {"$slice": -1}} + ) # End Example 50 for doc in cursor: @@ -531,54 +574,77 @@ def test_update_and_replace(self): db = self.db # Start Example 51 - db.inventory.insert_many([ - {"item": "canvas", - "qty": 100, - "size": {"h": 28, "w": 35.5, "uom": "cm"}, - "status": "A"}, - {"item": "journal", - "qty": 25, - "size": {"h": 14, "w": 21, "uom": "cm"}, - "status": "A"}, - {"item": "mat", - "qty": 85, - "size": {"h": 27.9, "w": 35.5, "uom": "cm"}, - "status": "A"}, - {"item": "mousepad", - "qty": 25, - "size": {"h": 19, "w": 22.85, "uom": "cm"}, - "status": "P"}, - {"item": "notebook", - "qty": 50, - "size": {"h": 8.5, "w": 11, "uom": "in"}, - "status": "P"}, - {"item": "paper", - "qty": 100, - "size": {"h": 8.5, "w": 11, "uom": "in"}, - "status": "D"}, - {"item": "planner", - "qty": 75, - "size": {"h": 22.85, "w": 30, "uom": "cm"}, - "status": "D"}, - {"item": "postcard", - "qty": 45, - "size": {"h": 10, "w": 15.25, "uom": "cm"}, - "status": "A"}, - {"item": "sketchbook", - "qty": 80, - "size": {"h": 14, "w": 21, "uom": "cm"}, - "status": "A"}, - {"item": "sketch pad", - "qty": 95, - "size": {"h": 22.85, "w": 30.5, "uom": "cm"}, - "status": "A"}]) + db.inventory.insert_many( + [ + { + "item": "canvas", + "qty": 100, + "size": {"h": 28, "w": 35.5, "uom": "cm"}, + "status": "A", + }, + { + "item": "journal", + "qty": 25, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "mat", + "qty": 85, + "size": {"h": 27.9, "w": 35.5, "uom": "cm"}, + "status": "A", + }, + { + "item": "mousepad", + "qty": 25, + "size": {"h": 19, "w": 22.85, "uom": "cm"}, + "status": "P", + }, + { + "item": "notebook", + "qty": 50, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "P", + }, + { + "item": "paper", + "qty": 100, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "status": "A", + }, + { + "item": "sketchbook", + "qty": 80, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "sketch pad", + "qty": 95, + "size": {"h": 22.85, "w": 30.5, "uom": "cm"}, + "status": "A", + }, + ] + ) # End Example 51 # Start Example 52 db.inventory.update_one( {"item": "paper"}, - {"$set": {"size.uom": "cm", "status": "P"}, - "$currentDate": {"lastModified": True}}) + {"$set": {"size.uom": "cm", "status": "P"}, "$currentDate": {"lastModified": True}}, + ) # End Example 52 for doc in db.inventory.find({"item": "paper"}): @@ -589,8 +655,8 @@ def test_update_and_replace(self): # Start Example 53 db.inventory.update_many( {"qty": {"$lt": 50}}, - {"$set": {"size.uom": "in", "status": "P"}, - "$currentDate": {"lastModified": True}}) + {"$set": {"size.uom": "in", "status": "P"}, "$currentDate": {"lastModified": True}}, + ) # End Example 53 for doc in db.inventory.find({"qty": {"$lt": 50}}): @@ -601,10 +667,11 @@ def test_update_and_replace(self): # Start Example 54 db.inventory.replace_one( {"item": "paper"}, - {"item": "paper", - "instock": [ - {"warehouse": "A", "qty": 60}, - 
{"warehouse": "B", "qty": 40}]}) + { + "item": "paper", + "instock": [{"warehouse": "A", "qty": 60}, {"warehouse": "B", "qty": 40}], + }, + ) # End Example 54 for doc in db.inventory.find({"item": "paper"}, {"_id": 0}): @@ -617,27 +684,40 @@ def test_delete(self): db = self.db # Start Example 55 - db.inventory.insert_many([ - {"item": "journal", - "qty": 25, - "size": {"h": 14, "w": 21, "uom": "cm"}, - "status": "A"}, - {"item": "notebook", - "qty": 50, - "size": {"h": 8.5, "w": 11, "uom": "in"}, - "status": "P"}, - {"item": "paper", - "qty": 100, - "size": {"h": 8.5, "w": 11, "uom": "in"}, - "status": "D"}, - {"item": "planner", - "qty": 75, - "size": {"h": 22.85, "w": 30, "uom": "cm"}, - "status": "D"}, - {"item": "postcard", - "qty": 45, - "size": {"h": 10, "w": 15.25, "uom": "cm"}, - "status": "A"}]) + db.inventory.insert_many( + [ + { + "item": "journal", + "qty": 25, + "size": {"h": 14, "w": 21, "uom": "cm"}, + "status": "A", + }, + { + "item": "notebook", + "qty": 50, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "P", + }, + { + "item": "paper", + "qty": 100, + "size": {"h": 8.5, "w": 11, "uom": "in"}, + "status": "D", + }, + { + "item": "planner", + "qty": 75, + "size": {"h": 22.85, "w": 30, "uom": "cm"}, + "status": "D", + }, + { + "item": "postcard", + "qty": 45, + "size": {"h": 10, "w": 15.25, "uom": "cm"}, + "status": "A", + }, + ] + ) # End Example 55 self.assertEqual(db.inventory.count_documents({}), 5) @@ -682,7 +762,7 @@ def insert_docs(): # End Changestream Example 1 # Start Changestream Example 2 - cursor = db.inventory.watch(full_document='updateLookup') + cursor = db.inventory.watch(full_document="updateLookup") document = next(cursor) # End Changestream Example 2 @@ -694,8 +774,8 @@ def insert_docs(): # Start Changestream Example 4 pipeline = [ - {'$match': {'fullDocument.username': 'alice'}}, - {'$addFields': {'newField': 'this is an added field!'}} + {"$match": {"fullDocument.username": "alice"}}, + {"$addFields": {"newField": "this is an added field!"}}, ] cursor = db.inventory.watch(pipeline=pipeline) document = next(cursor) @@ -708,83 +788,77 @@ def test_aggregate_examples(self): db = self.db # Start Aggregation Example 1 - db.sales.aggregate([ - {"$match": {"items.fruit": "banana"}}, - {"$sort": {"date": 1}} - ]) + db.sales.aggregate([{"$match": {"items.fruit": "banana"}}, {"$sort": {"date": 1}}]) # End Aggregation Example 1 # Start Aggregation Example 2 - db.sales.aggregate([ - {"$unwind": "$items"}, - {"$match": {"items.fruit": "banana"}}, - {"$group": { - "_id": {"day": {"$dayOfWeek": "$date"}}, - "count": {"$sum": "$items.quantity"}} - }, - {"$project": { - "dayOfWeek": "$_id.day", - "numberSold": "$count", - "_id": 0} - }, - {"$sort": {"numberSold": 1}} - ]) + db.sales.aggregate( + [ + {"$unwind": "$items"}, + {"$match": {"items.fruit": "banana"}}, + { + "$group": { + "_id": {"day": {"$dayOfWeek": "$date"}}, + "count": {"$sum": "$items.quantity"}, + } + }, + {"$project": {"dayOfWeek": "$_id.day", "numberSold": "$count", "_id": 0}}, + {"$sort": {"numberSold": 1}}, + ] + ) # End Aggregation Example 2 # Start Aggregation Example 3 - db.sales.aggregate([ - {"$unwind": "$items"}, - {"$group": { - "_id": {"day": {"$dayOfWeek": "$date"}}, - "items_sold": {"$sum": "$items.quantity"}, - "revenue": { - "$sum": { - "$multiply": [ - "$items.quantity", "$items.price"] - } + db.sales.aggregate( + [ + {"$unwind": "$items"}, + { + "$group": { + "_id": {"day": {"$dayOfWeek": "$date"}}, + "items_sold": {"$sum": "$items.quantity"}, + "revenue": {"$sum": {"$multiply": 
["$items.quantity", "$items.price"]}}, } - } - }, - {"$project": { - "day": "$_id.day", - "revenue": 1, - "items_sold": 1, - "discount": { - "$cond": { - "if": {"$lte": ["$revenue", 250]}, - "then": 25, - "else": 0 - } + }, + { + "$project": { + "day": "$_id.day", + "revenue": 1, + "items_sold": 1, + "discount": { + "$cond": {"if": {"$lte": ["$revenue", 250]}, "then": 25, "else": 0} + }, } - } - } - ]) + }, + ] + ) # End Aggregation Example 3 # Start Aggregation Example 4 - db.air_alliances.aggregate([ - {"$lookup": { - "from": "air_airlines", - "let": {"constituents": "$airlines"}, - "pipeline": [ - {"$match": {"$expr": {"$in": ["$name", "$$constituents"]}}} - ], - "as": "airlines" - } - }, - {"$project": { - "_id": 0, - "name": 1, - "airlines": { - "$filter": { - "input": "$airlines", - "as": "airline", - "cond": {"$eq": ["$$airline.country", "Canada"]} - } + db.air_alliances.aggregate( + [ + { + "$lookup": { + "from": "air_airlines", + "let": {"constituents": "$airlines"}, + "pipeline": [{"$match": {"$expr": {"$in": ["$name", "$$constituents"]}}}], + "as": "airlines", } - } - } - ]) + }, + { + "$project": { + "_id": 0, + "name": 1, + "airlines": { + "$filter": { + "input": "$airlines", + "as": "airline", + "cond": {"$eq": ["$$airline.country", "Canada"]}, + } + }, + } + }, + ] + ) # End Aggregation Example 4 def test_commands(self): @@ -809,7 +883,7 @@ def test_index_management(self): # Start Index Example 1 db.restaurants.create_index( [("cuisine", pymongo.ASCENDING), ("name", pymongo.ASCENDING)], - partialFilterExpression={"rating": {"$gt": 5}} + partialFilterExpression={"rating": {"$gt": 5}}, ) # End Index Example 1 @@ -823,18 +897,14 @@ def test_misc(self): # 2. Tunable consistency controls collection = client.my_database.my_collection with client.start_session() as session: - collection.insert_one({'_id': 1}, session=session) - collection.update_one( - {'_id': 1}, {"$set": {"a": 1}}, session=session) + collection.insert_one({"_id": 1}, session=session) + collection.update_one({"_id": 1}, {"$set": {"a": 1}}, session=session) for doc in collection.find({}, session=session): pass # 3. 
Exploiting the power of arrays collection = client.test.array_updates_test - collection.update_one( - {'_id': 1}, - {"$set": {"a.$[i].b": 2}}, - array_filters=[{"i.b": 0}]) + collection.update_one({"_id": 1}, {"$set": {"a.$[i].b": 2}}, array_filters=[{"i.b": 0}]) class TestTransactionExamples(IntegrationTest): @@ -848,8 +918,7 @@ def test_transactions(self): employees = client.hr.employees events = client.reporting.events employees.insert_one({"employee": 3, "status": "Active"}) - events.insert_one( - {"employee": 3, "status": {"new": "Active", "old": None}}) + events.insert_one({"employee": 3, "status": {"new": "Active", "old": None}}) # Start Transactions Intro Example 1 @@ -858,15 +927,14 @@ def update_employee_info(session): events_coll = session.client.reporting.events with session.start_transaction( - read_concern=ReadConcern("snapshot"), - write_concern=WriteConcern(w="majority")): + read_concern=ReadConcern("snapshot"), write_concern=WriteConcern(w="majority") + ): employees_coll.update_one( - {"employee": 3}, {"$set": {"status": "Inactive"}}, - session=session) + {"employee": 3}, {"$set": {"status": "Inactive"}}, session=session + ) events_coll.insert_one( - {"employee": 3, "status": { - "new": "Inactive", "old": "Active"}}, - session=session) + {"employee": 3, "status": {"new": "Inactive", "old": "Active"}}, session=session + ) while True: try: @@ -876,14 +944,15 @@ def update_employee_info(session): break except (ConnectionFailure, OperationFailure) as exc: # Can retry commit - if exc.has_error_label( - "UnknownTransactionCommitResult"): - print("UnknownTransactionCommitResult, retrying " - "commit operation ...") + if exc.has_error_label("UnknownTransactionCommitResult"): + print( + "UnknownTransactionCommitResult, retrying " "commit operation ..." + ) continue else: print("Error during commit ...") raise + # End Transactions Intro Example 1 with client.start_session() as session: @@ -891,7 +960,7 @@ def update_employee_info(session): employee = employees.find_one({"employee": 3}) self.assertIsNotNone(employee) - self.assertEqual(employee['status'], 'Inactive') + self.assertEqual(employee["status"], "Inactive") # Start Transactions Retry Example 1 def run_transaction_with_retry(txn_func, session): @@ -900,16 +969,15 @@ def run_transaction_with_retry(txn_func, session): txn_func(session) # performs transaction break except (ConnectionFailure, OperationFailure) as exc: - print("Transaction aborted. Caught exception during " - "transaction.") + print("Transaction aborted. 
Caught exception during " "transaction.") # If transient error, retry the whole transaction if exc.has_error_label("TransientTransactionError"): - print("TransientTransactionError, retrying" - "transaction ...") + print("TransientTransactionError, retrying" "transaction ...") continue else: raise + # End Transactions Retry Example 1 with client.start_session() as session: @@ -917,7 +985,7 @@ def run_transaction_with_retry(txn_func, session): employee = employees.find_one({"employee": 3}) self.assertIsNotNone(employee) - self.assertEqual(employee['status'], 'Inactive') + self.assertEqual(employee["status"], "Inactive") # Start Transactions Retry Example 2 def commit_with_retry(session): @@ -930,23 +998,21 @@ def commit_with_retry(session): except (ConnectionFailure, OperationFailure) as exc: # Can retry commit if exc.has_error_label("UnknownTransactionCommitResult"): - print("UnknownTransactionCommitResult, retrying " - "commit operation ...") + print("UnknownTransactionCommitResult, retrying " "commit operation ...") continue else: print("Error during commit ...") raise + # End Transactions Retry Example 2 # Test commit_with_retry from the previous examples def _insert_employee_retry_commit(session): with session.start_transaction(): - employees.insert_one( - {"employee": 4, "status": "Active"}, - session=session) + employees.insert_one({"employee": 4, "status": "Active"}, session=session) events.insert_one( - {"employee": 4, "status": {"new": "Active", "old": None}}, - session=session) + {"employee": 4, "status": {"new": "Active", "old": None}}, session=session + ) commit_with_retry(session) @@ -955,7 +1021,7 @@ def _insert_employee_retry_commit(session): employee = employees.find_one({"employee": 4}) self.assertIsNotNone(employee) - self.assertEqual(employee['status'], 'Active') + self.assertEqual(employee["status"], "Active") # Start Transactions Retry Example 3 @@ -967,8 +1033,7 @@ def run_transaction_with_retry(txn_func, session): except (ConnectionFailure, OperationFailure) as exc: # If transient error, retry the whole transaction if exc.has_error_label("TransientTransactionError"): - print("TransientTransactionError, retrying " - "transaction ...") + print("TransientTransactionError, retrying " "transaction ...") continue else: raise @@ -983,8 +1048,7 @@ def commit_with_retry(session): except (ConnectionFailure, OperationFailure) as exc: # Can retry commit if exc.has_error_label("UnknownTransactionCommitResult"): - print("UnknownTransactionCommitResult, retrying " - "commit operation ...") + print("UnknownTransactionCommitResult, retrying " "commit operation ...") continue else: print("Error during commit ...") @@ -997,16 +1061,16 @@ def update_employee_info(session): events_coll = session.client.reporting.events with session.start_transaction( - read_concern=ReadConcern("snapshot"), - write_concern=WriteConcern(w="majority"), - read_preference=ReadPreference.PRIMARY): + read_concern=ReadConcern("snapshot"), + write_concern=WriteConcern(w="majority"), + read_preference=ReadPreference.PRIMARY, + ): employees_coll.update_one( - {"employee": 3}, {"$set": {"status": "Inactive"}}, - session=session) + {"employee": 3}, {"$set": {"status": "Inactive"}}, session=session + ) events_coll.insert_one( - {"employee": 3, "status": { - "new": "Inactive", "old": "Active"}}, - session=session) + {"employee": 3, "status": {"new": "Inactive", "old": "Active"}}, session=session + ) commit_with_retry(session) @@ -1022,7 +1086,7 @@ def update_employee_info(session): employee = 
employees.find_one({"employee": 3}) self.assertIsNotNone(employee) - self.assertEqual(employee['status'], 'Inactive') + self.assertEqual(employee["status"], "Inactive") MongoClient = lambda _: rs_client() uriString = None @@ -1038,10 +1102,8 @@ def update_employee_info(session): wc_majority = WriteConcern("majority", wtimeout=1000) # Prereq: Create collections. - client.get_database( - "mydb1", write_concern=wc_majority).foo.insert_one({'abc': 0}) - client.get_database( - "mydb2", write_concern=wc_majority).bar.insert_one({'xyz': 0}) + client.get_database("mydb1", write_concern=wc_majority).foo.insert_one({"abc": 0}) + client.get_database("mydb2", write_concern=wc_majority).bar.insert_one({"xyz": 0}) # Step 1: Define the callback that specifies the sequence of operations to perform inside the transactions. def callback(session): @@ -1049,16 +1111,18 @@ def callback(session): collection_two = session.client.mydb2.bar # Important:: You must pass the session to the operations. - collection_one.insert_one({'abc': 1}, session=session) - collection_two.insert_one({'xyz': 999}, session=session) + collection_one.insert_one({"abc": 1}, session=session) + collection_two.insert_one({"xyz": 999}, session=session) # Step 2: Start a client session. with client.start_session() as session: # Step 3: Use with_transaction to start a transaction, execute the callback, and commit (or abort on error). session.with_transaction( - callback, read_concern=ReadConcern('local'), + callback, + read_concern=ReadConcern("local"), write_concern=wc_majority, - read_preference=ReadPreference.PRIMARY) + read_preference=ReadPreference.PRIMARY, + ) # End Transactions withTxn API Example 1 @@ -1069,24 +1133,26 @@ class TestCausalConsistencyExamples(IntegrationTest): def test_causal_consistency(self): # Causal consistency examples client = self.client - self.addCleanup(client.drop_database, 'test') - client.test.drop_collection('items') - client.test.items.insert_one({ - 'sku': "111", 'name': 'Peanuts', - 'start':datetime.datetime.today()}) + self.addCleanup(client.drop_database, "test") + client.test.drop_collection("items") + client.test.items.insert_one( + {"sku": "111", "name": "Peanuts", "start": datetime.datetime.today()} + ) # Start Causal Consistency Example 1 with client.start_session(causal_consistency=True) as s1: current_date = datetime.datetime.today() items = client.get_database( - 'test', read_concern=ReadConcern('majority'), - write_concern=WriteConcern('majority', wtimeout=1000)).items + "test", + read_concern=ReadConcern("majority"), + write_concern=WriteConcern("majority", wtimeout=1000), + ).items items.update_one( - {'sku': "111", 'end': None}, - {'$set': {'end': current_date}}, session=s1) + {"sku": "111", "end": None}, {"$set": {"end": current_date}}, session=s1 + ) items.insert_one( - {'sku': "nuts-111", 'name': "Pecans", - 'start': current_date}, session=s1) + {"sku": "nuts-111", "name": "Pecans", "start": current_date}, session=s1 + ) # End Causal Consistency Example 1 # Start Causal Consistency Example 2 @@ -1095,10 +1161,12 @@ def test_causal_consistency(self): s2.advance_operation_time(s1.operation_time) items = client.get_database( - 'test', read_preference=ReadPreference.SECONDARY, - read_concern=ReadConcern('majority'), - write_concern=WriteConcern('majority', wtimeout=1000)).items - for item in items.find({'end': None}, session=s2): + "test", + read_preference=ReadPreference.SECONDARY, + read_concern=ReadConcern("majority"), + write_concern=WriteConcern("majority", wtimeout=1000), + ).items + 
for item in items.find({"end": None}, session=s2): print(item) # End Causal Consistency Example 2 @@ -1107,35 +1175,33 @@ class TestVersionedApiExamples(IntegrationTest): @client_context.require_version_min(4, 7) def test_versioned_api(self): # Versioned API examples - MongoClient = lambda _, server_api: rs_client( - server_api=server_api, connect=False) + MongoClient = lambda _, server_api: rs_client(server_api=server_api, connect=False) uri = None # Start Versioned API Example 1 from pymongo.server_api import ServerApi + client = MongoClient(uri, server_api=ServerApi("1")) # End Versioned API Example 1 # Start Versioned API Example 2 - client = MongoClient( - uri, server_api=ServerApi("1", strict=True)) + client = MongoClient(uri, server_api=ServerApi("1", strict=True)) # End Versioned API Example 2 # Start Versioned API Example 3 - client = MongoClient( - uri, server_api=ServerApi("1", strict=False)) + client = MongoClient(uri, server_api=ServerApi("1", strict=False)) # End Versioned API Example 3 # Start Versioned API Example 4 - client = MongoClient( - uri, server_api=ServerApi("1", deprecation_errors=True)) + client = MongoClient(uri, server_api=ServerApi("1", deprecation_errors=True)) # End Versioned API Example 4 @client_context.require_version_min(4, 7) def test_versioned_api_migration(self): # SERVER-58785 - if (client_context.is_topology_type(["sharded"]) and - not client_context.version.at_least(5, 0, 2)): + if client_context.is_topology_type(["sharded"]) and not client_context.version.at_least( + 5, 0, 2 + ): self.skipTest("This test needs MongoDB 5.0.2 or newer") client = rs_client(server_api=ServerApi("1", strict=True)) @@ -1144,22 +1210,74 @@ def test_versioned_api_migration(self): # Start Versioned API Example 5 def strptime(s): return datetime.datetime.strptime(s, "%Y-%m-%dT%H:%M:%SZ") - client.db.sales.insert_many([ - {"_id": 1, "item": "abc", "price": 10, "quantity": 2, "date": strptime("2021-01-01T08:00:00Z")}, - {"_id": 2, "item": "jkl", "price": 20, "quantity": 1, "date": strptime("2021-02-03T09:00:00Z")}, - {"_id": 3, "item": "xyz", "price": 5, "quantity": 5, "date": strptime("2021-02-03T09:05:00Z")}, - {"_id": 4, "item": "abc", "price": 10, "quantity": 10, "date": strptime("2021-02-15T08:00:00Z")}, - {"_id": 5, "item": "xyz", "price": 5, "quantity": 10, "date": strptime("2021-02-15T09:05:00Z")}, - {"_id": 6, "item": "xyz", "price": 5, "quantity": 5, "date": strptime("2021-02-15T12:05:10Z")}, - {"_id": 7, "item": "xyz", "price": 5, "quantity": 10, "date": strptime("2021-02-15T14:12:12Z")}, - {"_id": 8, "item": "abc", "price": 10, "quantity": 5, "date": strptime("2021-03-16T20:20:13Z")} - ]) + + client.db.sales.insert_many( + [ + { + "_id": 1, + "item": "abc", + "price": 10, + "quantity": 2, + "date": strptime("2021-01-01T08:00:00Z"), + }, + { + "_id": 2, + "item": "jkl", + "price": 20, + "quantity": 1, + "date": strptime("2021-02-03T09:00:00Z"), + }, + { + "_id": 3, + "item": "xyz", + "price": 5, + "quantity": 5, + "date": strptime("2021-02-03T09:05:00Z"), + }, + { + "_id": 4, + "item": "abc", + "price": 10, + "quantity": 10, + "date": strptime("2021-02-15T08:00:00Z"), + }, + { + "_id": 5, + "item": "xyz", + "price": 5, + "quantity": 10, + "date": strptime("2021-02-15T09:05:00Z"), + }, + { + "_id": 6, + "item": "xyz", + "price": 5, + "quantity": 5, + "date": strptime("2021-02-15T12:05:10Z"), + }, + { + "_id": 7, + "item": "xyz", + "price": 5, + "quantity": 10, + "date": strptime("2021-02-15T14:12:12Z"), + }, + { + "_id": 8, + "item": "abc", + "price": 10, + 
"quantity": 5, + "date": strptime("2021-03-16T20:20:13Z"), + }, + ] + ) # End Versioned API Example 5 with self.assertRaisesRegex( - OperationFailure, "Provided apiStrict:true, but the command " - "count is not in API Version 1"): - client.db.command('count', 'sales', query={}) + OperationFailure, + "Provided apiStrict:true, but the command " "count is not in API Version 1", + ): + client.db.command("count", "sales", query={}) # Start Versioned API Example 6 # pymongo.errors.OperationFailure: Provided apiStrict:true, but the command count is not in API Version 1, full error: {'ok': 0.0, 'errmsg': 'Provided apiStrict:true, but the command count is not in API Version 1', 'code': 323, 'codeName': 'APIStrictError'} # End Versioned API Example 6 diff --git a/test/test_grid_file.py b/test/test_grid_file.py index a53e40c4c9..94312c778c 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -21,32 +21,31 @@ import io import sys import zipfile - from io import BytesIO sys.path[0:0] = [""] +from test import IntegrationTest, qcheck, unittest +from test.utils import EventListener, rs_or_single_client + from bson.objectid import ObjectId from gridfs import GridFS -from gridfs.grid_file import (DEFAULT_CHUNK_SIZE, - _SEEK_CUR, - _SEEK_END, - GridIn, - GridOut, - GridOutCursor) from gridfs.errors import NoFile +from gridfs.grid_file import ( + _SEEK_CUR, + _SEEK_END, + DEFAULT_CHUNK_SIZE, + GridIn, + GridOut, + GridOutCursor, +) from pymongo import MongoClient from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError from pymongo.message import _CursorAddress -from test import (IntegrationTest, - unittest, - qcheck) -from test.utils import rs_or_single_client, EventListener class TestGridFileNoConnect(unittest.TestCase): - """Test GridFile features on a client that does not connect. 
- """ + """Test GridFile features on a client that does not connect.""" @classmethod def setUpClass(cls): @@ -55,9 +54,17 @@ def setUpClass(cls): def test_grid_in_custom_opts(self): self.assertRaises(TypeError, GridIn, "foo") - a = GridIn(self.db.fs, _id=5, filename="my_file", - contentType="text/html", chunkSize=1000, aliases=["foo"], - metadata={"foo": 1, "bar": 2}, bar=3, baz="hello") + a = GridIn( + self.db.fs, + _id=5, + filename="my_file", + contentType="text/html", + chunkSize=1000, + aliases=["foo"], + metadata={"foo": 1, "bar": 2}, + bar=3, + baz="hello", + ) self.assertEqual(5, a._id) self.assertEqual("my_file", a.filename) @@ -70,15 +77,13 @@ def test_grid_in_custom_opts(self): self.assertEqual("hello", a.baz) self.assertRaises(AttributeError, getattr, a, "mike") - b = GridIn(self.db.fs, - content_type="text/html", chunk_size=1000, baz=100) + b = GridIn(self.db.fs, content_type="text/html", chunk_size=1000, baz=100) self.assertEqual("text/html", b.content_type) self.assertEqual(1000, b.chunk_size) self.assertEqual(100, b.baz) class TestGridFile(IntegrationTest): - def setUp(self): self.cleanup_colls(self.db.fs.files, self.db.fs.chunks) @@ -223,30 +228,48 @@ def test_grid_out_default_opts(self): self.assertEqual(None, b.metadata) self.assertEqual(None, b.md5) - for attr in ["_id", "name", "content_type", "length", "chunk_size", - "upload_date", "aliases", "metadata", "md5"]: + for attr in [ + "_id", + "name", + "content_type", + "length", + "chunk_size", + "upload_date", + "aliases", + "metadata", + "md5", + ]: self.assertRaises(AttributeError, setattr, b, attr, 5) def test_grid_out_cursor_options(self): - self.assertRaises(TypeError, GridOutCursor.__init__, self.db.fs, {}, - projection={"filename": 1}) + self.assertRaises( + TypeError, GridOutCursor.__init__, self.db.fs, {}, projection={"filename": 1} + ) cursor = GridOutCursor(self.db.fs, {}) cursor_clone = cursor.clone() cursor_dict = cursor.__dict__.copy() - cursor_dict.pop('_Cursor__session') + cursor_dict.pop("_Cursor__session") cursor_clone_dict = cursor_clone.__dict__.copy() - cursor_clone_dict.pop('_Cursor__session') + cursor_clone_dict.pop("_Cursor__session") self.assertEqual(cursor_dict, cursor_clone_dict) self.assertRaises(NotImplementedError, cursor.add_option, 0) self.assertRaises(NotImplementedError, cursor.remove_option, 0) def test_grid_out_custom_opts(self): - one = GridIn(self.db.fs, _id=5, filename="my_file", - contentType="text/html", chunkSize=1000, aliases=["foo"], - metadata={"foo": 1, "bar": 2}, bar=3, baz="hello") + one = GridIn( + self.db.fs, + _id=5, + filename="my_file", + contentType="text/html", + chunkSize=1000, + aliases=["foo"], + metadata={"foo": 1, "bar": 2}, + bar=3, + baz="hello", + ) one.write(b"hello world") one.close() @@ -264,8 +287,17 @@ def test_grid_out_custom_opts(self): self.assertEqual(3, two.bar) self.assertEqual(None, two.md5) - for attr in ["_id", "name", "content_type", "length", "chunk_size", - "upload_date", "aliases", "metadata", "md5"]: + for attr in [ + "_id", + "name", + "content_type", + "length", + "chunk_size", + "upload_date", + "aliases", + "metadata", + "md5", + ]: self.assertRaises(AttributeError, setattr, two, attr, 5) def test_grid_out_file_document(self): @@ -276,8 +308,7 @@ def test_grid_out_file_document(self): two = GridOut(self.db.fs, file_document=self.db.fs.files.find_one()) self.assertEqual(b"foo bar", two.read()) - three = GridOut(self.db.fs, 5, - file_document=self.db.fs.files.find_one()) + three = GridOut(self.db.fs, 5, 
file_document=self.db.fs.files.find_one()) self.assertEqual(b"foo bar", three.read()) four = GridOut(self.db.fs, file_document={}) @@ -304,8 +335,7 @@ def test_write_file_like(self): five.write(buffer) five.write(b" and mongodb") five.close() - self.assertEqual(b"hello world and mongodb", - GridOut(self.db.fs, five._id).read()) + self.assertEqual(b"hello world and mongodb", GridOut(self.db.fs, five._id).read()) def test_write_lines(self): a = GridIn(self.db.fs) @@ -335,7 +365,7 @@ def test_closed(self): self.assertTrue(g.closed) def test_multi_chunk_file(self): - random_string = b'a' * (DEFAULT_CHUNK_SIZE + 1000) + random_string = b"a" * (DEFAULT_CHUNK_SIZE + 1000) f = GridIn(self.db.fs) f.write(random_string) @@ -369,8 +399,7 @@ def helper(data): self.assertEqual(data, g.read(10) + g.read(10)) return True - qcheck.check_unittest(self, helper, - qcheck.gen_string(qcheck.gen_range(0, 20))) + qcheck.check_unittest(self, helper, qcheck.gen_string(qcheck.gen_range(0, 20))) def test_seek(self): f = GridIn(self.db.fs, chunkSize=3) @@ -428,10 +457,14 @@ def test_multiple_reads(self): def test_readline(self): f = GridIn(self.db.fs, chunkSize=5) - f.write((b"""Hello world, + f.write( + ( + b"""Hello world, How are you? Hope all is well. -Bye""")) +Bye""" + ) + ) f.close() # Try read(), then readline(). @@ -460,10 +493,14 @@ def test_readline(self): def test_readlines(self): f = GridIn(self.db.fs, chunkSize=5) - f.write((b"""Hello world, + f.write( + ( + b"""Hello world, How are you? Hope all is well. -Bye""")) +Bye""" + ) + ) f.close() # Try read(), then readlines(). @@ -483,13 +520,13 @@ def test_readlines(self): # Only readlines(). g = GridOut(self.db.fs, f._id) self.assertEqual( - [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], - g.readlines()) + [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], g.readlines() + ) g = GridOut(self.db.fs, f._id) self.assertEqual( - [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], - g.readlines(0)) + [b"Hello world,\n", b"How are you?\n", b"Hope all is well.\n", b"Bye"], g.readlines(0) + ) g = GridOut(self.db.fs, f._id) self.assertEqual([b"Hello world,\n"], g.readlines(1)) @@ -539,14 +576,13 @@ def test_iterator(self): self.assertEqual([b"hello world"], list(g)) def test_read_unaligned_buffer_size(self): - in_data = (b"This is a text that doesn't " - b"quite fit in a single 16-byte chunk.") + in_data = b"This is a text that doesn't " b"quite fit in a single 16-byte chunk." 
f = GridIn(self.db.fs, chunkSize=16) f.write(in_data) f.close() g = GridOut(self.db.fs, f._id) - out_data = b'' + out_data = b"" while 1: s = g.read(13) if not s: @@ -556,7 +592,7 @@ def test_read_unaligned_buffer_size(self): self.assertEqual(in_data, out_data) def test_readchunk(self): - in_data = b'a' * 10 + in_data = b"a" * 10 f = GridIn(self.db.fs, chunkSize=3) f.write(in_data) f.close() @@ -636,13 +672,12 @@ def test_context_manager(self): self.assertEqual(contents, outfile.read()) def test_prechunked_string(self): - def write_me(s, chunk_size): buf = BytesIO(s) infile = GridIn(self.db.fs) while True: to_write = buf.read(chunk_size) - if to_write == b'': + if to_write == b"": break infile.write(to_write) infile.close() @@ -652,7 +687,7 @@ def write_me(s, chunk_size): data = outfile.read() self.assertEqual(s, data) - s = b'x' * DEFAULT_CHUNK_SIZE * 4 + s = b"x" * DEFAULT_CHUNK_SIZE * 4 # Test with default chunk size write_me(s, DEFAULT_CHUNK_SIZE) # Multiple @@ -664,7 +699,7 @@ def test_grid_out_lazy_connect(self): fs = self.db.fs outfile = GridOut(fs, file_id=-1) self.assertRaises(NoFile, outfile.read) - self.assertRaises(NoFile, getattr, outfile, 'filename') + self.assertRaises(NoFile, getattr, outfile, "filename") infile = GridIn(fs, filename=1) infile.close() @@ -677,11 +712,10 @@ def test_grid_out_lazy_connect(self): outfile.readchunk() def test_grid_in_lazy_connect(self): - client = MongoClient('badhost', connect=False, - serverSelectionTimeoutMS=10) + client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) fs = client.db.fs infile = GridIn(fs, file_id=-1, chunk_size=1) - self.assertRaises(ServerSelectionTimeoutError, infile.write, b'data') + self.assertRaises(ServerSelectionTimeoutError, infile.write, b"data") self.assertRaises(ServerSelectionTimeoutError, infile.close) def test_unacknowledged(self): @@ -693,7 +727,7 @@ def test_survive_cursor_not_found(self): # By default the find command returns 101 documents in the first batch. # Use 102 batches to cause a single getMore. chunk_size = 1024 - data = b'd' * (102 * chunk_size) + data = b"d" * (102 * chunk_size) listener = EventListener() client = rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test @@ -708,7 +742,8 @@ def test_survive_cursor_not_found(self): # readchunk(). client._close_cursor_now( outfile._GridOut__chunk_iter._cursor.cursor_id, - _CursorAddress(client.address, db.fs.chunks.full_name)) + _CursorAddress(client.address, db.fs.chunks.full_name), + ) # Read the rest of the file without error. 
self.assertEqual(len(outfile.read()), len(data) - chunk_size) diff --git a/test/test_gridfs.py b/test/test_gridfs.py index d7d5a74e5f..b6f959e496 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -21,32 +21,27 @@ import sys import threading import time - from io import BytesIO sys.path[0:0] = [""] -from bson.binary import Binary -from pymongo.mongo_client import MongoClient -from pymongo.errors import (ConfigurationError, - NotPrimaryError, - ServerSelectionTimeoutError) -from pymongo.read_preferences import ReadPreference +from test import IntegrationTest, client_context, unittest +from test.utils import joinall, one, rs_client, rs_or_single_client, single_client + import gridfs +from bson.binary import Binary from gridfs.errors import CorruptGridFile, FileExists, NoFile from gridfs.grid_file import GridOutCursor -from test import (client_context, - unittest, - IntegrationTest) -from test.utils import (joinall, - one, - rs_client, - rs_or_single_client, - single_client) +from pymongo.errors import ( + ConfigurationError, + NotPrimaryError, + ServerSelectionTimeoutError, +) +from pymongo.mongo_client import MongoClient +from pymongo.read_preferences import ReadPreference class JustWrite(threading.Thread): - def __init__(self, fs, n): threading.Thread.__init__(self) self.fs = fs @@ -61,7 +56,6 @@ def run(self): class JustRead(threading.Thread): - def __init__(self, fs, n, results): threading.Thread.__init__(self) self.fs = fs @@ -78,7 +72,6 @@ def run(self): class TestGridfsNoConnect(unittest.TestCase): - @classmethod def setUpClass(cls): cls.db = MongoClient(connect=False).pymongo_test @@ -89,7 +82,6 @@ def test_gridfs(self): class TestGridfs(IntegrationTest): - @classmethod def setUpClass(cls): super(TestGridfs, cls).setUpClass() @@ -97,8 +89,9 @@ def setUpClass(cls): cls.alt = gridfs.GridFS(cls.db, "alt") def setUp(self): - self.cleanup_colls(self.db.fs.files, self.db.fs.chunks, - self.db.alt.files, self.db.alt.chunks) + self.cleanup_colls( + self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks + ) def test_basic(self): oid = self.fs.put(b"hello world") @@ -142,8 +135,7 @@ def test_list(self): self.fs.put(b"foo", filename="test") self.fs.put(b"", filename="hello world") - self.assertEqual(set(["mike", "test", "hello world"]), - set(self.fs.list())) + self.assertEqual(set(["mike", "test", "hello world"]), set(self.fs.list())) def test_empty_file(self): oid = self.fs.put(b"") @@ -159,9 +151,8 @@ def test_empty_file(self): self.assertNotIn("md5", raw) def test_corrupt_chunk(self): - files_id = self.fs.put(b'foobar') - self.db.fs.chunks.update_one({'files_id': files_id}, - {'$set': {'data': Binary(b'foo', 0)}}) + files_id = self.fs.put(b"foobar") + self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}) try: out = self.fs.get(files_id) self.assertRaises(CorruptGridFile, out.read) @@ -179,12 +170,18 @@ def test_put_ensures_index(self): files.drop() self.fs.put(b"junk") - self.assertTrue(any( - info.get('key') == [('files_id', 1), ('n', 1)] - for info in chunks.index_information().values())) - self.assertTrue(any( - info.get('key') == [('filename', 1), ('uploadDate', 1)] - for info in files.index_information().values())) + self.assertTrue( + any( + info.get("key") == [("files_id", 1), ("n", 1)] + for info in chunks.index_information().values() + ) + ) + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in files.index_information().values() + ) + ) def test_alt_collection(self): oid 
= self.alt.put(b"hello world") @@ -206,8 +203,7 @@ def test_alt_collection(self): self.alt.put(b"foo", filename="test") self.alt.put(b"", filename="hello world") - self.assertEqual(set(["mike", "test", "hello world"]), - set(self.alt.list())) + self.assertEqual(set(["mike", "test", "hello world"]), set(self.alt.list())) def test_threaded_reads(self): self.fs.put(b"hello", _id="test") @@ -220,10 +216,7 @@ def test_threaded_reads(self): joinall(threads) - self.assertEqual( - 100 * [b'hello'], - results - ) + self.assertEqual(100 * [b"hello"], results) def test_threaded_writes(self): threads = [] @@ -237,10 +230,7 @@ def test_threaded_writes(self): self.assertEqual(f.read(), b"hello") # Should have created 100 versions of 'test' file - self.assertEqual( - 100, - self.db.fs.files.count_documents({'filename': 'test'}) - ) + self.assertEqual(100, self.db.fs.files.count_documents({"filename": "test"})) def test_get_last_version(self): one = self.fs.put(b"foo", filename="test") @@ -311,30 +301,25 @@ def test_get_version_with_metadata(self): three = self.fs.put(b"baz", filename="test", author="author2") self.assertEqual( - b"foo", - self.fs.get_version( - filename="test", author="author1", version=-2).read()) - self.assertEqual( - b"bar", self.fs.get_version( - filename="test", author="author1", version=-1).read()) - self.assertEqual( - b"foo", self.fs.get_version( - filename="test", author="author1", version=0).read()) + b"foo", self.fs.get_version(filename="test", author="author1", version=-2).read() + ) self.assertEqual( - b"bar", self.fs.get_version( - filename="test", author="author1", version=1).read()) + b"bar", self.fs.get_version(filename="test", author="author1", version=-1).read() + ) self.assertEqual( - b"baz", self.fs.get_version( - filename="test", author="author2", version=0).read()) + b"foo", self.fs.get_version(filename="test", author="author1", version=0).read() + ) self.assertEqual( - b"baz", self.fs.get_version(filename="test", version=-1).read()) + b"bar", self.fs.get_version(filename="test", author="author1", version=1).read() + ) self.assertEqual( - b"baz", self.fs.get_version(filename="test", version=2).read()) + b"baz", self.fs.get_version(filename="test", author="author2", version=0).read() + ) + self.assertEqual(b"baz", self.fs.get_version(filename="test", version=-1).read()) + self.assertEqual(b"baz", self.fs.get_version(filename="test", version=2).read()) - self.assertRaises( - NoFile, self.fs.get_version, filename="test", author="author3") - self.assertRaises( - NoFile, self.fs.get_version, filename="test", author="author1", version=2) + self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author3") + self.assertRaises(NoFile, self.fs.get_version, filename="test", author="author1", version=2) self.fs.delete(one) self.fs.delete(two) @@ -354,7 +339,7 @@ def test_file_exists(self): one.close() two = self.fs.new_file(_id=123) - self.assertRaises(FileExists, two.write, b'x' * 262146) + self.assertRaises(FileExists, two.write, b"x" * 262146) def test_exists(self): oid = self.fs.put(b"hello") @@ -408,8 +393,7 @@ def iterate_file(grid_file): self.assertTrue(iterate_file(f)) def test_gridfs_lazy_connect(self): - client = MongoClient('badhost', connect=False, - serverSelectionTimeoutMS=10) + client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) db = client.db gfs = gridfs.GridFS(db) self.assertRaises(ServerSelectionTimeoutError, gfs.list) @@ -429,8 +413,7 @@ def test_gridfs_find(self): files = self.db.fs.files self.assertEqual(3, 
files.count_documents({"filename": "two"})) self.assertEqual(4, files.count_documents({})) - cursor = self.fs.find( - no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2) + cursor = self.fs.find(no_cursor_timeout=False).sort("uploadDate", -1).skip(1).limit(2) gout = next(cursor) self.assertEqual(b"test1", gout.read()) cursor.rewind() @@ -453,26 +436,24 @@ def test_delete_not_initialized(self): def test_gridfs_find_one(self): self.assertEqual(None, self.fs.find_one()) - id1 = self.fs.put(b'test1', filename='file1') - self.assertEqual(b'test1', self.fs.find_one().read()) + id1 = self.fs.put(b"test1", filename="file1") + self.assertEqual(b"test1", self.fs.find_one().read()) - id2 = self.fs.put(b'test2', filename='file2', meta='data') - self.assertEqual(b'test1', self.fs.find_one(id1).read()) - self.assertEqual(b'test2', self.fs.find_one(id2).read()) + id2 = self.fs.put(b"test2", filename="file2", meta="data") + self.assertEqual(b"test1", self.fs.find_one(id1).read()) + self.assertEqual(b"test2", self.fs.find_one(id2).read()) - self.assertEqual(b'test1', - self.fs.find_one({'filename': 'file1'}).read()) + self.assertEqual(b"test1", self.fs.find_one({"filename": "file1"}).read()) - self.assertEqual('data', self.fs.find_one(id2).meta) + self.assertEqual("data", self.fs.find_one(id2).meta) def test_grid_in_non_int_chunksize(self): # Lua, and perhaps other buggy GridFS clients, store size as a float. - data = b'data' - self.fs.put(data, filename='f') - self.db.fs.files.update_one({'filename': 'f'}, - {'$set': {'chunkSize': 100.0}}) + data = b"data" + self.fs.put(data, filename="f") + self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) - self.assertEqual(data, self.fs.get_version('f').read()) + self.assertEqual(data, self.fs.get_version("f").read()) def test_unacknowledged(self): # w=0 is prohibited. @@ -494,7 +475,6 @@ def test_md5(self): class TestGridfsReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): @@ -502,51 +482,47 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - client_context.client.drop_database('gfsreplica') + client_context.client.drop_database("gfsreplica") def test_gridfs_replica_set(self): - rsc = rs_client( - w=client_context.w, - read_preference=ReadPreference.SECONDARY) + rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) - fs = gridfs.GridFS(rsc.gfsreplica, 'gfsreplicatest') + fs = gridfs.GridFS(rsc.gfsreplica, "gfsreplicatest") gin = fs.new_file() self.assertEqual(gin._coll.read_preference, ReadPreference.PRIMARY) - oid = fs.put(b'foo') + oid = fs.put(b"foo") content = fs.get(oid).read() - self.assertEqual(b'foo', content) + self.assertEqual(b"foo", content) def test_gridfs_secondary(self): secondary_host, secondary_port = one(self.client.secondaries) secondary_connection = single_client( - secondary_host, secondary_port, - read_preference=ReadPreference.SECONDARY) + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY + ) # Should detect it's connected to secondary and not attempt to # create index - fs = gridfs.GridFS(secondary_connection.gfsreplica, 'gfssecondarytest') + fs = gridfs.GridFS(secondary_connection.gfsreplica, "gfssecondarytest") # This won't detect secondary, raises error - self.assertRaises(NotPrimaryError, fs.put, b'foo') + self.assertRaises(NotPrimaryError, fs.put, b"foo") def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to # create index. 
secondary_host, secondary_port = one(self.client.secondaries) client = single_client( - secondary_host, - secondary_port, - read_preference=ReadPreference.SECONDARY, - connect=False) + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False + ) # Still no connection. - fs = gridfs.GridFS(client.gfsreplica, 'gfssecondarylazytest') + fs = gridfs.GridFS(client.gfsreplica, "gfssecondarylazytest") # Connects, doesn't create index. self.assertRaises(NoFile, fs.get_last_version) - self.assertRaises(NotPrimaryError, fs.put, 'data') + self.assertRaises(NotPrimaryError, fs.put, "data") if __name__ == "__main__": diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index 499643f673..38f81b3613 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -20,32 +20,26 @@ import itertools import threading import time - from io import BytesIO +from test import IntegrationTest, client_context, unittest +from test.utils import joinall, one, rs_client, rs_or_single_client, single_client +import gridfs from bson.binary import Binary from bson.int64 import Int64 from bson.objectid import ObjectId from bson.son import SON -import gridfs -from gridfs.errors import NoFile, CorruptGridFile -from pymongo.errors import (ConfigurationError, - NotPrimaryError, - ServerSelectionTimeoutError) +from gridfs.errors import CorruptGridFile, NoFile +from pymongo.errors import ( + ConfigurationError, + NotPrimaryError, + ServerSelectionTimeoutError, +) from pymongo.mongo_client import MongoClient from pymongo.read_preferences import ReadPreference -from test import (client_context, - unittest, - IntegrationTest) -from test.utils import (joinall, - one, - rs_client, - rs_or_single_client, - single_client) class JustWrite(threading.Thread): - def __init__(self, gfs, num): threading.Thread.__init__(self) self.gfs = gfs @@ -60,7 +54,6 @@ def run(self): class JustRead(threading.Thread): - def __init__(self, gfs, num, results): threading.Thread.__init__(self) self.gfs = gfs @@ -77,23 +70,20 @@ def run(self): class TestGridfs(IntegrationTest): - @classmethod def setUpClass(cls): super(TestGridfs, cls).setUpClass() cls.fs = gridfs.GridFSBucket(cls.db) - cls.alt = gridfs.GridFSBucket( - cls.db, bucket_name="alt") + cls.alt = gridfs.GridFSBucket(cls.db, bucket_name="alt") def setUp(self): - self.cleanup_colls(self.db.fs.files, self.db.fs.chunks, - self.db.alt.files, self.db.alt.chunks) + self.cleanup_colls( + self.db.fs.files, self.db.fs.chunks, self.db.alt.files, self.db.alt.chunks + ) def test_basic(self): - oid = self.fs.upload_from_stream("test_filename", - b"hello world") - self.assertEqual(b"hello world", - self.fs.open_download_stream(oid).read()) + oid = self.fs.upload_from_stream("test_filename", b"hello world") + self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(1, self.db.fs.chunks.count_documents({})) @@ -106,9 +96,7 @@ def test_multi_chunk_delete(self): self.assertEqual(0, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) gfs = gridfs.GridFSBucket(self.db) - oid = gfs.upload_from_stream("test_filename", - b"hello", - chunk_size_bytes=1) + oid = gfs.upload_from_stream("test_filename", b"hello", chunk_size_bytes=1) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(5, self.db.fs.chunks.count_documents({})) gfs.delete(oid) @@ -116,8 +104,7 @@ def test_multi_chunk_delete(self): 
self.assertEqual(0, self.db.fs.chunks.count_documents({})) def test_empty_file(self): - oid = self.fs.upload_from_stream("test_filename", - b"") + oid = self.fs.upload_from_stream("test_filename", b"") self.assertEqual(b"", self.fs.open_download_stream(oid).read()) self.assertEqual(1, self.db.fs.files.count_documents({})) self.assertEqual(0, self.db.fs.chunks.count_documents({})) @@ -130,10 +117,8 @@ def test_empty_file(self): self.assertNotIn("md5", raw) def test_corrupt_chunk(self): - files_id = self.fs.upload_from_stream("test_filename", - b'foobar') - self.db.fs.chunks.update_one({'files_id': files_id}, - {'$set': {'data': Binary(b'foo', 0)}}) + files_id = self.fs.upload_from_stream("test_filename", b"foobar") + self.db.fs.chunks.update_one({"files_id": files_id}, {"$set": {"data": Binary(b"foo", 0)}}) try: out = self.fs.open_download_stream(files_id) self.assertRaises(CorruptGridFile, out.read) @@ -151,37 +136,45 @@ def test_upload_ensures_index(self): files.drop() self.fs.upload_from_stream("filename", b"junk") - self.assertTrue(any( - info.get('key') == [('files_id', 1), ('n', 1)] - for info in chunks.index_information().values())) - self.assertTrue(any( - info.get('key') == [('filename', 1), ('uploadDate', 1)] - for info in files.index_information().values())) + self.assertTrue( + any( + info.get("key") == [("files_id", 1), ("n", 1)] + for info in chunks.index_information().values() + ) + ) + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in files.index_information().values() + ) + ) def test_ensure_index_shell_compat(self): files = self.db.fs.files - for i, j in itertools.combinations_with_replacement( - [1, 1.0, Int64(1)], 2): + for i, j in itertools.combinations_with_replacement([1, 1.0, Int64(1)], 2): # Create the index with different numeric types (as might be done # from the mongo shell). - shell_index = [('filename', i), ('uploadDate', j)] - self.db.command('createIndexes', files.name, - indexes=[{'key': SON(shell_index), - 'name': 'filename_1.0_uploadDate_1.0'}]) + shell_index = [("filename", i), ("uploadDate", j)] + self.db.command( + "createIndexes", + files.name, + indexes=[{"key": SON(shell_index), "name": "filename_1.0_uploadDate_1.0"}], + ) # No error. 
self.fs.upload_from_stream("filename", b"data") - self.assertTrue(any( - info.get('key') == [('filename', 1), ('uploadDate', 1)] - for info in files.index_information().values())) + self.assertTrue( + any( + info.get("key") == [("filename", 1), ("uploadDate", 1)] + for info in files.index_information().values() + ) + ) files.drop() def test_alt_collection(self): - oid = self.alt.upload_from_stream("test_filename", - b"hello world") - self.assertEqual(b"hello world", - self.alt.open_download_stream(oid).read()) + oid = self.alt.upload_from_stream("test_filename", b"hello world") + self.assertEqual(b"hello world", self.alt.open_download_stream(oid).read()) self.assertEqual(1, self.db.alt.files.count_documents({})) self.assertEqual(1, self.db.alt.chunks.count_documents({})) @@ -191,18 +184,17 @@ def test_alt_collection(self): self.assertEqual(0, self.db.alt.chunks.count_documents({})) self.assertRaises(NoFile, self.alt.open_download_stream, "foo") - self.alt.upload_from_stream("foo", - b"hello world") - self.assertEqual(b"hello world", - self.alt.open_download_stream_by_name("foo").read()) + self.alt.upload_from_stream("foo", b"hello world") + self.assertEqual(b"hello world", self.alt.open_download_stream_by_name("foo").read()) self.alt.upload_from_stream("mike", b"") self.alt.upload_from_stream("test", b"foo") self.alt.upload_from_stream("hello world", b"") - self.assertEqual(set(["mike", "test", "hello world", "foo"]), - set(k["filename"] for k in list( - self.db.alt.files.find()))) + self.assertEqual( + set(["mike", "test", "hello world", "foo"]), + set(k["filename"] for k in list(self.db.alt.files.find())), + ) def test_threaded_reads(self): self.fs.upload_from_stream("test", b"hello") @@ -215,10 +207,7 @@ def test_threaded_reads(self): joinall(threads) - self.assertEqual( - 100 * [b'hello'], - results - ) + self.assertEqual(100 * [b"hello"], results) def test_threaded_writes(self): threads = [] @@ -232,10 +221,7 @@ def test_threaded_writes(self): self.assertEqual(fstr.read(), b"hello") # Should have created 100 versions of 'test' file - self.assertEqual( - 100, - self.db.fs.files.count_documents({'filename': 'test'}) - ) + self.assertEqual(100, self.db.fs.files.count_documents({"filename": "test"})) def test_get_last_version(self): one = self.fs.upload_from_stream("test", b"foo") @@ -247,17 +233,13 @@ def test_get_last_version(self): two = two._id three = self.fs.upload_from_stream("test", b"baz") - self.assertEqual(b"baz", - self.fs.open_download_stream_by_name("test").read()) + self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test").read()) self.fs.delete(three) - self.assertEqual(b"bar", - self.fs.open_download_stream_by_name("test").read()) + self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test").read()) self.fs.delete(two) - self.assertEqual(b"foo", - self.fs.open_download_stream_by_name("test").read()) + self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test").read()) self.fs.delete(one) - self.assertRaises(NoFile, - self.fs.open_download_stream_by_name, "test") + self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test") def test_get_version(self): self.fs.upload_from_stream("test", b"foo") @@ -267,56 +249,41 @@ def test_get_version(self): self.fs.upload_from_stream("test", b"baz") time.sleep(0.01) - self.assertEqual(b"foo", self.fs.open_download_stream_by_name( - "test", revision=0).read()) - self.assertEqual(b"bar", self.fs.open_download_stream_by_name( - "test", revision=1).read()) - self.assertEqual(b"baz", 
self.fs.open_download_stream_by_name( - "test", revision=2).read()) - - self.assertEqual(b"baz", self.fs.open_download_stream_by_name( - "test", revision=-1).read()) - self.assertEqual(b"bar", self.fs.open_download_stream_by_name( - "test", revision=-2).read()) - self.assertEqual(b"foo", self.fs.open_download_stream_by_name( - "test", revision=-3).read()) - - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, - "test", revision=3) - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, - "test", revision=-4) + self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=0).read()) + self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=1).read()) + self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=2).read()) + + self.assertEqual(b"baz", self.fs.open_download_stream_by_name("test", revision=-1).read()) + self.assertEqual(b"bar", self.fs.open_download_stream_by_name("test", revision=-2).read()) + self.assertEqual(b"foo", self.fs.open_download_stream_by_name("test", revision=-3).read()) + + self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=3) + self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "test", revision=-4) def test_upload_from_stream(self): - oid = self.fs.upload_from_stream("test_file", - BytesIO(b"hello world"), - chunk_size_bytes=1) + oid = self.fs.upload_from_stream("test_file", BytesIO(b"hello world"), chunk_size_bytes=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) - self.assertEqual(b"hello world", - self.fs.open_download_stream(oid).read()) + self.assertEqual(b"hello world", self.fs.open_download_stream(oid).read()) def test_upload_from_stream_with_id(self): oid = ObjectId() - self.fs.upload_from_stream_with_id(oid, - "test_file_custom_id", - BytesIO(b"custom id"), - chunk_size_bytes=1) - self.assertEqual(b"custom id", - self.fs.open_download_stream(oid).read()) + self.fs.upload_from_stream_with_id( + oid, "test_file_custom_id", BytesIO(b"custom id"), chunk_size_bytes=1 + ) + self.assertEqual(b"custom id", self.fs.open_download_stream(oid).read()) def test_open_upload_stream(self): gin = self.fs.open_upload_stream("from_stream") gin.write(b"from stream") gin.close() - self.assertEqual(b"from stream", - self.fs.open_download_stream(gin._id).read()) + self.assertEqual(b"from stream", self.fs.open_download_stream(gin._id).read()) def test_open_upload_stream_with_id(self): oid = ObjectId() gin = self.fs.open_upload_stream_with_id(oid, "from_stream_custom_id") gin.write(b"from stream with custom id") gin.close() - self.assertEqual(b"from stream with custom id", - self.fs.open_download_stream(oid).read()) + self.assertEqual(b"from stream with custom id", self.fs.open_download_stream(oid).read()) def test_missing_length_iter(self): # Test fix that guards against PHP-237 @@ -334,16 +301,15 @@ def iterate_file(grid_file): self.assertTrue(iterate_file(fstr)) def test_gridfs_lazy_connect(self): - client = MongoClient('badhost', connect=False, - serverSelectionTimeoutMS=0) + client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=0) cdb = client.db gfs = gridfs.GridFSBucket(cdb) self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0) gfs = gridfs.GridFSBucket(cdb) self.assertRaises( - ServerSelectionTimeoutError, - gfs.upload_from_stream, "test", b"") # Still no connection. + ServerSelectionTimeoutError, gfs.upload_from_stream, "test", b"" + ) # Still no connection. 
def test_gridfs_find(self): self.fs.upload_from_stream("two", b"test2") @@ -357,8 +323,8 @@ def test_gridfs_find(self): self.assertEqual(3, files.count_documents({"filename": "two"})) self.assertEqual(4, files.count_documents({})) cursor = self.fs.find( - {}, no_cursor_timeout=False, sort=[("uploadDate", -1)], - skip=1, limit=2) + {}, no_cursor_timeout=False, sort=[("uploadDate", -1)], skip=1, limit=2 + ) gout = next(cursor) self.assertEqual(b"test1", gout.read()) cursor.rewind() @@ -372,13 +338,11 @@ def test_gridfs_find(self): def test_grid_in_non_int_chunksize(self): # Lua, and perhaps other buggy GridFS clients, store size as a float. - data = b'data' - self.fs.upload_from_stream('f', data) - self.db.fs.files.update_one({'filename': 'f'}, - {'$set': {'chunkSize': 100.0}}) + data = b"data" + self.fs.upload_from_stream("f", data) + self.db.fs.files.update_one({"filename": "f"}, {"$set": {"chunkSize": 100.0}}) - self.assertEqual(data, - self.fs.open_download_stream_by_name('f').read()) + self.assertEqual(data, self.fs.open_download_stream_by_name("f").read()) def test_unacknowledged(self): # w=0 is prohibited. @@ -386,29 +350,23 @@ def test_unacknowledged(self): gridfs.GridFSBucket(rs_or_single_client(w=0).pymongo_test) def test_rename(self): - _id = self.fs.upload_from_stream("first_name", b'testing') - self.assertEqual(b'testing', self.fs.open_download_stream_by_name( - "first_name").read()) + _id = self.fs.upload_from_stream("first_name", b"testing") + self.assertEqual(b"testing", self.fs.open_download_stream_by_name("first_name").read()) self.fs.rename(_id, "second_name") - self.assertRaises(NoFile, self.fs.open_download_stream_by_name, - "first_name") - self.assertEqual(b"testing", self.fs.open_download_stream_by_name( - "second_name").read()) + self.assertRaises(NoFile, self.fs.open_download_stream_by_name, "first_name") + self.assertEqual(b"testing", self.fs.open_download_stream_by_name("second_name").read()) def test_abort(self): - gin = self.fs.open_upload_stream("test_filename", - chunk_size_bytes=5) + gin = self.fs.open_upload_stream("test_filename", chunk_size_bytes=5) gin.write(b"test1") gin.write(b"test2") gin.write(b"test3") - self.assertEqual(3, self.db.fs.chunks.count_documents( - {"files_id": gin._id})) + self.assertEqual(3, self.db.fs.chunks.count_documents({"files_id": gin._id})) gin.abort() self.assertTrue(gin.closed) self.assertRaises(ValueError, gin.write, b"test4") - self.assertEqual(0, self.db.fs.chunks.count_documents( - {"files_id": gin._id})) + self.assertEqual(0, self.db.fs.chunks.count_documents({"files_id": gin._id})) def test_download_to_stream(self): file1 = BytesIO(b"hello world") @@ -425,9 +383,7 @@ def test_download_to_stream(self): self.db.drop_collection("fs.files") self.db.drop_collection("fs.chunks") file1.seek(0) - oid = self.fs.upload_from_stream("many_chunks", - file1, - chunk_size_bytes=1) + oid = self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1) self.assertEqual(11, self.db.fs.chunks.count_documents({})) file2 = BytesIO() self.fs.download_to_stream(oid, file2) @@ -478,7 +434,6 @@ def test_md5(self): class TestGridfsBucketReplicaSet(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): @@ -486,52 +441,43 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): - client_context.client.drop_database('gfsbucketreplica') + client_context.client.drop_database("gfsbucketreplica") def test_gridfs_replica_set(self): - rsc = rs_client( - w=client_context.w, - 
read_preference=ReadPreference.SECONDARY) + rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) - gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, 'gfsbucketreplicatest') - oid = gfs.upload_from_stream("test_filename", b'foo') + gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest") + oid = gfs.upload_from_stream("test_filename", b"foo") content = gfs.open_download_stream(oid).read() - self.assertEqual(b'foo', content) + self.assertEqual(b"foo", content) def test_gridfs_secondary(self): secondary_host, secondary_port = one(self.client.secondaries) secondary_connection = single_client( - secondary_host, secondary_port, - read_preference=ReadPreference.SECONDARY) + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY + ) # Should detect it's connected to secondary and not attempt to # create index - gfs = gridfs.GridFSBucket( - secondary_connection.gfsbucketreplica, 'gfsbucketsecondarytest') + gfs = gridfs.GridFSBucket(secondary_connection.gfsbucketreplica, "gfsbucketsecondarytest") # This won't detect secondary, raises error - self.assertRaises(NotPrimaryError, gfs.upload_from_stream, - "test_filename", b'foo') + self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"foo") def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to # create index. secondary_host, secondary_port = one(self.client.secondaries) client = single_client( - secondary_host, - secondary_port, - read_preference=ReadPreference.SECONDARY, - connect=False) + secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False + ) # Still no connection. - gfs = gridfs.GridFSBucket( - client.gfsbucketreplica, 'gfsbucketsecondarylazytest') + gfs = gridfs.GridFSBucket(client.gfsbucketreplica, "gfsbucketsecondarylazytest") # Connects, doesn't create index. - self.assertRaises(NoFile, gfs.open_download_stream_by_name, - "test_filename") - self.assertRaises(NotPrimaryError, gfs.upload_from_stream, - "test_filename", b'data') + self.assertRaises(NoFile, gfs.open_download_stream_by_name, "test_filename") + self.assertRaises(NotPrimaryError, gfs.upload_from_stream, "test_filename", b"data") if __name__ == "__main__": diff --git a/test/test_gridfs_spec.py b/test/test_gridfs_spec.py index 86449db370..2ba8f461b9 100644 --- a/test/test_gridfs_spec.py +++ b/test/test_gridfs_spec.py @@ -19,39 +19,35 @@ import os import re import sys - from json import loads sys.path[0:0] = [""] +from test import IntegrationTest, unittest + +import gridfs from bson import Binary from bson.int64 import Int64 from bson.json_util import object_hook -import gridfs -from gridfs.errors import NoFile, CorruptGridFile -from test import (unittest, - IntegrationTest) +from gridfs.errors import CorruptGridFile, NoFile # Commands. -_COMMANDS = {"delete": lambda coll, doc: [coll.delete_many(d["q"]) - for d in doc['deletes']], - "insert": lambda coll, doc: coll.insert_many(doc['documents']), - "update": lambda coll, doc: [coll.update_many(u["q"], u["u"]) - for u in doc['updates']] - } +_COMMANDS = { + "delete": lambda coll, doc: [coll.delete_many(d["q"]) for d in doc["deletes"]], + "insert": lambda coll, doc: coll.insert_many(doc["documents"]), + "update": lambda coll, doc: [coll.update_many(u["q"], u["u"]) for u in doc["updates"]], +} # Location of JSON test specifications. 
-_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'gridfs') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "gridfs") def camel_to_snake(camel): # Regex to convert CamelCase to snake_case. Special case for _id. if camel == "id": return "file_id" - snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel) - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower() + snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() class TestAllScenarios(IntegrationTest): @@ -63,23 +59,25 @@ def setUpClass(cls): "upload": cls.fs.upload_from_stream, "download": cls.fs.open_download_stream, "delete": cls.fs.delete, - "download_by_name": cls.fs.open_download_stream_by_name} + "download_by_name": cls.fs.open_download_stream_by_name, + } def init_db(self, data, test): - self.cleanup_colls(self.db.fs.files, self.db.fs.chunks, - self.db.expected.files, self.db.expected.chunks) + self.cleanup_colls( + self.db.fs.files, self.db.fs.chunks, self.db.expected.files, self.db.expected.chunks + ) # Read in data. - if data['files']: - self.db.fs.files.insert_many(data['files']) - self.db.expected.files.insert_many(data['files']) - if data['chunks']: - self.db.fs.chunks.insert_many(data['chunks']) - self.db.expected.chunks.insert_many(data['chunks']) + if data["files"]: + self.db.fs.files.insert_many(data["files"]) + self.db.expected.files.insert_many(data["files"]) + if data["chunks"]: + self.db.fs.chunks.insert_many(data["chunks"]) + self.db.expected.chunks.insert_many(data["chunks"]) # Make initial modifications. if "arrange" in test: - for cmd in test['arrange'].get('data', []): + for cmd in test["arrange"].get("data", []): for key in cmd.keys(): if key in _COMMANDS: coll = self.db.get_collection(cmd[key]) @@ -87,11 +85,11 @@ def init_db(self, data, test): def init_expected_db(self, test, result): # Modify outcome DB. - for cmd in test['assert'].get('data', []): + for cmd in test["assert"].get("data", []): for key in cmd.keys(): if key in _COMMANDS: # Replace wildcards in inserts. - for doc in cmd.get('documents', []): + for doc in cmd.get("documents", []): keylist = doc.keys() for dockey in copy.deepcopy(list(keylist)): if "result" in str(doc[dockey]): @@ -104,8 +102,8 @@ def init_expected_db(self, test, result): coll = self.db.get_collection(cmd[key]) _COMMANDS[key](coll, cmd) - if test['assert'].get('result') == "&result": - test['assert']['result'] = result + if test["assert"].get("result") == "&result": + test["assert"]["result"] = result def sorted_list(self, coll, ignore_id): to_sort = [] @@ -126,30 +124,28 @@ def create_test(scenario_def): def run_scenario(self): # Run tests. - self.assertTrue(scenario_def['tests'], "tests cannot be empty") - for test in scenario_def['tests']: - self.init_db(scenario_def['data'], test) + self.assertTrue(scenario_def["tests"], "tests cannot be empty") + for test in scenario_def["tests"]: + self.init_db(scenario_def["data"], test) # Run GridFs Operation. 
- operation = self.str_to_cmd[test['act']['operation']] - args = test['act']['arguments'] + operation = self.str_to_cmd[test["act"]["operation"]] + args = test["act"]["arguments"] extra_opts = args.pop("options", {}) if "contentType" in extra_opts: - extra_opts["metadata"] = { - "contentType": extra_opts.pop("contentType")} + extra_opts["metadata"] = {"contentType": extra_opts.pop("contentType")} args.update(extra_opts) - converted_args = dict((camel_to_snake(c), v) - for c, v in args.items()) + converted_args = dict((camel_to_snake(c), v) for c, v in args.items()) - expect_error = test['assert'].get("error", False) + expect_error = test["assert"].get("error", False) result = None error = None try: result = operation(**converted_args) - if 'download' in test['act']['operation']: + if "download" in test["act"]["operation"]: result = Binary(result.read()) except Exception as exc: if not expect_error: @@ -159,47 +155,51 @@ def run_scenario(self): self.init_expected_db(test, result) # Asserts. - errors = {"FileNotFound": NoFile, - "ChunkIsMissing": CorruptGridFile, - "ExtraChunk": CorruptGridFile, - "ChunkIsWrongSize": CorruptGridFile, - "RevisionNotFound": NoFile} + errors = { + "FileNotFound": NoFile, + "ChunkIsMissing": CorruptGridFile, + "ExtraChunk": CorruptGridFile, + "ChunkIsWrongSize": CorruptGridFile, + "RevisionNotFound": NoFile, + } if expect_error: self.assertIsNotNone(error) - self.assertIsInstance(error, errors[test['assert']['error']], - test['description']) + self.assertIsInstance(error, errors[test["assert"]["error"]], test["description"]) else: self.assertIsNone(error) - if 'result' in test['assert']: - if test['assert']['result'] == 'void': - test['assert']['result'] = None - self.assertEqual(result, test['assert'].get('result')) + if "result" in test["assert"]: + if test["assert"]["result"] == "void": + test["assert"]["result"] = None + self.assertEqual(result, test["assert"].get("result")) - if 'data' in test['assert']: + if "data" in test["assert"]: # Create alphabetized list self.assertEqual( set(self.sorted_list(self.db.fs.chunks, True)), - set(self.sorted_list(self.db.expected.chunks, True))) + set(self.sorted_list(self.db.expected.chunks, True)), + ) self.assertEqual( set(self.sorted_list(self.db.fs.files, False)), - set(self.sorted_list(self.db.expected.files, False))) + set(self.sorted_list(self.db.expected.files, False)), + ) return run_scenario + def _object_hook(dct): - if 'length' in dct: - dct['length'] = Int64(dct['length']) + if "length" in dct: + dct["length"] = Int64(dct["length"]) return object_hook(dct) + def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: - scenario_def = loads( - scenario_stream.read(), object_hook=_object_hook) + scenario_def = loads(scenario_stream.read(), object_hook=_object_hook) # Because object_hook is already defined by bson.json_util, # and everything is named 'data' @@ -207,7 +207,7 @@ def str2hex(jsn): for key, val in jsn.items(): if key in ("data", "source", "result"): if "$hex" in val: - jsn[key] = Binary(bytes.fromhex(val['$hex'])) + jsn[key] = Binary(bytes.fromhex(val["$hex"])) if isinstance(jsn[key], dict): str2hex(jsn[key]) if isinstance(jsn[key], list): @@ -218,8 +218,7 @@ def str2hex(jsn): # Construct test from scenario. 
new_test = create_test(scenario_def) - test_name = 'test_%s' % ( - os.path.splitext(filename)[0]) + test_name = "test_%s" % (os.path.splitext(filename)[0]) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 6941e6bd84..cd4a875e9e 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -19,21 +19,21 @@ sys.path[0:0] = [""] +from test import IntegrationTest, client_knobs, unittest +from test.utils import HeartbeatEventListener, MockPool, single_client, wait_until + from pymongo.errors import ConnectionFailure from pymongo.hello import Hello, HelloCompat from pymongo.monitor import Monitor -from test import unittest, client_knobs, IntegrationTest -from test.utils import (HeartbeatEventListener, MockPool, single_client, - wait_until) class TestHeartbeatMonitoring(IntegrationTest): - def create_mock_monitor(self, responses, uri, expected_results): listener = HeartbeatEventListener() - with client_knobs(heartbeat_frequency=0.1, - min_heartbeat_interval=0.1, - events_queue_frequency=0.1): + with client_knobs( + heartbeat_frequency=0.1, min_heartbeat_interval=0.1, events_queue_frequency=0.1 + ): + class MockMonitor(Monitor): def _check_with_socket(self, *args, **kwargs): if isinstance(responses[1], Exception): @@ -41,27 +41,21 @@ def _check_with_socket(self, *args, **kwargs): return Hello(responses[1]), 99 m = single_client( - h=uri, - event_listeners=(listener,), - _monitor_class=MockMonitor, - _pool_class=MockPool) + h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool + ) expected_len = len(expected_results) # Wait for *at least* expected_len number of results. The # monitor thread may run multiple times during the execution # of this test. - wait_until( - lambda: len(listener.events) >= expected_len, - "publish all events") + wait_until(lambda: len(listener.events) >= expected_len, "publish all events") try: # zip gives us len(expected_results) pairs. for expected, actual in zip(expected_results, listener.events): - self.assertEqual(expected, - actual.__class__.__name__) - self.assertEqual(actual.connection_id, - responses[0]) - if expected != 'ServerHeartbeatStartedEvent': + self.assertEqual(expected, actual.__class__.__name__) + self.assertEqual(actual.connection_id, responses[0]) + if expected != "ServerHeartbeatStartedEvent": if isinstance(actual.reply, Hello): self.assertEqual(actual.duration, 99) self.assertEqual(actual.reply._doc, responses[1]) @@ -72,28 +66,25 @@ def _check_with_socket(self, *args, **kwargs): m.close() def test_standalone(self): - responses = (('a', 27017), - { - HelloCompat.LEGACY_CMD: True, - "maxWireVersion": 4, - "minWireVersion": 0, - "ok": 1 - }) + responses = ( + ("a", 27017), + {HelloCompat.LEGACY_CMD: True, "maxWireVersion": 4, "minWireVersion": 0, "ok": 1}, + ) uri = "mongodb://a:27017" - expected_results = ['ServerHeartbeatStartedEvent', - 'ServerHeartbeatSucceededEvent'] + expected_results = ["ServerHeartbeatStartedEvent", "ServerHeartbeatSucceededEvent"] self.create_mock_monitor(responses, uri, expected_results) def test_standalone_error(self): - responses = (('a', 27017), - ConnectionFailure("SPECIAL MESSAGE")) + responses = (("a", 27017), ConnectionFailure("SPECIAL MESSAGE")) uri = "mongodb://a:27017" # _check_with_socket failing results in a second attempt. 
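# --- reviewer sketch (illustration, not from the patch) -------------------
# The event stream these assertions consume, reduced to a minimal
# hypothetical listener. pymongo publishes a Started event and then a
# Succeeded/Failed event per check, and a failed check is retried once,
# which is why the expected list below doubles up.
from pymongo import MongoClient, monitoring

class HeartbeatLogger(monitoring.ServerHeartbeatListener):
    def started(self, event):
        print("heartbeat started:", event.connection_id)

    def succeeded(self, event):
        print("heartbeat ok in", event.duration, "seconds")

    def failed(self, event):
        print("heartbeat failed:", event.reply)

client = MongoClient(event_listeners=[HeartbeatLogger()])
# --- end sketch -------------------------------------------------------------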
- expected_results = ['ServerHeartbeatStartedEvent', - 'ServerHeartbeatFailedEvent', - 'ServerHeartbeatStartedEvent', - 'ServerHeartbeatFailedEvent'] + expected_results = [ + "ServerHeartbeatStartedEvent", + "ServerHeartbeatFailedEvent", + "ServerHeartbeatStartedEvent", + "ServerHeartbeatFailedEvent", + ] self.create_mock_monitor(responses, uri, expected_results) diff --git a/test/test_json_util.py b/test/test_json_util.py index dbf4f1c26a..dc12aa74b6 100644 --- a/test/test_json_util.py +++ b/test/test_json_util.py @@ -22,16 +22,26 @@ sys.path[0:0] = [""] -from bson import json_util, EPOCH_AWARE, EPOCH_NAIVE, SON -from bson.json_util import (DatetimeRepresentation, - JSONMode, - JSONOptions, - LEGACY_JSON_OPTIONS) -from bson.binary import (ALL_UUID_REPRESENTATIONS, Binary, MD5_SUBTYPE, - USER_DEFINED_SUBTYPE, UuidRepresentation, STANDARD) +from test import IntegrationTest, unittest + +from bson import EPOCH_AWARE, EPOCH_NAIVE, SON, json_util +from bson.binary import ( + ALL_UUID_REPRESENTATIONS, + MD5_SUBTYPE, + STANDARD, + USER_DEFINED_SUBTYPE, + Binary, + UuidRepresentation, +) from bson.code import Code from bson.dbref import DBRef from bson.int64 import Int64 +from bson.json_util import ( + LEGACY_JSON_OPTIONS, + DatetimeRepresentation, + JSONMode, + JSONOptions, +) from bson.max_key import MaxKey from bson.min_key import MinKey from bson.objectid import ObjectId @@ -39,14 +49,12 @@ from bson.timestamp import Timestamp from bson.tz_util import FixedOffset, utc -from test import unittest, IntegrationTest - - STRICT_JSON_OPTIONS = JSONOptions( strict_number_long=True, datetime_representation=DatetimeRepresentation.ISO8601, strict_uuid=True, - json_mode=JSONMode.LEGACY) + json_mode=JSONMode.LEGACY, +) class TestJsonUtil(unittest.TestCase): @@ -61,15 +69,13 @@ def test_basic(self): def test_json_options_with_options(self): opts = JSONOptions( - datetime_representation=DatetimeRepresentation.NUMBERLONG, - json_mode=JSONMode.LEGACY) - self.assertEqual( - opts.datetime_representation, DatetimeRepresentation.NUMBERLONG) + datetime_representation=DatetimeRepresentation.NUMBERLONG, json_mode=JSONMode.LEGACY + ) + self.assertEqual(opts.datetime_representation, DatetimeRepresentation.NUMBERLONG) opts2 = opts.with_options( - datetime_representation=DatetimeRepresentation.ISO8601, - json_mode=JSONMode.LEGACY) - self.assertEqual( - opts2.datetime_representation, DatetimeRepresentation.ISO8601) + datetime_representation=DatetimeRepresentation.ISO8601, json_mode=JSONMode.LEGACY + ) + self.assertEqual(opts2.datetime_representation, DatetimeRepresentation.ISO8601) opts = JSONOptions(strict_number_long=True, json_mode=JSONMode.LEGACY) self.assertEqual(opts.strict_number_long, True) @@ -77,16 +83,12 @@ def test_json_options_with_options(self): self.assertEqual(opts2.strict_number_long, False) opts = json_util.CANONICAL_JSON_OPTIONS - self.assertNotEqual( - opts.uuid_representation, UuidRepresentation.JAVA_LEGACY) - opts2 = opts.with_options( - uuid_representation=UuidRepresentation.JAVA_LEGACY) - self.assertEqual( - opts2.uuid_representation, UuidRepresentation.JAVA_LEGACY) + self.assertNotEqual(opts.uuid_representation, UuidRepresentation.JAVA_LEGACY) + opts2 = opts.with_options(uuid_representation=UuidRepresentation.JAVA_LEGACY) + self.assertEqual(opts2.uuid_representation, UuidRepresentation.JAVA_LEGACY) self.assertEqual(opts2.document_class, dict) opts3 = opts2.with_options(document_class=SON) - self.assertEqual( - opts3.uuid_representation, UuidRepresentation.JAVA_LEGACY) + 
self.assertEqual(opts3.uuid_representation, UuidRepresentation.JAVA_LEGACY) self.assertEqual(opts3.document_class, SON) def test_objectid(self): @@ -100,41 +102,42 @@ def test_dbref(self): # Check order. self.assertEqual( '{"$ref": "collection", "$id": 1, "$db": "db"}', - json_util.dumps(DBRef('collection', 1, 'db'))) + json_util.dumps(DBRef("collection", 1, "db")), + ) def test_datetime(self): - tz_aware_opts = json_util.DEFAULT_JSON_OPTIONS.with_options( - tz_aware=True) + tz_aware_opts = json_util.DEFAULT_JSON_OPTIONS.with_options(tz_aware=True) # only millis, not micros - self.round_trip({"date": datetime.datetime(2009, 12, 9, 15, 49, 45, - 191000, utc)}, json_options=tz_aware_opts) - self.round_trip({"date": datetime.datetime(2009, 12, 9, 15, - 49, 45, 191000)}) - - for jsn in ['{"dt": { "$date" : "1970-01-01T00:00:00.000+0000"}}', - '{"dt": { "$date" : "1970-01-01T00:00:00.000000+0000"}}', - '{"dt": { "$date" : "1970-01-01T00:00:00.000+00:00"}}', - '{"dt": { "$date" : "1970-01-01T00:00:00.000000+00:00"}}', - '{"dt": { "$date" : "1970-01-01T00:00:00.000000+00"}}', - '{"dt": { "$date" : "1970-01-01T00:00:00.000Z"}}', - '{"dt": { "$date" : "1970-01-01T00:00:00.000000Z"}}', - '{"dt": { "$date" : "1970-01-01T00:00:00Z"}}', - '{"dt": {"$date": "1970-01-01T00:00:00.000"}}', - '{"dt": { "$date" : "1970-01-01T00:00:00"}}', - '{"dt": { "$date" : "1970-01-01T00:00:00.000000"}}', - '{"dt": { "$date" : "1969-12-31T16:00:00.000-0800"}}', - '{"dt": { "$date" : "1969-12-31T16:00:00.000000-0800"}}', - '{"dt": { "$date" : "1969-12-31T16:00:00.000-08:00"}}', - '{"dt": { "$date" : "1969-12-31T16:00:00.000000-08:00"}}', - '{"dt": { "$date" : "1969-12-31T16:00:00.000000-08"}}', - '{"dt": { "$date" : "1970-01-01T01:00:00.000+0100"}}', - '{"dt": { "$date" : "1970-01-01T01:00:00.000000+0100"}}', - '{"dt": { "$date" : "1970-01-01T01:00:00.000+01:00"}}', - '{"dt": { "$date" : "1970-01-01T01:00:00.000000+01:00"}}', - '{"dt": { "$date" : "1970-01-01T01:00:00.000000+01"}}' - ]: - self.assertEqual(EPOCH_AWARE, json_util.loads( - jsn, json_options=tz_aware_opts)["dt"]) + self.round_trip( + {"date": datetime.datetime(2009, 12, 9, 15, 49, 45, 191000, utc)}, + json_options=tz_aware_opts, + ) + self.round_trip({"date": datetime.datetime(2009, 12, 9, 15, 49, 45, 191000)}) + + for jsn in [ + '{"dt": { "$date" : "1970-01-01T00:00:00.000+0000"}}', + '{"dt": { "$date" : "1970-01-01T00:00:00.000000+0000"}}', + '{"dt": { "$date" : "1970-01-01T00:00:00.000+00:00"}}', + '{"dt": { "$date" : "1970-01-01T00:00:00.000000+00:00"}}', + '{"dt": { "$date" : "1970-01-01T00:00:00.000000+00"}}', + '{"dt": { "$date" : "1970-01-01T00:00:00.000Z"}}', + '{"dt": { "$date" : "1970-01-01T00:00:00.000000Z"}}', + '{"dt": { "$date" : "1970-01-01T00:00:00Z"}}', + '{"dt": {"$date": "1970-01-01T00:00:00.000"}}', + '{"dt": { "$date" : "1970-01-01T00:00:00"}}', + '{"dt": { "$date" : "1970-01-01T00:00:00.000000"}}', + '{"dt": { "$date" : "1969-12-31T16:00:00.000-0800"}}', + '{"dt": { "$date" : "1969-12-31T16:00:00.000000-0800"}}', + '{"dt": { "$date" : "1969-12-31T16:00:00.000-08:00"}}', + '{"dt": { "$date" : "1969-12-31T16:00:00.000000-08:00"}}', + '{"dt": { "$date" : "1969-12-31T16:00:00.000000-08"}}', + '{"dt": { "$date" : "1970-01-01T01:00:00.000+0100"}}', + '{"dt": { "$date" : "1970-01-01T01:00:00.000000+0100"}}', + '{"dt": { "$date" : "1970-01-01T01:00:00.000+01:00"}}', + '{"dt": { "$date" : "1970-01-01T01:00:00.000000+01:00"}}', + '{"dt": { "$date" : "1970-01-01T01:00:00.000000+01"}}', + ]: + self.assertEqual(EPOCH_AWARE, json_util.loads(jsn, 
json_options=tz_aware_opts)["dt"]) self.assertEqual(EPOCH_NAIVE, json_util.loads(jsn)["dt"]) dtm = datetime.datetime(1, 1, 1, 1, 1, 1, 0, utc) @@ -147,84 +150,99 @@ def test_datetime(self): pre_epoch = {"dt": datetime.datetime(1, 1, 1, 1, 1, 1, 10000, utc)} post_epoch = {"dt": datetime.datetime(1972, 1, 1, 1, 1, 1, 10000, utc)} self.assertEqual( - '{"dt": {"$date": {"$numberLong": "-62135593138990"}}}', - json_util.dumps(pre_epoch)) + '{"dt": {"$date": {"$numberLong": "-62135593138990"}}}', json_util.dumps(pre_epoch) + ) self.assertEqual( - '{"dt": {"$date": "1972-01-01T01:01:01.010Z"}}', - json_util.dumps(post_epoch)) + '{"dt": {"$date": "1972-01-01T01:01:01.010Z"}}', json_util.dumps(post_epoch) + ) self.assertEqual( '{"dt": {"$date": -62135593138990}}', - json_util.dumps(pre_epoch, json_options=LEGACY_JSON_OPTIONS)) + json_util.dumps(pre_epoch, json_options=LEGACY_JSON_OPTIONS), + ) self.assertEqual( '{"dt": {"$date": 63075661010}}', - json_util.dumps(post_epoch, json_options=LEGACY_JSON_OPTIONS)) + json_util.dumps(post_epoch, json_options=LEGACY_JSON_OPTIONS), + ) self.assertEqual( '{"dt": {"$date": {"$numberLong": "-62135593138990"}}}', - json_util.dumps(pre_epoch, json_options=STRICT_JSON_OPTIONS)) + json_util.dumps(pre_epoch, json_options=STRICT_JSON_OPTIONS), + ) self.assertEqual( '{"dt": {"$date": "1972-01-01T01:01:01.010Z"}}', - json_util.dumps(post_epoch, json_options=STRICT_JSON_OPTIONS)) + json_util.dumps(post_epoch, json_options=STRICT_JSON_OPTIONS), + ) number_long_options = JSONOptions( - datetime_representation=DatetimeRepresentation.NUMBERLONG, - json_mode=JSONMode.LEGACY) + datetime_representation=DatetimeRepresentation.NUMBERLONG, json_mode=JSONMode.LEGACY + ) self.assertEqual( '{"dt": {"$date": {"$numberLong": "63075661010"}}}', - json_util.dumps(post_epoch, json_options=number_long_options)) + json_util.dumps(post_epoch, json_options=number_long_options), + ) self.assertEqual( '{"dt": {"$date": {"$numberLong": "-62135593138990"}}}', - json_util.dumps(pre_epoch, json_options=number_long_options)) + json_util.dumps(pre_epoch, json_options=number_long_options), + ) # ISO8601 mode assumes naive datetimes are UTC pre_epoch_naive = {"dt": datetime.datetime(1, 1, 1, 1, 1, 1, 10000)} - post_epoch_naive = { - "dt": datetime.datetime(1972, 1, 1, 1, 1, 1, 10000)} + post_epoch_naive = {"dt": datetime.datetime(1972, 1, 1, 1, 1, 1, 10000)} self.assertEqual( '{"dt": {"$date": {"$numberLong": "-62135593138990"}}}', - json_util.dumps(pre_epoch_naive, json_options=STRICT_JSON_OPTIONS)) + json_util.dumps(pre_epoch_naive, json_options=STRICT_JSON_OPTIONS), + ) self.assertEqual( '{"dt": {"$date": "1972-01-01T01:01:01.010Z"}}', - json_util.dumps(post_epoch_naive, - json_options=STRICT_JSON_OPTIONS)) + json_util.dumps(post_epoch_naive, json_options=STRICT_JSON_OPTIONS), + ) # Test tz_aware and tzinfo options self.assertEqual( datetime.datetime(1972, 1, 1, 1, 1, 1, 10000, utc), json_util.loads( - '{"dt": {"$date": "1972-01-01T01:01:01.010+0000"}}', - json_options=tz_aware_opts)["dt"]) + '{"dt": {"$date": "1972-01-01T01:01:01.010+0000"}}', json_options=tz_aware_opts + )["dt"], + ) self.assertEqual( datetime.datetime(1972, 1, 1, 1, 1, 1, 10000, utc), json_util.loads( '{"dt": {"$date": "1972-01-01T01:01:01.010+0000"}}', - json_options=JSONOptions(tz_aware=True, - tzinfo=utc))["dt"]) + json_options=JSONOptions(tz_aware=True, tzinfo=utc), + )["dt"], + ) self.assertEqual( datetime.datetime(1972, 1, 1, 1, 1, 1, 10000), json_util.loads( '{"dt": {"$date": "1972-01-01T01:01:01.010+0000"}}', - 
json_options=JSONOptions(tz_aware=False))["dt"]) - self.round_trip(pre_epoch_naive, json_options=JSONOptions( - tz_aware=False)) + json_options=JSONOptions(tz_aware=False), + )["dt"], + ) + self.round_trip(pre_epoch_naive, json_options=JSONOptions(tz_aware=False)) # Test a non-utc timezone - pacific = FixedOffset(-8 * 60, 'US/Pacific') - aware_datetime = {"dt": datetime.datetime(2002, 10, 27, 6, 0, 0, 10000, - pacific)} + pacific = FixedOffset(-8 * 60, "US/Pacific") + aware_datetime = {"dt": datetime.datetime(2002, 10, 27, 6, 0, 0, 10000, pacific)} self.assertEqual( '{"dt": {"$date": "2002-10-27T06:00:00.010-0800"}}', - json_util.dumps(aware_datetime, json_options=STRICT_JSON_OPTIONS)) - self.round_trip(aware_datetime, json_options=JSONOptions( - json_mode=JSONMode.LEGACY, - tz_aware=True, tzinfo=pacific)) - self.round_trip(aware_datetime, json_options=JSONOptions( - datetime_representation=DatetimeRepresentation.ISO8601, - json_mode=JSONMode.LEGACY, - tz_aware=True, tzinfo=pacific)) + json_util.dumps(aware_datetime, json_options=STRICT_JSON_OPTIONS), + ) + self.round_trip( + aware_datetime, + json_options=JSONOptions(json_mode=JSONMode.LEGACY, tz_aware=True, tzinfo=pacific), + ) + self.round_trip( + aware_datetime, + json_options=JSONOptions( + datetime_representation=DatetimeRepresentation.ISO8601, + json_mode=JSONMode.LEGACY, + tz_aware=True, + tzinfo=pacific, + ), + ) def test_regex_object_hook(self): # Extended JSON format regular expression. - pat = 'a*b' + pat = "a*b" json_re = '{"$regex": "%s", "$options": "u"}' % pat loaded = json_util.object_hook(json.loads(json_re)) self.assertTrue(isinstance(loaded, Regex)) @@ -232,9 +250,7 @@ def test_regex_object_hook(self): self.assertEqual(re.U, loaded.flags) def test_regex(self): - for regex_instance in ( - re.compile("a*b", re.IGNORECASE), - Regex("a*b", re.IGNORECASE)): + for regex_instance in (re.compile("a*b", re.IGNORECASE), Regex("a*b", re.IGNORECASE)): res = self.round_tripped({"r": regex_instance})["r"] self.assertEqual("a*b", res.pattern) @@ -242,33 +258,34 @@ def test_regex(self): self.assertEqual("a*b", res.pattern) self.assertEqual(re.IGNORECASE, res.flags) - unicode_options = re.I|re.M|re.S|re.U|re.X + unicode_options = re.I | re.M | re.S | re.U | re.X regex = re.compile("a*b", unicode_options) res = self.round_tripped({"r": regex})["r"] self.assertEqual(unicode_options, res.flags) # Some tools may not add $options if no flags are set. - res = json_util.loads('{"r": {"$regex": "a*b"}}')['r'] + res = json_util.loads('{"r": {"$regex": "a*b"}}')["r"] self.assertEqual(0, res.flags) self.assertEqual( - Regex('.*', 'ilm'), - json_util.loads( - '{"r": {"$regex": ".*", "$options": "ilm"}}')['r']) + Regex(".*", "ilm"), json_util.loads('{"r": {"$regex": ".*", "$options": "ilm"}}')["r"] + ) # Check order. 
self.assertEqual( '{"$regularExpression": {"pattern": ".*", "options": "mx"}}', - json_util.dumps(Regex('.*', re.M | re.X))) + json_util.dumps(Regex(".*", re.M | re.X)), + ) self.assertEqual( '{"$regularExpression": {"pattern": ".*", "options": "mx"}}', - json_util.dumps(re.compile(b'.*', re.M | re.X))) + json_util.dumps(re.compile(b".*", re.M | re.X)), + ) self.assertEqual( '{"$regex": ".*", "$options": "mx"}', - json_util.dumps(Regex('.*', re.M | re.X), - json_options=LEGACY_JSON_OPTIONS)) + json_util.dumps(Regex(".*", re.M | re.X), json_options=LEGACY_JSON_OPTIONS), + ) def test_regex_validation(self): non_str_types = [10, {}, []] @@ -295,87 +312,94 @@ def test_timestamp(self): def test_uuid_default(self): # Cannot directly encode native UUIDs with the default # uuid_representation. - doc = {'uuid': uuid.UUID('f47ac10b-58cc-4372-a567-0e02b2c3d479')} - with self.assertRaisesRegex(ValueError, 'cannot encode native uuid'): + doc = {"uuid": uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")} + with self.assertRaisesRegex(ValueError, "cannot encode native uuid"): json_util.dumps(doc) legacy_jsn = '{"uuid": {"$uuid": "f47ac10b58cc4372a5670e02b2c3d479"}}' - expected = {'uuid': Binary( - b'\xf4z\xc1\x0bX\xccCr\xa5g\x0e\x02\xb2\xc3\xd4y', 4)} + expected = {"uuid": Binary(b"\xf4z\xc1\x0bX\xccCr\xa5g\x0e\x02\xb2\xc3\xd4y", 4)} self.assertEqual(json_util.loads(legacy_jsn), expected) def test_uuid(self): - doc = {'uuid': uuid.UUID('f47ac10b-58cc-4372-a567-0e02b2c3d479')} + doc = {"uuid": uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")} uuid_legacy_opts = LEGACY_JSON_OPTIONS.with_options( - uuid_representation=UuidRepresentation.PYTHON_LEGACY) + uuid_representation=UuidRepresentation.PYTHON_LEGACY + ) self.round_trip(doc, json_options=uuid_legacy_opts) self.assertEqual( '{"uuid": {"$uuid": "f47ac10b58cc4372a5670e02b2c3d479"}}', - json_util.dumps(doc, json_options=LEGACY_JSON_OPTIONS)) + json_util.dumps(doc, json_options=LEGACY_JSON_OPTIONS), + ) self.assertEqual( - '{"uuid": ' - '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "03"}}', + '{"uuid": ' '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "03"}}', json_util.dumps( - doc, json_options=STRICT_JSON_OPTIONS.with_options( - uuid_representation=UuidRepresentation.PYTHON_LEGACY))) - self.assertEqual( - '{"uuid": ' - '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "04"}}', + doc, + json_options=STRICT_JSON_OPTIONS.with_options( + uuid_representation=UuidRepresentation.PYTHON_LEGACY + ), + ), + ) + self.assertEqual( + '{"uuid": ' '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "04"}}', json_util.dumps( - doc, json_options=JSONOptions( - strict_uuid=True, json_mode=JSONMode.LEGACY, - uuid_representation=STANDARD))) - self.assertEqual( - doc, json_util.loads( - '{"uuid": ' - '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "03"}}', - json_options=uuid_legacy_opts)) - for uuid_representation in (set(ALL_UUID_REPRESENTATIONS) - - {UuidRepresentation.UNSPECIFIED}): + doc, + json_options=JSONOptions( + strict_uuid=True, json_mode=JSONMode.LEGACY, uuid_representation=STANDARD + ), + ), + ) + self.assertEqual( + doc, + json_util.loads( + '{"uuid": ' '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "03"}}', + json_options=uuid_legacy_opts, + ), + ) + for uuid_representation in set(ALL_UUID_REPRESENTATIONS) - {UuidRepresentation.UNSPECIFIED}: options = JSONOptions( - strict_uuid=True, json_mode=JSONMode.LEGACY, - uuid_representation=uuid_representation) + strict_uuid=True, json_mode=JSONMode.LEGACY, uuid_representation=uuid_representation + ) 
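# --- reviewer sketch (illustration, not from the patch) -------------------
# What this loop exercises: the same native UUID dumps to a different
# legacy "$type" subtype depending on uuid_representation. Variable names
# here are illustrative.
import uuid
from bson.binary import UuidRepresentation
from bson.json_util import JSONMode, JSONOptions, dumps

value = {"uuid": uuid.UUID("f47ac10b-58cc-4372-a567-0e02b2c3d479")}
for rep in (UuidRepresentation.PYTHON_LEGACY, UuidRepresentation.STANDARD):
    opts = JSONOptions(strict_uuid=True, json_mode=JSONMode.LEGACY, uuid_representation=rep)
    print(dumps(value, json_options=opts))
# PYTHON_LEGACY prints '"$type": "03"'; STANDARD prints '"$type": "04"'.
# --- end sketch; the loop body continues -------------------------------------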
self.round_trip(doc, json_options=options) # Ignore UUID representation when decoding BSON binary subtype 4. - self.assertEqual(doc, json_util.loads( - '{"uuid": ' - '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "04"}}', - json_options=options)) + self.assertEqual( + doc, + json_util.loads( + '{"uuid": ' '{"$binary": "9HrBC1jMQ3KlZw4CssPUeQ==", "$type": "04"}}', + json_options=options, + ), + ) def test_uuid_uuid_rep_unspecified(self): _uuid = uuid.uuid4() options = JSONOptions( strict_uuid=True, json_mode=JSONMode.LEGACY, - uuid_representation=UuidRepresentation.UNSPECIFIED) + uuid_representation=UuidRepresentation.UNSPECIFIED, + ) # Cannot directly encode native UUIDs with UNSPECIFIED. - doc = {'uuid': _uuid} + doc = {"uuid": _uuid} with self.assertRaises(ValueError): json_util.dumps(doc, json_options=options) # All UUID subtypes are decoded as Binary with UNSPECIFIED. # subtype 3 - doc = {'uuid': Binary(_uuid.bytes, subtype=3)} + doc = {"uuid": Binary(_uuid.bytes, subtype=3)} ext_json_str = json_util.dumps(doc) - self.assertEqual( - doc, json_util.loads(ext_json_str, json_options=options)) + self.assertEqual(doc, json_util.loads(ext_json_str, json_options=options)) # subtype 4 - doc = {'uuid': Binary(_uuid.bytes, subtype=4)} + doc = {"uuid": Binary(_uuid.bytes, subtype=4)} ext_json_str = json_util.dumps(doc) - self.assertEqual( - doc, json_util.loads(ext_json_str, json_options=options)) + self.assertEqual(doc, json_util.loads(ext_json_str, json_options=options)) # $uuid-encoded fields - doc = {'uuid': Binary(_uuid.bytes, subtype=4)} - ext_json_str = json_util.dumps({'uuid': _uuid}, - json_options=LEGACY_JSON_OPTIONS) - self.assertEqual( - doc, json_util.loads(ext_json_str, json_options=options)) + doc = {"uuid": Binary(_uuid.bytes, subtype=4)} + ext_json_str = json_util.dumps({"uuid": _uuid}, json_options=LEGACY_JSON_OPTIONS) + self.assertEqual(doc, json_util.loads(ext_json_str, json_options=options)) def test_binary(self): bin_type_dict = {"bin": b"\x00\x01\x02\x03\x04"} md5_type_dict = { - "md5": Binary(b' n7\x18\xaf\t/\xd1\xd1/\x80\xca\xe7q\xcc\xac', - MD5_SUBTYPE)} + "md5": Binary(b" n7\x18\xaf\t/\xd1\xd1/\x80\xca\xe7q\xcc\xac", MD5_SUBTYPE) + } custom_type_dict = {"custom": Binary(b"hello", USER_DEFINED_SUBTYPE)} self.round_trip(bin_type_dict) @@ -383,43 +407,47 @@ def test_binary(self): self.round_trip(custom_type_dict) # Binary with subtype 0 is decoded into bytes in Python 3. - bin = json_util.loads( - '{"bin": {"$binary": "AAECAwQ=", "$type": "00"}}')['bin'] + bin = json_util.loads('{"bin": {"$binary": "AAECAwQ=", "$type": "00"}}')["bin"] self.assertEqual(type(bin), bytes) # PYTHON-443 ensure old type formats are supported - json_bin_dump = json_util.dumps(bin_type_dict, - json_options=LEGACY_JSON_OPTIONS) + json_bin_dump = json_util.dumps(bin_type_dict, json_options=LEGACY_JSON_OPTIONS) self.assertIn('"$type": "00"', json_bin_dump) - self.assertEqual(bin_type_dict, - json_util.loads('{"bin": {"$type": 0, "$binary": "AAECAwQ="}}')) - json_bin_dump = json_util.dumps(md5_type_dict, - json_options=LEGACY_JSON_OPTIONS) + self.assertEqual( + bin_type_dict, json_util.loads('{"bin": {"$type": 0, "$binary": "AAECAwQ="}}') + ) + json_bin_dump = json_util.dumps(md5_type_dict, json_options=LEGACY_JSON_OPTIONS) # Check order. 
self.assertEqual( - '{"md5": {"$binary": "IG43GK8JL9HRL4DK53HMrA==",' - + ' "$type": "05"}}', - json_bin_dump) + '{"md5": {"$binary": "IG43GK8JL9HRL4DK53HMrA==",' + ' "$type": "05"}}', json_bin_dump + ) - self.assertEqual(md5_type_dict, - json_util.loads('{"md5": {"$type": 5, "$binary":' - ' "IG43GK8JL9HRL4DK53HMrA=="}}')) + self.assertEqual( + md5_type_dict, + json_util.loads('{"md5": {"$type": 5, "$binary":' ' "IG43GK8JL9HRL4DK53HMrA=="}}'), + ) - json_bin_dump = json_util.dumps(custom_type_dict, - json_options=LEGACY_JSON_OPTIONS) + json_bin_dump = json_util.dumps(custom_type_dict, json_options=LEGACY_JSON_OPTIONS) self.assertIn('"$type": "80"', json_bin_dump) - self.assertEqual(custom_type_dict, - json_util.loads('{"custom": {"$type": 128, "$binary":' - ' "aGVsbG8="}}')) + self.assertEqual( + custom_type_dict, + json_util.loads('{"custom": {"$type": 128, "$binary":' ' "aGVsbG8="}}'), + ) # Handle mongoexport where subtype >= 128 - self.assertEqual(128, - json_util.loads('{"custom": {"$type": "ffffff80", "$binary":' - ' "aGVsbG8="}}')['custom'].subtype) + self.assertEqual( + 128, + json_util.loads('{"custom": {"$type": "ffffff80", "$binary":' ' "aGVsbG8="}}')[ + "custom" + ].subtype, + ) - self.assertEqual(255, - json_util.loads('{"custom": {"$type": "ffffffff", "$binary":' - ' "aGVsbG8="}}')['custom'].subtype) + self.assertEqual( + 255, + json_util.loads('{"custom": {"$type": "ffffffff", "$binary":' ' "aGVsbG8="}}')[ + "custom" + ].subtype, + ) def test_code(self): self.round_trip({"code": Code("function x() { return 1; }")}) @@ -431,34 +459,30 @@ def test_code(self): # Check order. self.assertEqual('{"$code": "return z", "$scope": {"z": 2}}', res) - no_scope = Code('function() {}') - self.assertEqual( - '{"$code": "function() {}"}', json_util.dumps(no_scope)) + no_scope = Code("function() {}") + self.assertEqual('{"$code": "function() {}"}', json_util.dumps(no_scope)) def test_undefined(self): jsn = '{"name": {"$undefined": true}}' - self.assertIsNone(json_util.loads(jsn)['name']) + self.assertIsNone(json_util.loads(jsn)["name"]) def test_numberlong(self): jsn = '{"weight": {"$numberLong": "65535"}}' - self.assertEqual(json_util.loads(jsn)['weight'], - Int64(65535)) - self.assertEqual(json_util.dumps({"weight": Int64(65535)}), - '{"weight": 65535}') - json_options = JSONOptions(strict_number_long=True, - json_mode=JSONMode.LEGACY) - self.assertEqual(json_util.dumps({"weight": Int64(65535)}, - json_options=json_options), - jsn) + self.assertEqual(json_util.loads(jsn)["weight"], Int64(65535)) + self.assertEqual(json_util.dumps({"weight": Int64(65535)}), '{"weight": 65535}') + json_options = JSONOptions(strict_number_long=True, json_mode=JSONMode.LEGACY) + self.assertEqual(json_util.dumps({"weight": Int64(65535)}, json_options=json_options), jsn) def test_loads_document_class(self): # document_class dict should always work - self.assertEqual({"foo": "bar"}, json_util.loads( - '{"foo": "bar"}', - json_options=JSONOptions(document_class=dict))) - self.assertEqual(SON([("foo", "bar"), ("b", 1)]), json_util.loads( - '{"foo": "bar", "b": 1}', - json_options=JSONOptions(document_class=SON))) + self.assertEqual( + {"foo": "bar"}, + json_util.loads('{"foo": "bar"}', json_options=JSONOptions(document_class=dict)), + ) + self.assertEqual( + SON([("foo", "bar"), ("b", 1)]), + json_util.loads('{"foo": "bar", "b": 1}', json_options=JSONOptions(document_class=SON)), + ) class TestJsonUtilRoundtrip(IntegrationTest): @@ -467,12 +491,11 @@ def test_cursor(self): db.drop_collection("test") docs = [ - 
{'foo': [1, 2]}, - {'bar': {'hello': 'world'}}, - {'code': Code("function x() { return 1; }")}, - {'bin': Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)}, - {'dbref': {'_ref': DBRef('simple', - ObjectId('509b8db456c02c5ab7e63c34'))}} + {"foo": [1, 2]}, + {"bar": {"hello": "world"}}, + {"code": Code("function x() { return 1; }")}, + {"bin": Binary(b"\x00\x01\x02\x03\x04", USER_DEFINED_SUBTYPE)}, + {"dbref": {"_ref": DBRef("simple", ObjectId("509b8db456c02c5ab7e63c34"))}}, ] db.test.insert_many(docs) @@ -480,5 +503,6 @@ def test_cursor(self): for doc in docs: self.assertTrue(doc in reloaded_docs) + if __name__ == "__main__": unittest.main() diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 247072c7bd..547cf327d3 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -21,16 +21,12 @@ sys.path[0:0] = [""] -from test import unittest, IntegrationTest, client_context -from test.utils import (ExceptionCatchingThread, - get_pool, - rs_client, - wait_until) +from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes +from test.utils import ExceptionCatchingThread, get_pool, rs_client, wait_until # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'load_balancer') +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "load_balancer") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) @@ -45,7 +41,7 @@ def test_connections_are_only_returned_once(self): nconns = len(pool.sockets) self.db.test.find_one({}) self.assertEqual(len(pool.sockets), nconns) - list(self.db.test.aggregate([{'$limit': 1}])) + list(self.db.test.aggregate([{"$limit": 1}])) self.assertEqual(len(pool.sockets), nconns) @client_context.require_load_balancer @@ -68,6 +64,7 @@ def create_resource(coll): cursor = coll.find({}, batch_size=3) next(cursor) return cursor + self._test_no_gc_deadlock(create_resource) @client_context.require_failCommand_fail_point @@ -76,6 +73,7 @@ def create_resource(coll): cursor = coll.aggregate([], batchSize=3) next(cursor) return cursor + self._test_no_gc_deadlock(create_resource) def _test_no_gc_deadlock(self, create_resource): @@ -87,15 +85,11 @@ def _test_no_gc_deadlock(self, create_resource): self.assertEqual(pool.active_sockets, 0) # Cause the initial find attempt to fail to induce a reference cycle. args = { - "mode": { - "times": 1 - }, + "mode": {"times": 1}, "data": { - "failCommands": [ - "find", "aggregate" - ], - "closeConnection": True, - } + "failCommands": ["find", "aggregate"], + "closeConnection": True, + }, } with self.fail_point(args): resource = create_resource(coll) @@ -104,7 +98,7 @@ def _test_no_gc_deadlock(self, create_resource): thread = PoolLocker(pool) thread.start() - self.assertTrue(thread.locked.wait(5), 'timed out') + self.assertTrue(thread.locked.wait(5), "timed out") # Garbage collect the resource while the pool is locked to ensure we # don't deadlock. del resource @@ -116,7 +110,7 @@ def _test_no_gc_deadlock(self, create_resource): self.assertFalse(thread.is_alive()) self.assertIsNone(thread.exc) - wait_until(lambda: pool.active_sockets == 0, 'return socket') + wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. 
coll.delete_many({}) @@ -133,7 +127,7 @@ def test_session_gc(self): thread = PoolLocker(pool) thread.start() - self.assertTrue(thread.locked.wait(5), 'timed out') + self.assertTrue(thread.locked.wait(5), "timed out") # Garbage collect the session while the pool is locked to ensure we # don't deadlock. del session @@ -145,7 +139,7 @@ def test_session_gc(self): self.assertFalse(thread.is_alive()) self.assertIsNone(thread.exc) - wait_until(lambda: pool.active_sockets == 0, 'return socket') + wait_until(lambda: pool.active_sockets == 0, "return socket") # Run another operation to ensure the socket still works. client[self.db.name].test.delete_many({}) @@ -164,8 +158,7 @@ def lock_pool(self): # Wait for the unlock flag. unlock_pool = self.unlock.wait(10) if not unlock_pool: - raise Exception('timed out waiting for unlock signal:' - ' deadlock?') + raise Exception("timed out waiting for unlock signal:" " deadlock?") if __name__ == "__main__": diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 1fd82884f1..4c8491b59f 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -21,18 +21,16 @@ sys.path[0:0] = [""] -from pymongo import MongoClient -from pymongo.errors import ConfigurationError -from pymongo.server_selectors import writable_server_selector - from test import client_context, unittest from test.utils import rs_or_single_client from test.utils_selection_tests import create_selection_tests +from pymongo import MongoClient +from pymongo.errors import ConfigurationError +from pymongo.server_selectors import writable_server_selector + # Location of JSON test specifications. -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'max_staleness') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "max_staleness") class TestAllScenarios(create_selection_tests(_TEST_PATH)): @@ -54,26 +52,21 @@ def test_max_staleness(self): with self.assertRaises(ConfigurationError): # Read pref "primary" can't be used with max staleness. 
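# --- reviewer sketch (illustration, not from the patch) -------------------
# The URI semantics asserted below, as a self-contained example:
# maxStalenessSeconds only applies when reads may go to secondaries, so
# combining it with readPreference=primary is rejected, and -1 means
# "no maximum".
from pymongo import MongoClient
from pymongo.errors import ConfigurationError

c = MongoClient("mongodb://host/?readPreference=secondary&maxStalenessSeconds=120")
print(c.read_preference.max_staleness)  # 120
try:
    MongoClient("mongodb://host/?readPreference=primary&maxStalenessSeconds=120")
except ConfigurationError:
    print("rejected: primary reads cannot be stale")
# --- end sketch -------------------------------------------------------------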
- MongoClient("mongodb://a/?readPreference=primary&" - "maxStalenessSeconds=120") + MongoClient("mongodb://a/?readPreference=primary&" "maxStalenessSeconds=120") client = MongoClient("mongodb://host/?maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient("mongodb://host/?readPreference=primary&" - "maxStalenessSeconds=-1") + client = MongoClient("mongodb://host/?readPreference=primary&" "maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient("mongodb://host/?readPreference=secondary&" - "maxStalenessSeconds=120") + client = MongoClient("mongodb://host/?readPreference=secondary&" "maxStalenessSeconds=120") self.assertEqual(120, client.read_preference.max_staleness) - client = MongoClient("mongodb://a/?readPreference=secondary&" - "maxStalenessSeconds=1") + client = MongoClient("mongodb://a/?readPreference=secondary&" "maxStalenessSeconds=1") self.assertEqual(1, client.read_preference.max_staleness) - client = MongoClient("mongodb://a/?readPreference=secondary&" - "maxStalenessSeconds=-1") + client = MongoClient("mongodb://a/?readPreference=secondary&" "maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) client = MongoClient(maxStalenessSeconds=-1, readPreference="nearest") @@ -85,15 +78,15 @@ def test_max_staleness(self): def test_max_staleness_float(self): with self.assertRaises(TypeError) as ctx: - rs_or_single_client(maxStalenessSeconds=1.5, - readPreference="nearest") + rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest") self.assertIn("must be an integer", str(ctx.exception)) with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") - client = MongoClient("mongodb://host/?maxStalenessSeconds=1.5" - "&readPreference=nearest") + client = MongoClient( + "mongodb://host/?maxStalenessSeconds=1.5" "&readPreference=nearest" + ) # Option was ignored. self.assertEqual(-1, client.read_preference.max_staleness) @@ -102,15 +95,13 @@ def test_max_staleness_float(self): def test_max_staleness_zero(self): # Zero is too small. with self.assertRaises(ValueError) as ctx: - rs_or_single_client(maxStalenessSeconds=0, - readPreference="nearest") + rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest") self.assertIn("must be a positive integer", str(ctx.exception)) with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") - client = MongoClient("mongodb://host/?maxStalenessSeconds=0" - "&readPreference=nearest") + client = MongoClient("mongodb://host/?maxStalenessSeconds=0" "&readPreference=nearest") # Option was ignored. 
self.assertEqual(-1, client.read_preference.max_staleness) diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py index c110b8b10c..e39940f56b 100644 --- a/test/test_mongos_load_balancing.py +++ b/test/test_mongos_load_balancing.py @@ -19,12 +19,13 @@ sys.path[0:0] = [""] +from test import MockClientTest, client_context, unittest +from test.pymongo_mocks import MockClient +from test.utils import connected, wait_until + from pymongo.errors import AutoReconnect, InvalidOperation from pymongo.server_selectors import writable_server_selector from pymongo.topology_description import TOPOLOGY_TYPE -from test import unittest, client_context, MockClientTest -from test.pymongo_mocks import MockClient -from test.utils import connected, wait_until @client_context.require_connection @@ -34,14 +35,13 @@ def setUpModule(): class SimpleOp(threading.Thread): - def __init__(self, client): super(SimpleOp, self).__init__() self.client = client self.passed = False def run(self): - self.client.db.command('ping') + self.client.db.command("ping") self.passed = True # No exception raised. @@ -58,26 +58,27 @@ def do_simple_op(client, nthreads): def writable_addresses(topology): - return set(server.description.address for server in - topology.select_servers(writable_server_selector)) + return set( + server.description.address for server in topology.select_servers(writable_server_selector) + ) class TestMongosLoadBalancing(MockClientTest): - def mock_client(self, **kwargs): mock_client = MockClient( standalones=[], members=[], - mongoses=['a:1', 'b:2', 'c:3'], - host='a:1,b:2,c:3', + mongoses=["a:1", "b:2", "c:3"], + host="a:1,b:2,c:3", connect=False, - **kwargs) + **kwargs + ) self.addCleanup(mock_client.close) # Latencies in seconds. - mock_client.mock_rtts['a:1'] = 0.020 - mock_client.mock_rtts['b:2'] = 0.025 - mock_client.mock_rtts['c:3'] = 0.045 + mock_client.mock_rtts["a:1"] = 0.020 + mock_client.mock_rtts["b:2"] = 0.025 + mock_client.mock_rtts["c:3"] = 0.045 return mock_client def test_lazy_connect(self): @@ -90,15 +91,15 @@ def test_lazy_connect(self): # Trigger initial connection. do_simple_op(client, nthreads) - wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') + wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") def test_failover(self): nthreads = 10 client = connected(self.mock_client(localThresholdMS=0.001)) - wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') + wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") # Our chosen mongos goes down. - client.kill_host('a:1') + client.kill_host("a:1") # Trigger failover to higher-latency nodes. AutoReconnect should be # raised at most once in each thread. @@ -106,10 +107,10 @@ def test_failover(self): def f(): try: - client.db.command('ping') + client.db.command("ping") except AutoReconnect: # Second attempt succeeds. - client.db.command('ping') + client.db.command("ping") passed.append(True) @@ -128,34 +129,34 @@ def f(): def test_local_threshold(self): client = connected(self.mock_client(localThresholdMS=30)) self.assertEqual(30, client.options.local_threshold_ms) - wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') + wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") topology = client._topology # All are within a 30-ms latency window, see self.mock_client(). 
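# --- reviewer sketch (illustration, not from the patch) -------------------
# The selection rule behind this assertion: mongoses whose average RTT is
# within localThresholdMS of the fastest one are eligible. Using the mock
# RTTs from self.mock_client():
def within_window(rtts, threshold_ms):
    fastest = min(rtts.values())
    return {addr for addr, rtt in rtts.items() if rtt - fastest <= threshold_ms / 1000.0}

rtts = {"a:1": 0.020, "b:2": 0.025, "c:3": 0.045}
print(within_window(rtts, 30))  # all three mongoses
print(within_window(rtts, 15))  # only a:1 and b:2
# --- end sketch -------------------------------------------------------------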
- self.assertEqual(set([('a', 1), ('b', 2), ('c', 3)]), - writable_addresses(topology)) + self.assertEqual(set([("a", 1), ("b", 2), ("c", 3)]), writable_addresses(topology)) # No error - client.admin.command('ping') + client.admin.command("ping") client = connected(self.mock_client(localThresholdMS=0)) self.assertEqual(0, client.options.local_threshold_ms) # No error - client.db.command('ping') + client.db.command("ping") # Our chosen mongos goes down. - client.kill_host('%s:%s' % next(iter(client.nodes))) + client.kill_host("%s:%s" % next(iter(client.nodes))) try: - client.db.command('ping') + client.db.command("ping") except: pass # We eventually connect to a new mongos. def connect_to_new_mongos(): try: - return client.db.command('ping') + return client.db.command("ping") except AutoReconnect: pass - wait_until(connect_to_new_mongos, 'connect to a new mongos') + + wait_until(connect_to_new_mongos, "connect to a new mongos") def test_load_balancing(self): # Although the server selection JSON tests already prove that @@ -163,25 +164,25 @@ def test_load_balancing(self): # test of discovering servers' round trip times and configuring # localThresholdMS. client = connected(self.mock_client()) - wait_until(lambda: len(client.nodes) == 3, 'connect to all mongoses') + wait_until(lambda: len(client.nodes) == 3, "connect to all mongoses") # Prohibited for topology type Sharded. with self.assertRaises(InvalidOperation): client.address topology = client._topology - self.assertEqual(TOPOLOGY_TYPE.Sharded, - topology.description.topology_type) + self.assertEqual(TOPOLOGY_TYPE.Sharded, topology.description.topology_type) # a and b are within the 15-ms latency window, see self.mock_client(). - self.assertEqual(set([('a', 1), ('b', 2)]), - writable_addresses(topology)) + self.assertEqual(set([("a", 1), ("b", 2)]), writable_addresses(topology)) - client.mock_rtts['a:1'] = 0.045 + client.mock_rtts["a:1"] = 0.045 # Discover only b is within latency window. - wait_until(lambda: set([('b', 2)]) == writable_addresses(topology), - 'discover server "a" is too far') + wait_until( + lambda: set([("b", 2)]) == writable_addresses(topology), + 'discover server "a" is too far', + ) if __name__ == "__main__": diff --git a/test/test_monitor.py b/test/test_monitor.py index 61e2057b52..85cfb0bc40 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -20,13 +20,15 @@ sys.path[0:0] = [""] -from pymongo.periodic_executor import _EXECUTORS +from test import IntegrationTest, unittest +from test.utils import ( + ServerAndTopologyEventListener, + connected, + single_client, + wait_until, +) -from test import unittest, IntegrationTest -from test.utils import (connected, - ServerAndTopologyEventListener, - single_client, - wait_until) +from pymongo.periodic_executor import _EXECUTORS def unregistered(ref): @@ -58,16 +60,13 @@ def test_cleanup_executors_on_client_del(self): self.assertEqual(len(executors), 4) # Each executor stores a weakref to itself in _EXECUTORS. 
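# --- reviewer sketch (illustration, not from the patch) -------------------
# The registry pattern the next assertion relies on, reduced to a
# hypothetical minimum: a weakref with a discard callback drops the entry
# once the executor is garbage collected.
import weakref

_REGISTRY = set()

class FakeExecutor:
    def __init__(self, name):
        self._name = name
        ref = weakref.ref(self, _REGISTRY.discard)  # auto-removed on GC
        _REGISTRY.add(ref)

executor = FakeExecutor("monitor")
del executor  # once collected, _REGISTRY holds no live reference
# --- end sketch -------------------------------------------------------------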
- executor_refs = [ - (r, r()._name) for r in _EXECUTORS.copy() if r() in executors] + executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] del executors del client for ref, name in executor_refs: - wait_until(partial(unregistered, ref), - 'unregister executor: %s' % (name,), - timeout=5) + wait_until(partial(unregistered, ref), "unregister executor: %s" % (name,), timeout=5) def test_cleanup_executors_on_client_close(self): client = create_client() @@ -77,9 +76,9 @@ def test_cleanup_executors_on_client_close(self): client.close() for executor in executors: - wait_until(lambda: executor._stopped, - 'closed executor: %s' % (executor._name,), - timeout=5) + wait_until( + lambda: executor._stopped, "closed executor: %s" % (executor._name,), timeout=5 + ) if __name__ == "__main__": diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 0d925b04bf..7c583d9316 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -20,38 +20,32 @@ sys.path[0:0] = [""] +from test import IntegrationTest, client_context, client_knobs, sanitize_cmd, unittest +from test.utils import ( + EventListener, + get_pool, + rs_or_single_client, + single_client, + wait_until, +) + from bson.int64 import Int64 from bson.objectid import ObjectId from bson.son import SON -from pymongo import CursorType, monitoring, InsertOne, UpdateOne, DeleteOne +from pymongo import CursorType, DeleteOne, InsertOne, UpdateOne, monitoring from pymongo.command_cursor import CommandCursor -from pymongo.errors import (AutoReconnect, - NotPrimaryError, - OperationFailure) +from pymongo.errors import AutoReconnect, NotPrimaryError, OperationFailure from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern -from test import (client_context, - client_knobs, - IntegrationTest, - sanitize_cmd, - unittest) -from test.utils import (EventListener, - get_pool, - rs_or_single_client, - single_client, - wait_until) class TestCommandMonitoring(IntegrationTest): - @classmethod @client_context.require_connection def setUpClass(cls): super(TestCommandMonitoring, cls).setUpClass() cls.listener = EventListener() - cls.client = rs_or_single_client( - event_listeners=[cls.listener], - retryWrites=False) + cls.client = rs_or_single_client(event_listeners=[cls.listener], retryWrites=False) @classmethod def tearDownClass(cls): @@ -63,107 +57,93 @@ def tearDown(self): super(TestCommandMonitoring, self).tearDown() def test_started_simple(self): - self.client.pymongo_test.command('ping') + self.client.pymongo_test.command("ping") results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) - self.assertEqualCommand(SON([('ping', 1)]), started.command) - self.assertEqual('ping', started.command_name) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) + self.assertEqualCommand(SON([("ping", 1)]), started.command) + self.assertEqual("ping", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) 
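# --- reviewer sketch (illustration, not from the patch) -------------------
# The listener plumbing these assertions exercise, as a minimal
# hypothetical example: any CommandListener registered on the client sees
# started/succeeded/failed events carrying the fields checked here.
from pymongo import MongoClient, monitoring

class PingLogger(monitoring.CommandListener):
    def started(self, event):
        print(event.database_name, event.command_name, event.request_id)

    def succeeded(self, event):
        print(event.command_name, event.duration_micros, event.reply.get("ok"))

    def failed(self, event):
        print(event.command_name, event.failure.get("ok"))

client = MongoClient(event_listeners=[PingLogger()])
client.pymongo_test.command("ping")
# --- end sketch; the assertions continue -------------------------------------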
self.assertTrue(isinstance(started.request_id, int)) def test_succeeded_simple(self): - self.client.pymongo_test.command('ping') + self.client.pymongo_test.command("ping") results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) - self.assertEqual('ping', succeeded.command_name) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertEqual("ping", succeeded.command_name) self.assertEqual(self.client.address, succeeded.connection_id) - self.assertEqual(1, succeeded.reply.get('ok')) + self.assertEqual(1, succeeded.reply.get("ok")) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertTrue(isinstance(succeeded.duration_micros, int)) def test_failed_simple(self): try: - self.client.pymongo_test.command('oops!') + self.client.pymongo_test.command("oops!") except OperationFailure: pass results = self.listener.results - started = results['started'][0] - failed = results['failed'][0] - self.assertEqual(0, len(results['succeeded'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) - self.assertTrue( - isinstance(failed, monitoring.CommandFailedEvent)) - self.assertEqual('oops!', failed.command_name) + started = results["started"][0] + failed = results["failed"][0] + self.assertEqual(0, len(results["succeeded"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) + self.assertTrue(isinstance(failed, monitoring.CommandFailedEvent)) + self.assertEqual("oops!", failed.command_name) self.assertEqual(self.client.address, failed.connection_id) - self.assertEqual(0, failed.failure.get('ok')) + self.assertEqual(0, failed.failure.get("ok")) self.assertTrue(isinstance(failed.request_id, int)) self.assertTrue(isinstance(failed.duration_micros, int)) def test_find_one(self): self.client.pymongo_test.test.find_one() results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( - SON([('find', 'test'), - ('filter', {}), - ('limit', 1), - ('singleBatch', True)]), - started.command) - self.assertEqual('find', started.command_name) + SON([("find", "test"), ("filter", {}), ("limit", 1), ("singleBatch", True)]), + started.command, + ) + self.assertEqual("find", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) def test_find_and_get_more(self): self.client.pymongo_test.test.drop() self.client.pymongo_test.test.insert_many([{} for _ in range(10)]) self.listener.results.clear() - cursor = 
self.client.pymongo_test.test.find( - projection={'_id': False}, - batch_size=4) + cursor = self.client.pymongo_test.test.find(projection={"_id": False}, batch_size=4) for _ in range(4): next(cursor) cursor_id = cursor.cursor_id results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( - SON([('find', 'test'), - ('filter', {}), - ('projection', {'_id': False}), - ('batchSize', 4)]), - started.command) - self.assertEqual('find', started.command_name) + SON( + [("find", "test"), ("filter", {}), ("projection", {"_id": False}), ("batchSize", 4)] + ), + started.command, + ) + self.assertEqual("find", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) - self.assertEqual('find', succeeded.command_name) + self.assertEqual("find", succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) csr = succeeded.reply["cursor"] @@ -177,24 +157,21 @@ def test_find_and_get_more(self): next(cursor) try: results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( - SON([('getMore', cursor_id), - ('collection', 'test'), - ('batchSize', 4)]), - started.command) - self.assertEqual('getMore', started.command_name) + SON([("getMore", cursor_id), ("collection", "test"), ("batchSize", 4)]), + started.command, + ) + self.assertEqual("getMore", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) - self.assertEqual('getMore', succeeded.command_name) + self.assertEqual("getMore", succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) csr = succeeded.reply["cursor"] @@ -206,32 +183,28 @@ def test_find_and_get_more(self): tuple(cursor) def test_find_with_explain(self): - cmd = SON([('explain', SON([('find', 'test'), - ('filter', {})]))]) + cmd = SON([("explain", SON([("find", "test"), ("filter", {})]))]) self.client.pymongo_test.test.drop() self.client.pymongo_test.test.insert_one({}) 
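# --- reviewer sketch (illustration, not from the patch) -------------------
# What "publish the unwrapped command" means here: Cursor.explain() sends
# an explain command wrapping the find, so monitoring reports
# command_name "explain" together with the wrapper document.
from pymongo import MongoClient, monitoring

class ExplainCapture(monitoring.CommandListener):
    def started(self, event):
        if event.command_name == "explain":
            print(event.command)  # {"explain": {"find": "test", "filter": {}}, ...}

    def succeeded(self, event):
        pass

    def failed(self, event):
        pass

client = MongoClient(event_listeners=[ExplainCapture()])
client.pymongo_test.test.find().explain()
# --- end sketch -------------------------------------------------------------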
self.listener.results.clear() coll = self.client.pymongo_test.test # Test that we publish the unwrapped command. if self.client.is_mongos: - coll = coll.with_options( - read_preference=ReadPreference.PRIMARY_PREFERRED) + coll = coll.with_options(read_preference=ReadPreference.PRIMARY_PREFERRED) res = coll.find().explain() results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand(cmd, started.command) - self.assertEqual('explain', started.command_name) + self.assertEqual("explain", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) - self.assertEqual('explain', succeeded.command_name) + self.assertEqual("explain", succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(self.client.address, succeeded.connection_id) self.assertEqual(res, succeeded.reply) @@ -239,34 +212,31 @@ def test_find_with_explain(self): def _test_find_options(self, query, expected_cmd): coll = self.client.pymongo_test.test coll.drop() - coll.create_index('x') - coll.insert_many([{'x': i} for i in range(5)]) + coll.create_index("x") + coll.insert_many([{"x": i} for i in range(5)]) # Test that we publish the unwrapped command. 
self.listener.results.clear() if self.client.is_mongos: - coll = coll.with_options( - read_preference=ReadPreference.PRIMARY_PREFERRED) + coll = coll.with_options(read_preference=ReadPreference.PRIMARY_PREFERRED) cursor = coll.find(**query) next(cursor) try: results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand(expected_cmd, started.command) - self.assertEqual('find', started.command_name) + self.assertEqual("find", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) - self.assertEqual('find', succeeded.command_name) + self.assertEqual("find", succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(self.client.address, succeeded.connection_id) finally: @@ -274,125 +244,128 @@ def _test_find_options(self, query, expected_cmd): tuple(cursor) def test_find_options(self): - query = dict(filter={}, - hint=[('x', 1)], - max_time_ms=10000, - max={'x': 10}, - min={'x': -10}, - return_key=True, - show_record_id=True, - projection={'x': False}, - skip=1, - no_cursor_timeout=True, - sort=[('_id', 1)], - allow_partial_results=True, - comment='this is a test', - batch_size=2) - - cmd = dict(find='test', - filter={}, - hint=SON([('x', 1)]), - comment='this is a test', - maxTimeMS=10000, - max={'x': 10}, - min={'x': -10}, - returnKey=True, - showRecordId=True, - sort=SON([('_id', 1)]), - projection={'x': False}, - skip=1, - batchSize=2, - noCursorTimeout=True, - allowPartialResults=True) + query = dict( + filter={}, + hint=[("x", 1)], + max_time_ms=10000, + max={"x": 10}, + min={"x": -10}, + return_key=True, + show_record_id=True, + projection={"x": False}, + skip=1, + no_cursor_timeout=True, + sort=[("_id", 1)], + allow_partial_results=True, + comment="this is a test", + batch_size=2, + ) + + cmd = dict( + find="test", + filter={}, + hint=SON([("x", 1)]), + comment="this is a test", + maxTimeMS=10000, + max={"x": 10}, + min={"x": -10}, + returnKey=True, + showRecordId=True, + sort=SON([("_id", 1)]), + projection={"x": False}, + skip=1, + batchSize=2, + noCursorTimeout=True, + allowPartialResults=True, + ) if client_context.version < (4, 1, 0, -1): - query['max_scan'] = 10 - cmd['maxScan'] = 10 + query["max_scan"] = 10 + cmd["maxScan"] = 10 self._test_find_options(query, cmd) @client_context.require_version_max(3, 7, 2) def test_find_snapshot(self): # Test "snapshot" parameter separately, can't combine with "sort". 
- query = dict(filter={}, - snapshot=True) + query = dict(filter={}, snapshot=True) - cmd = dict(find='test', - filter={}, - snapshot=True) + cmd = dict(find="test", filter={}, snapshot=True) self._test_find_options(query, cmd) def test_command_and_get_more(self): self.client.pymongo_test.test.drop() - self.client.pymongo_test.test.insert_many( - [{'x': 1} for _ in range(10)]) + self.client.pymongo_test.test.insert_many([{"x": 1} for _ in range(10)]) self.listener.results.clear() coll = self.client.pymongo_test.test # Test that we publish the unwrapped command. if self.client.is_mongos: - coll = coll.with_options( - read_preference=ReadPreference.PRIMARY_PREFERRED) - cursor = coll.aggregate( - [{'$project': {'_id': False, 'x': 1}}], batchSize=4) + coll = coll.with_options(read_preference=ReadPreference.PRIMARY_PREFERRED) + cursor = coll.aggregate([{"$project": {"_id": False, "x": 1}}], batchSize=4) for _ in range(4): next(cursor) cursor_id = cursor.cursor_id results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( - SON([('aggregate', 'test'), - ('pipeline', [{'$project': {'_id': False, 'x': 1}}]), - ('cursor', {'batchSize': 4})]), - started.command) - self.assertEqual('aggregate', started.command_name) + SON( + [ + ("aggregate", "test"), + ("pipeline", [{"$project": {"_id": False, "x": 1}}]), + ("cursor", {"batchSize": 4}), + ] + ), + started.command, + ) + self.assertEqual("aggregate", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) - self.assertEqual('aggregate', succeeded.command_name) + self.assertEqual("aggregate", succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) - expected_cursor = {'id': cursor_id, - 'ns': 'pymongo_test.test', - 'firstBatch': [{'x': 1} for _ in range(4)]} - self.assertEqualCommand(expected_cursor, succeeded.reply.get('cursor')) + expected_cursor = { + "id": cursor_id, + "ns": "pymongo_test.test", + "firstBatch": [{"x": 1} for _ in range(4)], + } + self.assertEqualCommand(expected_cursor, succeeded.reply.get("cursor")) self.listener.results.clear() next(cursor) try: results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( - SON([('getMore', cursor_id), - ('collection', 'test'), - ('batchSize', 4)]), - started.command) - self.assertEqual('getMore', started.command_name) + SON([("getMore", cursor_id), ("collection", "test"), 
("batchSize", 4)]), + started.command, + ) + self.assertEqual("getMore", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) - self.assertEqual('getMore', succeeded.command_name) + self.assertEqual("getMore", succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) expected_result = { - 'cursor': {'id': cursor_id, - 'ns': 'pymongo_test.test', - 'nextBatch': [{'x': 1} for _ in range(4)]}, - 'ok': 1.0} + "cursor": { + "id": cursor_id, + "ns": "pymongo_test.test", + "nextBatch": [{"x": 1} for _ in range(4)], + }, + "ok": 1.0, + } self.assertEqualReply(expected_result, succeeded.reply) finally: # Exhaust the cursor to avoid kill cursors. @@ -409,23 +382,20 @@ def test_get_more_failure(self): except Exception: pass results = self.listener.results - started = results['started'][0] - self.assertEqual(0, len(results['succeeded'])) - failed = results['failed'][0] - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) + started = results["started"][0] + self.assertEqual(0, len(results["succeeded"])) + failed = results["failed"][0] + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) self.assertEqualCommand( - SON([('getMore', cursor_id), - ('collection', 'test')]), - started.command) - self.assertEqual('getMore', started.command_name) + SON([("getMore", cursor_id), ("collection", "test")]), started.command + ) + self.assertEqual("getMore", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(failed, monitoring.CommandFailedEvent)) + self.assertTrue(isinstance(failed, monitoring.CommandFailedEvent)) self.assertTrue(isinstance(failed.duration_micros, int)) - self.assertEqual('getMore', failed.command_name) + self.assertEqual("getMore", failed.command_name) self.assertTrue(isinstance(failed.request_id, int)) self.assertEqual(cursor.address, failed.connection_id) self.assertEqual(0, failed.failure.get("ok")) @@ -436,7 +406,7 @@ def test_not_primary_error(self): address = next(iter(client_context.client.secondaries)) client = single_client(*address, event_listeners=[self.listener]) # Clear authentication command results from the listener. 
- client.admin.command('ping') + client.admin.command("ping") self.listener.results.clear() error = None try: @@ -444,16 +414,14 @@ def test_not_primary_error(self): except NotPrimaryError as exc: error = exc.errors results = self.listener.results - started = results['started'][0] - failed = results['failed'][0] - self.assertEqual(0, len(results['succeeded'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) - self.assertTrue( - isinstance(failed, monitoring.CommandFailedEvent)) - self.assertEqual('findAndModify', failed.command_name) + started = results["started"][0] + failed = results["failed"][0] + self.assertEqual(0, len(results["succeeded"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) + self.assertTrue(isinstance(failed, monitoring.CommandFailedEvent)) + self.assertEqual("findAndModify", failed.command_name) self.assertEqual(address, failed.connection_id) - self.assertEqual(0, failed.failure.get('ok')) + self.assertEqual(0, failed.failure.get("ok")) self.assertTrue(isinstance(failed.request_id, int)) self.assertTrue(isinstance(failed.duration_micros, int)) self.assertEqual(error, failed.failure) @@ -464,60 +432,62 @@ def test_exhaust(self): self.client.pymongo_test.test.insert_many([{} for _ in range(11)]) self.listener.results.clear() cursor = self.client.pymongo_test.test.find( - projection={'_id': False}, - batch_size=5, - cursor_type=CursorType.EXHAUST) + projection={"_id": False}, batch_size=5, cursor_type=CursorType.EXHAUST + ) next(cursor) cursor_id = cursor.cursor_id results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) - self.assertEqualCommand(SON([('find', 'test'), - ('filter', {}), - ('projection', {'_id': False}), - ('batchSize', 5)]), started.command) - self.assertEqual('find', started.command_name) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) + self.assertEqualCommand( + SON( + [("find", "test"), ("filter", {}), ("projection", {"_id": False}), ("batchSize", 5)] + ), + started.command, + ) + self.assertEqual("find", started.command_name) self.assertEqual(cursor.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) - self.assertEqual('find', succeeded.command_name) + self.assertEqual("find", succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertEqual(cursor.address, succeeded.connection_id) expected_result = { - 'cursor': {'id': cursor_id, - 'ns': 'pymongo_test.test', - 'firstBatch': [{} for _ in range(5)]}, - 'ok': 1} + "cursor": { + "id": cursor_id, + "ns": "pymongo_test.test", + "firstBatch": [{} for _ in range(5)], + }, + "ok": 1, + } self.assertEqualReply(expected_result, succeeded.reply) self.listener.results.clear() tuple(cursor) results = self.listener.results - self.assertEqual(0, len(results['failed'])) - for event in results['started']: + self.assertEqual(0, len(results["failed"])) + for event in 
results["started"]: self.assertTrue(isinstance(event, monitoring.CommandStartedEvent)) - self.assertEqualCommand(SON([('getMore', cursor_id), - ('collection', 'test'), - ('batchSize', 5)]), event.command) - self.assertEqual('getMore', event.command_name) + self.assertEqualCommand( + SON([("getMore", cursor_id), ("collection", "test"), ("batchSize", 5)]), + event.command, + ) + self.assertEqual("getMore", event.command_name) self.assertEqual(cursor.address, event.connection_id) - self.assertEqual('pymongo_test', event.database_name) + self.assertEqual("pymongo_test", event.database_name) self.assertTrue(isinstance(event.request_id, int)) - for event in results['succeeded']: - self.assertTrue( - isinstance(event, monitoring.CommandSucceededEvent)) + for event in results["succeeded"]: + self.assertTrue(isinstance(event, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(event.duration_micros, int)) - self.assertEqual('getMore', event.command_name) + self.assertEqual("getMore", event.command_name) self.assertTrue(isinstance(event.request_id, int)) self.assertEqual(cursor.address, event.connection_id) # Last getMore receives a response with cursor id 0. - self.assertEqual(0, results['succeeded'][-1].reply['cursor']['id']) + self.assertEqual(0, results["succeeded"][-1].reply["cursor"]["id"]) def test_kill_cursors(self): with client_knobs(kill_cursor_frequency=0.01): @@ -530,30 +500,30 @@ def test_kill_cursors(self): cursor.close() time.sleep(2) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) # There could be more than one cursor_id here depending on # when the thread last ran. - self.assertIn(cursor_id, started.command['cursors']) - self.assertEqual('killCursors', started.command_name) + self.assertIn(cursor_id, started.command["cursors"]) + self.assertEqual("killCursors", started.command_name) self.assertIs(type(started.connection_id), tuple) self.assertEqual(cursor.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) self.assertTrue(isinstance(succeeded.duration_micros, int)) - self.assertEqual('killCursors', succeeded.command_name) + self.assertEqual("killCursors", succeeded.command_name) self.assertTrue(isinstance(succeeded.request_id, int)) self.assertIs(type(succeeded.connection_id), tuple) self.assertEqual(cursor.address, succeeded.connection_id) # There could be more than one cursor_id here depending on # when the thread last ran. 
- self.assertTrue(cursor_id in succeeded.reply['cursorsUnknown'] - or cursor_id in succeeded.reply['cursorsKilled']) + self.assertTrue( + cursor_id in succeeded.reply["cursorsUnknown"] + or cursor_id in succeeded.reply["cursorsKilled"] + ) def test_non_bulk_writes(self): coll = self.client.pymongo_test.test @@ -561,18 +531,22 @@ def test_non_bulk_writes(self): self.listener.results.clear() # Implied write concern insert_one - res = coll.insert_one({'x': 1}) + res = coll.insert_one({"x": 1}) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) - expected = SON([('insert', coll.name), - ('ordered', True), - ('documents', [{'_id': res.inserted_id, 'x': 1}])]) + expected = SON( + [ + ("insert", coll.name), + ("ordered", True), + ("documents", [{"_id": res.inserted_id, "x": 1}]), + ] + ) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('insert', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("insert", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -581,25 +555,29 @@ def test_non_bulk_writes(self): self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply - self.assertEqual(1, reply.get('ok')) - self.assertEqual(1, reply.get('n')) + self.assertEqual(1, reply.get("ok")) + self.assertEqual(1, reply.get("n")) # Unacknowledged insert_one self.listener.results.clear() coll = coll.with_options(write_concern=WriteConcern(w=0)) - res = coll.insert_one({'x': 1}) + res = coll.insert_one({"x": 1}) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) - expected = SON([('insert', coll.name), - ('ordered', True), - ('documents', [{'_id': res.inserted_id, 'x': 1}]), - ('writeConcern', {'w': 0})]) + expected = SON( + [ + ("insert", coll.name), + ("ordered", True), + ("documents", [{"_id": res.inserted_id, "x": 1}]), + ("writeConcern", {"w": 0}), + ] + ) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('insert', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("insert", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -607,24 +585,28 @@ def test_non_bulk_writes(self): self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) - self.assertEqualReply(succeeded.reply, {'ok': 1}) + self.assertEqualReply(succeeded.reply, {"ok": 1}) # Explicit write concern insert_one self.listener.results.clear() coll = 
coll.with_options(write_concern=WriteConcern(w=1)) - res = coll.insert_one({'x': 1}) + res = coll.insert_one({"x": 1}) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) - expected = SON([('insert', coll.name), - ('ordered', True), - ('documents', [{'_id': res.inserted_id, 'x': 1}]), - ('writeConcern', {'w': 1})]) + expected = SON( + [ + ("insert", coll.name), + ("ordered", True), + ("documents", [{"_id": res.inserted_id, "x": 1}]), + ("writeConcern", {"w": 1}), + ] + ) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('insert', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("insert", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -633,25 +615,28 @@ def test_non_bulk_writes(self): self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply - self.assertEqual(1, reply.get('ok')) - self.assertEqual(1, reply.get('n')) + self.assertEqual(1, reply.get("ok")) + self.assertEqual(1, reply.get("n")) # delete_many self.listener.results.clear() - res = coll.delete_many({'x': 1}) + res = coll.delete_many({"x": 1}) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) - expected = SON([('delete', coll.name), - ('ordered', True), - ('deletes', [SON([('q', {'x': 1}), - ('limit', 0)])]), - ('writeConcern', {'w': 1})]) + expected = SON( + [ + ("delete", coll.name), + ("ordered", True), + ("deletes", [SON([("q", {"x": 1}), ("limit", 0)])]), + ("writeConcern", {"w": 1}), + ] + ) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('delete', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("delete", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -660,28 +645,41 @@ def test_non_bulk_writes(self): self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply - self.assertEqual(1, reply.get('ok')) - self.assertEqual(res.deleted_count, reply.get('n')) + self.assertEqual(1, reply.get("ok")) + self.assertEqual(res.deleted_count, reply.get("n")) # replace_one self.listener.results.clear() oid = ObjectId() - res = coll.replace_one({'_id': oid}, {'_id': oid, 'x': 1}, upsert=True) + res = coll.replace_one({"_id": oid}, {"_id": oid, "x": 1}, upsert=True) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + succeeded = 
results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) - expected = SON([('update', coll.name), - ('ordered', True), - ('updates', [SON([('q', {'_id': oid}), - ('u', {'_id': oid, 'x': 1}), - ('multi', False), - ('upsert', True)])]), - ('writeConcern', {'w': 1})]) + expected = SON( + [ + ("update", coll.name), + ("ordered", True), + ( + "updates", + [ + SON( + [ + ("q", {"_id": oid}), + ("u", {"_id": oid, "x": 1}), + ("multi", False), + ("upsert", True), + ] + ) + ], + ), + ("writeConcern", {"w": 1}), + ] + ) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('update', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("update", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -690,28 +688,41 @@ def test_non_bulk_writes(self): self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply - self.assertEqual(1, reply.get('ok')) - self.assertEqual(1, reply.get('n')) - self.assertEqual([{'index': 0, '_id': oid}], reply.get('upserted')) + self.assertEqual(1, reply.get("ok")) + self.assertEqual(1, reply.get("n")) + self.assertEqual([{"index": 0, "_id": oid}], reply.get("upserted")) # update_one self.listener.results.clear() - res = coll.update_one({'x': 1}, {'$inc': {'x': 1}}) + res = coll.update_one({"x": 1}, {"$inc": {"x": 1}}) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) - expected = SON([('update', coll.name), - ('ordered', True), - ('updates', [SON([('q', {'x': 1}), - ('u', {'$inc': {'x': 1}}), - ('multi', False), - ('upsert', False)])]), - ('writeConcern', {'w': 1})]) + expected = SON( + [ + ("update", coll.name), + ("ordered", True), + ( + "updates", + [ + SON( + [ + ("q", {"x": 1}), + ("u", {"$inc": {"x": 1}}), + ("multi", False), + ("upsert", False), + ] + ) + ], + ), + ("writeConcern", {"w": 1}), + ] + ) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('update', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("update", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -720,27 +731,40 @@ def test_non_bulk_writes(self): self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply - self.assertEqual(1, reply.get('ok')) - self.assertEqual(1, reply.get('n')) + self.assertEqual(1, reply.get("ok")) + self.assertEqual(1, reply.get("n")) # update_many self.listener.results.clear() - res = coll.update_many({'x': 2}, {'$inc': {'x': 1}}) + res = coll.update_many({"x": 2}, {"$inc": {"x": 1}}) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, 
len(results['failed'])) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) - expected = SON([('update', coll.name), - ('ordered', True), - ('updates', [SON([('q', {'x': 2}), - ('u', {'$inc': {'x': 1}}), - ('multi', True), - ('upsert', False)])]), - ('writeConcern', {'w': 1})]) + expected = SON( + [ + ("update", coll.name), + ("ordered", True), + ( + "updates", + [ + SON( + [ + ("q", {"x": 2}), + ("u", {"$inc": {"x": 1}}), + ("multi", True), + ("upsert", False), + ] + ) + ], + ), + ("writeConcern", {"w": 1}), + ] + ) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('update', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("update", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -749,25 +773,28 @@ def test_non_bulk_writes(self): self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply - self.assertEqual(1, reply.get('ok')) - self.assertEqual(1, reply.get('n')) + self.assertEqual(1, reply.get("ok")) + self.assertEqual(1, reply.get("n")) # delete_one self.listener.results.clear() - res = coll.delete_one({'x': 3}) + res = coll.delete_one({"x": 3}) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) - expected = SON([('delete', coll.name), - ('ordered', True), - ('deletes', [SON([('q', {'x': 3}), - ('limit', 1)])]), - ('writeConcern', {'w': 1})]) + expected = SON( + [ + ("delete", coll.name), + ("ordered", True), + ("deletes", [SON([("q", {"x": 3}), ("limit", 1)])]), + ("writeConcern", {"w": 1}), + ] + ) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('delete', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("delete", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -776,30 +803,34 @@ def test_non_bulk_writes(self): self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply - self.assertEqual(1, reply.get('ok')) - self.assertEqual(1, reply.get('n')) + self.assertEqual(1, reply.get("ok")) + self.assertEqual(1, reply.get("n")) self.assertEqual(0, coll.count_documents({})) # write errors - coll.insert_one({'_id': 1}) + coll.insert_one({"_id": 1}) try: self.listener.results.clear() - coll.insert_one({'_id': 1}) + coll.insert_one({"_id": 1}) except OperationFailure: pass results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, 
monitoring.CommandStartedEvent) - expected = SON([('insert', coll.name), - ('ordered', True), - ('documents', [{'_id': 1}]), - ('writeConcern', {'w': 1})]) + expected = SON( + [ + ("insert", coll.name), + ("ordered", True), + ("documents", [{"_id": 1}]), + ("writeConcern", {"w": 1}), + ] + ) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('insert', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("insert", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -808,14 +839,14 @@ def test_non_bulk_writes(self): self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) reply = succeeded.reply - self.assertEqual(1, reply.get('ok')) - self.assertEqual(0, reply.get('n')) - errors = reply.get('writeErrors') + self.assertEqual(1, reply.get("ok")) + self.assertEqual(0, reply.get("n")) + errors = reply.get("writeErrors") self.assertIsInstance(errors, list) error = errors[0] - self.assertEqual(0, error.get('index')) - self.assertIsInstance(error.get('code'), int) - self.assertIsInstance(error.get('errmsg'), str) + self.assertEqual(0, error.get("index")) + self.assertIsInstance(error.get("code"), int) + self.assertIsInstance(error.get("errmsg"), str) def test_insert_many(self): # This always uses the bulk API. @@ -823,13 +854,13 @@ def test_insert_many(self): coll.drop() self.listener.results.clear() - big = 'x' * (1024 * 1024 * 4) - docs = [{'_id': i, 'big': big} for i in range(6)] + big = "x" * (1024 * 1024 * 4) + docs = [{"_id": i, "big": big} for i in range(6)] coll.insert_many(docs) results = self.listener.results - started = results['started'] - succeeded = results['succeeded'] - self.assertEqual(0, len(results['failed'])) + started = results["started"] + succeeded = results["succeeded"] + self.assertEqual(0, len(results["failed"])) documents = [] count = 0 operation_id = started[0].operation_id @@ -837,13 +868,12 @@ def test_insert_many(self): for start, succeed in zip(started, succeeded): self.assertIsInstance(start, monitoring.CommandStartedEvent) cmd = sanitize_cmd(start.command) - self.assertEqual(['insert', 'ordered', 'documents'], - list(cmd.keys())) - self.assertEqual(coll.name, cmd['insert']) - self.assertIs(True, cmd['ordered']) - documents.extend(cmd['documents']) - self.assertEqual('pymongo_test', start.database_name) - self.assertEqual('insert', start.command_name) + self.assertEqual(["insert", "ordered", "documents"], list(cmd.keys())) + self.assertEqual(coll.name, cmd["insert"]) + self.assertIs(True, cmd["ordered"]) + documents.extend(cmd["documents"]) + self.assertEqual("pymongo_test", start.database_name) + self.assertEqual("insert", start.command_name) self.assertIsInstance(start.request_id, int) self.assertEqual(self.client.address, start.connection_id) self.assertIsInstance(succeed, monitoring.CommandSucceededEvent) @@ -854,8 +884,8 @@ def test_insert_many(self): self.assertEqual(start.operation_id, operation_id) self.assertEqual(succeed.operation_id, operation_id) reply = succeed.reply - self.assertEqual(1, reply.get('ok')) - count += reply.get('n', 0) + self.assertEqual(1, reply.get("ok")) + count += reply.get("n", 0) self.assertEqual(documents, docs) self.assertEqual(6, count) @@ -866,27 +896,26 @@ def 
test_insert_many_unacknowledged(self): self.listener.results.clear() # Force two batches on legacy servers. - big = 'x' * (1024 * 1024 * 12) - docs = [{'_id': i, 'big': big} for i in range(6)] + big = "x" * (1024 * 1024 * 12) + docs = [{"_id": i, "big": big} for i in range(6)] unack_coll.insert_many(docs) results = self.listener.results - started = results['started'] - succeeded = results['succeeded'] - self.assertEqual(0, len(results['failed'])) + started = results["started"] + succeeded = results["succeeded"] + self.assertEqual(0, len(results["failed"])) documents = [] operation_id = started[0].operation_id self.assertIsInstance(operation_id, int) for start, succeed in zip(started, succeeded): self.assertIsInstance(start, monitoring.CommandStartedEvent) cmd = sanitize_cmd(start.command) - cmd.pop('writeConcern', None) - self.assertEqual(['insert', 'ordered', 'documents'], - list(cmd.keys())) - self.assertEqual(coll.name, cmd['insert']) - self.assertIs(True, cmd['ordered']) - documents.extend(cmd['documents']) - self.assertEqual('pymongo_test', start.database_name) - self.assertEqual('insert', start.command_name) + cmd.pop("writeConcern", None) + self.assertEqual(["insert", "ordered", "documents"], list(cmd.keys())) + self.assertEqual(coll.name, cmd["insert"]) + self.assertIs(True, cmd["ordered"]) + documents.extend(cmd["documents"]) + self.assertEqual("pymongo_test", start.database_name) + self.assertEqual("insert", start.command_name) self.assertIsInstance(start.request_id, int) self.assertEqual(self.client.address, start.connection_id) self.assertIsInstance(succeed, monitoring.CommandSucceededEvent) @@ -896,29 +925,32 @@ def test_insert_many_unacknowledged(self): self.assertEqual(start.connection_id, succeed.connection_id) self.assertEqual(start.operation_id, operation_id) self.assertEqual(succeed.operation_id, operation_id) - self.assertEqual(1, succeed.reply.get('ok')) + self.assertEqual(1, succeed.reply.get("ok")) self.assertEqual(documents, docs) - wait_until(lambda: coll.count_documents({}) == 6, - 'insert documents with w=0') + wait_until(lambda: coll.count_documents({}) == 6, "insert documents with w=0") def test_bulk_write(self): coll = self.client.pymongo_test.test coll.drop() self.listener.results.clear() - coll.bulk_write([InsertOne({'_id': 1}), - UpdateOne({'_id': 1}, {'$set': {'x': 1}}), - DeleteOne({'_id': 1})]) + coll.bulk_write( + [ + InsertOne({"_id": 1}), + UpdateOne({"_id": 1}, {"$set": {"x": 1}}), + DeleteOne({"_id": 1}), + ] + ) results = self.listener.results - started = results['started'] - succeeded = results['succeeded'] - self.assertEqual(0, len(results['failed'])) + started = results["started"] + succeeded = results["succeeded"] + self.assertEqual(0, len(results["failed"])) operation_id = started[0].operation_id pairs = list(zip(started, succeeded)) self.assertEqual(3, len(pairs)) for start, succeed in pairs: self.assertIsInstance(start, monitoring.CommandStartedEvent) - self.assertEqual('pymongo_test', start.database_name) + self.assertEqual("pymongo_test", start.database_name) self.assertIsInstance(start.request_id, int) self.assertEqual(self.client.address, start.connection_id) self.assertIsInstance(succeed, monitoring.CommandSucceededEvent) @@ -929,21 +961,35 @@ def test_bulk_write(self): self.assertEqual(start.operation_id, operation_id) self.assertEqual(succeed.operation_id, operation_id) - expected = SON([('insert', coll.name), - ('ordered', True), - ('documents', [{'_id': 1}])]) + expected = SON([("insert", coll.name), ("ordered", True), ("documents", 
[{"_id": 1}])]) self.assertEqualCommand(expected, started[0].command) - expected = SON([('update', coll.name), - ('ordered', True), - ('updates', [SON([('q', {'_id': 1}), - ('u', {'$set': {'x': 1}}), - ('multi', False), - ('upsert', False)])])]) + expected = SON( + [ + ("update", coll.name), + ("ordered", True), + ( + "updates", + [ + SON( + [ + ("q", {"_id": 1}), + ("u", {"$set": {"x": 1}}), + ("multi", False), + ("upsert", False), + ] + ) + ], + ), + ] + ) self.assertEqualCommand(expected, started[1].command) - expected = SON([('delete', coll.name), - ('ordered', True), - ('deletes', [SON([('q', {'_id': 1}), - ('limit', 1)])])]) + expected = SON( + [ + ("delete", coll.name), + ("ordered", True), + ("deletes", [SON([("q", {"_id": 1}), ("limit", 1)])]), + ] + ) self.assertEqualCommand(expected, started[2].command) @client_context.require_failCommand_fail_point @@ -952,23 +998,23 @@ def test_bulk_write_command_network_error(self): self.listener.results.clear() insert_network_error = { - 'configureFailPoint': 'failCommand', - 'mode': {'times': 1}, - 'data': { - 'failCommands': ['insert'], - 'closeConnection': True, + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["insert"], + "closeConnection": True, }, } with self.fail_point(insert_network_error): with self.assertRaises(AutoReconnect): - coll.bulk_write([InsertOne({'_id': 1})]) - failed = self.listener.results['failed'] + coll.bulk_write([InsertOne({"_id": 1})]) + failed = self.listener.results["failed"] self.assertEqual(1, len(failed)) event = failed[0] - self.assertEqual(event.command_name, 'insert') + self.assertEqual(event.command_name, "insert") self.assertIsInstance(event.failure, dict) - self.assertEqual(event.failure['errtype'], 'AutoReconnect') - self.assertTrue(event.failure['errmsg']) + self.assertEqual(event.failure["errtype"], "AutoReconnect") + self.assertTrue(event.failure["errmsg"]) @client_context.require_failCommand_fail_point def test_bulk_write_command_error(self): @@ -976,24 +1022,24 @@ def test_bulk_write_command_error(self): self.listener.results.clear() insert_command_error = { - 'configureFailPoint': 'failCommand', - 'mode': {'times': 1}, - 'data': { - 'failCommands': ['insert'], - 'closeConnection': False, - 'errorCode': 10107, # Not primary + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["insert"], + "closeConnection": False, + "errorCode": 10107, # Not primary }, } with self.fail_point(insert_command_error): with self.assertRaises(NotPrimaryError): - coll.bulk_write([InsertOne({'_id': 1})]) - failed = self.listener.results['failed'] + coll.bulk_write([InsertOne({"_id": 1})]) + failed = self.listener.results["failed"] self.assertEqual(1, len(failed)) event = failed[0] - self.assertEqual(event.command_name, 'insert') + self.assertEqual(event.command_name, "insert") self.assertIsInstance(event.failure, dict) - self.assertEqual(event.failure['code'], 10107) - self.assertTrue(event.failure['errmsg']) + self.assertEqual(event.failure["code"], 10107) + self.assertTrue(event.failure["errmsg"]) def test_write_errors(self): coll = self.client.pymongo_test.test @@ -1001,23 +1047,27 @@ def test_write_errors(self): self.listener.results.clear() try: - coll.bulk_write([InsertOne({'_id': 1}), - InsertOne({'_id': 1}), - InsertOne({'_id': 1}), - DeleteOne({'_id': 1})], - ordered=False) + coll.bulk_write( + [ + InsertOne({"_id": 1}), + InsertOne({"_id": 1}), + InsertOne({"_id": 1}), + DeleteOne({"_id": 1}), + ], + ordered=False, + ) except 
OperationFailure: pass results = self.listener.results - started = results['started'] - succeeded = results['succeeded'] - self.assertEqual(0, len(results['failed'])) + started = results["started"] + succeeded = results["succeeded"] + self.assertEqual(0, len(results["failed"])) operation_id = started[0].operation_id pairs = list(zip(started, succeeded)) errors = [] for start, succeed in pairs: self.assertIsInstance(start, monitoring.CommandStartedEvent) - self.assertEqual('pymongo_test', start.database_name) + self.assertEqual("pymongo_test", start.database_name) self.assertIsInstance(start.request_id, int) self.assertEqual(self.client.address, start.connection_id) self.assertIsInstance(succeed, monitoring.CommandSucceededEvent) @@ -1027,11 +1077,11 @@ def test_write_errors(self): self.assertEqual(start.connection_id, succeed.connection_id) self.assertEqual(start.operation_id, operation_id) self.assertEqual(succeed.operation_id, operation_id) - if 'writeErrors' in succeed.reply: - errors.extend(succeed.reply['writeErrors']) + if "writeErrors" in succeed.reply: + errors.extend(succeed.reply["writeErrors"]) self.assertEqual(2, len(errors)) - fields = set(['index', 'code', 'errmsg']) + fields = set(["index", "code", "errmsg"]) for error in errors: self.assertTrue(fields.issubset(set(error))) @@ -1041,14 +1091,14 @@ def test_first_batch_helper(self): self.listener.results.clear() tuple(self.client.pymongo_test.test.list_indexes()) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) - expected = SON([('listIndexes', 'test'), ('cursor', {})]) + expected = SON([("listIndexes", "test"), ("cursor", {})]) self.assertEqualCommand(expected, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('listIndexes', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("listIndexes", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -1056,8 +1106,8 @@ def test_first_batch_helper(self): self.assertEqual(started.command_name, succeeded.command_name) self.assertEqual(started.request_id, succeeded.request_id) self.assertEqual(started.connection_id, succeeded.connection_id) - self.assertTrue('cursor' in succeeded.reply) - self.assertTrue('ok' in succeeded.reply) + self.assertTrue("cursor" in succeeded.reply) + self.assertTrue("ok" in succeeded.reply) self.listener.results.clear() @@ -1066,20 +1116,19 @@ def test_sensitive_commands(self): self.listener.results.clear() cmd = SON([("getnonce", 1)]) - listeners.publish_command_start( - cmd, "pymongo_test", 12345, self.client.address) + listeners.publish_command_start(cmd, "pymongo_test", 12345, self.client.address) delta = datetime.timedelta(milliseconds=100) listeners.publish_command_success( - delta, {'nonce': 'e474f4561c5eb40b', 'ok': 1.0}, - "getnonce", 12345, self.client.address) + delta, {"nonce": "e474f4561c5eb40b", "ok": 1.0}, "getnonce", 12345, self.client.address + ) results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) + started = results["started"][0] + 
succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) self.assertIsInstance(started, monitoring.CommandStartedEvent) self.assertEqual({}, started.command) - self.assertEqual('pymongo_test', started.database_name) - self.assertEqual('getnonce', started.command_name) + self.assertEqual("pymongo_test", started.database_name) + self.assertEqual("getnonce", started.command_name) self.assertIsInstance(started.request_id, int) self.assertEqual(self.client.address, started.connection_id) self.assertIsInstance(succeeded, monitoring.CommandSucceededEvent) @@ -1091,7 +1140,6 @@ def test_sensitive_commands(self): class TestGlobalListener(IntegrationTest): - @classmethod @client_context.require_connection def setUpClass(cls): @@ -1102,7 +1150,7 @@ def setUpClass(cls): monitoring.register(cls.listener) cls.client = single_client() # Get one (authenticated) socket in the pool. - cls.client.pymongo_test.command('ping') + cls.client.pymongo_test.command("ping") @classmethod def tearDownClass(cls): @@ -1115,107 +1163,101 @@ def setUp(self): self.listener.results.clear() def test_simple(self): - self.client.pymongo_test.command('ping') + self.client.pymongo_test.command("ping") results = self.listener.results - started = results['started'][0] - succeeded = results['succeeded'][0] - self.assertEqual(0, len(results['failed'])) - self.assertTrue( - isinstance(succeeded, monitoring.CommandSucceededEvent)) - self.assertTrue( - isinstance(started, monitoring.CommandStartedEvent)) - self.assertEqualCommand(SON([('ping', 1)]), started.command) - self.assertEqual('ping', started.command_name) + started = results["started"][0] + succeeded = results["succeeded"][0] + self.assertEqual(0, len(results["failed"])) + self.assertTrue(isinstance(succeeded, monitoring.CommandSucceededEvent)) + self.assertTrue(isinstance(started, monitoring.CommandStartedEvent)) + self.assertEqualCommand(SON([("ping", 1)]), started.command) + self.assertEqual("ping", started.command_name) self.assertEqual(self.client.address, started.connection_id) - self.assertEqual('pymongo_test', started.database_name) + self.assertEqual("pymongo_test", started.database_name) self.assertTrue(isinstance(started.request_id, int)) class TestEventClasses(unittest.TestCase): - def test_command_event_repr(self): - request_id, connection_id, operation_id = 1, ('localhost', 27017), 2 + request_id, connection_id, operation_id = 1, ("localhost", 27017), 2 event = monitoring.CommandStartedEvent( - {'ping': 1}, 'admin', request_id, connection_id, operation_id) + {"ping": 1}, "admin", request_id, connection_id, operation_id + ) self.assertEqual( repr(event), "") + "command: 'ping', operation_id: 2, service_id: None>", + ) delta = datetime.timedelta(milliseconds=100) event = monitoring.CommandSucceededEvent( - delta, {'ok': 1}, 'ping', request_id, connection_id, - operation_id) + delta, {"ok": 1}, "ping", request_id, connection_id, operation_id + ) self.assertEqual( repr(event), "") + "service_id: None>", + ) event = monitoring.CommandFailedEvent( - delta, {'ok': 0}, 'ping', request_id, connection_id, - operation_id) + delta, {"ok": 0}, "ping", request_id, connection_id, operation_id + ) self.assertEqual( repr(event), "") + "failure: {'ok': 0}, service_id: None>", + ) def test_server_heartbeat_event_repr(self): - connection_id = ('localhost', 27017) + connection_id = ("localhost", 27017) event = monitoring.ServerHeartbeatStartedEvent(connection_id) - self.assertEqual( - repr(event), - "") + self.assertEqual(repr(event), "") delta = 0.1 - event 
= monitoring.ServerHeartbeatSucceededEvent( - delta, {'ok': 1}, connection_id) + event = monitoring.ServerHeartbeatSucceededEvent(delta, {"ok": 1}, connection_id) self.assertEqual( repr(event), "") - event = monitoring.ServerHeartbeatFailedEvent( - delta, 'ERROR', connection_id) + "duration: 0.1, awaited: False, reply: {'ok': 1}>", + ) + event = monitoring.ServerHeartbeatFailedEvent(delta, "ERROR", connection_id) self.assertEqual( repr(event), "") + "duration: 0.1, awaited: False, reply: 'ERROR'>", + ) def test_server_event_repr(self): - server_address = ('localhost', 27017) - topology_id = ObjectId('000000000000000000000001') + server_address = ("localhost", 27017) + topology_id = ObjectId("000000000000000000000001") event = monitoring.ServerOpeningEvent(server_address, topology_id) self.assertEqual( repr(event), - "") - event = monitoring.ServerDescriptionChangedEvent( - 'PREV', 'NEW', server_address, topology_id) + "", + ) + event = monitoring.ServerDescriptionChangedEvent("PREV", "NEW", server_address, topology_id) self.assertEqual( repr(event), - "") + "", + ) event = monitoring.ServerClosedEvent(server_address, topology_id) self.assertEqual( repr(event), - "") + "", + ) def test_topology_event_repr(self): - topology_id = ObjectId('000000000000000000000001') + topology_id = ObjectId("000000000000000000000001") event = monitoring.TopologyOpenedEvent(topology_id) - self.assertEqual( - repr(event), - "") - event = monitoring.TopologyDescriptionChangedEvent( - 'PREV', 'NEW', topology_id) + self.assertEqual(repr(event), "") + event = monitoring.TopologyDescriptionChangedEvent("PREV", "NEW", topology_id) self.assertEqual( repr(event), "") + "changed from: PREV, to: NEW>", + ) event = monitoring.TopologyClosedEvent(topology_id) - self.assertEqual( - repr(event), - "") + self.assertEqual(repr(event), "") if __name__ == "__main__": diff --git a/test/test_objectid.py b/test/test_objectid.py index 490505234e..c768b0596b 100644 --- a/test/test_objectid.py +++ b/test/test_objectid.py @@ -21,13 +21,13 @@ sys.path[0:0] = [""] -from bson.errors import InvalidId -from bson.objectid import ObjectId, _MAX_COUNTER_VALUE -from bson.tz_util import (FixedOffset, - utc) from test import SkipTest, unittest from test.utils import oid_generated_on_process +from bson.errors import InvalidId +from bson.objectid import _MAX_COUNTER_VALUE, ObjectId +from bson.tz_util import FixedOffset, utc + def oid(x): return ObjectId() @@ -57,29 +57,28 @@ def test_from_hex(self): self.assertRaises(InvalidId, ObjectId, "123456789012123456789G12") def test_repr_str(self): - self.assertEqual(repr(ObjectId("1234567890abcdef12345678")), - "ObjectId('1234567890abcdef12345678')") - self.assertEqual(str(ObjectId("1234567890abcdef12345678")), - "1234567890abcdef12345678") - self.assertEqual(str(ObjectId(b"123456789012")), - "313233343536373839303132") - self.assertEqual(ObjectId("1234567890abcdef12345678").binary, - b'\x124Vx\x90\xab\xcd\xef\x124Vx') - self.assertEqual(str(ObjectId(b'\x124Vx\x90\xab\xcd\xef\x124Vx')), - "1234567890abcdef12345678") + self.assertEqual( + repr(ObjectId("1234567890abcdef12345678")), "ObjectId('1234567890abcdef12345678')" + ) + self.assertEqual(str(ObjectId("1234567890abcdef12345678")), "1234567890abcdef12345678") + self.assertEqual(str(ObjectId(b"123456789012")), "313233343536373839303132") + self.assertEqual( + ObjectId("1234567890abcdef12345678").binary, b"\x124Vx\x90\xab\xcd\xef\x124Vx" + ) + self.assertEqual( + str(ObjectId(b"\x124Vx\x90\xab\xcd\xef\x124Vx")), "1234567890abcdef12345678" + ) def 
test_equality(self): a = ObjectId() self.assertEqual(a, ObjectId(a)) - self.assertEqual(ObjectId(b"123456789012"), - ObjectId(b"123456789012")) + self.assertEqual(ObjectId(b"123456789012"), ObjectId(b"123456789012")) self.assertNotEqual(ObjectId(), ObjectId()) self.assertNotEqual(ObjectId(b"123456789012"), b"123456789012") # Explicitly test inequality self.assertFalse(a != ObjectId(a)) - self.assertFalse(ObjectId(b"123456789012") != - ObjectId(b"123456789012")) + self.assertFalse(ObjectId(b"123456789012") != ObjectId(b"123456789012")) def test_binary_str_equivalence(self): a = ObjectId() @@ -95,7 +94,7 @@ def test_generation_time(self): self.assertTrue(d2 - d1 < datetime.timedelta(seconds=2)) def test_from_datetime(self): - if 'PyPy 1.8.0' in sys.version: + if "PyPy 1.8.0" in sys.version: # See https://bugs.pypy.org/issue1092 raise SkipTest("datetime.timedelta is broken in pypy 1.8.0") d = datetime.datetime.utcnow() @@ -104,8 +103,7 @@ def test_from_datetime(self): self.assertEqual(d, oid.generation_time.replace(tzinfo=None)) self.assertEqual("0" * 16, str(oid)[8:]) - aware = datetime.datetime(1993, 4, 4, 2, - tzinfo=FixedOffset(555, "SomeZone")) + aware = datetime.datetime(1993, 4, 4, 2, tzinfo=FixedOffset(555, "SomeZone")) as_utc = (aware - aware.utcoffset()).replace(tzinfo=utc) oid = ObjectId.from_datetime(aware) self.assertEqual(as_utc, oid.generation_time) @@ -124,7 +122,8 @@ def test_pickle_backwards_compatability(self): b"(cbson.objectid\nObjectId\np1\nc__builtin__\n" b"object\np2\nNtp3\nRp4\n" b"(dp5\nS'_ObjectId__id'\np6\n" - b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np7\nsb.") + b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np7\nsb." + ) # We also test against a hardcoded "New" pickle format so that we # make sure we're backward compatible with the current version in @@ -133,11 +132,12 @@ def test_pickle_backwards_compatability(self): b"ccopy_reg\n_reconstructor\np0\n" b"(cbson.objectid\nObjectId\np1\nc__builtin__\n" b"object\np2\nNtp3\nRp4\n" - b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np5\nb.") + b"S'M\\x9afV\\x13v\\xc0\\x0b\\x88\\x00\\x00\\x00'\np5\nb." + ) # Have to load using 'latin-1' since these were pickled in python2.x. - oid_1_9 = pickle.loads(pickled_with_1_9, encoding='latin-1') - oid_1_10 = pickle.loads(pickled_with_1_10, encoding='latin-1') + oid_1_9 = pickle.loads(pickled_with_1_9, encoding="latin-1") + oid_1_10 = pickle.loads(pickled_with_1_10, encoding="latin-1") self.assertEqual(oid_1_9, ObjectId("4d9a66561376c00b88000000")) self.assertEqual(oid_1_9, oid_1_10) @@ -187,9 +187,7 @@ def generate_objectid_with_timestamp(timestamp): oid.generation_time except (OverflowError, ValueError): continue - self.assertEqual( - oid.generation_time, - datetime.datetime(*exp_datetime_args, tzinfo=utc)) + self.assertEqual(oid.generation_time, datetime.datetime(*exp_datetime_args, tzinfo=utc)) def test_random_regenerated_on_pid_change(self): # Test that change of pid triggers new random number generation. 
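Reviewer note, not part of the diff: the test_objectid.py hunks above are pure re-wrapping, and the ObjectId behavior they assert is unchanged. For reference, a minimal standalone sketch of those round-trips (hex and 12-byte construction, binary equality, from_datetime/generation_time); only the bson package this diff touches is assumed, and the literal values are copied from the assertions above:

# Sketch, not part of the diff: the ObjectId round-trips asserted by the
# reformatted tests above. Requires PyMongo's bson package.
import datetime

from bson.objectid import ObjectId
from bson.tz_util import utc

oid = ObjectId("1234567890abcdef12345678")  # 24-char hex string
assert str(oid) == "1234567890abcdef12345678"
assert oid.binary == b"\x124Vx\x90\xab\xcd\xef\x124Vx"  # 12 raw bytes
assert ObjectId(oid.binary) == oid  # bytes/hex round-trip

# The leading 4 bytes encode a timestamp; from_datetime() zeroes the rest,
# so the last 16 hex digits are all "0".
aware = datetime.datetime(1993, 4, 4, 2, tzinfo=utc)
generated = ObjectId.from_datetime(aware)
assert generated.generation_time == aware
assert str(generated)[8:] == "0" * 16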
diff --git a/test/test_ocsp_cache.py b/test/test_ocsp_cache.py index 7562d0f74a..6e0e10edac 100644 --- a/test/test_ocsp_cache.py +++ b/test/test_ocsp_cache.py @@ -14,36 +14,35 @@ """Test the pymongo ocsp_support module.""" +import random +import sys from collections import namedtuple from datetime import datetime, timedelta from os import urandom -import random -import sys from time import sleep sys.path[0:0] = [""] -from pymongo.ocsp_cache import _OCSPCache from test import unittest +from pymongo.ocsp_cache import _OCSPCache + class TestOcspCache(unittest.TestCase): @classmethod def setUpClass(cls): - cls.MockHashAlgorithm = namedtuple( - "MockHashAlgorithm", ['name']) + cls.MockHashAlgorithm = namedtuple("MockHashAlgorithm", ["name"]) cls.MockOcspRequest = namedtuple( - "MockOcspRequest", ['hash_algorithm', 'issuer_name_hash', - 'issuer_key_hash', 'serial_number']) - cls.MockOcspResponse = namedtuple( - "MockOcspResponse", ["this_update", "next_update"]) + "MockOcspRequest", + ["hash_algorithm", "issuer_name_hash", "issuer_key_hash", "serial_number"], + ) + cls.MockOcspResponse = namedtuple("MockOcspResponse", ["this_update", "next_update"]) def setUp(self): self.cache = _OCSPCache() def _create_mock_request(self): - hash_algorithm = self.MockHashAlgorithm( - random.choice(['sha1', 'md5', 'sha256'])) + hash_algorithm = self.MockHashAlgorithm(random.choice(["sha1", "md5", "sha256"])) issuer_name_hash = urandom(8) issuer_key_hash = urandom(8) serial_number = random.randint(0, 10**10) @@ -51,19 +50,17 @@ def _create_mock_request(self): hash_algorithm=hash_algorithm, issuer_name_hash=issuer_name_hash, issuer_key_hash=issuer_key_hash, - serial_number=serial_number) + serial_number=serial_number, + ) - def _create_mock_response(self, this_update_delta_seconds, - next_update_delta_seconds): + def _create_mock_response(self, this_update_delta_seconds, next_update_delta_seconds): now = datetime.utcnow() this_update = now + timedelta(seconds=this_update_delta_seconds) if next_update_delta_seconds is not None: next_update = now + timedelta(seconds=next_update_delta_seconds) else: next_update = None - return self.MockOcspResponse( - this_update=this_update, - next_update=next_update) + return self.MockOcspResponse(this_update=this_update, next_update=next_update) def _add_mock_cache_entry(self, mock_request, mock_response): key = self.cache._get_cache_key(mock_request) diff --git a/test/test_pooling.py b/test/test_pooling.py index b8f3cf1908..bd4d5f5980 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -21,29 +21,25 @@ import threading import time -from bson.son import SON from bson.codec_options import DEFAULT_CODEC_OPTIONS - +from bson.son import SON from pymongo import MongoClient, message -from pymongo.errors import (AutoReconnect, - ConnectionFailure, - DuplicateKeyError) +from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError sys.path[0:0] = [""] +from test import IntegrationTest, client_context, unittest +from test.utils import delay, get_pool, joinall, rs_or_single_client + from pymongo.pool import Pool, PoolOptions from pymongo.socket_checker import SocketChecker -from test import client_context, IntegrationTest, unittest -from test.utils import (get_pool, - joinall, - delay, - rs_or_single_client) @client_context.require_connection def setUpModule(): pass + N = 10 DB = "pymongo-pooling-tests" @@ -62,6 +58,7 @@ def gc_collect_until_done(threads, timeout=60): class MongoThread(threading.Thread): """A thread that uses a MongoClient.""" + def 
__init__(self, client): super(MongoThread, self).__init__() self.daemon = True # Don't hang whole test if thread hangs. @@ -108,21 +105,22 @@ class SocketGetter(MongoThread): Checks out a socket and holds it forever. Used in test_no_wait_queue_timeout. """ + def __init__(self, client, pool): super(SocketGetter, self).__init__(client) - self.state = 'init' + self.state = "init" self.pool = pool self.sock = None def run_mongo_thread(self): - self.state = 'get_socket' + self.state = "get_socket" # Call 'pin_cursor' so we can hold the socket. with self.pool.get_socket({}) as sock: sock.pin_cursor() self.sock = sock - self.state = 'sock' + self.state = "sock" def __del__(self): if self.sock: @@ -162,16 +160,12 @@ def tearDown(self): self.c.close() super(_TestPoolingBase, self).tearDown() - def create_pool( - self, - pair=(client_context.host, client_context.port), - *args, - **kwargs): + def create_pool(self, pair=(client_context.host, client_context.port), *args, **kwargs): # Start the pool with the correct ssl options. pool_options = client_context.client._topology_settings.pool_options - kwargs['ssl_context'] = pool_options._ssl_context - kwargs['tls_allow_invalid_hostnames'] = pool_options.tls_allow_invalid_hostnames - kwargs['server_api'] = pool_options.server_api + kwargs["ssl_context"] = pool_options._ssl_context + kwargs["tls_allow_invalid_hostnames"] = pool_options.tls_allow_invalid_hostnames + kwargs["server_api"] = pool_options.server_api pool = Pool(pair, PoolOptions(*args, **kwargs)) pool.ready() return pool @@ -180,11 +174,9 @@ def create_pool( class TestPooling(_TestPoolingBase): def test_max_pool_size_validation(self): host, port = client_context.host, client_context.port - self.assertRaises( - ValueError, MongoClient, host=host, port=port, maxPoolSize=-1) + self.assertRaises(ValueError, MongoClient, host=host, port=port, maxPoolSize=-1) - self.assertRaises( - ValueError, MongoClient, host=host, port=port, maxPoolSize='foo') + self.assertRaises(ValueError, MongoClient, host=host, port=port, maxPoolSize="foo") c = MongoClient(host=host, port=port, maxPoolSize=100, connect=False) self.assertEqual(c.options.pool_options.max_pool_size, 100) @@ -264,27 +256,27 @@ def test_socket_checker(self): # Socket has nothing to read. self.assertFalse(socket_checker.select(s, read=True)) self.assertFalse(socket_checker.select(s, read=True, timeout=0)) - self.assertFalse(socket_checker.select(s, read=True, timeout=.05)) + self.assertFalse(socket_checker.select(s, read=True, timeout=0.05)) # Socket is writable. self.assertTrue(socket_checker.select(s, write=True, timeout=None)) self.assertTrue(socket_checker.select(s, write=True)) self.assertTrue(socket_checker.select(s, write=True, timeout=0)) - self.assertTrue(socket_checker.select(s, write=True, timeout=.05)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0.05)) # Make the socket readable _, msg, _ = message._query( - 0, 'admin.$cmd', 0, -1, SON([('ping', 1)]), None, - DEFAULT_CODEC_OPTIONS) + 0, "admin.$cmd", 0, -1, SON([("ping", 1)]), None, DEFAULT_CODEC_OPTIONS + ) s.sendall(msg) # Block until the socket is readable. self.assertTrue(socket_checker.select(s, read=True, timeout=None)) self.assertTrue(socket_checker.select(s, read=True)) self.assertTrue(socket_checker.select(s, read=True, timeout=0)) - self.assertTrue(socket_checker.select(s, read=True, timeout=.05)) + self.assertTrue(socket_checker.select(s, read=True, timeout=0.05)) # Socket is still writable. 
self.assertTrue(socket_checker.select(s, write=True, timeout=None)) self.assertTrue(socket_checker.select(s, write=True)) self.assertTrue(socket_checker.select(s, write=True, timeout=0)) - self.assertTrue(socket_checker.select(s, write=True, timeout=.05)) + self.assertTrue(socket_checker.select(s, write=True, timeout=0.05)) s.close() self.assertTrue(socket_checker.socket_closed(s)) @@ -303,9 +295,7 @@ def test_return_socket_after_reset(self): def test_pool_check(self): # Test that Pool recovers from two connection failures in a row. # This exercises code at the end of Pool._check(). - cx_pool = self.create_pool(max_pool_size=1, - connect_timeout=1, - wait_queue_timeout=1) + cx_pool = self.create_pool(max_pool_size=1, connect_timeout=1, wait_queue_timeout=1) cx_pool._check_interval_seconds = 0 # Always check. self.addCleanup(cx_pool.close) @@ -315,7 +305,7 @@ def test_pool_check(self): sock_info.sock.close() # Swap pool's address with a bad one. - address, cx_pool.address = cx_pool.address, ('foo.com', 1234) + address, cx_pool.address = cx_pool.address, ("foo.com", 1234) with self.assertRaises(AutoReconnect): with cx_pool.get_socket({}): pass @@ -327,8 +317,7 @@ def test_pool_check(self): def test_wait_queue_timeout(self): wait_queue_timeout = 2 # Seconds - pool = self.create_pool( - max_pool_size=1, wait_queue_timeout=wait_queue_timeout) + pool = self.create_pool(max_pool_size=1, wait_queue_timeout=wait_queue_timeout) self.addCleanup(pool.close) with pool.get_socket({}) as sock_info: @@ -340,8 +329,8 @@ def test_wait_queue_timeout(self): duration = time.time() - start self.assertTrue( abs(wait_queue_timeout - duration) < 1, - "Waited %.2f seconds for a socket, expected %f" % ( - duration, wait_queue_timeout)) + "Waited %.2f seconds for a socket, expected %f" % (duration, wait_queue_timeout), + ) def test_no_wait_queue_timeout(self): # Verify get_socket() with no wait_queue_timeout blocks forever. 
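The hunk that follows verifies that a checkout with no wait_queue_timeout blocks until a socket is returned, while the test_wait_queue_timeout hunk above expects a failure after roughly the configured two seconds. Both behaviors reduce to a semaphore acquire with an optional timeout; an illustrative-only sketch, not pymongo's actual Pool:

import threading

class CheckoutTimeout(Exception):
    pass

class TinyWaitQueue:
    # Illustrative stand-in for the pool's wait queue.
    def __init__(self, max_size, wait_queue_timeout=None):
        self._slots = threading.BoundedSemaphore(max_size)
        self._timeout = wait_queue_timeout

    def check_out(self):
        # acquire(timeout=None) blocks forever, matching test_no_wait_queue_timeout.
        if not self._slots.acquire(timeout=self._timeout):
            raise CheckoutTimeout("timed out waiting for a socket from the pool")

    def check_in(self):
        self._slots.release()
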
@@ -352,16 +341,16 @@ def test_no_wait_queue_timeout(self): with pool.get_socket({}) as s1: t = SocketGetter(self.c, pool) t.start() - while t.state != 'get_socket': + while t.state != "get_socket": time.sleep(0.1) time.sleep(1) - self.assertEqual(t.state, 'get_socket') + self.assertEqual(t.state, "get_socket") - while t.state != 'sock': + while t.state != "sock": time.sleep(0.1) - self.assertEqual(t.state, 'sock') + self.assertEqual(t.state, "sock") self.assertEqual(t.sock, s1) def test_checkout_more_than_max_pool_size(self): @@ -381,7 +370,7 @@ def test_checkout_more_than_max_pool_size(self): threads.append(t) time.sleep(1) for t in threads: - self.assertEqual(t.state, 'get_socket') + self.assertEqual(t.state, "get_socket") for socket_info in socks: socket_info.close_socket(None) @@ -394,7 +383,8 @@ def test_maxConnecting(self): # Run 50 short running operations def find_one(): - docs.append(client.test.test.find_one({'$where': delay(0.001)})) + docs.append(client.test.test.find_one({"$where": delay(0.001)})) + threads = [threading.Thread(target=find_one) for _ in range(50)] for thread in threads: thread.start() @@ -443,7 +433,7 @@ def test_max_pool_size(self): def f(): for _ in range(5): - collection.find_one({'$where': delay(0.1)}) + collection.find_one({"$where": delay(0.1)}) assert len(cx_pool.sockets) <= max_pool_size with lock: @@ -476,7 +466,7 @@ def test_max_pool_size_none(self): def f(): for _ in range(5): - collection.find_one({'$where': delay(0.1)}) + collection.find_one({"$where": delay(0.1)}) with lock: self.n_passed += 1 @@ -489,25 +479,21 @@ def f(): joinall(threads) self.assertEqual(nthreads, self.n_passed) self.assertTrue(len(cx_pool.sockets) > 1) - self.assertEqual(cx_pool.max_pool_size, float('inf')) - + self.assertEqual(cx_pool.max_pool_size, float("inf")) def test_max_pool_size_zero(self): c = rs_or_single_client(maxPoolSize=0) self.addCleanup(c.close) pool = get_pool(c) - self.assertEqual(pool.max_pool_size, float('inf')) + self.assertEqual(pool.max_pool_size, float("inf")) def test_max_pool_size_with_connection_failure(self): # The pool acquires its semaphore before attempting to connect; ensure # it releases the semaphore on connection failure. test_pool = Pool( - ('somedomainthatdoesntexist.org', 27017), - PoolOptions( - max_pool_size=1, - connect_timeout=1, - socket_timeout=1, - wait_queue_timeout=1)) + ("somedomainthatdoesntexist.org", 27017), + PoolOptions(max_pool_size=1, connect_timeout=1, socket_timeout=1, wait_queue_timeout=1), + ) test_pool.ready() # First call to get_socket fails; if pool doesn't release its semaphore @@ -521,8 +507,7 @@ def test_max_pool_size_with_connection_failure(self): # Testing for AutoReconnect instead of ConnectionFailure, above, # is sufficient right *now* to catch a semaphore leak. But that # seems error-prone, so check the message too. 
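That message check is the assertNotIn that resumes just below. The invariant test_max_pool_size_with_connection_failure protects (release the pool slot when the connection attempt fails) is easy to state on its own; an illustrative-only sketch, not pymongo's Pool:

import threading

class LeakFreePool:
    # Illustrative only; `connect` is a hypothetical callable that opens a socket.
    def __init__(self, connect, max_size=1):
        self._connect = connect
        self._slots = threading.BoundedSemaphore(max_size)

    def get_socket(self):
        self._slots.acquire()
        try:
            return self._connect()
        except Exception:
            # Without this release, a max_size=1 pool would leave the next
            # get_socket() waiting in the queue instead of retrying the
            # connection: the semaphore leak the test guards against.
            self._slots.release()
            raise
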
- self.assertNotIn('waiting for socket from pool', - str(context.exception)) + self.assertNotIn("waiting for socket from pool", str(context.exception)) if __name__ == "__main__": diff --git a/test/test_pymongo.py b/test/test_pymongo.py index 780a4beb8b..7ec32e16a6 100644 --- a/test/test_pymongo.py +++ b/test/test_pymongo.py @@ -15,17 +15,18 @@ """Test the pymongo module itself.""" import sys + sys.path[0:0] = [""] -import pymongo from test import unittest +import pymongo + class TestPyMongo(unittest.TestCase): def test_mongo_client_alias(self): # Testing that pymongo module imports mongo_client.MongoClient - self.assertEqual(pymongo.MongoClient, - pymongo.mongo_client.MongoClient) + self.assertEqual(pymongo.MongoClient, pymongo.mongo_client.MongoClient) if __name__ == "__main__": diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 7e1bf6f837..007fae2473 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -18,15 +18,16 @@ sys.path[0:0] = [""] +from test import client_context, unittest +from test.test_client import IntegrationTest +from test.utils import rs_or_single_client + from bson import decode, encode -from bson.binary import Binary, JAVA_LEGACY, UuidRepresentation +from bson.binary import JAVA_LEGACY, Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.errors import InvalidBSON -from bson.raw_bson import RawBSONDocument, DEFAULT_RAW_BSON_OPTIONS +from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument from bson.son import SON -from test import client_context, unittest -from test.utils import rs_or_single_client -from test.test_client import IntegrationTest class TestRawBSONDocument(IntegrationTest): @@ -35,9 +36,9 @@ class TestRawBSONDocument(IntegrationTest): # 'name': 'Sherlock', # 'addresses': [{'street': 'Baker Street'}]} bson_string = ( - b'Z\x00\x00\x00\x07_id\x00Um\xf6\x8bn2\xab!\xa9^\x07\x85\x02name\x00\t' - b'\x00\x00\x00Sherlock\x00\x04addresses\x00&\x00\x00\x00\x030\x00\x1e' - b'\x00\x00\x00\x02street\x00\r\x00\x00\x00Baker Street\x00\x00\x00\x00' + b"Z\x00\x00\x00\x07_id\x00Um\xf6\x8bn2\xab!\xa9^\x07\x85\x02name\x00\t" + b"\x00\x00\x00Sherlock\x00\x04addresses\x00&\x00\x00\x00\x030\x00\x1e" + b"\x00\x00\x00\x02street\x00\r\x00\x00\x00Baker Street\x00\x00\x00\x00" ) document = RawBSONDocument(bson_string) @@ -52,10 +53,10 @@ def tearDown(self): self.client.pymongo_test.test_raw.drop() def test_decode(self): - self.assertEqual('Sherlock', self.document['name']) - first_address = self.document['addresses'][0] + self.assertEqual("Sherlock", self.document["name"]) + first_address = self.document["addresses"][0] self.assertIsInstance(first_address, RawBSONDocument) - self.assertEqual('Baker Street', first_address['street']) + self.assertEqual("Baker Street", first_address["street"]) def test_raw(self): self.assertEqual(self.bson_string, self.document.raw) @@ -63,43 +64,44 @@ def test_raw(self): def test_empty_doc(self): doc = RawBSONDocument(encode({})) with self.assertRaises(KeyError): - doc['does-not-exist'] + doc["does-not-exist"] def test_invalid_bson_sequence(self): - bson_byte_sequence = encode({'a': 1})+encode({}) - with self.assertRaisesRegex(InvalidBSON, 'invalid object length'): + bson_byte_sequence = encode({"a": 1}) + encode({}) + with self.assertRaisesRegex(InvalidBSON, "invalid object length"): RawBSONDocument(bson_byte_sequence) def test_invalid_bson_eoo(self): - invalid_bson_eoo = encode({'a': 1})[:-1] + b'\x01' - with self.assertRaisesRegex(InvalidBSON, 'bad eoo'): + invalid_bson_eoo = 
encode({"a": 1})[:-1] + b"\x01" + with self.assertRaisesRegex(InvalidBSON, "bad eoo"): RawBSONDocument(invalid_bson_eoo) @client_context.require_connection def test_round_trip(self): db = self.client.get_database( - 'pymongo_test', - codec_options=CodecOptions(document_class=RawBSONDocument)) + "pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument) + ) db.test_raw.insert_one(self.document) - result = db.test_raw.find_one(self.document['_id']) + result = db.test_raw.find_one(self.document["_id"]) self.assertIsInstance(result, RawBSONDocument) self.assertEqual(dict(self.document.items()), dict(result.items())) @client_context.require_connection def test_round_trip_raw_uuid(self): - coll = self.client.get_database('pymongo_test').test_raw + coll = self.client.get_database("pymongo_test").test_raw uid = uuid.uuid4() - doc = {'_id': 1, - 'bin4': Binary(uid.bytes, 4), - 'bin3': Binary(uid.bytes, 3)} + doc = {"_id": 1, "bin4": Binary(uid.bytes, 4), "bin3": Binary(uid.bytes, 3)} raw = RawBSONDocument(encode(doc)) coll.insert_one(raw) self.assertEqual(coll.find_one(), doc) uuid_coll = coll.with_options( codec_options=coll.codec_options.with_options( - uuid_representation=UuidRepresentation.STANDARD)) - self.assertEqual(uuid_coll.find_one(), - {'_id': 1, 'bin4': uid, 'bin3': Binary(uid.bytes, 3)}) + uuid_representation=UuidRepresentation.STANDARD + ) + ) + self.assertEqual( + uuid_coll.find_one(), {"_id": 1, "bin4": uid, "bin3": Binary(uid.bytes, 3)} + ) # Test that the raw bytes haven't changed. raw_coll = coll.with_options(codec_options=DEFAULT_RAW_BSON_OPTIONS) @@ -110,43 +112,45 @@ def test_with_codec_options(self): # '_id': UUID('026fab8f-975f-4965-9fbf-85ad874c60ff')} # encoded with JAVA_LEGACY uuid representation. bson_string = ( - b'-\x00\x00\x00\x05_id\x00\x10\x00\x00\x00\x03eI_\x97\x8f\xabo\x02' - b'\xff`L\x87\xad\x85\xbf\x9f\tdate\x00\x8a\xd6\xb9\xbaM' - b'\x01\x00\x00\x00' + b"-\x00\x00\x00\x05_id\x00\x10\x00\x00\x00\x03eI_\x97\x8f\xabo\x02" + b"\xff`L\x87\xad\x85\xbf\x9f\tdate\x00\x8a\xd6\xb9\xbaM" + b"\x01\x00\x00\x00" ) document = RawBSONDocument( bson_string, - codec_options=CodecOptions(uuid_representation=JAVA_LEGACY, - document_class=RawBSONDocument)) + codec_options=CodecOptions( + uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument + ), + ) - self.assertEqual(uuid.UUID('026fab8f-975f-4965-9fbf-85ad874c60ff'), - document['_id']) + self.assertEqual(uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"), document["_id"]) @client_context.require_connection def test_round_trip_codec_options(self): doc = { - 'date': datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), - '_id': uuid.UUID('026fab8f-975f-4965-9fbf-85ad874c60ff') + "date": datetime.datetime(2015, 6, 3, 18, 40, 50, 826000), + "_id": uuid.UUID("026fab8f-975f-4965-9fbf-85ad874c60ff"), } db = self.client.pymongo_test coll = db.get_collection( - 'test_raw', - codec_options=CodecOptions(uuid_representation=JAVA_LEGACY)) + "test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY) + ) coll.insert_one(doc) - raw_java_legacy = CodecOptions(uuid_representation=JAVA_LEGACY, - document_class=RawBSONDocument) - coll = db.get_collection('test_raw', codec_options=raw_java_legacy) + raw_java_legacy = CodecOptions( + uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument + ) + coll = db.get_collection("test_raw", codec_options=raw_java_legacy) self.assertEqual( - RawBSONDocument(encode(doc, codec_options=raw_java_legacy)), - coll.find_one()) + RawBSONDocument(encode(doc, 
codec_options=raw_java_legacy)), coll.find_one() + ) @client_context.require_connection def test_raw_bson_document_embedded(self): - doc = {'embedded': self.document} + doc = {"embedded": self.document} db = self.client.pymongo_test db.test_raw.insert_one(doc) result = db.test_raw.find_one() - self.assertEqual(decode(self.document.raw), result['embedded']) + self.assertEqual(decode(self.document.raw), result["embedded"]) # Make sure that CodecOptions are preserved. # {'embedded': [ @@ -155,39 +159,45 @@ def test_raw_bson_document_embedded(self): # ]} # encoded with JAVA_LEGACY uuid representation. bson_string = ( - b'D\x00\x00\x00\x04embedded\x005\x00\x00\x00\x030\x00-\x00\x00\x00' - b'\tdate\x00\x8a\xd6\xb9\xbaM\x01\x00\x00\x05_id\x00\x10\x00\x00' - b'\x00\x03eI_\x97\x8f\xabo\x02\xff`L\x87\xad\x85\xbf\x9f\x00\x00' - b'\x00' + b"D\x00\x00\x00\x04embedded\x005\x00\x00\x00\x030\x00-\x00\x00\x00" + b"\tdate\x00\x8a\xd6\xb9\xbaM\x01\x00\x00\x05_id\x00\x10\x00\x00" + b"\x00\x03eI_\x97\x8f\xabo\x02\xff`L\x87\xad\x85\xbf\x9f\x00\x00" + b"\x00" ) rbd = RawBSONDocument( bson_string, - codec_options=CodecOptions(uuid_representation=JAVA_LEGACY, - document_class=RawBSONDocument)) + codec_options=CodecOptions( + uuid_representation=JAVA_LEGACY, document_class=RawBSONDocument + ), + ) db.test_raw.drop() db.test_raw.insert_one(rbd) - result = db.get_collection('test_raw', codec_options=CodecOptions( - uuid_representation=JAVA_LEGACY)).find_one() - self.assertEqual(rbd['embedded'][0]['_id'], - result['embedded'][0]['_id']) + result = db.get_collection( + "test_raw", codec_options=CodecOptions(uuid_representation=JAVA_LEGACY) + ).find_one() + self.assertEqual(rbd["embedded"][0]["_id"], result["embedded"][0]["_id"]) @client_context.require_connection def test_write_response_raw_bson(self): coll = self.client.get_database( - 'pymongo_test', - codec_options=CodecOptions(document_class=RawBSONDocument)).test_raw + "pymongo_test", codec_options=CodecOptions(document_class=RawBSONDocument) + ).test_raw # No Exceptions raised while handling write response. 
coll.insert_one(self.document) coll.delete_one(self.document) coll.insert_many([self.document]) coll.delete_many(self.document) - coll.update_one(self.document, {'$set': {'a': 'b'}}, upsert=True) - coll.update_many(self.document, {'$set': {'b': 'c'}}) + coll.update_one(self.document, {"$set": {"a": "b"}}, upsert=True) + coll.update_many(self.document, {"$set": {"b": "c"}}) def test_preserve_key_ordering(self): - keyvaluepairs = [('a', 1), ('b', 2), ('c', 3),] + keyvaluepairs = [ + ("a", 1), + ("b", 2), + ("c", 3), + ] rawdoc = RawBSONDocument(encode(SON(keyvaluepairs))) for rkey, elt in zip(rawdoc, keyvaluepairs): diff --git a/test/test_read_concern.py b/test/test_read_concern.py index 81a6863f5e..42b1a03369 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -14,16 +14,15 @@ """Test the read_concern module.""" +from test import IntegrationTest, client_context +from test.utils import OvertCommandListener, rs_or_single_client, single_client + from bson.son import SON from pymongo.errors import OperationFailure from pymongo.read_concern import ReadConcern -from test import client_context, IntegrationTest -from test.utils import single_client, rs_or_single_client, OvertCommandListener - class TestReadConcern(IntegrationTest): - @classmethod @client_context.require_connection def setUpClass(cls): @@ -31,12 +30,12 @@ def setUpClass(cls): cls.listener = OvertCommandListener() cls.client = single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test - client_context.client.pymongo_test.create_collection('coll') + client_context.client.pymongo_test.create_collection("coll") @classmethod def tearDownClass(cls): cls.client.close() - client_context.client.pymongo_test.drop_collection('coll') + client_context.client.pymongo_test.drop_collection("coll") super(TestReadConcern, cls).tearDownClass() def tearDown(self): @@ -48,25 +47,23 @@ def test_read_concern(self): self.assertIsNone(rc.level) self.assertTrue(rc.ok_for_legacy) - rc = ReadConcern('majority') - self.assertEqual('majority', rc.level) + rc = ReadConcern("majority") + self.assertEqual("majority", rc.level) self.assertFalse(rc.ok_for_legacy) - rc = ReadConcern('local') - self.assertEqual('local', rc.level) + rc = ReadConcern("local") + self.assertEqual("local", rc.level) self.assertTrue(rc.ok_for_legacy) self.assertRaises(TypeError, ReadConcern, 42) def test_read_concern_uri(self): - uri = 'mongodb://%s/?readConcernLevel=majority' % ( - client_context.pair,) + uri = "mongodb://%s/?readConcernLevel=majority" % (client_context.pair,) client = rs_or_single_client(uri, connect=False) - self.assertEqual(ReadConcern('majority'), client.read_concern) + self.assertEqual(ReadConcern("majority"), client.read_concern) def test_invalid_read_concern(self): - coll = self.db.get_collection( - 'coll', read_concern=ReadConcern('unknown')) + coll = self.db.get_collection("coll", read_concern=ReadConcern("unknown")) # We rely on the server to validate read concern. with self.assertRaises(OperationFailure): coll.find_one() @@ -74,46 +71,46 @@ def test_invalid_read_concern(self): def test_find_command(self): # readConcern not sent in command if not specified. coll = self.db.coll - tuple(coll.find({'field': 'value'})) - self.assertNotIn('readConcern', - self.listener.results['started'][0].command) + tuple(coll.find({"field": "value"})) + self.assertNotIn("readConcern", self.listener.results["started"][0].command) self.listener.results.clear() # Explicitly set readConcern to 'local'. 
- coll = self.db.get_collection('coll', read_concern=ReadConcern('local')) - tuple(coll.find({'field': 'value'})) + coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) + tuple(coll.find({"field": "value"})) self.assertEqualCommand( - SON([('find', 'coll'), - ('filter', {'field': 'value'}), - ('readConcern', {'level': 'local'})]), - self.listener.results['started'][0].command) + SON( + [ + ("find", "coll"), + ("filter", {"field": "value"}), + ("readConcern", {"level": "local"}), + ] + ), + self.listener.results["started"][0].command, + ) def test_command_cursor(self): # readConcern not sent in command if not specified. coll = self.db.coll - tuple(coll.aggregate([{'$match': {'field': 'value'}}])) - self.assertNotIn('readConcern', - self.listener.results['started'][0].command) + tuple(coll.aggregate([{"$match": {"field": "value"}}])) + self.assertNotIn("readConcern", self.listener.results["started"][0].command) self.listener.results.clear() # Explicitly set readConcern to 'local'. - coll = self.db.get_collection('coll', read_concern=ReadConcern('local')) - tuple(coll.aggregate([{'$match': {'field': 'value'}}])) + coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) + tuple(coll.aggregate([{"$match": {"field": "value"}}])) self.assertEqual( - {'level': 'local'}, - self.listener.results['started'][0].command['readConcern']) + {"level": "local"}, self.listener.results["started"][0].command["readConcern"] + ) def test_aggregate_out(self): - coll = self.db.get_collection('coll', read_concern=ReadConcern('local')) - tuple(coll.aggregate([{'$match': {'field': 'value'}}, - {'$out': 'output_collection'}])) + coll = self.db.get_collection("coll", read_concern=ReadConcern("local")) + tuple(coll.aggregate([{"$match": {"field": "value"}}, {"$out": "output_collection"}])) # Aggregate with $out supports readConcern MongoDB 4.2 onwards. 
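The version check for that $out behavior resumes just below. The pattern the listener assertions in this file encode is simply that a non-default read concern attaches a readConcern subdocument to each read command; a minimal sketch with placeholder host and collection names:

from pymongo import MongoClient
from pymongo.read_concern import ReadConcern

client = MongoClient("localhost", 27017)  # placeholder address
coll = client.pymongo_test.get_collection("coll", read_concern=ReadConcern("local"))

# The default ReadConcern() serializes to an empty document, so no readConcern
# field is sent; an explicit level rides along with every find/aggregate.
assert ReadConcern().document == {}
assert coll.read_concern.document == {"level": "local"}
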
if client_context.version >= (4, 1): - self.assertIn('readConcern', - self.listener.results['started'][0].command) + self.assertIn("readConcern", self.listener.results["started"][0].command) else: - self.assertNotIn('readConcern', - self.listener.results['started'][0].command) + self.assertNotIn("readConcern", self.listener.results["started"][0].command) diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 18dbd0bee4..4df5b09df4 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -22,52 +22,56 @@ sys.path[0:0] = [""] +from test import IntegrationTest, SkipTest, client_context, unittest +from test.utils import ( + OvertCommandListener, + connected, + one, + rs_client, + single_client, + wait_until, +) +from test.version import Version + from bson.son import SON from pymongo.errors import ConfigurationError, OperationFailure from pymongo.message import _maybe_add_read_preference from pymongo.mongo_client import MongoClient -from pymongo.read_preferences import (ReadPreference, MovingAverage, - Primary, PrimaryPreferred, - Secondary, SecondaryPreferred, - Nearest) +from pymongo.read_preferences import ( + MovingAverage, + Nearest, + Primary, + PrimaryPreferred, + ReadPreference, + Secondary, + SecondaryPreferred, +) from pymongo.server_description import ServerDescription -from pymongo.server_selectors import readable_server_selector, Selection +from pymongo.server_selectors import Selection, readable_server_selector from pymongo.server_type import SERVER_TYPE from pymongo.write_concern import WriteConcern -from test import (SkipTest, - client_context, - IntegrationTest, - unittest) -from test.utils import (connected, - one, - OvertCommandListener, - rs_client, - single_client, - wait_until) -from test.version import Version - class TestSelections(IntegrationTest): - @client_context.require_connection def test_bool(self): client = single_client() wait_until(lambda: client.address, "discover primary") - selection = Selection.from_topology_description( - client._topology.description) + selection = Selection.from_topology_description(client._topology.description) self.assertTrue(selection) self.assertFalse(selection.with_server_descriptions([])) class TestReadPreferenceObjects(unittest.TestCase): - prefs = [Primary(), - PrimaryPreferred(), - Secondary(), - Nearest(tag_sets=[{'a': 1}, {'b': 2}]), - SecondaryPreferred(max_staleness=30)] + prefs = [ + Primary(), + PrimaryPreferred(), + Secondary(), + Nearest(tag_sets=[{"a": 1}, {"b": 2}]), + SecondaryPreferred(max_staleness=30), + ] def test_pickle(self): for pref in self.prefs: @@ -83,7 +87,6 @@ def test_deepcopy(self): class TestReadPreferencesBase(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): @@ -94,47 +97,41 @@ def setUp(self): # Insert some data so we can use cursors in read_from_which_host self.client.pymongo_test.test.drop() self.client.get_database( - "pymongo_test", - write_concern=WriteConcern(w=client_context.w)).test.insert_many( - [{'_id': i} for i in range(10)]) + "pymongo_test", write_concern=WriteConcern(w=client_context.w) + ).test.insert_many([{"_id": i} for i in range(10)]) self.addCleanup(self.client.pymongo_test.test.drop) def read_from_which_host(self, client): - """Do a find() on the client and return which host was used - """ + """Do a find() on the client and return which host was used""" cursor = client.pymongo_test.test.find() next(cursor) return cursor.address def read_from_which_kind(self, client): """Do a 
find() on the client and return 'primary' or 'secondary' - depending on which the client used. + depending on which the client used. """ address = self.read_from_which_host(client) if address == client.primary: - return 'primary' + return "primary" elif address in client.secondaries: - return 'secondary' + return "secondary" else: self.fail( - 'Cursor used address %s, expected either primary ' - '%s or secondaries %s' % ( - address, client.primary, client.secondaries)) + "Cursor used address %s, expected either primary " + "%s or secondaries %s" % (address, client.primary, client.secondaries) + ) def assertReadsFrom(self, expected, **kwargs): c = rs_client(**kwargs) - wait_until( - lambda: len(c.nodes - c.arbiters) == client_context.w, - "discovered all nodes") + wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes") used = self.read_from_which_kind(c) - self.assertEqual(expected, used, 'Cursor used %s, expected %s' % ( - used, expected)) + self.assertEqual(expected, used, "Cursor used %s, expected %s" % (used, expected)) class TestSingleSecondaryOk(TestReadPreferencesBase): - def test_reads_from_secondary(self): host, port = next(iter(self.client.secondaries)) @@ -167,62 +164,53 @@ def test_reads_from_secondary(self): class TestReadPreferences(TestReadPreferencesBase): - def test_mode_validation(self): - for mode in (ReadPreference.PRIMARY, - ReadPreference.PRIMARY_PREFERRED, - ReadPreference.SECONDARY, - ReadPreference.SECONDARY_PREFERRED, - ReadPreference.NEAREST): - self.assertEqual( - mode, - rs_client(read_preference=mode).read_preference) - - self.assertRaises( - TypeError, - rs_client, read_preference='foo') + for mode in ( + ReadPreference.PRIMARY, + ReadPreference.PRIMARY_PREFERRED, + ReadPreference.SECONDARY, + ReadPreference.SECONDARY_PREFERRED, + ReadPreference.NEAREST, + ): + self.assertEqual(mode, rs_client(read_preference=mode).read_preference) + + self.assertRaises(TypeError, rs_client, read_preference="foo") def test_tag_sets_validation(self): S = Secondary(tag_sets=[{}]) - self.assertEqual( - [{}], - rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{}], rs_client(read_preference=S).read_preference.tag_sets) - S = Secondary(tag_sets=[{'k': 'v'}]) - self.assertEqual( - [{'k': 'v'}], - rs_client(read_preference=S).read_preference.tag_sets) + S = Secondary(tag_sets=[{"k": "v"}]) + self.assertEqual([{"k": "v"}], rs_client(read_preference=S).read_preference.tag_sets) - S = Secondary(tag_sets=[{'k': 'v'}, {}]) - self.assertEqual( - [{'k': 'v'}, {}], - rs_client(read_preference=S).read_preference.tag_sets) + S = Secondary(tag_sets=[{"k": "v"}, {}]) + self.assertEqual([{"k": "v"}, {}], rs_client(read_preference=S).read_preference.tag_sets) self.assertRaises(ValueError, Secondary, tag_sets=[]) # One dict not ok, must be a list of dicts - self.assertRaises(TypeError, Secondary, tag_sets={'k': 'v'}) + self.assertRaises(TypeError, Secondary, tag_sets={"k": "v"}) - self.assertRaises(TypeError, Secondary, tag_sets='foo') + self.assertRaises(TypeError, Secondary, tag_sets="foo") - self.assertRaises(TypeError, Secondary, tag_sets=['foo']) + self.assertRaises(TypeError, Secondary, tag_sets=["foo"]) def test_threshold_validation(self): - self.assertEqual(17, rs_client( - localThresholdMS=17, connect=False).options.local_threshold_ms) + self.assertEqual( + 17, rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms + ) - self.assertEqual(42, rs_client( - localThresholdMS=42, 
connect=False).options.local_threshold_ms) + self.assertEqual( + 42, rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms + ) - self.assertEqual(666, rs_client( - localThresholdMS=666, connect=False).options.local_threshold_ms) + self.assertEqual( + 666, rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms + ) - self.assertEqual(0, rs_client( - localThresholdMS=0, connect=False).options.local_threshold_ms) + self.assertEqual(0, rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms) - self.assertRaises(ValueError, - rs_client, - localthresholdms=-1) + self.assertRaises(ValueError, rs_client, localthresholdms=-1) def test_zero_latency(self): ping_times = set() @@ -232,11 +220,8 @@ def test_zero_latency(self): for ping_time, host in zip(ping_times, self.client.nodes): ServerDescription._host_to_round_trip_time[host] = ping_time try: - client = connected( - rs_client(readPreference='nearest', localThresholdMS=0)) - wait_until( - lambda: client.nodes == self.client.nodes, - "discovered all nodes") + client = connected(rs_client(readPreference="nearest", localThresholdMS=0)) + wait_until(lambda: client.nodes == self.client.nodes, "discovered all nodes") host = self.read_from_which_host(client) for _ in range(5): self.assertEqual(host, self.read_from_which_host(client)) @@ -244,33 +229,25 @@ def test_zero_latency(self): ServerDescription._host_to_round_trip_time.clear() def test_primary(self): - self.assertReadsFrom( - 'primary', read_preference=ReadPreference.PRIMARY) + self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY) def test_primary_with_tags(self): # Tags not allowed with PRIMARY - self.assertRaises( - ConfigurationError, - rs_client, tag_sets=[{'dc': 'ny'}]) + self.assertRaises(ConfigurationError, rs_client, tag_sets=[{"dc": "ny"}]) def test_primary_preferred(self): - self.assertReadsFrom( - 'primary', read_preference=ReadPreference.PRIMARY_PREFERRED) + self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) def test_secondary(self): - self.assertReadsFrom( - 'secondary', read_preference=ReadPreference.SECONDARY) + self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY) def test_secondary_preferred(self): - self.assertReadsFrom( - 'secondary', read_preference=ReadPreference.SECONDARY_PREFERRED) + self.assertReadsFrom("secondary", read_preference=ReadPreference.SECONDARY_PREFERRED) def test_nearest(self): # With high localThresholdMS, expect to read from any # member - c = rs_client( - read_preference=ReadPreference.NEAREST, - localThresholdMS=10000) # 10 seconds + c = rs_client(read_preference=ReadPreference.NEAREST, localThresholdMS=10000) # 10 seconds data_members = {self.client.primary} | self.client.secondaries @@ -286,16 +263,16 @@ def test_nearest(self): i += 1 not_used = data_members.difference(used) - latencies = ', '.join( - '%s: %dms' % (server.description.address, - server.description.round_trip_time) - for server in c._get_topology().select_servers( - readable_server_selector)) + latencies = ", ".join( + "%s: %dms" % (server.description.address, server.description.round_trip_time) + for server in c._get_topology().select_servers(readable_server_selector) + ) self.assertFalse( not_used, "Expected to use primary and all secondaries for mode NEAREST," - " but didn't use %s\nlatencies: %s" % (not_used, latencies)) + " but didn't use %s\nlatencies: %s" % (not_used, latencies), + ) class ReadPrefTester(MongoClient): @@ -307,8 +284,7 @@ def __init__(self, 
*args, **kwargs): @contextlib.contextmanager def _socket_for_reads(self, read_preference, session): - context = super(ReadPrefTester, self)._socket_for_reads( - read_preference, session) + context = super(ReadPrefTester, self)._socket_for_reads(read_preference, session) with context as (sock_info, secondary_ok): self.record_a_read(sock_info.address) yield sock_info, secondary_ok @@ -316,7 +292,8 @@ def _socket_for_reads(self, read_preference, session): @contextlib.contextmanager def _secondaryok_for_server(self, read_preference, server, session): context = super(ReadPrefTester, self)._secondaryok_for_server( - read_preference, server, session) + read_preference, server, session + ) with context as (sock_info, secondary_ok): self.record_a_read(sock_info.address) yield sock_info, secondary_ok @@ -325,17 +302,17 @@ def record_a_read(self, address): server = self._get_topology().select_server_by_address(address, 0) self.has_read_from.add(server) + _PREF_MAP = [ (Primary, SERVER_TYPE.RSPrimary), (PrimaryPreferred, SERVER_TYPE.RSPrimary), (Secondary, SERVER_TYPE.RSSecondary), (SecondaryPreferred, SERVER_TYPE.RSSecondary), - (Nearest, 'any') + (Nearest, "any"), ] class TestCommandAndReadPreference(IntegrationTest): - @classmethod @client_context.require_secondaries_count(1) def setUpClass(cls): @@ -343,16 +320,18 @@ def setUpClass(cls): cls.c = ReadPrefTester( client_context.pair, # Ignore round trip times, to test ReadPreference modes only. - localThresholdMS=1000*1000) + localThresholdMS=1000 * 1000, + ) cls.client_version = Version.from_client(cls.c) # mapReduce fails if the collection does not exist. coll = cls.c.pymongo_test.get_collection( - 'test', write_concern=WriteConcern(w=client_context.w)) + "test", write_concern=WriteConcern(w=client_context.w) + ) coll.insert_one({}) @classmethod def tearDownClass(cls): - cls.c.drop_database('pymongo_test') + cls.c.drop_database("pymongo_test") cls.c.close() def executed_on_which_server(self, client, fn, *args, **kwargs): @@ -364,12 +343,13 @@ def executed_on_which_server(self, client, fn, *args, **kwargs): def assertExecutedOn(self, server_type, client, fn, *args, **kwargs): server = self.executed_on_which_server(client, fn, *args, **kwargs) - self.assertEqual(SERVER_TYPE._fields[server_type], - SERVER_TYPE._fields[server.description.server_type]) + self.assertEqual( + SERVER_TYPE._fields[server_type], SERVER_TYPE._fields[server.description.server_type] + ) def _test_fn(self, server_type, fn): for _ in range(10): - if server_type == 'any': + if server_type == "any": used = set() for _ in range(1000): server = self.executed_on_which_server(self.c, fn) @@ -378,13 +358,9 @@ def _test_fn(self, server_type, fn): # Success break - unused = self.c.secondaries.union( - set([self.c.primary]) - ).difference(used) + unused = self.c.secondaries.union(set([self.c.primary])).difference(used) if unused: - self.fail( - "Some members not used for NEAREST: %s" % ( - unused)) + self.fail("Some members not used for NEAREST: %s" % (unused)) else: self.assertExecutedOn(server_type, self.c, fn) @@ -405,8 +381,7 @@ def test_command(self): # Test that the generic command helper obeys the read preference # passed to it. for mode, server_type in _PREF_MAP: - func = lambda: self.c.pymongo_test.command('dbStats', - read_preference=mode()) + func = lambda: self.c.pymongo_test.command("dbStats", read_preference=mode()) self._test_fn(server_type, func) def test_create_collection(self): @@ -414,28 +389,31 @@ def test_create_collection(self): # the collection already exists. 
self._test_primary_helper( lambda: self.c.pymongo_test.create_collection( - 'some_collection%s' % random.randint(0, sys.maxsize))) + "some_collection%s" % random.randint(0, sys.maxsize) + ) + ) def test_count_documents(self): - self._test_coll_helper( - True, self.c.pymongo_test.test, 'count_documents', {}) + self._test_coll_helper(True, self.c.pymongo_test.test, "count_documents", {}) def test_estimated_document_count(self): - self._test_coll_helper( - True, self.c.pymongo_test.test, 'estimated_document_count') + self._test_coll_helper(True, self.c.pymongo_test.test, "estimated_document_count") def test_distinct(self): - self._test_coll_helper(True, self.c.pymongo_test.test, 'distinct', 'a') + self._test_coll_helper(True, self.c.pymongo_test.test, "distinct", "a") def test_aggregate(self): - self._test_coll_helper(True, self.c.pymongo_test.test, - 'aggregate', - [{'$project': {'_id': 1}}]) + self._test_coll_helper( + True, self.c.pymongo_test.test, "aggregate", [{"$project": {"_id": 1}}] + ) def test_aggregate_write(self): - self._test_coll_helper(False, self.c.pymongo_test.test, - 'aggregate', - [{'$project': {'_id': 1}}, {'$out': "agg_write_test"}]) + self._test_coll_helper( + False, + self.c.pymongo_test.test, + "aggregate", + [{"$project": {"_id": 1}}, {"$out": "agg_write_test"}], + ) class TestMovingAverage(unittest.TestCase): @@ -451,77 +429,48 @@ def test_moving_average(self): class TestMongosAndReadPreference(IntegrationTest): - def test_read_preference_document(self): pref = Primary() - self.assertEqual( - pref.document, - {'mode': 'primary'}) + self.assertEqual(pref.document, {"mode": "primary"}) pref = PrimaryPreferred() + self.assertEqual(pref.document, {"mode": "primaryPreferred"}) + pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "primaryPreferred", "tags": [{"dc": "sf"}]}) + pref = PrimaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30) self.assertEqual( pref.document, - {'mode': 'primaryPreferred'}) - pref = PrimaryPreferred(tag_sets=[{'dc': 'sf'}]) - self.assertEqual( - pref.document, - {'mode': 'primaryPreferred', 'tags': [{'dc': 'sf'}]}) - pref = PrimaryPreferred( - tag_sets=[{'dc': 'sf'}], max_staleness=30) - self.assertEqual( - pref.document, - {'mode': 'primaryPreferred', - 'tags': [{'dc': 'sf'}], - 'maxStalenessSeconds': 30}) + {"mode": "primaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}, + ) pref = Secondary() + self.assertEqual(pref.document, {"mode": "secondary"}) + pref = Secondary(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}]}) + pref = Secondary(tag_sets=[{"dc": "sf"}], max_staleness=30) self.assertEqual( - pref.document, - {'mode': 'secondary'}) - pref = Secondary(tag_sets=[{'dc': 'sf'}]) - self.assertEqual( - pref.document, - {'mode': 'secondary', 'tags': [{'dc': 'sf'}]}) - pref = Secondary( - tag_sets=[{'dc': 'sf'}], max_staleness=30) - self.assertEqual( - pref.document, - {'mode': 'secondary', - 'tags': [{'dc': 'sf'}], - 'maxStalenessSeconds': 30}) + pref.document, {"mode": "secondary", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30} + ) pref = SecondaryPreferred() + self.assertEqual(pref.document, {"mode": "secondaryPreferred"}) + pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}]}) + pref = SecondaryPreferred(tag_sets=[{"dc": "sf"}], max_staleness=30) self.assertEqual( pref.document, - {'mode': 'secondaryPreferred'}) - pref = 
SecondaryPreferred(tag_sets=[{'dc': 'sf'}]) - self.assertEqual( - pref.document, - {'mode': 'secondaryPreferred', 'tags': [{'dc': 'sf'}]}) - pref = SecondaryPreferred( - tag_sets=[{'dc': 'sf'}], max_staleness=30) - self.assertEqual( - pref.document, - {'mode': 'secondaryPreferred', - 'tags': [{'dc': 'sf'}], - 'maxStalenessSeconds': 30}) + {"mode": "secondaryPreferred", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30}, + ) pref = Nearest() + self.assertEqual(pref.document, {"mode": "nearest"}) + pref = Nearest(tag_sets=[{"dc": "sf"}]) + self.assertEqual(pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}]}) + pref = Nearest(tag_sets=[{"dc": "sf"}], max_staleness=30) self.assertEqual( - pref.document, - {'mode': 'nearest'}) - pref = Nearest(tag_sets=[{'dc': 'sf'}]) - self.assertEqual( - pref.document, - {'mode': 'nearest', 'tags': [{'dc': 'sf'}]}) - pref = Nearest( - tag_sets=[{'dc': 'sf'}], max_staleness=30) - self.assertEqual( - pref.document, - {'mode': 'nearest', - 'tags': [{'dc': 'sf'}], - 'maxStalenessSeconds': 30}) + pref.document, {"mode": "nearest", "tags": [{"dc": "sf"}], "maxStalenessSeconds": 30} + ) with self.assertRaises(TypeError): Nearest(max_staleness=1.5) # Float is prohibited. @@ -534,69 +483,64 @@ def test_read_preference_document(self): def test_read_preference_document_hedge(self): cases = { - 'primaryPreferred': PrimaryPreferred, - 'secondary': Secondary, - 'secondaryPreferred': SecondaryPreferred, - 'nearest': Nearest, + "primaryPreferred": PrimaryPreferred, + "secondary": Secondary, + "secondaryPreferred": SecondaryPreferred, + "nearest": Nearest, } for mode, cls in cases.items(): with self.assertRaises(TypeError): cls(hedge=[]) pref = cls(hedge={}) - self.assertEqual(pref.document, {'mode': mode}) + self.assertEqual(pref.document, {"mode": mode}) out = _maybe_add_read_preference({}, pref) if cls == SecondaryPreferred: # SecondaryPreferred without hedge doesn't add $readPreference. 
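For the SecondaryPreferred case, the assertEqual(out, {}) that follows is the point: plain secondaryPreferred is the wire-protocol default, so nothing is added. An illustrative-only restatement (not pymongo's actual helper) of the wrapping rule these _maybe_add_read_preference assertions observe:

from bson.son import SON

def wrap_with_read_preference(spec, pref_document):
    # Plain secondaryPreferred is the default: pass the query through untouched.
    if pref_document == {"mode": "secondaryPreferred"}:
        return spec
    # Otherwise fold the filter under $query and attach $readPreference,
    # preserving any modifiers (e.g. $orderby) already present.
    if "$query" not in spec:
        spec = SON([("$query", spec)])
    spec["$readPreference"] = pref_document
    return spec
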
self.assertEqual(out, {}) else: - self.assertEqual( - out, - SON([("$query", {}), ("$readPreference", pref.document)])) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) - hedge = {'enabled': True} + hedge = {"enabled": True} pref = cls(hedge=hedge) - self.assertEqual(pref.document, {'mode': mode, 'hedge': hedge}) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) - hedge = {'enabled': False} + hedge = {"enabled": False} pref = cls(hedge=hedge) - self.assertEqual(pref.document, {'mode': mode, 'hedge': hedge}) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) - hedge = {'enabled': False, 'extra': 'option'} + hedge = {"enabled": False, "extra": "option"} pref = cls(hedge=hedge) - self.assertEqual(pref.document, {'mode': mode, 'hedge': hedge}) + self.assertEqual(pref.document, {"mode": mode, "hedge": hedge}) out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) def test_send_hedge(self): cases = { - 'primaryPreferred': PrimaryPreferred, - 'secondaryPreferred': SecondaryPreferred, - 'nearest': Nearest, + "primaryPreferred": PrimaryPreferred, + "secondaryPreferred": SecondaryPreferred, + "nearest": Nearest, } if client_context.supports_secondary_read_pref: - cases['secondary'] = Secondary + cases["secondary"] = Secondary listener = OvertCommandListener() client = rs_client(event_listeners=[listener]) self.addCleanup(client.close) - client.admin.command('ping') + client.admin.command("ping") for mode, cls in cases.items(): - pref = cls(hedge={'enabled': True}) - coll = client.test.get_collection('test', read_preference=pref) + pref = cls(hedge={"enabled": True}) + coll = client.test.get_collection("test", read_preference=pref) listener.reset() coll.find_one() - started = listener.results['started'] + started = listener.results["started"] self.assertEqual(len(started), 1, started) cmd = started[0].command - self.assertIn('$readPreference', cmd) - self.assertEqual(cmd['$readPreference'], pref.document) + self.assertIn("$readPreference", cmd) + self.assertEqual(cmd["$readPreference"], pref.document) def test_maybe_add_read_preference(self): @@ -606,70 +550,72 @@ def test_maybe_add_read_preference(self): pref = PrimaryPreferred() out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) - pref = PrimaryPreferred(tag_sets=[{'dc': 'nyc'}]) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = PrimaryPreferred(tag_sets=[{"dc": "nyc"}]) out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) pref = Secondary() out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) - pref = Secondary(tag_sets=[{'dc': 'nyc'}]) + self.assertEqual(out, 
SON([("$query", {}), ("$readPreference", pref.document)])) + pref = Secondary(tag_sets=[{"dc": "nyc"}]) out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) # SecondaryPreferred without tag_sets or max_staleness doesn't add # $readPreference pref = SecondaryPreferred() out = _maybe_add_read_preference({}, pref) self.assertEqual(out, {}) - pref = SecondaryPreferred(tag_sets=[{'dc': 'nyc'}]) + pref = SecondaryPreferred(tag_sets=[{"dc": "nyc"}]) out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) pref = SecondaryPreferred(max_staleness=120) out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) pref = Nearest() out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) - pref = Nearest(tag_sets=[{'dc': 'nyc'}]) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) + pref = Nearest(tag_sets=[{"dc": "nyc"}]) out = _maybe_add_read_preference({}, pref) - self.assertEqual( - out, SON([("$query", {}), ("$readPreference", pref.document)])) + self.assertEqual(out, SON([("$query", {}), ("$readPreference", pref.document)])) criteria = SON([("$query", {}), ("$orderby", SON([("_id", 1)]))]) pref = Nearest() out = _maybe_add_read_preference(criteria, pref) self.assertEqual( out, - SON([("$query", {}), - ("$orderby", SON([("_id", 1)])), - ("$readPreference", pref.document)])) - pref = Nearest(tag_sets=[{'dc': 'nyc'}]) + SON( + [ + ("$query", {}), + ("$orderby", SON([("_id", 1)])), + ("$readPreference", pref.document), + ] + ), + ) + pref = Nearest(tag_sets=[{"dc": "nyc"}]) out = _maybe_add_read_preference(criteria, pref) self.assertEqual( out, - SON([("$query", {}), - ("$orderby", SON([("_id", 1)])), - ("$readPreference", pref.document)])) + SON( + [ + ("$query", {}), + ("$orderby", SON([("_id", 1)])), + ("$readPreference", pref.document), + ] + ), + ) @client_context.require_mongos def test_mongos(self): - shard = client_context.client.config.shards.find_one()['host'] - num_members = shard.count(',') + 1 + shard = client_context.client.config.shards.find_one()["host"] + num_members = shard.count(",") + 1 if num_members == 1: raise SkipTest("Need a replica set shard to test.") coll = client_context.client.pymongo_test.get_collection( - "test", - write_concern=WriteConcern(w=num_members)) + "test", write_concern=WriteConcern(w=num_members) + ) coll.drop() res = coll.insert_many([{} for _ in range(5)]) first_id = res.inserted_ids[0] @@ -677,11 +623,7 @@ def test_mongos(self): # Note - this isn't a perfect test since there's no way to # tell what shard member a query ran on. 
- for pref in (Primary(), - PrimaryPreferred(), - Secondary(), - SecondaryPreferred(), - Nearest()): + for pref in (Primary(), PrimaryPreferred(), Secondary(), SecondaryPreferred(), Nearest()): qcoll = coll.with_options(read_preference=pref) results = list(qcoll.find().sort([("_id", 1)])) self.assertEqual(first_id, results[0]["_id"]) @@ -694,12 +636,14 @@ def test_mongos(self): def test_mongos_max_staleness(self): # Sanity check that we're sending maxStalenessSeconds coll = client_context.client.pymongo_test.get_collection( - "test", read_preference=SecondaryPreferred(max_staleness=120)) + "test", read_preference=SecondaryPreferred(max_staleness=120) + ) # No error coll.find_one() coll = client_context.client.pymongo_test.get_collection( - "test", read_preference=SecondaryPreferred(max_staleness=10)) + "test", read_preference=SecondaryPreferred(max_staleness=10) + ) try: coll.find_one() except OperationFailure as exc: @@ -708,14 +652,14 @@ def test_mongos_max_staleness(self): self.fail("mongos accepted invalid staleness") coll = single_client( - readPreference='secondaryPreferred', - maxStalenessSeconds=120).pymongo_test.test + readPreference="secondaryPreferred", maxStalenessSeconds=120 + ).pymongo_test.test # No error coll.find_one() coll = single_client( - readPreference='secondaryPreferred', - maxStalenessSeconds=10).pymongo_test.test + readPreference="secondaryPreferred", maxStalenessSeconds=10 + ).pymongo_test.test try: coll.find_one() except OperationFailure as exc: @@ -723,5 +667,6 @@ def test_mongos_max_staleness(self): else: self.fail("mongos accepted invalid staleness") + if __name__ == "__main__": unittest.main() diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index d6c5a68c32..8c554c656a 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -21,33 +21,33 @@ sys.path[0:0] = [""] +from test import IntegrationTest, client_context, unittest +from test.utils import ( + EventListener, + TestCreator, + disable_replication, + enable_replication, + rs_or_single_client, +) +from test.utils_spec_runner import SpecRunner + from pymongo import DESCENDING -from pymongo.errors import (BulkWriteError, - ConfigurationError, - WTimeoutError, - WriteConcernError, - WriteError) +from pymongo.errors import ( + BulkWriteError, + ConfigurationError, + WriteConcernError, + WriteError, + WTimeoutError, +) from pymongo.mongo_client import MongoClient from pymongo.operations import IndexModel, InsertOne from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern -from test import (client_context, - IntegrationTest, - unittest) -from test.utils import (EventListener, - disable_replication, - enable_replication, - rs_or_single_client, - TestCreator) -from test.utils_spec_runner import SpecRunner - -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'read_write_concern') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "read_write_concern") class TestReadWriteConcernSpec(IntegrationTest): - def test_omit_default_read_write_concern(self): listener = EventListener() # Client with default readConcern and writeConcern @@ -63,85 +63,87 @@ def test_omit_default_read_write_concern(self): def rename_and_drop(): # Ensure collection exists. 
collection.insert_one({}) - collection.rename('collection2') + collection.rename("collection2") client.pymongo_test.collection2.drop() def insert_command_default_write_concern(): collection.database.command( - 'insert', 'collection', documents=[{}], - write_concern=WriteConcern()) + "insert", "collection", documents=[{}], write_concern=WriteConcern() + ) ops = [ - ('aggregate', lambda: list(collection.aggregate([]))), - ('find', lambda: list(collection.find())), - ('insert_one', lambda: collection.insert_one({})), - ('update_one', - lambda: collection.update_one({}, {'$set': {'x': 1}})), - ('update_many', - lambda: collection.update_many({}, {'$set': {'x': 1}})), - ('delete_one', lambda: collection.delete_one({})), - ('delete_many', lambda: collection.delete_many({})), - ('bulk_write', lambda: collection.bulk_write([InsertOne({})])), - ('rename_and_drop', rename_and_drop), - ('command', insert_command_default_write_concern) + ("aggregate", lambda: list(collection.aggregate([]))), + ("find", lambda: list(collection.find())), + ("insert_one", lambda: collection.insert_one({})), + ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: collection.delete_one({})), + ("delete_many", lambda: collection.delete_many({})), + ("bulk_write", lambda: collection.bulk_write([InsertOne({})])), + ("rename_and_drop", rename_and_drop), + ("command", insert_command_default_write_concern), ] for name, f in ops: listener.results.clear() f() - self.assertGreaterEqual(len(listener.results['started']), 1) - for i, event in enumerate(listener.results['started']): + self.assertGreaterEqual(len(listener.results["started"]), 1) + for i, event in enumerate(listener.results["started"]): self.assertNotIn( - 'readConcern', event.command, - "%s sent default readConcern with %s" % ( - name, event.command_name)) + "readConcern", + event.command, + "%s sent default readConcern with %s" % (name, event.command_name), + ) self.assertNotIn( - 'writeConcern', event.command, - "%s sent default writeConcern with %s" % ( - name, event.command_name)) + "writeConcern", + event.command, + "%s sent default writeConcern with %s" % (name, event.command_name), + ) def assertWriteOpsRaise(self, write_concern, expected_exception): wc = write_concern.document # Set socket timeout to avoid indefinite stalls - client = rs_or_single_client( - w=wc['w'], wTimeoutMS=wc['wtimeout'], socketTimeoutMS=30000) - db = client.get_database('pymongo_test') + client = rs_or_single_client(w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000) + db = client.get_database("pymongo_test") coll = db.test def insert_command(): coll.database.command( - 'insert', 'new_collection', documents=[{}], + "insert", + "new_collection", + documents=[{}], writeConcern=write_concern.document, - parse_write_concern_error=True) + parse_write_concern_error=True, + ) ops = [ - ('insert_one', lambda: coll.insert_one({})), - ('insert_many', lambda: coll.insert_many([{}, {}])), - ('update_one', lambda: coll.update_one({}, {'$set': {'x': 1}})), - ('update_many', lambda: coll.update_many({}, {'$set': {'x': 1}})), - ('delete_one', lambda: coll.delete_one({})), - ('delete_many', lambda: coll.delete_many({})), - ('bulk_write', lambda: coll.bulk_write([InsertOne({})])), - ('command', insert_command), - ('aggregate', lambda: coll.aggregate([{'$out': 'out'}])), + ("insert_one", lambda: coll.insert_one({})), + ("insert_many", lambda: coll.insert_many([{}, {}])), + 
("update_one", lambda: coll.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: coll.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: coll.delete_one({})), + ("delete_many", lambda: coll.delete_many({})), + ("bulk_write", lambda: coll.bulk_write([InsertOne({})])), + ("command", insert_command), + ("aggregate", lambda: coll.aggregate([{"$out": "out"}])), # SERVER-46668 Delete all the documents in the collection to # workaround a hang in createIndexes. - ('delete_many', lambda: coll.delete_many({})), - ('create_index', lambda: coll.create_index([('a', DESCENDING)])), - ('create_indexes', lambda: coll.create_indexes([IndexModel('b')])), - ('drop_index', lambda: coll.drop_index([('a', DESCENDING)])), - ('create', lambda: db.create_collection('new')), - ('rename', lambda: coll.rename('new')), - ('drop', lambda: db.new.drop()), + ("delete_many", lambda: coll.delete_many({})), + ("create_index", lambda: coll.create_index([("a", DESCENDING)])), + ("create_indexes", lambda: coll.create_indexes([IndexModel("b")])), + ("drop_index", lambda: coll.drop_index([("a", DESCENDING)])), + ("create", lambda: db.create_collection("new")), + ("rename", lambda: coll.rename("new")), + ("drop", lambda: db.new.drop()), ] # SERVER-47194: dropDatabase does not respect wtimeout in 3.6. if client_context.version[:2] != (3, 6): - ops.append(('drop_database', lambda: client.drop_database(db))) + ops.append(("drop_database", lambda: client.drop_database(db))) for name, f in ops: # Ensure insert_many and bulk_write still raise BulkWriteError. - if name in ('insert_many', 'bulk_write'): + if name in ("insert_many", "bulk_write"): expected = BulkWriteError else: expected = expected_exception @@ -149,24 +151,24 @@ def insert_command(): f() if expected == BulkWriteError: bulk_result = cm.exception.details - wc_errors = bulk_result['writeConcernErrors'] + wc_errors = bulk_result["writeConcernErrors"] self.assertTrue(wc_errors) @client_context.require_replica_set def test_raise_write_concern_error(self): - self.addCleanup(client_context.client.drop_database, 'pymongo_test') + self.addCleanup(client_context.client.drop_database, "pymongo_test") self.assertWriteOpsRaise( - WriteConcern(w=client_context.w+1, wtimeout=1), WriteConcernError) + WriteConcern(w=client_context.w + 1, wtimeout=1), WriteConcernError + ) @client_context.require_secondaries_count(1) @client_context.require_test_commands def test_raise_wtimeout(self): - self.addCleanup(client_context.client.drop_database, 'pymongo_test') + self.addCleanup(client_context.client.drop_database, "pymongo_test") self.addCleanup(enable_replication, client_context.client) # Disable replication to guarantee a wtimeout error. 
disable_replication(client_context.client) - self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), - WTimeoutError) + self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), WTimeoutError) @client_context.require_failCommand_fail_point def test_error_includes_errInfo(self): @@ -174,21 +176,12 @@ def test_error_includes_errInfo(self): "code": 100, "codeName": "UnsatisfiableWriteConcern", "errmsg": "Not enough data-bearing nodes", - "errInfo": { - "writeConcern": { - "w": 2, - "wtimeout": 0, - "provenance": "clientSupplied" - } - } + "errInfo": {"writeConcern": {"w": 2, "wtimeout": 0, "provenance": "clientSupplied"}}, } cause_wce = { "configureFailPoint": "failCommand", "mode": {"times": 2}, - "data": { - "failCommands": ["insert"], - "writeConcernError": expected_wce - }, + "data": {"failCommands": ["insert"], "writeConcernError": expected_wce}, } with self.fail_point(cause_wce): # Write concern error on insert includes errInfo. @@ -200,10 +193,15 @@ def test_error_includes_errInfo(self): with self.assertRaises(BulkWriteError) as ctx: self.db.test.bulk_write([InsertOne({})]) expected_details = { - 'writeErrors': [], - 'writeConcernErrors': [expected_wce], - 'nInserted': 1, 'nUpserted': 0, 'nMatched': 0, 'nModified': 0, - 'nRemoved': 0, 'upserted': []} + "writeErrors": [], + "writeConcernErrors": [expected_wce], + "nInserted": 1, + "nUpserted": 0, + "nMatched": 0, + "nModified": 0, + "nRemoved": 0, + "upserted": [], + } self.assertEqual(ctx.exception.details, expected_details) @client_context.require_version_min(4, 9) @@ -216,14 +214,13 @@ def test_write_error_details_exposes_errinfo(self): validator = {"x": {"$type": "string"}} db.create_collection("test", validator=validator) with self.assertRaises(WriteError) as ctx: - db.test.insert_one({'x': 1}) + db.test.insert_one({"x": 1}) self.assertEqual(ctx.exception.code, 121) self.assertIsNotNone(ctx.exception.details) - self.assertIsNotNone(ctx.exception.details.get('errInfo')) - for event in listener.results['succeeded']: - if event.command_name == 'insert': - self.assertEqual( - event.reply['writeErrors'][0], ctx.exception.details) + self.assertIsNotNone(ctx.exception.details.get("errInfo")) + for event in listener.results["succeeded"]: + if event.command_name == "insert": + self.assertEqual(event.reply["writeErrors"][0], ctx.exception.details) break else: self.fail("Couldn't find insert event.") @@ -232,77 +229,58 @@ def test_write_error_details_exposes_errinfo(self): def normalize_write_concern(concern): result = {} for key in concern: - if key.lower() == 'wtimeoutms': - result['wtimeout'] = concern[key] - elif key == 'journal': - result['j'] = concern[key] + if key.lower() == "wtimeoutms": + result["wtimeout"] = concern[key] + elif key == "journal": + result["j"] = concern[key] else: result[key] = concern[key] return result def create_connection_string_test(test_case): - def run_test(self): - uri = test_case['uri'] - valid = test_case['valid'] - warning = test_case['warning'] + uri = test_case["uri"] + valid = test_case["valid"] + warning = test_case["warning"] if not valid: if warning is False: - self.assertRaises( - (ConfigurationError, ValueError), - MongoClient, - uri, - connect=False) + self.assertRaises((ConfigurationError, ValueError), MongoClient, uri, connect=False) else: with warnings.catch_warnings(): - warnings.simplefilter('error', UserWarning) - self.assertRaises( - UserWarning, - MongoClient, - uri, - connect=False) + warnings.simplefilter("error", UserWarning) + self.assertRaises(UserWarning, 
MongoClient, uri, connect=False) else: client = MongoClient(uri, connect=False) - if 'writeConcern' in test_case: + if "writeConcern" in test_case: document = client.write_concern.document - self.assertEqual( - document, - normalize_write_concern(test_case['writeConcern'])) - if 'readConcern' in test_case: + self.assertEqual(document, normalize_write_concern(test_case["writeConcern"])) + if "readConcern" in test_case: document = client.read_concern.document - self.assertEqual(document, test_case['readConcern']) + self.assertEqual(document, test_case["readConcern"]) return run_test def create_document_test(test_case): - def run_test(self): - valid = test_case['valid'] + valid = test_case["valid"] - if 'writeConcern' in test_case: - normalized = normalize_write_concern(test_case['writeConcern']) + if "writeConcern" in test_case: + normalized = normalize_write_concern(test_case["writeConcern"]) if not valid: - self.assertRaises( - (ConfigurationError, ValueError), - WriteConcern, - **normalized) + self.assertRaises((ConfigurationError, ValueError), WriteConcern, **normalized) else: concern = WriteConcern(**normalized) - self.assertEqual( - concern.document, test_case['writeConcernDocument']) - self.assertEqual( - concern.acknowledged, test_case['isAcknowledged']) - self.assertEqual( - concern.is_server_default, test_case['isServerDefault']) - if 'readConcern' in test_case: + self.assertEqual(concern.document, test_case["writeConcernDocument"]) + self.assertEqual(concern.acknowledged, test_case["isAcknowledged"]) + self.assertEqual(concern.is_server_default, test_case["isServerDefault"]) + if "readConcern" in test_case: # Any string for 'level' is equally valid - concern = ReadConcern(**test_case['readConcern']) - self.assertEqual(concern.document, test_case['readConcernDocument']) - self.assertEqual( - not bool(concern.level), test_case['isServerDefault']) + concern = ReadConcern(**test_case["readConcern"]) + self.assertEqual(concern.document, test_case["readConcernDocument"]) + self.assertEqual(not bool(concern.level), test_case["isServerDefault"]) return run_test @@ -311,25 +289,26 @@ def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): dirname = os.path.split(dirpath)[-1] - if dirname == 'operation': + if dirname == "operation": # This directory is tested by TestOperations. continue - elif dirname == 'connection-string': + elif dirname == "connection-string": create_test = create_connection_string_test else: create_test = create_document_test for filename in filenames: with open(os.path.join(dirpath, filename)) as test_stream: - test_cases = json.load(test_stream)['tests'] + test_cases = json.load(test_stream)["tests"] fname = os.path.splitext(filename)[0] for test_case in test_cases: new_test = create_test(test_case) - test_name = 'test_%s_%s_%s' % ( - dirname.replace('-', '_'), - fname.replace('-', '_'), - str(test_case['description'].lower().replace(' ', '_'))) + test_name = "test_%s_%s_%s" % ( + dirname.replace("-", "_"), + fname.replace("-", "_"), + str(test_case["description"].lower().replace(" ", "_")), + ) new_test.__name__ = test_name setattr(TestReadWriteConcernSpec, new_test.__name__, new_test) @@ -340,11 +319,11 @@ def create_tests(): class TestOperation(SpecRunner): # Location of JSON test specifications.
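The create_tests() loop above is the standard data-driven pattern used throughout these spec runners: build one closure per JSON case, give it a discoverable name, and attach it with setattr. A stripped-down sketch of the same technique (the case data below is a stand-in, not taken from the spec files):

    import unittest

    class GeneratedTests(unittest.TestCase):
        pass

    def make_test(case):
        def run_test(self):
            self.assertIn("valid", case)
        return run_test

    for case in [{"description": "default write concern", "valid": True}]:
        test = make_test(case)
        # unittest only discovers methods whose names start with "test".
        test.__name__ = "test_%s" % case["description"].replace(" ", "_")
        setattr(GeneratedTests, test.__name__, test)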
- TEST_PATH = os.path.join(_TEST_PATH, 'operation') + TEST_PATH = os.path.join(_TEST_PATH, "operation") def get_outcome_coll_name(self, outcome, collection): """Spec says outcome has an optional 'collection.name'.""" - return outcome['collection'].get('name', collection.name) + return outcome["collection"].get("name", collection.name) def create_operation_test(scenario_def, test, name): @@ -355,10 +334,9 @@ def run_scenario(self): return run_scenario -test_creator = TestCreator( - create_operation_test, TestOperation, TestOperation.TEST_PATH) +test_creator = TestCreator(create_operation_test, TestOperation, TestOperation.TEST_PATH) test_creator.create_tests() -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_replica_set_reconfig.py b/test/test_replica_set_reconfig.py index f19a32ea4e..898be99d4d 100644 --- a/test/test_replica_set_reconfig.py +++ b/test/test_replica_set_reconfig.py @@ -18,12 +18,13 @@ sys.path[0:0] = [""] -from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError -from pymongo import ReadPreference -from test import unittest, client_context, client_knobs, MockClientTest +from test import MockClientTest, client_context, client_knobs, unittest from test.pymongo_mocks import MockClient from test.utils import wait_until +from pymongo import ReadPreference +from pymongo.errors import ConnectionFailure, ServerSelectionTimeoutError + @client_context.require_connection @client_context.require_no_load_balancer @@ -38,54 +39,53 @@ class TestSecondaryBecomesStandalone(MockClientTest): def test_client(self): c = MockClient( standalones=[], - members=['a:1', 'b:2', 'c:3'], + members=["a:1", "b:2", "c:3"], mongoses=[], - host='a:1,b:2,c:3', - replicaSet='rs', + host="a:1,b:2,c:3", + replicaSet="rs", serverSelectionTimeoutMS=100, - connect=False) + connect=False, + ) self.addCleanup(c.close) # C is brought up as a standalone. - c.mock_members.remove('c:3') - c.mock_standalones.append('c:3') + c.mock_members.remove("c:3") + c.mock_standalones.append("c:3") # Fail over. - c.kill_host('a:1') - c.kill_host('b:2') + c.kill_host("a:1") + c.kill_host("b:2") with self.assertRaises(ServerSelectionTimeoutError): - c.db.command('ping') + c.db.command("ping") self.assertEqual(c.address, None) # Client can still discover the primary node - c.revive_host('a:1') - wait_until(lambda: c.address is not None, 'connect to primary') - self.assertEqual(c.address, ('a', 1)) + c.revive_host("a:1") + wait_until(lambda: c.address is not None, "connect to primary") + self.assertEqual(c.address, ("a", 1)) def test_replica_set_client(self): c = MockClient( standalones=[], - members=['a:1', 'b:2', 'c:3'], + members=["a:1", "b:2", "c:3"], mongoses=[], - host='a:1,b:2,c:3', - replicaSet='rs') + host="a:1,b:2,c:3", + replicaSet="rs", + ) self.addCleanup(c.close) - wait_until(lambda: ('b', 2) in c.secondaries, - 'discover host "b"') + wait_until(lambda: ("b", 2) in c.secondaries, 'discover host "b"') - wait_until(lambda: ('c', 3) in c.secondaries, - 'discover host "c"') + wait_until(lambda: ("c", 3) in c.secondaries, 'discover host "c"') # C is brought up as a standalone. 
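The reconfig tests below repeatedly poll with test.utils.wait_until until server discovery catches up with the mocked topology. Roughly, the helper behaves like this approximation (not its exact source):

    import time

    def wait_until(predicate, success_description, timeout=10):
        """Poll predicate() until it returns a truthy value."""
        start = time.time()
        while time.time() - start < timeout:
            if predicate():
                return
            time.sleep(0.1)
        raise AssertionError(
            "Did not %s within %d seconds" % (success_description, timeout)
        )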
- c.mock_members.remove('c:3') - c.mock_standalones.append('c:3') + c.mock_members.remove("c:3") + c.mock_standalones.append("c:3") - wait_until(lambda: set([('b', 2)]) == c.secondaries, - 'update the list of secondaries') + wait_until(lambda: set([("b", 2)]) == c.secondaries, "update the list of secondaries") - self.assertEqual(('a', 1), c.primary) + self.assertEqual(("a", 1), c.primary) class TestSecondaryRemoved(MockClientTest): @@ -94,21 +94,21 @@ class TestSecondaryRemoved(MockClientTest): def test_replica_set_client(self): c = MockClient( standalones=[], - members=['a:1', 'b:2', 'c:3'], + members=["a:1", "b:2", "c:3"], mongoses=[], - host='a:1,b:2,c:3', - replicaSet='rs') + host="a:1,b:2,c:3", + replicaSet="rs", + ) self.addCleanup(c.close) - wait_until(lambda: ('b', 2) in c.secondaries, 'discover host "b"') - wait_until(lambda: ('c', 3) in c.secondaries, 'discover host "c"') + wait_until(lambda: ("b", 2) in c.secondaries, 'discover host "b"') + wait_until(lambda: ("c", 3) in c.secondaries, 'discover host "c"') # C is removed. - c.mock_hello_hosts.remove('c:3') - wait_until(lambda: set([('b', 2)]) == c.secondaries, - 'update list of secondaries') + c.mock_hello_hosts.remove("c:3") + wait_until(lambda: set([("b", 2)]) == c.secondaries, "update list of secondaries") - self.assertEqual(('a', 1), c.primary) + self.assertEqual(("a", 1), c.primary) class TestSocketError(MockClientTest): @@ -117,21 +117,22 @@ def test_socket_error_marks_member_down(self): with client_knobs(heartbeat_frequency=999999): c = MockClient( standalones=[], - members=['a:1', 'b:2'], + members=["a:1", "b:2"], mongoses=[], - host='a:1', - replicaSet='rs', - serverSelectionTimeoutMS=100) + host="a:1", + replicaSet="rs", + serverSelectionTimeoutMS=100, + ) self.addCleanup(c.close) - wait_until(lambda: len(c.nodes) == 2, 'discover both nodes') + wait_until(lambda: len(c.nodes) == 2, "discover both nodes") # b now raises socket.error. - c.mock_down_hosts.append('b:2') + c.mock_down_hosts.append("b:2") self.assertRaises( ConnectionFailure, - c.db.collection.with_options( - read_preference=ReadPreference.SECONDARY).find_one) + c.db.collection.with_options(read_preference=ReadPreference.SECONDARY).find_one, + ) self.assertEqual(1, len(c.nodes)) @@ -139,51 +140,44 @@ def test_socket_error_marks_member_down(self): class TestSecondaryAdded(MockClientTest): def test_client(self): c = MockClient( - standalones=[], - members=['a:1', 'b:2'], - mongoses=[], - host='a:1', - replicaSet='rs') + standalones=[], members=["a:1", "b:2"], mongoses=[], host="a:1", replicaSet="rs" + ) self.addCleanup(c.close) - wait_until(lambda: len(c.nodes) == 2, 'discover both nodes') + wait_until(lambda: len(c.nodes) == 2, "discover both nodes") # MongoClient connects to primary by default. - self.assertEqual(c.address, ('a', 1)) - self.assertEqual(set([('a', 1), ('b', 2)]), c.nodes) + self.assertEqual(c.address, ("a", 1)) + self.assertEqual(set([("a", 1), ("b", 2)]), c.nodes) # C is added. 
- c.mock_members.append('c:3') - c.mock_hello_hosts.append('c:3') + c.mock_members.append("c:3") + c.mock_hello_hosts.append("c:3") - c.db.command('ping') + c.db.command("ping") - self.assertEqual(c.address, ('a', 1)) + self.assertEqual(c.address, ("a", 1)) - wait_until(lambda: set([('a', 1), ('b', 2), ('c', 3)]) == c.nodes, - 'reconnect to both secondaries') + wait_until( + lambda: set([("a", 1), ("b", 2), ("c", 3)]) == c.nodes, "reconnect to both secondaries" + ) def test_replica_set_client(self): c = MockClient( - standalones=[], - members=['a:1', 'b:2'], - mongoses=[], - host='a:1', - replicaSet='rs') + standalones=[], members=["a:1", "b:2"], mongoses=[], host="a:1", replicaSet="rs" + ) self.addCleanup(c.close) - wait_until(lambda: ('a', 1) == c.primary, 'discover the primary') - wait_until(lambda: set([('b', 2)]) == c.secondaries, - 'discover the secondary') + wait_until(lambda: ("a", 1) == c.primary, "discover the primary") + wait_until(lambda: set([("b", 2)]) == c.secondaries, "discover the secondary") # C is added. - c.mock_members.append('c:3') - c.mock_hello_hosts.append('c:3') + c.mock_members.append("c:3") + c.mock_hello_hosts.append("c:3") - wait_until(lambda: set([('b', 2), ('c', 3)]) == c.secondaries, - 'discover the new secondary') + wait_until(lambda: set([("b", 2), ("c", 3)]) == c.secondaries, "discover the new secondary") - self.assertEqual(('a', 1), c.primary) + self.assertEqual(("a", 1), c.primary) if __name__ == "__main__": diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index c4c093f66f..808477a8c0 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -21,28 +21,32 @@ sys.path[0:0] = [""] -from pymongo.mongo_client import MongoClient -from pymongo.monitoring import (ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - ConnectionCheckOutFailedReason, - PoolClearedEvent) -from pymongo.write_concern import WriteConcern - -from test import (client_context, - client_knobs, - IntegrationTest, - PyMongoTestCase, - unittest) -from test.utils import (CMAPListener, - OvertCommandListener, - rs_or_single_client, - TestCreator) +from test import ( + IntegrationTest, + PyMongoTestCase, + client_context, + client_knobs, + unittest, +) +from test.utils import ( + CMAPListener, + OvertCommandListener, + TestCreator, + rs_or_single_client, +) from test.utils_spec_runner import SpecRunner +from pymongo.mongo_client import MongoClient +from pymongo.monitoring import ( + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutFailedReason, + PoolClearedEvent, +) +from pymongo.write_concern import WriteConcern # Location of JSON test specifications. 
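The TestClientOptions cases that follow confirm that retryable reads can be toggled either as a constructor keyword or in the URI, and that both spellings surface on client.options. For example:

    from pymongo import MongoClient

    # connect=False builds the client without contacting a server.
    by_kwarg = MongoClient(retryReads=False, connect=False)
    by_uri = MongoClient("mongodb://h/?retryReads=false", connect=False)
    assert by_kwarg.options.retry_reads is False
    assert by_uri.options.retry_reads is False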
-_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'retryable_reads') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "retryable_reads") class TestClientOptions(PyMongoTestCase): @@ -57,9 +61,9 @@ def test_kwargs(self): self.assertEqual(client.options.retry_reads, False) def test_uri(self): - client = MongoClient('mongodb://h/?retryReads=true', connect=False) + client = MongoClient("mongodb://h/?retryReads=true", connect=False) self.assertEqual(client.options.retry_reads, True) - client = MongoClient('mongodb://h/?retryReads=false', connect=False) + client = MongoClient("mongodb://h/?retryReads=false", connect=False) self.assertEqual(client.options.retry_reads, False) @@ -76,51 +80,49 @@ def setUpClass(cls): def maybe_skip_scenario(self, test): super(TestSpec, self).maybe_skip_scenario(test) - skip_names = [ - 'listCollectionObjects', 'listIndexNames', 'listDatabaseObjects'] + skip_names = ["listCollectionObjects", "listIndexNames", "listDatabaseObjects"] for name in skip_names: - if name.lower() in test['description'].lower(): - self.skipTest('PyMongo does not support %s' % (name,)) + if name.lower() in test["description"].lower(): + self.skipTest("PyMongo does not support %s" % (name,)) # Serverless does not support $out and collation. if client_context.serverless: - for operation in test['operations']: - if operation['name'] == 'aggregate': - for stage in operation['arguments']['pipeline']: + for operation in test["operations"]: + if operation["name"] == "aggregate": + for stage in operation["arguments"]["pipeline"]: if "$out" in stage: - self.skipTest( - "MongoDB Serverless does not support $out") - if "collation" in operation['arguments']: - self.skipTest( - "MongoDB Serverless does not support collations") + self.skipTest("MongoDB Serverless does not support $out") + if "collation" in operation["arguments"]: + self.skipTest("MongoDB Serverless does not support collations") # Skip changeStream related tests on MMAPv1 and serverless. - test_name = self.id().rsplit('.')[-1] - if 'changestream' in test_name.lower(): - if client_context.storage_engine == 'mmapv1': + test_name = self.id().rsplit(".")[-1] + if "changestream" in test_name.lower(): + if client_context.storage_engine == "mmapv1": self.skipTest("MMAPv1 does not support change streams.") if client_context.serverless: self.skipTest("Serverless does not support change streams.") def get_scenario_coll_name(self, scenario_def): """Override a test's collection name to support GridFS tests.""" - if 'bucket_name' in scenario_def: - return scenario_def['bucket_name'] + if "bucket_name" in scenario_def: + return scenario_def["bucket_name"] return super(TestSpec, self).get_scenario_coll_name(scenario_def) def setup_scenario(self, scenario_def): """Override a test's setup to support GridFS tests.""" - if 'bucket_name' in scenario_def: + if "bucket_name" in scenario_def: db_name = self.get_scenario_db_name(scenario_def) db = client_context.client.get_database( - db_name, write_concern=WriteConcern(w='majority')) + db_name, write_concern=WriteConcern(w="majority") + ) # Create a bucket for the retryable reads GridFS tests. client_context.client.drop_database(db_name) - if scenario_def['data']: - data = scenario_def['data'] + if scenario_def["data"]: + data = scenario_def["data"] # Load data. 
- db['fs.chunks'].insert_many(data['fs.chunks']) - db['fs.files'].insert_many(data['fs.files']) + db["fs.chunks"].insert_many(data["fs.chunks"]) + db["fs.files"].insert_many(data["fs.files"]) else: super(TestSpec, self).setup_scenario(scenario_def) @@ -155,25 +157,23 @@ class TestPoolPausedError(IntegrationTest): RUN_ON_SERVERLESS = False @client_context.require_failCommand_blockConnection - @client_knobs(heartbeat_frequency=.05, min_heartbeat_interval=.05) + @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) def test_pool_paused_error_is_retryable(self): cmap_listener = CMAPListener() cmd_listener = OvertCommandListener() - client = rs_or_single_client( - maxPoolSize=1, - event_listeners=[cmap_listener, cmd_listener]) + client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]) self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() cmd_listener.reset() threads = [FindThread(client.pymongo_test.test) for _ in range(2)] fail_command = { - 'mode': {'times': 1}, - 'data': { - 'failCommands': ['find'], - 'blockConnection': True, - 'blockTimeMS': 1000, - 'errorCode': 91, + "mode": {"times": 1}, + "data": { + "failCommands": ["find"], + "blockConnection": True, + "blockTimeMS": 1000, + "errorCode": 91, }, } with self.fail_point(fail_command): @@ -192,29 +192,25 @@ def test_pool_paused_error_is_retryable(self): break # Via CMAP monitoring, assert that the first check out succeeds. - cmap_events = cmap_listener.events_by_type(( - ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - PoolClearedEvent)) + cmap_events = cmap_listener.events_by_type( + (ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, PoolClearedEvent) + ) msg = pprint.pformat(cmap_listener.events) self.assertIsInstance(cmap_events[0], ConnectionCheckedOutEvent, msg) self.assertIsInstance(cmap_events[1], PoolClearedEvent, msg) - self.assertIsInstance( - cmap_events[2], ConnectionCheckOutFailedEvent, msg) - self.assertEqual(cmap_events[2].reason, - ConnectionCheckOutFailedReason.CONN_ERROR, - msg) + self.assertIsInstance(cmap_events[2], ConnectionCheckOutFailedEvent, msg) + self.assertEqual(cmap_events[2].reason, ConnectionCheckOutFailedReason.CONN_ERROR, msg) self.assertIsInstance(cmap_events[3], ConnectionCheckedOutEvent, msg) # Connection check out failures are not reflected in command # monitoring because we only publish command events _after_ checking # out a connection. 
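That ordering explains the event counts asserted next: a command that never obtains a connection produces CMAP events but no command events. A self-contained illustration of command monitoring (assumes a mongod on the default port):

    from pymongo import MongoClient, monitoring

    class CommandCounter(monitoring.CommandListener):
        """Count command events; a command shows up here only after a
        connection was checked out for it."""
        def __init__(self):
            self.n_started = self.n_succeeded = self.n_failed = 0
        def started(self, event):
            self.n_started += 1
        def succeeded(self, event):
            self.n_succeeded += 1
        def failed(self, event):
            self.n_failed += 1

    counter = CommandCounter()
    client = MongoClient(event_listeners=[counter])
    client.admin.command("ping")
    assert counter.n_started == counter.n_succeeded == 1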
- started = cmd_listener.results['started'] + started = cmd_listener.results["started"] msg = pprint.pformat(cmd_listener.results) self.assertEqual(3, len(started), msg) - succeeded = cmd_listener.results['succeeded'] + succeeded = cmd_listener.results["succeeded"] self.assertEqual(2, len(succeeded), msg) - failed = cmd_listener.results['failed'] + failed = cmd_listener.results["failed"] self.assertEqual(1, len(failed), msg) diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index f3f09095d7..f4af65966a 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -22,45 +22,46 @@ sys.path[0:0] = [""] +from test import IntegrationTest, SkipTest, client_context, client_knobs, unittest +from test.utils import ( + CMAPListener, + DeprecationFilter, + OvertCommandListener, + TestCreator, + rs_or_single_client, +) +from test.utils_spec_runner import SpecRunner +from test.version import Version + from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.int64 import Int64 from bson.raw_bson import RawBSONDocument from bson.son import SON - - -from pymongo.errors import (ConnectionFailure, - OperationFailure, - ServerSelectionTimeoutError, - WriteConcernError) +from pymongo.errors import ( + ConnectionFailure, + OperationFailure, + ServerSelectionTimeoutError, + WriteConcernError, +) from pymongo.mongo_client import MongoClient -from pymongo.monitoring import (ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - ConnectionCheckOutFailedReason, - PoolClearedEvent) -from pymongo.operations import (InsertOne, - DeleteMany, - DeleteOne, - ReplaceOne, - UpdateMany, - UpdateOne) +from pymongo.monitoring import ( + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutFailedReason, + PoolClearedEvent, +) +from pymongo.operations import ( + DeleteMany, + DeleteOne, + InsertOne, + ReplaceOne, + UpdateMany, + UpdateOne, +) from pymongo.write_concern import WriteConcern -from test import (client_context, - client_knobs, - IntegrationTest, - SkipTest, - unittest) -from test.utils import (CMAPListener, - DeprecationFilter, - OvertCommandListener, - rs_or_single_client, - TestCreator) -from test.utils_spec_runner import SpecRunner -from test.version import Version - # Location of JSON test specifications. -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'retryable_writes', 'legacy') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "retryable_writes", "legacy") class TestAllScenarios(SpecRunner): @@ -68,23 +69,23 @@ class TestAllScenarios(SpecRunner): RUN_ON_SERVERLESS = True def get_object_name(self, op): - return op.get('object', 'collection') + return op.get("object", "collection") def get_scenario_db_name(self, scenario_def): - return scenario_def.get('database_name', 'pymongo_test') + return scenario_def.get("database_name", "pymongo_test") def get_scenario_coll_name(self, scenario_def): - return scenario_def.get('collection_name', 'test') + return scenario_def.get("collection_name", "test") def run_test_ops(self, sessions, collection, test): # Transform retryable writes spec format into transactions. 
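The transformation that follows hoists a legacy test's single operation/outcome pair into the list-of-operations shape the transactions runner expects. Schematically, with a made-up case:

    # Legacy retryable-writes shape (illustrative, not from the spec files):
    test = {
        "operation": {"name": "insertOne", "arguments": {"document": {"_id": 1}}},
        "outcome": {"result": {"insertedId": 1}},
    }
    # run_test_ops copies the outcome fields onto the operation:
    operation = dict(test["operation"], result=test["outcome"]["result"])
    assert operation == {
        "name": "insertOne",
        "arguments": {"document": {"_id": 1}},
        "result": {"insertedId": 1},
    }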
- operation = test['operation'] - outcome = test['outcome'] - if 'error' in outcome: - operation['error'] = outcome['error'] - if 'result' in outcome: - operation['result'] = outcome['result'] - test['operations'] = [operation] + operation = test["operation"] + outcome = test["outcome"] + if "error" in outcome: + operation["error"] = outcome["error"] + if "result" in outcome: + operation["result"] = outcome["result"] + test["operations"] = [operation] super(TestAllScenarios, self).run_test_ops(sessions, collection, test) @@ -96,6 +97,7 @@ def run_scenario(self): return run_scenario + test_creator = TestCreator(create_test, TestAllScenarios, _TEST_PATH) test_creator.create_tests() @@ -103,31 +105,36 @@ def run_scenario(self): def retryable_single_statement_ops(coll): return [ (coll.bulk_write, [[InsertOne({}), InsertOne({})]], {}), - (coll.bulk_write, [[InsertOne({}), - InsertOne({})]], {'ordered': False}), + (coll.bulk_write, [[InsertOne({}), InsertOne({})]], {"ordered": False}), (coll.bulk_write, [[ReplaceOne({}, {})]], {}), (coll.bulk_write, [[ReplaceOne({}, {}), ReplaceOne({}, {})]], {}), - (coll.bulk_write, [[UpdateOne({}, {'$set': {'a': 1}}), - UpdateOne({}, {'$set': {'a': 1}})]], {}), + ( + coll.bulk_write, + [[UpdateOne({}, {"$set": {"a": 1}}), UpdateOne({}, {"$set": {"a": 1}})]], + {}, + ), (coll.bulk_write, [[DeleteOne({})]], {}), (coll.bulk_write, [[DeleteOne({}), DeleteOne({})]], {}), (coll.insert_one, [{}], {}), (coll.insert_many, [[{}, {}]], {}), (coll.replace_one, [{}, {}], {}), - (coll.update_one, [{}, {'$set': {'a': 1}}], {}), + (coll.update_one, [{}, {"$set": {"a": 1}}], {}), (coll.delete_one, [{}], {}), - (coll.find_one_and_replace, [{}, {'a': 3}], {}), - (coll.find_one_and_update, [{}, {'$set': {'a': 1}}], {}), + (coll.find_one_and_replace, [{}, {"a": 3}], {}), + (coll.find_one_and_update, [{}, {"$set": {"a": 1}}], {}), (coll.find_one_and_delete, [{}, {}], {}), ] def non_retryable_single_statement_ops(coll): return [ - (coll.bulk_write, [[UpdateOne({}, {'$set': {'a': 1}}), - UpdateMany({}, {'$set': {'a': 1}})]], {}), + ( + coll.bulk_write, + [[UpdateOne({}, {"$set": {"a": 1}}), UpdateMany({}, {"$set": {"a": 1}})]], + {}, + ), (coll.bulk_write, [[DeleteOne({}), DeleteMany({})]], {}), - (coll.update_many, [{}, {'$set': {'a': 1}}], {}), + (coll.update_many, [{}, {"$set": {"a": 1}}], {}), (coll.delete_many, [{}], {}), ] @@ -148,13 +155,11 @@ def tearDownClass(cls): class TestRetryableWritesMMAPv1(IgnoreDeprecationsTest): - @classmethod def setUpClass(cls): super(TestRetryableWritesMMAPv1, cls).setUpClass() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, - min_heartbeat_interval=0.1) + cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() cls.client = rs_or_single_client(retryWrites=True) cls.db = cls.client.pymongo_test @@ -167,31 +172,29 @@ def tearDownClass(cls): @client_context.require_no_standalone def test_actionable_error_message(self): - if client_context.storage_engine != 'mmapv1': - raise SkipTest('This cluster is not running MMAPv1') - - expected_msg = ("This MongoDB deployment does not support retryable " - "writes. Please add retryWrites=false to your " - "connection string.") - for method, args, kwargs in retryable_single_statement_ops( - self.db.retryable_write_test): + if client_context.storage_engine != "mmapv1": + raise SkipTest("This cluster is not running MMAPv1") + + expected_msg = ( + "This MongoDB deployment does not support retryable " + "writes. 
Please add retryWrites=false to your " + "connection string." + ) + for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test): with self.assertRaisesRegex(OperationFailure, expected_msg): method(*args, **kwargs) class TestRetryableWrites(IgnoreDeprecationsTest): - @classmethod @client_context.require_no_mmap def setUpClass(cls): super(TestRetryableWrites, cls).setUpClass() # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, - min_heartbeat_interval=0.1) + cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client( - retryWrites=True, event_listeners=[cls.listener]) + cls.client = rs_or_single_client(retryWrites=True, event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test @classmethod @@ -202,117 +205,123 @@ def tearDownClass(cls): def setUp(self): if client_context.is_rs and client_context.test_commands_enabled: - self.client.admin.command(SON([ - ('configureFailPoint', 'onPrimaryTransactionalWrite'), - ('mode', 'alwaysOn')])) + self.client.admin.command( + SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "alwaysOn")]) + ) def tearDown(self): if client_context.is_rs and client_context.test_commands_enabled: - self.client.admin.command(SON([ - ('configureFailPoint', 'onPrimaryTransactionalWrite'), - ('mode', 'off')])) + self.client.admin.command( + SON([("configureFailPoint", "onPrimaryTransactionalWrite"), ("mode", "off")]) + ) def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() - client = rs_or_single_client( - retryWrites=False, event_listeners=[listener]) + client = rs_or_single_client(retryWrites=False, event_listeners=[listener]) self.addCleanup(client.close) - for method, args, kwargs in retryable_single_statement_ops( - client.db.retryable_write_test): - msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) + for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): + msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) listener.results.clear() method(*args, **kwargs) - for event in listener.results['started']: + for event in listener.results["started"]: self.assertNotIn( - 'txnNumber', event.command, - '%s sent txnNumber with %s' % (msg, event.command_name)) + "txnNumber", + event.command, + "%s sent txnNumber with %s" % (msg, event.command_name), + ) @client_context.require_no_standalone def test_supported_single_statement_supported_cluster(self): - for method, args, kwargs in retryable_single_statement_ops( - self.db.retryable_write_test): - msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) + for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test): + msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) self.listener.results.clear() method(*args, **kwargs) - commands_started = self.listener.results['started'] - self.assertEqual(len(self.listener.results['succeeded']), 1, msg) + commands_started = self.listener.results["started"] + self.assertEqual(len(self.listener.results["succeeded"]), 1, msg) first_attempt = commands_started[0] self.assertIn( - 'lsid', first_attempt.command, - '%s sent no lsid with %s' % (msg, first_attempt.command_name)) - initial_session_id = first_attempt.command['lsid'] + "lsid", + first_attempt.command, + "%s sent no lsid with %s" % (msg, first_attempt.command_name), + ) + initial_session_id = 
first_attempt.command["lsid"] self.assertIn( - 'txnNumber', first_attempt.command, - '%s sent no txnNumber with %s' % ( - msg, first_attempt.command_name)) + "txnNumber", + first_attempt.command, + "%s sent no txnNumber with %s" % (msg, first_attempt.command_name), + ) # There should be no retry when the failpoint is not active. - if (client_context.is_mongos or - not client_context.test_commands_enabled): + if client_context.is_mongos or not client_context.test_commands_enabled: self.assertEqual(len(commands_started), 1) continue - initial_transaction_id = first_attempt.command['txnNumber'] + initial_transaction_id = first_attempt.command["txnNumber"] retry_attempt = commands_started[1] self.assertIn( - 'lsid', retry_attempt.command, - '%s sent no lsid with %s' % (msg, first_attempt.command_name)) - self.assertEqual( - retry_attempt.command['lsid'], initial_session_id, msg) + "lsid", + retry_attempt.command, + "%s sent no lsid with %s" % (msg, first_attempt.command_name), + ) + self.assertEqual(retry_attempt.command["lsid"], initial_session_id, msg) self.assertIn( - 'txnNumber', retry_attempt.command, - '%s sent no txnNumber with %s' % ( - msg, first_attempt.command_name)) - self.assertEqual(retry_attempt.command['txnNumber'], - initial_transaction_id, msg) + "txnNumber", + retry_attempt.command, + "%s sent no txnNumber with %s" % (msg, first_attempt.command_name), + ) + self.assertEqual(retry_attempt.command["txnNumber"], initial_transaction_id, msg) def test_supported_single_statement_unsupported_cluster(self): if client_context.is_rs or client_context.is_mongos: - raise SkipTest('This cluster supports retryable writes') + raise SkipTest("This cluster supports retryable writes") - for method, args, kwargs in retryable_single_statement_ops( - self.db.retryable_write_test): - msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) + for method, args, kwargs in retryable_single_statement_ops(self.db.retryable_write_test): + msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) self.listener.results.clear() method(*args, **kwargs) - for event in self.listener.results['started']: + for event in self.listener.results["started"]: self.assertNotIn( - 'txnNumber', event.command, - '%s sent txnNumber with %s' % (msg, event.command_name)) + "txnNumber", + event.command, + "%s sent txnNumber with %s" % (msg, event.command_name), + ) def test_unsupported_single_statement(self): coll = self.db.retryable_write_test coll.insert_many([{}, {}]) coll_w0 = coll.with_options(write_concern=WriteConcern(w=0)) - for method, args, kwargs in (non_retryable_single_statement_ops(coll) + - retryable_single_statement_ops(coll_w0)): - msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) + for method, args, kwargs in non_retryable_single_statement_ops( + coll + ) + retryable_single_statement_ops(coll_w0): + msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) self.listener.results.clear() method(*args, **kwargs) - started_events = self.listener.results['started'] - self.assertEqual(len(self.listener.results['succeeded']), - len(started_events), msg) - self.assertEqual(len(self.listener.results['failed']), 0, msg) + started_events = self.listener.results["started"] + self.assertEqual(len(self.listener.results["succeeded"]), len(started_events), msg) + self.assertEqual(len(self.listener.results["failed"]), 0, msg) for event in started_events: self.assertNotIn( - 'txnNumber', event.command, - '%s sent txnNumber with %s' % (msg, event.command_name)) + "txnNumber", + event.command, + "%s sent txnNumber with %s" % 
(msg, event.command_name), + ) def test_server_selection_timeout_not_retried(self): """A ServerSelectionTimeoutError is not retried.""" listener = OvertCommandListener() client = MongoClient( - 'somedomainthatdoesntexist.org', + "somedomainthatdoesntexist.org", serverSelectionTimeoutMS=1, - retryWrites=True, event_listeners=[listener]) - for method, args, kwargs in retryable_single_statement_ops( - client.db.retryable_write_test): - msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) + retryWrites=True, + event_listeners=[listener], + ) + for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): + msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) listener.results.clear() with self.assertRaises(ServerSelectionTimeoutError, msg=msg): method(*args, **kwargs) - self.assertEqual(len(listener.results['started']), 0, msg) + self.assertEqual(len(listener.results["started"]), 0, msg) @client_context.require_replica_set @client_context.require_test_commands @@ -321,8 +330,7 @@ def test_retry_timeout_raises_original_error(self): original error. """ listener = OvertCommandListener() - client = rs_or_single_client( - retryWrites=True, event_listeners=[listener]) + client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) topology = client._topology select_server = topology.select_server @@ -331,43 +339,44 @@ def mock_select_server(*args, **kwargs): server = select_server(*args, **kwargs) def raise_error(*args, **kwargs): - raise ServerSelectionTimeoutError( - 'No primary available for writes') + raise ServerSelectionTimeoutError("No primary available for writes") + # Raise ServerSelectionTimeout on the retry attempt. topology.select_server = raise_error return server - for method, args, kwargs in retryable_single_statement_ops( - client.db.retryable_write_test): - msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) + for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): + msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) listener.results.clear() topology.select_server = mock_select_server with self.assertRaises(ConnectionFailure, msg=msg): method(*args, **kwargs) - self.assertEqual(len(listener.results['started']), 1, msg) + self.assertEqual(len(listener.results["started"]), 1, msg) @client_context.require_replica_set @client_context.require_test_commands def test_batch_splitting(self): """Test retry succeeds after failures during batch splitting.""" - large = 's' * 1024 * 1024 * 15 + large = "s" * 1024 * 1024 * 15 coll = self.db.retryable_write_test coll.delete_many({}) self.listener.results.clear() - bulk_result = coll.bulk_write([ - InsertOne({'_id': 1, 'l': large}), - InsertOne({'_id': 2, 'l': large}), - InsertOne({'_id': 3, 'l': large}), - UpdateOne({'_id': 1, 'l': large}, - {'$unset': {'l': 1}, '$inc': {'count': 1}}), - UpdateOne({'_id': 2, 'l': large}, {'$set': {'foo': 'bar'}}), - DeleteOne({'l': large}), - DeleteOne({'l': large})]) + bulk_result = coll.bulk_write( + [ + InsertOne({"_id": 1, "l": large}), + InsertOne({"_id": 2, "l": large}), + InsertOne({"_id": 3, "l": large}), + UpdateOne({"_id": 1, "l": large}, {"$unset": {"l": 1}, "$inc": {"count": 1}}), + UpdateOne({"_id": 2, "l": large}, {"$set": {"foo": "bar"}}), + DeleteOne({"l": large}), + DeleteOne({"l": large}), + ] + ) # Each command should fail and be retried. # With OP_MSG 3 inserts are one batch. 2 updates another. # 2 deletes a third. 
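The batch arithmetic behind the six started commands asserted next: each padded document is 1024 * 1024 * 15 bytes, and an OP_MSG payload is capped by the server's maxMessageSizeBytes (48,000,000 by default), so at most three such inserts fit in one message; each of the three resulting batches fails once and is retried.

    doc_size = 1024 * 1024 * 15          # 15,728,640 bytes of "s" per document
    max_message_size = 48_000_000        # the server's default maxMessageSizeBytes
    print(max_message_size // doc_size)  # -> 3: three padded inserts per batch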
- self.assertEqual(len(self.listener.results['started']), 6) - self.assertEqual(coll.find_one(), {'_id': 1, 'count': 1}) + self.assertEqual(len(self.listener.results["started"]), 6) + self.assertEqual(coll.find_one(), {"_id": 1, "count": 1}) # Assert the final result expected_result = { "writeErrors": [], @@ -385,42 +394,51 @@ def test_batch_splitting(self): @client_context.require_test_commands def test_batch_splitting_retry_fails(self): """Test retry fails during batch splitting.""" - large = 's' * 1024 * 1024 * 15 + large = "s" * 1024 * 1024 * 15 coll = self.db.retryable_write_test coll.delete_many({}) - self.client.admin.command(SON([ - ('configureFailPoint', 'onPrimaryTransactionalWrite'), - ('mode', {'skip': 3}), # The number of _documents_ to skip. - ('data', {'failBeforeCommitExceptionCode': 1})])) + self.client.admin.command( + SON( + [ + ("configureFailPoint", "onPrimaryTransactionalWrite"), + ("mode", {"skip": 3}), # The number of _documents_ to skip. + ("data", {"failBeforeCommitExceptionCode": 1}), + ] + ) + ) self.listener.results.clear() with self.client.start_session() as session: initial_txn = session._server_session._transaction_id try: - coll.bulk_write([InsertOne({'_id': 1, 'l': large}), - InsertOne({'_id': 2, 'l': large}), - InsertOne({'_id': 3, 'l': large}), - InsertOne({'_id': 4, 'l': large})], - session=session) + coll.bulk_write( + [ + InsertOne({"_id": 1, "l": large}), + InsertOne({"_id": 2, "l": large}), + InsertOne({"_id": 3, "l": large}), + InsertOne({"_id": 4, "l": large}), + ], + session=session, + ) except ConnectionFailure: pass else: self.fail("bulk_write should have failed") - started = self.listener.results['started'] + started = self.listener.results["started"] self.assertEqual(len(started), 3) - self.assertEqual(len(self.listener.results['succeeded']), 1) + self.assertEqual(len(self.listener.results["succeeded"]), 1) expected_txn = Int64(initial_txn + 1) - self.assertEqual(started[0].command['txnNumber'], expected_txn) - self.assertEqual(started[0].command['lsid'], session.session_id) + self.assertEqual(started[0].command["txnNumber"], expected_txn) + self.assertEqual(started[0].command["lsid"], session.session_id) expected_txn = Int64(initial_txn + 2) - self.assertEqual(started[1].command['txnNumber'], expected_txn) - self.assertEqual(started[1].command['lsid'], session.session_id) - started[1].command.pop('$clusterTime') - started[2].command.pop('$clusterTime') + self.assertEqual(started[1].command["txnNumber"], expected_txn) + self.assertEqual(started[1].command["lsid"], session.session_id) + started[1].command.pop("$clusterTime") + started[2].command.pop("$clusterTime") self.assertEqual(started[1].command, started[2].command) final_txn = session._server_session._transaction_id self.assertEqual(final_txn, expected_txn) - self.assertEqual(coll.find_one(projection={'_id': True}), {'_id': 1}) + self.assertEqual(coll.find_one(projection={"_id": True}), {"_id": 1}) class TestWriteConcernError(IntegrationTest): @@ -434,20 +452,18 @@ class TestWriteConcernError(IntegrationTest): def setUpClass(cls): super(TestWriteConcernError, cls).setUpClass() cls.fail_insert = { - 'configureFailPoint': 'failCommand', - 'mode': {'times': 2}, - 'data': { - 'failCommands': ['insert'], - 'writeConcernError': { - 'code': 91, - 'errmsg': 'Replication is being shut down'}, - }} + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": { + "failCommands": ["insert"], + "writeConcernError": {"code": 91, "errmsg": "Replication is being shut down"}, + }, + } 
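The fail_insert document above is handed to the fail_point() context manager. In terms of raw commands, the pattern amounts to something like this sketch (assumes a mongod started with enableTestCommands; the use of ping here is illustrative):

    from pymongo import MongoClient
    from pymongo.errors import OperationFailure

    client = MongoClient()
    client.admin.command({
        "configureFailPoint": "failCommand",
        "mode": {"times": 1},
        "data": {"failCommands": ["ping"], "errorCode": 91},
    })
    try:
        client.admin.command("ping")  # fails once with code 91
    except OperationFailure as exc:
        print(exc.code)
    finally:
        # Always disarm the fail point so later operations are unaffected.
        client.admin.command("configureFailPoint", "failCommand", mode="off")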
@client_context.require_version_min(4, 0) def test_RetryableWriteError_error_label(self): listener = OvertCommandListener() - client = rs_or_single_client( - retryWrites=True, event_listeners=[listener]) + client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) # Ensure collection exists. client.pymongo_test.testcoll.insert_one({}) @@ -455,14 +471,13 @@ def test_RetryableWriteError_error_label(self): with self.fail_point(self.fail_insert): with self.assertRaises(WriteConcernError) as cm: client.pymongo_test.testcoll.insert_one({}) - self.assertTrue(cm.exception.has_error_label( - 'RetryableWriteError')) + self.assertTrue(cm.exception.has_error_label("RetryableWriteError")) if client_context.version >= Version(4, 4): # In MongoDB 4.4+ we rely on the server returning the error label. self.assertIn( - 'RetryableWriteError', - listener.results['succeeded'][-1].reply['errorLabels']) + "RetryableWriteError", listener.results["succeeded"][-1].reply["errorLabels"] + ) @client_context.require_version_min(4, 4) def test_RetryableWriteError_error_label_RawBSONDocument(self): @@ -471,13 +486,18 @@ def test_RetryableWriteError_error_label_RawBSONDocument(self): with self.client.start_session() as s: s._start_retryable_write() result = self.client.pymongo_test.command( - 'insert', 'testcoll', documents=[{'_id': 1}], - txnNumber=s._server_session.transaction_id, session=s, + "insert", + "testcoll", + documents=[{"_id": 1}], + txnNumber=s._server_session.transaction_id, + session=s, codec_options=DEFAULT_CODEC_OPTIONS.with_options( - document_class=RawBSONDocument)) + document_class=RawBSONDocument + ), + ) - self.assertIn('writeConcernError', result) - self.assertIn('RetryableWriteError', result['errorLabels']) + self.assertIn("writeConcernError", result) + self.assertIn("RetryableWriteError", result["errorLabels"]) class InsertThread(threading.Thread): @@ -499,26 +519,24 @@ class TestPoolPausedError(IntegrationTest): @client_context.require_failCommand_blockConnection @client_context.require_retryable_writes - @client_knobs(heartbeat_frequency=.05, min_heartbeat_interval=.05) + @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) def test_pool_paused_error_is_retryable(self): cmap_listener = CMAPListener() cmd_listener = OvertCommandListener() - client = rs_or_single_client( - maxPoolSize=1, - event_listeners=[cmap_listener, cmd_listener]) + client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]) self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() cmd_listener.reset() threads = [InsertThread(client.pymongo_test.test) for _ in range(2)] fail_command = { - 'mode': {'times': 1}, - 'data': { - 'failCommands': ['insert'], - 'blockConnection': True, - 'blockTimeMS': 1000, - 'errorCode': 91, - 'errorLabels': ['RetryableWriteError'], + "mode": {"times": 1}, + "data": { + "failCommands": ["insert"], + "blockConnection": True, + "blockTimeMS": 1000, + "errorCode": 91, + "errorLabels": ["RetryableWriteError"], }, } with self.fail_point(fail_command): @@ -536,29 +554,25 @@ def test_pool_paused_error_is_retryable(self): break # Via CMAP monitoring, assert that the first check out succeeds. 
- cmap_events = cmap_listener.events_by_type(( - ConnectionCheckedOutEvent, - ConnectionCheckOutFailedEvent, - PoolClearedEvent)) + cmap_events = cmap_listener.events_by_type( + (ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent, PoolClearedEvent) + ) msg = pprint.pformat(cmap_listener.events) self.assertIsInstance(cmap_events[0], ConnectionCheckedOutEvent, msg) self.assertIsInstance(cmap_events[1], PoolClearedEvent, msg) - self.assertIsInstance( - cmap_events[2], ConnectionCheckOutFailedEvent, msg) - self.assertEqual(cmap_events[2].reason, - ConnectionCheckOutFailedReason.CONN_ERROR, - msg) + self.assertIsInstance(cmap_events[2], ConnectionCheckOutFailedEvent, msg) + self.assertEqual(cmap_events[2].reason, ConnectionCheckOutFailedReason.CONN_ERROR, msg) self.assertIsInstance(cmap_events[3], ConnectionCheckedOutEvent, msg) # Connection check out failures are not reflected in command # monitoring because we only publish command events _after_ checking # out a connection. - started = cmd_listener.results['started'] + started = cmd_listener.results["started"] msg = pprint.pformat(cmd_listener.results) self.assertEqual(3, len(started), msg) - succeeded = cmd_listener.results['succeeded'] + succeeded = cmd_listener.results["succeeded"] self.assertEqual(2, len(succeeded), msg) - failed = cmd_listener.results['failed'] + failed = cmd_listener.results["failed"] self.assertEqual(1, len(failed), msg) @@ -571,8 +585,7 @@ def test_increment_transaction_id_without_sending_command(self): the first attempt fails before sending the command. """ listener = OvertCommandListener() - client = rs_or_single_client( - retryWrites=True, event_listeners=[listener]) + client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) topology = client._topology select_server = topology.select_server @@ -581,28 +594,27 @@ def raise_connection_err_select_server(*args, **kwargs): # Raise ConnectionFailure on the first attempt and perform # normal selection on the retry attempt. topology.select_server = select_server - raise ConnectionFailure('Connection refused') + raise ConnectionFailure("Connection refused") - for method, args, kwargs in retryable_single_statement_ops( - client.db.retryable_write_test): + for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): listener.results.clear() topology.select_server = raise_connection_err_select_server with client.start_session() as session: kwargs = copy.deepcopy(kwargs) - kwargs['session'] = session - msg = '%s(*%r, **%r)' % (method.__name__, args, kwargs) + kwargs["session"] = session + msg = "%s(*%r, **%r)" % (method.__name__, args, kwargs) initial_txn_id = session._server_session.transaction_id # Each operation should fail on the first attempt and succeed # on the second. 
method(*args, **kwargs) - self.assertEqual(len(listener.results['started']), 1, msg) - retry_cmd = listener.results['started'][0].command - sent_txn_id = retry_cmd['txnNumber'] + self.assertEqual(len(listener.results["started"]), 1, msg) + retry_cmd = listener.results["started"][0].command + sent_txn_id = retry_cmd["txnNumber"] final_txn_id = session._server_session.transaction_id self.assertEqual(Int64(initial_txn_id + 1), sent_txn_id, msg) self.assertEqual(sent_txn_id, final_txn_id, msg) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_retryable_writes_unified.py b/test/test_retryable_writes_unified.py index 4e851de273..4e97c14d4b 100644 --- a/test/test_retryable_writes_unified.py +++ b/test/test_retryable_writes_unified.py @@ -23,8 +23,7 @@ from test.unified_format import generate_test_classes # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'retryable_writes', 'unified') +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "retryable_writes", "unified") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_saslprep.py b/test/test_saslprep.py index c694224a6c..1dd4727181 100644 --- a/test/test_saslprep.py +++ b/test/test_saslprep.py @@ -16,11 +16,12 @@ sys.path[0:0] = [""] -from pymongo.saslprep import saslprep from test import unittest -class TestSASLprep(unittest.TestCase): +from pymongo.saslprep import saslprep + +class TestSASLprep(unittest.TestCase): def test_saslprep(self): try: import stringprep diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 02807f05ab..bb202dc447 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -21,43 +21,40 @@ sys.path[0:0] = [""] -from pymongo import MongoClient +from test import IntegrationTest, client_context, client_knobs, unittest +from test.utils import ( + ServerAndTopologyEventListener, + rs_or_single_client, + server_name_to_type, + wait_until, +) + from bson.json_util import object_hook -from pymongo import monitoring +from pymongo import MongoClient, monitoring from pymongo.common import clean_node -from pymongo.errors import (ConnectionFailure, - NotPrimaryError) +from pymongo.errors import ConnectionFailure, NotPrimaryError from pymongo.hello import Hello from pymongo.monitor import Monitor from pymongo.server_description import ServerDescription from pymongo.topology_description import TOPOLOGY_TYPE -from test import unittest, client_context, client_knobs, IntegrationTest -from test.utils import (ServerAndTopologyEventListener, - server_name_to_type, - rs_or_single_client, - wait_until) # Location of JSON test specifications. 
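One Python detail worth flagging before the comparison helpers below: server addresses are (host, port) tuples, and applying % to a tuple spreads it across the format specifiers, which is why expressions like "%s:%s" % actual.address need no extra parentheses:

    address = ("localhost", 27017)
    assert "%s:%s" % address == "localhost:27017"
    all_hosts = {("a", 1), ("b", 2)}
    assert {"%s:%s" % s for s in all_hosts} == {"a:1", "b:2"}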
-_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'sdam_monitoring') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sdam_monitoring") def compare_server_descriptions(expected, actual): - if ((not expected['address'] == "%s:%s" % actual.address) or - (not server_name_to_type(expected['type']) == - actual.server_type)): + if (not expected["address"] == "%s:%s" % actual.address) or ( + not server_name_to_type(expected["type"]) == actual.server_type + ): return False - expected_hosts = set( - expected['arbiters'] + expected['passives'] + expected['hosts']) + expected_hosts = set(expected["arbiters"] + expected["passives"] + expected["hosts"]) return expected_hosts == set("%s:%s" % s for s in actual.all_hosts) def compare_topology_descriptions(expected, actual): - if not (TOPOLOGY_TYPE.__getattribute__( - expected['topologyType']) == actual.topology_type): + if not (TOPOLOGY_TYPE.__getattribute__(expected["topologyType"]) == actual.topology_type): return False - expected = expected['servers'] + expected = expected["servers"] actual = actual.server_descriptions() if len(expected) != len(actual): return False @@ -80,70 +77,74 @@ def compare_events(expected_dict, actual): if expected_type == "server_opening_event": if not isinstance(actual, monitoring.ServerOpeningEvent): - return False, "Expected ServerOpeningEvent, got %s" % ( - actual.__class__) - if not expected['address'] == "%s:%s" % actual.server_address: - return (False, - "ServerOpeningEvent published with wrong address (expected" - " %s, got %s" % (expected['address'], - actual.server_address)) + return False, "Expected ServerOpeningEvent, got %s" % (actual.__class__) + if not expected["address"] == "%s:%s" % actual.server_address: + return ( + False, + "ServerOpeningEvent published with wrong address (expected" + " %s, got %s" % (expected["address"], actual.server_address), + ) elif expected_type == "server_description_changed_event": if not isinstance(actual, monitoring.ServerDescriptionChangedEvent): - return (False, - "Expected ServerDescriptionChangedEvent, got %s" % ( - actual.__class__)) - if not expected['address'] == "%s:%s" % actual.server_address: - return (False, "ServerDescriptionChangedEvent has wrong address" - " (expected %s, got %s" % (expected['address'], - actual.server_address)) + return (False, "Expected ServerDescriptionChangedEvent, got %s" % (actual.__class__)) + if not expected["address"] == "%s:%s" % actual.server_address: + return ( + False, + "ServerDescriptionChangedEvent has wrong address" + " (expected %s, got %s" % (expected["address"], actual.server_address), + ) + if not compare_server_descriptions(expected["newDescription"], actual.new_description): + return (False, "New ServerDescription incorrect in" " ServerDescriptionChangedEvent") if not compare_server_descriptions( - expected['newDescription'], actual.new_description): - return (False, "New ServerDescription incorrect in" - " ServerDescriptionChangedEvent") - if not compare_server_descriptions(expected['previousDescription'], - actual.previous_description): - return (False, "Previous ServerDescription incorrect in" - " ServerDescriptionChangedEvent") + expected["previousDescription"], actual.previous_description + ): + return ( + False, + "Previous ServerDescription incorrect in" " ServerDescriptionChangedEvent", + ) elif expected_type == "server_closed_event": if not isinstance(actual, monitoring.ServerClosedEvent): - return False, "Expected ServerClosedEvent, got %s" % ( - actual.__class__) - 
if not expected['address'] == "%s:%s" % actual.server_address: - return (False, "ServerClosedEvent published with wrong address" - " (expected %s, got %s" % (expected['address'], - actual.server_address)) + return False, "Expected ServerClosedEvent, got %s" % (actual.__class__) + if not expected["address"] == "%s:%s" % actual.server_address: + return ( + False, + "ServerClosedEvent published with wrong address" + " (expected %s, got %s" % (expected["address"], actual.server_address), + ) elif expected_type == "topology_opening_event": if not isinstance(actual, monitoring.TopologyOpenedEvent): - return False, "Expected TopologyOpeningEvent, got %s" % ( - actual.__class__) + return False, "Expected TopologyOpeningEvent, got %s" % (actual.__class__) elif expected_type == "topology_description_changed_event": if not isinstance(actual, monitoring.TopologyDescriptionChangedEvent): - return (False, "Expected TopologyDescriptionChangedEvent," - " got %s" % (actual.__class__)) - if not compare_topology_descriptions(expected['newDescription'], - actual.new_description): - return (False, "New TopologyDescription incorrect in " - "TopologyDescriptionChangedEvent") + return ( + False, + "Expected TopologyDescriptionChangedEvent," " got %s" % (actual.__class__), + ) + if not compare_topology_descriptions(expected["newDescription"], actual.new_description): + return ( + False, + "New TopologyDescription incorrect in " "TopologyDescriptionChangedEvent", + ) if not compare_topology_descriptions( - expected['previousDescription'], - actual.previous_description): - return (False, "Previous TopologyDescription incorrect in" - " TopologyDescriptionChangedEvent") + expected["previousDescription"], actual.previous_description + ): + return ( + False, + "Previous TopologyDescription incorrect in" " TopologyDescriptionChangedEvent", + ) elif expected_type == "topology_closed_event": if not isinstance(actual, monitoring.TopologyClosedEvent): - return False, "Expected TopologyClosedEvent, got %s" % ( - actual.__class__) + return False, "Expected TopologyClosedEvent, got %s" % (actual.__class__) else: - return False, "Incorrect event: expected %s, actual %s" % ( - expected_type, actual) + return False, "Incorrect event: expected %s, actual %s" % (expected_type, actual) return True, "" @@ -151,12 +152,10 @@ def compare_events(expected_dict, actual): def compare_multiple_events(i, expected_results, actual_results): events_in_a_row = [] j = i - while(j < len(expected_results) and isinstance( - actual_results[j], - actual_results[i].__class__)): + while j < len(expected_results) and isinstance(actual_results[j], actual_results[i].__class__): events_in_a_row.append(actual_results[j]) j += 1 - message = '' + message = "" for event in events_in_a_row: for k in range(i, j): passed, message = compare_events(expected_results[k], event) @@ -165,11 +164,10 @@ def compare_multiple_events(i, expected_results, actual_results): break else: return i, False, message - return j, True, '' + return j, True, "" class TestAllScenarios(IntegrationTest): - def setUp(self): super(TestAllScenarios, self).setUp() self.all_listener = ServerAndTopologyEventListener() @@ -183,51 +181,60 @@ def run_scenario(self): def _run_scenario(self): class NoopMonitor(Monitor): """Override the _run method to do nothing.""" + def _run(self): time.sleep(0.05) - m = MongoClient(host=scenario_def['uri'], port=27017, - event_listeners=[self.all_listener], - _monitor_class=NoopMonitor) + m = MongoClient( + host=scenario_def["uri"], + port=27017, + 
event_listeners=[self.all_listener], + _monitor_class=NoopMonitor, + ) topology = m._get_topology() try: - for phase in scenario_def['phases']: - for (source, response) in phase.get('responses', []): + for phase in scenario_def["phases"]: + for (source, response) in phase.get("responses", []): source_address = clean_node(source) - topology.on_change(ServerDescription( - address=source_address, - hello=Hello(response), - round_trip_time=0)) + topology.on_change( + ServerDescription( + address=source_address, hello=Hello(response), round_trip_time=0 + ) + ) - expected_results = phase['outcome']['events'] + expected_results = phase["outcome"]["events"] expected_len = len(expected_results) wait_until( lambda: len(self.all_listener.results) >= expected_len, - "publish all events", timeout=15) + "publish all events", + timeout=15, + ) # Wait some time to catch possible lagging extra events. time.sleep(0.5) i = 0 while i < expected_len: - result = self.all_listener.results[i] if len( - self.all_listener.results) > i else None + result = ( + self.all_listener.results[i] if len(self.all_listener.results) > i else None + ) # The order of ServerOpening/ClosedEvents doesn't matter - if isinstance(result, (monitoring.ServerOpeningEvent, - monitoring.ServerClosedEvent)): + if isinstance( + result, (monitoring.ServerOpeningEvent, monitoring.ServerClosedEvent) + ): i, passed, message = compare_multiple_events( - i, expected_results, self.all_listener.results) + i, expected_results, self.all_listener.results + ) self.assertTrue(passed, message) else: - self.assertTrue( - *compare_events(expected_results[i], result)) + self.assertTrue(*compare_events(expected_results[i], result)) i += 1 # Assert no extra events. extra_events = self.all_listener.results[expected_len:] if extra_events: - self.fail('Extra events %r' % (extra_events,)) + self.fail("Extra events %r" % (extra_events,)) self.all_listener.reset() finally: @@ -240,11 +247,10 @@ def create_tests(): for dirpath, _, filenames in os.walk(_TEST_PATH): for filename in filenames: with open(os.path.join(dirpath, filename)) as scenario_stream: - scenario_def = json.load( - scenario_stream, object_hook=object_hook) + scenario_def = json.load(scenario_stream, object_hook=object_hook) # Construct test from scenario. new_test = create_test(scenario_def) - test_name = 'test_%s' % (os.path.splitext(filename)[0],) + test_name = "test_%s" % (os.path.splitext(filename)[0],) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) @@ -253,7 +259,6 @@ def create_tests(): class TestSdamMonitoring(IntegrationTest): - @classmethod @client_context.require_failCommand_fail_point def setUpClass(cls): @@ -264,7 +269,8 @@ def setUpClass(cls): cls.listener = ServerAndTopologyEventListener() retry_writes = client_context.supports_transactions() cls.test_client = rs_or_single_client( - event_listeners=[cls.listener], retryWrites=retry_writes) + event_listeners=[cls.listener], retryWrites=retry_writes + ) cls.coll = cls.test_client[cls.client.db.name].test cls.coll.insert_one({}) @@ -282,12 +288,12 @@ def _test_app_error(self, fail_command_opts, expected_error): # Test that an application error causes a ServerDescriptionChangedEvent # to be published. 
- data = {'failCommands': ['insert']} + data = {"failCommands": ["insert"]} data.update(fail_command_opts) fail_insert = { - 'configureFailPoint': 'failCommand', - 'mode': {'times': 1}, - 'data': data, + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": data, } with self.fail_point(fail_insert): if self.test_client.options.retry_writes: @@ -301,43 +307,48 @@ def marked_unknown(event): return ( isinstance(event, monitoring.ServerDescriptionChangedEvent) and event.server_address == address - and not event.new_description.is_server_type_known) + and not event.new_description.is_server_type_known + ) def discovered_node(event): return ( isinstance(event, monitoring.ServerDescriptionChangedEvent) and event.server_address == address and not event.previous_description.is_server_type_known - and event.new_description.is_server_type_known) + and event.new_description.is_server_type_known + ) def marked_unknown_and_rediscovered(): - return (len(self.listener.matching(marked_unknown)) >= 1 and - len(self.listener.matching(discovered_node)) >= 1) + return ( + len(self.listener.matching(marked_unknown)) >= 1 + and len(self.listener.matching(discovered_node)) >= 1 + ) # Topology events are published asynchronously - wait_until(marked_unknown_and_rediscovered, 'rediscover node') + wait_until(marked_unknown_and_rediscovered, "rediscover node") # Expect a single ServerDescriptionChangedEvent for the network error. marked_unknown_events = self.listener.matching(marked_unknown) self.assertEqual(len(marked_unknown_events), 1, marked_unknown_events) - self.assertIsInstance( - marked_unknown_events[0].new_description.error, expected_error) + self.assertIsInstance(marked_unknown_events[0].new_description.error, expected_error) def test_network_error_publishes_events(self): - self._test_app_error({'closeConnection': True}, ConnectionFailure) + self._test_app_error({"closeConnection": True}, ConnectionFailure) # In 4.4+, not primary errors from failCommand don't cause SDAM state # changes because topologyVersion is not incremented. 
@client_context.require_version_max(4, 3) def test_not_primary_error_publishes_events(self): - self._test_app_error({'errorCode': 10107, 'closeConnection': False, - 'errorLabels': ['RetryableWriteError']}, - NotPrimaryError) + self._test_app_error( + {"errorCode": 10107, "closeConnection": False, "errorLabels": ["RetryableWriteError"]}, + NotPrimaryError, + ) def test_shutdown_error_publishes_events(self): - self._test_app_error({'errorCode': 91, 'closeConnection': False, - 'errorLabels': ['RetryableWriteError']}, - NotPrimaryError) + self._test_app_error( + {"errorCode": 91, "closeConnection": False, "errorLabels": ["RetryableWriteError"]}, + NotPrimaryError, + ) if __name__ == "__main__": diff --git a/test/test_server.py b/test/test_server.py index e4996d2e09..064d77d024 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -18,18 +18,19 @@ sys.path[0:0] = [""] +from test import unittest + from pymongo.hello import Hello from pymongo.server import Server from pymongo.server_description import ServerDescription -from test import unittest class TestServer(unittest.TestCase): def test_repr(self): - hello = Hello({'ok': 1}) - sd = ServerDescription(('localhost', 27017), hello) + hello = Hello({"ok": 1}) + sd = ServerDescription(("localhost", 27017), hello) server = Server(sd, pool=object(), monitor=object()) - self.assertTrue('Standalone' in str(server)) + self.assertTrue("Standalone" in str(server)) if __name__ == "__main__": diff --git a/test/test_server_description.py b/test/test_server_description.py index 23d6c8f377..1562711375 100644 --- a/test/test_server_description.py +++ b/test/test_server_description.py @@ -18,14 +18,15 @@ sys.path[0:0] = [""] -from bson.objectid import ObjectId +from test import unittest + from bson.int64 import Int64 -from pymongo.server_type import SERVER_TYPE +from bson.objectid import ObjectId from pymongo.hello import Hello, HelloCompat from pymongo.server_description import ServerDescription -from test import unittest +from pymongo.server_type import SERVER_TYPE -address = ('localhost', 27017) +address = ("localhost", 27017) def parse_hello_response(doc): @@ -42,82 +43,88 @@ def test_unknown(self): self.assertFalse(s.is_readable) def test_mongos(self): - s = parse_hello_response({'ok': 1, 'msg': 'isdbgrid'}) + s = parse_hello_response({"ok": 1, "msg": "isdbgrid"}) self.assertEqual(SERVER_TYPE.Mongos, s.server_type) - self.assertEqual('Mongos', s.server_type_name) + self.assertEqual("Mongos", s.server_type_name) self.assertTrue(s.is_writable) self.assertTrue(s.is_readable) def test_primary(self): - s = parse_hello_response( - {'ok': 1, HelloCompat.LEGACY_CMD: True, 'setName': 'rs'}) + s = parse_hello_response({"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs"}) self.assertEqual(SERVER_TYPE.RSPrimary, s.server_type) - self.assertEqual('RSPrimary', s.server_type_name) + self.assertEqual("RSPrimary", s.server_type_name) self.assertTrue(s.is_writable) self.assertTrue(s.is_readable) def test_secondary(self): s = parse_hello_response( - {'ok': 1, HelloCompat.LEGACY_CMD: False, 'secondary': True, 'setName': 'rs'}) + {"ok": 1, HelloCompat.LEGACY_CMD: False, "secondary": True, "setName": "rs"} + ) self.assertEqual(SERVER_TYPE.RSSecondary, s.server_type) - self.assertEqual('RSSecondary', s.server_type_name) + self.assertEqual("RSSecondary", s.server_type_name) self.assertFalse(s.is_writable) self.assertTrue(s.is_readable) def test_arbiter(self): s = parse_hello_response( - {'ok': 1, HelloCompat.LEGACY_CMD: False, 'arbiterOnly': True, 'setName': 'rs'}) + 
{"ok": 1, HelloCompat.LEGACY_CMD: False, "arbiterOnly": True, "setName": "rs"} + ) self.assertEqual(SERVER_TYPE.RSArbiter, s.server_type) - self.assertEqual('RSArbiter', s.server_type_name) + self.assertEqual("RSArbiter", s.server_type_name) self.assertFalse(s.is_writable) self.assertFalse(s.is_readable) def test_other(self): - s = parse_hello_response( - {'ok': 1, HelloCompat.LEGACY_CMD: False, 'setName': 'rs'}) + s = parse_hello_response({"ok": 1, HelloCompat.LEGACY_CMD: False, "setName": "rs"}) self.assertEqual(SERVER_TYPE.RSOther, s.server_type) - self.assertEqual('RSOther', s.server_type_name) + self.assertEqual("RSOther", s.server_type_name) - s = parse_hello_response({ - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'hidden': True, - 'setName': 'rs'}) + s = parse_hello_response( + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "hidden": True, + "setName": "rs", + } + ) self.assertEqual(SERVER_TYPE.RSOther, s.server_type) self.assertFalse(s.is_writable) self.assertFalse(s.is_readable) def test_ghost(self): - s = parse_hello_response({'ok': 1, 'isreplicaset': True}) + s = parse_hello_response({"ok": 1, "isreplicaset": True}) self.assertEqual(SERVER_TYPE.RSGhost, s.server_type) - self.assertEqual('RSGhost', s.server_type_name) + self.assertEqual("RSGhost", s.server_type_name) self.assertFalse(s.is_writable) self.assertFalse(s.is_readable) def test_fields(self): - s = parse_hello_response({ - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'primary': 'a:27017', - 'tags': {'a': 'foo', 'b': 'baz'}, - 'maxMessageSizeBytes': 1, - 'maxBsonObjectSize': 2, - 'maxWriteBatchSize': 3, - 'minWireVersion': 4, - 'maxWireVersion': 5, - 'setName': 'rs'}) + s = parse_hello_response( + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "primary": "a:27017", + "tags": {"a": "foo", "b": "baz"}, + "maxMessageSizeBytes": 1, + "maxBsonObjectSize": 2, + "maxWriteBatchSize": 3, + "minWireVersion": 4, + "maxWireVersion": 5, + "setName": "rs", + } + ) self.assertEqual(SERVER_TYPE.RSSecondary, s.server_type) - self.assertEqual(('a', 27017), s.primary) - self.assertEqual({'a': 'foo', 'b': 'baz'}, s.tags) + self.assertEqual(("a", 27017), s.primary) + self.assertEqual({"a": "foo", "b": "baz"}, s.tags) self.assertEqual(1, s.max_message_size) self.assertEqual(2, s.max_bson_size) self.assertEqual(3, s.max_write_batch_size) @@ -125,55 +132,57 @@ def test_fields(self): self.assertEqual(5, s.max_wire_version) def test_default_max_message_size(self): - s = parse_hello_response({ - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'maxBsonObjectSize': 2}) + s = parse_hello_response({"ok": 1, HelloCompat.LEGACY_CMD: True, "maxBsonObjectSize": 2}) # Twice max_bson_size. self.assertEqual(4, s.max_message_size) def test_standalone(self): - s = parse_hello_response({'ok': 1, HelloCompat.LEGACY_CMD: True}) + s = parse_hello_response({"ok": 1, HelloCompat.LEGACY_CMD: True}) self.assertEqual(SERVER_TYPE.Standalone, s.server_type) # Mongod started with --slave. # master-slave replication was removed in MongoDB 4.0. 
- s = parse_hello_response({'ok': 1, HelloCompat.LEGACY_CMD: False}) + s = parse_hello_response({"ok": 1, HelloCompat.LEGACY_CMD: False}) self.assertEqual(SERVER_TYPE.Standalone, s.server_type) self.assertTrue(s.is_writable) self.assertTrue(s.is_readable) def test_ok_false(self): - s = parse_hello_response({'ok': 0, HelloCompat.LEGACY_CMD: True}) + s = parse_hello_response({"ok": 0, HelloCompat.LEGACY_CMD: True}) self.assertEqual(SERVER_TYPE.Unknown, s.server_type) self.assertFalse(s.is_writable) self.assertFalse(s.is_readable) def test_all_hosts(self): - s = parse_hello_response({ - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'hosts': ['a'], - 'passives': ['b:27018'], - 'arbiters': ['c'] - }) + s = parse_hello_response( + { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "hosts": ["a"], + "passives": ["b:27018"], + "arbiters": ["c"], + } + ) - self.assertEqual( - [('a', 27017), ('b', 27018), ('c', 27017)], - sorted(s.all_hosts)) + self.assertEqual([("a", 27017), ("b", 27018), ("c", 27017)], sorted(s.all_hosts)) def test_repr(self): - s = parse_hello_response({'ok': 1, 'msg': 'isdbgrid'}) - self.assertEqual(repr(s), - "<ServerDescription ('localhost', 27017) server_type: Mongos, rtt: None>") + s = parse_hello_response({"ok": 1, "msg": "isdbgrid"}) + self.assertEqual( + repr(s), "<ServerDescription ('localhost', 27017) server_type: Mongos, rtt: None>" + ) def test_topology_version(self): - topology_version = {'processId': ObjectId(), 'counter': Int64('0')} + topology_version = {"processId": ObjectId(), "counter": Int64("0")} s = parse_hello_response( - {'ok': 1, HelloCompat.LEGACY_CMD: True, 'setName': 'rs', - 'topologyVersion': topology_version}) + { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "setName": "rs", + "topologyVersion": topology_version, + } + ) self.assertEqual(SERVER_TYPE.RSPrimary, s.server_type) self.assertEqual(topology_version, s.topology_version) @@ -185,8 +194,7 @@ def test_topology_version(self): def test_topology_version_not_present(self): # No topologyVersion field. - s = parse_hello_response( - {'ok': 1, HelloCompat.LEGACY_CMD: True, 'setName': 'rs'}) + s = parse_hello_response({"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs"}) self.assertEqual(SERVER_TYPE.RSPrimary, s.server_type) self.assertEqual(None, s.topology_version) diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 46fce3b13a..5211097e61 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -17,8 +17,7 @@ import os import sys -from pymongo import MongoClient -from pymongo import ReadPreference +from pymongo import MongoClient, ReadPreference from pymongo.errors import ServerSelectionTimeoutError from pymongo.hello import HelloCompat from pymongo.server_selectors import writable_server_selector @@ -27,22 +26,30 @@ sys.path[0:0] = [""] -from test import client_context, unittest, IntegrationTest -from test.utils import (rs_or_single_client, wait_until, EventListener, - FunctionCallRecorder) +from test import IntegrationTest, client_context, unittest +from test.utils import ( + EventListener, + FunctionCallRecorder, + rs_or_single_client, + wait_until, +) from test.utils_selection_tests import ( - create_selection_tests, get_addresses, get_topology_settings_dict, - make_server_description) - + create_selection_tests, + get_addresses, + get_topology_settings_dict, + make_server_description, +) # Location of JSON test specifications. 
_TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), - os.path.join('server_selection', 'server_selection')) + os.path.join("server_selection", "server_selection"), +) class SelectionStoreSelector(object): """No-op selector that keeps track of what was passed to it.""" + def __init__(self): self.selection = None @@ -51,7 +58,6 @@ def __call__(self, selection): return selection - class TestAllScenarios(create_selection_tests(_TEST_PATH)): pass @@ -67,37 +73,33 @@ def custom_selector(servers): # Initialize client with appropriate listeners. listener = EventListener() - client = rs_or_single_client( - server_selector=custom_selector, event_listeners=[listener]) + client = rs_or_single_client(server_selector=custom_selector, event_listeners=[listener]) self.addCleanup(client.close) - coll = client.get_database( - 'testdb', read_preference=ReadPreference.NEAREST).coll - self.addCleanup(client.drop_database, 'testdb') + coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll + self.addCleanup(client.drop_database, "testdb") # Wait for the node list to be fully populated. def all_hosts_started(): - return (len(client.admin.command(HelloCompat.LEGACY_CMD)['hosts']) == - len(client._topology._description.readable_servers)) + return len(client.admin.command(HelloCompat.LEGACY_CMD)["hosts"]) == len( + client._topology._description.readable_servers + ) - wait_until(all_hosts_started, 'receive heartbeat from all hosts') - expected_port = max([ - n.address[1] - for n in client._topology._description.readable_servers]) + wait_until(all_hosts_started, "receive heartbeat from all hosts") + expected_port = max([n.address[1] for n in client._topology._description.readable_servers]) # Insert 1 record and access it 10 times. - coll.insert_one({'name': 'John Doe'}) + coll.insert_one({"name": "John Doe"}) for _ in range(10): - coll.find_one({'name': 'John Doe'}) + coll.find_one({"name": "John Doe"}) # Confirm all find commands are run against the appropriate host. - for command in listener.results['started']: - if command.command_name == 'find': - self.assertEqual( - command.connection_id[1], expected_port) + for command in listener.results["started"]: + if command.command_name == "find": + self.assertEqual(command.connection_id[1], expected_port) def test_invalid_server_selector(self): # Client initialization must fail if server_selector is not callable. - for selector_candidate in [list(), 10, 'string', {}]: + for selector_candidate in [list(), 10, "string", {}]: with self.assertRaisesRegex(ValueError, "must be a callable"): MongoClient(connect=False, server_selector=selector_candidate) @@ -112,13 +114,13 @@ def test_selector_called(self): mongo_client = rs_or_single_client(server_selector=selector) test_collection = mongo_client.testdb.test_collection self.addCleanup(mongo_client.close) - self.addCleanup(mongo_client.drop_database, 'testdb') + self.addCleanup(mongo_client.drop_database, "testdb") # Do N operations and test selector is called at least N times. 
- test_collection.insert_one({'age': 20, 'name': 'John'}) - test_collection.insert_one({'age': 31, 'name': 'Jane'}) - test_collection.update_one({'name': 'Jane'}, {'$set': {'age': 21}}) - test_collection.find_one({'name': 'Roe'}) + test_collection.insert_one({"age": 20, "name": "John"}) + test_collection.insert_one({"age": 31, "name": "Jane"}) + test_collection.update_one({"name": "Jane"}, {"$set": {"age": 21}}) + test_collection.find_one({"name": "Roe"}) self.assertGreaterEqual(selector.call_count, 4) @client_context.require_replica_set @@ -126,86 +128,66 @@ def test_latency_threshold_application(self): selector = SelectionStoreSelector() scenario_def = { - 'topology_description': { - 'type': 'ReplicaSetWithPrimary', 'servers': [ - {'address': 'b:27017', - 'avg_rtt_ms': 10000, - 'type': 'RSSecondary', - 'tag': {}}, - {'address': 'c:27017', - 'avg_rtt_ms': 20000, - 'type': 'RSSecondary', - 'tag': {}}, - {'address': 'a:27017', - 'avg_rtt_ms': 30000, - 'type': 'RSPrimary', - 'tag': {}}, - ]}} + "topology_description": { + "type": "ReplicaSetWithPrimary", + "servers": [ + {"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}}, + {"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}}, + {"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSPrimary", "tag": {}}, + ], + } + } # Create & populate Topology such that all but one server is too slow. - rtt_times = [srv['avg_rtt_ms'] for srv in - scenario_def['topology_description']['servers']] + rtt_times = [srv["avg_rtt_ms"] for srv in scenario_def["topology_description"]["servers"]] min_rtt_idx = rtt_times.index(min(rtt_times)) - seeds, hosts = get_addresses( - scenario_def["topology_description"]["servers"]) + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) settings = get_topology_settings_dict( - heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, - server_selector=selector) + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector + ) topology = Topology(TopologySettings(**settings)) topology.open() - for server in scenario_def['topology_description']['servers']: + for server in scenario_def["topology_description"]["servers"]: server_description = make_server_description(server, hosts) topology.on_change(server_description) # Invoke server selection and assert no filtering based on latency # prior to custom server selection logic kicking in. server = topology.select_server(ReadPreference.NEAREST) - self.assertEqual( - len(selector.selection), - len(topology.description.server_descriptions())) + self.assertEqual(len(selector.selection), len(topology.description.server_descriptions())) # Ensure proper filtering based on latency after custom selection. 
- self.assertEqual( - server.description.address, seeds[min_rtt_idx]) + self.assertEqual(server.description.address, seeds[min_rtt_idx]) @client_context.require_replica_set def test_server_selector_bypassed(self): selector = FunctionCallRecorder(lambda x: x) scenario_def = { - 'topology_description': { - 'type': 'ReplicaSetNoPrimary', 'servers': [ - {'address': 'b:27017', - 'avg_rtt_ms': 10000, - 'type': 'RSSecondary', - 'tag': {}}, - {'address': 'c:27017', - 'avg_rtt_ms': 20000, - 'type': 'RSSecondary', - 'tag': {}}, - {'address': 'a:27017', - 'avg_rtt_ms': 30000, - 'type': 'RSSecondary', - 'tag': {}}, - ]}} + "topology_description": { + "type": "ReplicaSetNoPrimary", + "servers": [ + {"address": "b:27017", "avg_rtt_ms": 10000, "type": "RSSecondary", "tag": {}}, + {"address": "c:27017", "avg_rtt_ms": 20000, "type": "RSSecondary", "tag": {}}, + {"address": "a:27017", "avg_rtt_ms": 30000, "type": "RSSecondary", "tag": {}}, + ], + } + } # Create & populate Topology such that no server is writeable. - seeds, hosts = get_addresses( - scenario_def["topology_description"]["servers"]) + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) settings = get_topology_settings_dict( - heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, - server_selector=selector) + heartbeat_frequency=1, local_threshold_ms=1, seeds=seeds, server_selector=selector + ) topology = Topology(TopologySettings(**settings)) topology.open() - for server in scenario_def['topology_description']['servers']: + for server in scenario_def["topology_description"]["servers"]: server_description = make_server_description(server, hosts) topology.on_change(server_description) # Invoke server selection and assert no calls to our custom selector. - with self.assertRaisesRegex( - ServerSelectionTimeoutError, 'No primary available for writes'): - topology.select_server( - writable_server_selector, server_selection_timeout=0.1) + with self.assertRaisesRegex(ServerSelectionTimeoutError, "No primary available for writes"): + topology.select_server(writable_server_selector, server_selection_timeout=0.1) self.assertEqual(selector.call_count, 0) diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index c599210d11..18fdbc11f2 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -16,18 +16,17 @@ import os import threading +from test import IntegrationTest, client_context, unittest +from test.utils import OvertCommandListener, TestCreator, rs_client, wait_until +from test.utils_selection_tests import create_topology from pymongo.common import clean_node from pymongo.read_preferences import ReadPreference -from test import client_context, IntegrationTest, unittest -from test.utils_selection_tests import create_topology -from test.utils import TestCreator, rs_client, OvertCommandListener, wait_until - # Location of JSON test specifications. 
TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.path.join('server_selection', 'in_window')) + os.path.dirname(os.path.realpath(__file__)), os.path.join("server_selection", "in_window") +) class TestAllScenarios(unittest.TestCase): @@ -35,28 +34,27 @@ def run_scenario(self, scenario_def): topology = create_topology(scenario_def) # Update mock operation_count state: - for mock in scenario_def['mocked_topology_state']: - address = clean_node(mock['address']) + for mock in scenario_def["mocked_topology_state"]: + address = clean_node(mock["address"]) server = topology.get_server_by_address(address) - server.pool.operation_count = mock['operation_count'] + server.pool.operation_count = mock["operation_count"] pref = ReadPreference.NEAREST - counts = dict((address, 0) for address in - topology._description.server_descriptions()) + counts = dict((address, 0) for address in topology._description.server_descriptions()) # Number of times to repeat server selection - iterations = scenario_def['iterations'] + iterations = scenario_def["iterations"] for _ in range(iterations): server = topology.select_server(pref, server_selection_timeout=0) counts[server.description.address] += 1 # Verify expected_frequencies - outcome = scenario_def['outcome'] - tolerance = outcome['tolerance'] - expected_frequencies = outcome['expected_frequencies'] + outcome = scenario_def["outcome"] + tolerance = outcome["tolerance"] + expected_frequencies = outcome["expected_frequencies"] for host_str, freq in expected_frequencies.items(): address = clean_node(host_str) - actual_freq = float(counts[address])/iterations + actual_freq = float(counts[address]) / iterations if freq == 0: # Should be exactly 0. self.assertEqual(actual_freq, 0) @@ -112,7 +110,7 @@ def frequencies(self, client, listener): for thread in threads: self.assertTrue(thread.passed) - events = listener.results['started'] + events = listener.results["started"] self.assertEqual(len(events), N_FINDS * N_THREADS) nodes = client.nodes self.assertEqual(len(nodes), 2) @@ -120,7 +118,7 @@ def frequencies(self, client, listener): for event in events: freqs[event.connection_id] += 1 for address in freqs: - freqs[address] = freqs[address]/float(len(events)) + freqs[address] = freqs[address] / float(len(events)) return freqs @client_context.require_failCommand_appName @@ -129,21 +127,23 @@ def test_load_balancing(self): listener = OvertCommandListener() # PYTHON-2584: Use a large localThresholdMS to avoid the impact of # varying RTTs. 
- client = rs_client(client_context.mongos_seeds(), - appName='loadBalancingTest', - event_listeners=[listener], - localThresholdMS=10000) + client = rs_client( + client_context.mongos_seeds(), + appName="loadBalancingTest", + event_listeners=[listener], + localThresholdMS=10000, + ) self.addCleanup(client.close) - wait_until(lambda: len(client.nodes) == 2, 'discover both nodes') + wait_until(lambda: len(client.nodes) == 2, "discover both nodes") # Delay find commands on delay_finds = { - 'configureFailPoint': 'failCommand', - 'mode': {'times': 10000}, - 'data': { - 'failCommands': ['find'], - 'blockConnection': True, - 'blockTimeMS': 500, - 'appName': 'loadBalancingTest', + "configureFailPoint": "failCommand", + "mode": {"times": 10000}, + "data": { + "failCommands": ["find"], + "blockConnection": True, + "blockTimeMS": 500, + "appName": "loadBalancingTest", }, } with self.fail_point(delay_finds): diff --git a/test/test_server_selection_rtt.py b/test/test_server_selection_rtt.py index f914e03030..d2d8768809 100644 --- a/test/test_server_selection_rtt.py +++ b/test/test_server_selection_rtt.py @@ -21,11 +21,11 @@ sys.path[0:0] = [""] from test import unittest + from pymongo.read_preferences import MovingAverage # Location of JSON test specifications. -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'server_selection/rtt') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection/rtt") class TestAllScenarios(unittest.TestCase): @@ -36,14 +36,13 @@ def create_test(scenario_def): def run_scenario(self): moving_average = MovingAverage() - if scenario_def['avg_rtt_ms'] != "NULL": - moving_average.add_sample(scenario_def['avg_rtt_ms']) + if scenario_def["avg_rtt_ms"] != "NULL": + moving_average.add_sample(scenario_def["avg_rtt_ms"]) - if scenario_def['new_rtt_ms'] != "NULL": - moving_average.add_sample(scenario_def['new_rtt_ms']) + if scenario_def["new_rtt_ms"] != "NULL": + moving_average.add_sample(scenario_def["new_rtt_ms"]) - self.assertAlmostEqual(moving_average.get(), - scenario_def['new_avg_rtt']) + self.assertAlmostEqual(moving_average.get(), scenario_def["new_avg_rtt"]) return run_scenario @@ -58,8 +57,7 @@ def create_tests(): # Construct test from scenario. 
new_test = create_test(scenario_def) - test_name = 'test_%s_%s' % ( - dirname, os.path.splitext(filename)[0]) + test_name = "test_%s_%s" % (dirname, os.path.splitext(filename)[0]) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) diff --git a/test/test_session.py b/test/test_session.py index e844ae3a08..598da1aeb4 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -18,45 +18,40 @@ import os import sys import time - from io import BytesIO sys.path[0:0] = [""] +from test import IntegrationTest, SkipTest, client_context, unittest +from test.utils import EventListener, TestCreator, rs_or_single_client, wait_until +from test.utils_spec_runner import SpecRunner + from bson import DBRef from gridfs import GridFS, GridFSBucket -from pymongo import ASCENDING, InsertOne, IndexModel, monitoring +from pymongo import ASCENDING, IndexModel, InsertOne, monitoring from pymongo.common import _MAX_END_SESSIONS -from pymongo.errors import (ConfigurationError, - InvalidOperation, - OperationFailure) +from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure from pymongo.read_concern import ReadConcern -from test import IntegrationTest, client_context, unittest, SkipTest -from test.utils import (rs_or_single_client, - EventListener, - TestCreator, - wait_until) -from test.utils_spec_runner import SpecRunner + # Ignore auth commands like saslStart, so we can assert lsid is in all commands. class SessionTestListener(EventListener): def started(self, event): - if not event.command_name.startswith('sasl'): + if not event.command_name.startswith("sasl"): super(SessionTestListener, self).started(event) def succeeded(self, event): - if not event.command_name.startswith('sasl'): + if not event.command_name.startswith("sasl"): super(SessionTestListener, self).succeeded(event) def failed(self, event): - if not event.command_name.startswith('sasl'): + if not event.command_name.startswith("sasl"): super(SessionTestListener, self).failed(event) def first_command_started(self): - assert len(self.results['started']) >= 1, ( - "No command-started events") + assert len(self.results["started"]) >= 1, "No command-started events" - return self.results['started'][0] + return self.results["started"][0] def session_ids(client): @@ -64,7 +59,6 @@ def session_ids(client): class TestSession(IntegrationTest): - @classmethod @client_context.require_sessions def setUpClass(cls): @@ -87,20 +81,21 @@ def setUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() self.client = rs_or_single_client( - event_listeners=[self.listener, self.session_checker_listener]) + event_listeners=[self.listener, self.session_checker_listener] + ) self.addCleanup(self.client.close) self.db = self.client.pymongo_test - self.initial_lsids = set(s['id'] for s in session_ids(self.client)) + self.initial_lsids = set(s["id"] for s in session_ids(self.client)) def tearDown(self): """All sessions used in the test must be returned to the pool.""" - self.client.drop_database('pymongo_test') + self.client.drop_database("pymongo_test") used_lsids = self.initial_lsids.copy() - for event in self.session_checker_listener.results['started']: - if 'lsid' in event.command: - used_lsids.add(event.command['lsid']['id']) + for event in self.session_checker_listener.results["started"]: + if "lsid" in event.command: + used_lsids.add(event.command["lsid"]["id"]) - current_lsids = set(s['id'] for s in session_ids(self.client)) + current_lsids = set(s["id"] for s in 
session_ids(self.client)) self.assertLessEqual(used_lsids, current_lsids) def _test_ops(self, client, *ops): @@ -115,21 +110,21 @@ def _test_ops(self, client, *ops): # In case "f" modifies its inputs. args = copy.copy(args) kw = copy.copy(kw) - kw['session'] = s + kw["session"] = s f(*args, **kw) self.assertGreaterEqual(s._server_session.last_use, start) - self.assertGreaterEqual(len(listener.results['started']), 1) - for event in listener.results['started']: + self.assertGreaterEqual(len(listener.results["started"]), 1) + for event in listener.results["started"]: self.assertTrue( - 'lsid' in event.command, - "%s sent no lsid with %s" % ( - f.__name__, event.command_name)) + "lsid" in event.command, + "%s sent no lsid with %s" % (f.__name__, event.command_name), + ) self.assertEqual( s.session_id, - event.command['lsid'], - "%s sent wrong lsid with %s" % ( - f.__name__, event.command_name)) + event.command["lsid"], + "%s sent wrong lsid with %s" % (f.__name__, event.command_name), + ) self.assertFalse(s.has_ended) @@ -142,35 +137,35 @@ def _test_ops(self, client, *ops): # In case "f" modifies its inputs. args = copy.copy(args) kw = copy.copy(kw) - kw['session'] = s + kw["session"] = s with self.assertRaisesRegex( - InvalidOperation, - 'Can only use session with the MongoClient' - ' that started it'): + InvalidOperation, "Can only use session with the MongoClient that started it" + ): f(*args, **kw) # No explicit session. for f, args, kw in ops: listener.results.clear() f(*args, **kw) - self.assertGreaterEqual(len(listener.results['started']), 1) + self.assertGreaterEqual(len(listener.results["started"]), 1) lsids = [] - for event in listener.results['started']: + for event in listener.results["started"]: self.assertTrue( - 'lsid' in event.command, - "%s sent no lsid with %s" % ( - f.__name__, event.command_name)) + "lsid" in event.command, + "%s sent no lsid with %s" % (f.__name__, event.command_name), + ) - lsids.append(event.command['lsid']) + lsids.append(event.command["lsid"]) - if not (sys.platform.startswith('java') or 'PyPy' in sys.version): + if not (sys.platform.startswith("java") or "PyPy" in sys.version): # Server session was returned to pool. Ignore interpreters with # non-deterministic GC. for lsid in lsids: self.assertIn( - lsid, session_ids(client), - "%s did not return implicit session to pool" % ( - f.__name__,)) + lsid, + session_ids(client), + "%s did not return implicit session to pool" % (f.__name__,), + ) def test_pool_lifo(self): # "Pool is LIFO" test from Driver Sessions Spec. @@ -210,31 +205,28 @@ def test_end_sessions(self): listener = SessionTestListener() client = rs_or_single_client(event_listeners=[listener]) # Start many sessions. - sessions = [client.start_session() - for _ in range(_MAX_END_SESSIONS + 1)] + sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)] for s in sessions: s.end_session() # Closing the client should end all sessions and clear the pool. - self.assertEqual(len(client._topology._session_pool), - _MAX_END_SESSIONS + 1) + self.assertEqual(len(client._topology._session_pool), _MAX_END_SESSIONS + 1) client.close() self.assertEqual(len(client._topology._session_pool), 0) - end_sessions = [e for e in listener.results['started'] - if e.command_name == 'endSessions'] + end_sessions = [e for e in listener.results["started"] if e.command_name == "endSessions"] self.assertEqual(len(end_sessions), 2) # Closing again should not send any commands. 
listener.results.clear() client.close() - self.assertEqual(len(listener.results['started']), 0) + self.assertEqual(len(listener.results["started"]), 0) def test_client(self): client = self.client ops = [ (client.server_info, [], {}), (client.list_database_names, [], {}), - (client.drop_database, ['pymongo_test'], {}), + (client.drop_database, ["pymongo_test"], {}), ] self._test_ops(client, *ops) @@ -243,12 +235,12 @@ def test_database(self): client = self.client db = client.pymongo_test ops = [ - (db.command, ['ping'], {}), - (db.create_collection, ['collection'], {}), + (db.command, ["ping"], {}), + (db.create_collection, ["collection"], {}), (db.list_collection_names, [], {}), - (db.validate_collection, ['collection'], {}), - (db.drop_collection, ['collection'], {}), - (db.dereference, [DBRef('collection', 1)], {}), + (db.validate_collection, ["collection"], {}), + (db.drop_collection, ["collection"], {}), + (db.dereference, [DBRef("collection", 1)], {}), ] self._test_ops(client, *ops) @@ -261,19 +253,19 @@ def collection_write_ops(coll): (coll.insert_one, [{}], {}), (coll.insert_many, [[{}, {}]], {}), (coll.replace_one, [{}, {}], {}), - (coll.update_one, [{}, {'$set': {'a': 1}}], {}), - (coll.update_many, [{}, {'$set': {'a': 1}}], {}), + (coll.update_one, [{}, {"$set": {"a": 1}}], {}), + (coll.update_many, [{}, {"$set": {"a": 1}}], {}), (coll.delete_one, [{}], {}), (coll.delete_many, [{}], {}), (coll.find_one_and_replace, [{}, {}], {}), - (coll.find_one_and_update, [{}, {'$set': {'a': 1}}], {}), + (coll.find_one_and_update, [{}, {"$set": {"a": 1}}], {}), (coll.find_one_and_delete, [{}, {}], {}), - (coll.rename, ['collection2'], {}), + (coll.rename, ["collection2"], {}), # Drop collection2 between tests of "rename", above. - (coll.database.drop_collection, ['collection2'], {}), - (coll.create_indexes, [[IndexModel('a')]], {}), - (coll.create_index, ['a'], {}), - (coll.drop_index, ['a_1'], {}), + (coll.database.drop_collection, ["collection2"], {}), + (coll.create_indexes, [[IndexModel("a")]], {}), + (coll.create_index, ["a"], {}), + (coll.drop_index, ["a_1"], {}), (coll.drop_indexes, [], {}), (coll.aggregate, [[{"$out": "aggout"}]], {}), ] @@ -284,15 +276,17 @@ def test_collection(self): # Test some collection methods - the rest are in test_cursor. ops = self.collection_write_ops(coll) - ops.extend([ - (coll.distinct, ['a'], {}), - (coll.find_one, [], {}), - (coll.count_documents, [{}], {}), - (coll.list_indexes, [], {}), - (coll.index_information, [], {}), - (coll.options, [], {}), - (coll.aggregate, [[]], {}), - ]) + ops.extend( + [ + (coll.distinct, ["a"], {}), + (coll.find_one, [], {}), + (coll.count_documents, [{}], {}), + (coll.list_indexes, [], {}), + (coll.index_information, [], {}), + (coll.options, [], {}), + (coll.aggregate, [[]], {}), + ] + ) self._test_ops(client, *ops) @@ -330,29 +324,28 @@ def test_cursor(self): # Test all cursor methods. 
ops = [ - ('find', lambda session: list(coll.find(session=session))), - ('getitem', lambda session: coll.find(session=session)[0]), - ('distinct', - lambda session: coll.find(session=session).distinct('a')), - ('explain', lambda session: coll.find(session=session).explain()), + ("find", lambda session: list(coll.find(session=session))), + ("getitem", lambda session: coll.find(session=session)[0]), + ("distinct", lambda session: coll.find(session=session).distinct("a")), + ("explain", lambda session: coll.find(session=session).explain()), ] for name, f in ops: with client.start_session() as s: listener.results.clear() f(session=s) - self.assertGreaterEqual(len(listener.results['started']), 1) - for event in listener.results['started']: + self.assertGreaterEqual(len(listener.results["started"]), 1) + for event in listener.results["started"]: self.assertTrue( - 'lsid' in event.command, - "%s sent no lsid with %s" % ( - name, event.command_name)) + "lsid" in event.command, + "%s sent no lsid with %s" % (name, event.command_name), + ) self.assertEqual( s.session_id, - event.command['lsid'], - "%s sent wrong lsid with %s" % ( - name, event.command_name)) + event.command["lsid"], + "%s sent wrong lsid with %s" % (name, event.command_name), + ) with self.assertRaisesRegex(InvalidOperation, "ended session"): f(session=s) @@ -363,67 +356,64 @@ def test_cursor(self): f(session=None) event0 = listener.first_command_started() self.assertTrue( - 'lsid' in event0.command, - "%s sent no lsid with %s" % ( - name, event0.command_name)) + "lsid" in event0.command, "%s sent no lsid with %s" % (name, event0.command_name) + ) - lsid = event0.command['lsid'] + lsid = event0.command["lsid"] - for event in listener.results['started'][1:]: + for event in listener.results["started"][1:]: self.assertTrue( - 'lsid' in event.command, - "%s sent no lsid with %s" % ( - name, event.command_name)) + "lsid" in event.command, "%s sent no lsid with %s" % (name, event.command_name) + ) self.assertEqual( lsid, - event.command['lsid'], - "%s sent wrong lsid with %s" % ( - name, event.command_name)) + event.command["lsid"], + "%s sent wrong lsid with %s" % (name, event.command_name), + ) def test_gridfs(self): client = self.client fs = GridFS(client.pymongo_test) def new_file(session=None): - grid_file = fs.new_file(_id=1, filename='f', session=session) + grid_file = fs.new_file(_id=1, filename="f", session=session) # 1 MB, 5 chunks, to test that each chunk is fetched with same lsid. 
- grid_file.write(b'a' * 1048576) + grid_file.write(b"a" * 1048576) grid_file.close() def find(session=None): - files = list(fs.find({'_id': 1}, session=session)) + files = list(fs.find({"_id": 1}, session=session)) for f in files: f.read() self._test_ops( client, (new_file, [], {}), - (fs.put, [b'data'], {}), + (fs.put, [b"data"], {}), (lambda session=None: fs.get(1, session=session).read(), [], {}), - (lambda session=None: fs.get_version('f', session=session).read(), - [], {}), - (lambda session=None: - fs.get_last_version('f', session=session).read(), [], {}), + (lambda session=None: fs.get_version("f", session=session).read(), [], {}), + (lambda session=None: fs.get_last_version("f", session=session).read(), [], {}), (fs.list, [], {}), (fs.find_one, [1], {}), (lambda session=None: list(fs.find(session=session)), [], {}), (fs.exists, [1], {}), (find, [], {}), - (fs.delete, [1], {})) + (fs.delete, [1], {}), + ) def test_gridfs_bucket(self): client = self.client bucket = GridFSBucket(client.pymongo_test) def upload(session=None): - stream = bucket.open_upload_stream('f', session=session) - stream.write(b'a' * 1048576) + stream = bucket.open_upload_stream("f", session=session) + stream.write(b"a" * 1048576) stream.close() def upload_with_id(session=None): - stream = bucket.open_upload_stream_with_id(1, 'f1', session=session) - stream.write(b'a' * 1048576) + stream = bucket.open_upload_stream_with_id(1, "f1", session=session) + stream.write(b"a" * 1048576) stream.close() def open_download_stream(session=None): @@ -431,11 +421,11 @@ def open_download_stream(session=None): stream.read() def open_download_stream_by_name(session=None): - stream = bucket.open_download_stream_by_name('f', session=session) + stream = bucket.open_download_stream_by_name("f", session=session) stream.read() def find(session=None): - files = list(bucket.find({'_id': 1}, session=session)) + files = list(bucket.find({"_id": 1}, session=session)) for f in files: f.read() @@ -445,17 +435,18 @@ def find(session=None): client, (upload, [], {}), (upload_with_id, [], {}), - (bucket.upload_from_stream, ['f', b'data'], {}), - (bucket.upload_from_stream_with_id, [2, 'f', b'data'], {}), + (bucket.upload_from_stream, ["f", b"data"], {}), + (bucket.upload_from_stream_with_id, [2, "f", b"data"], {}), (open_download_stream, [], {}), (open_download_stream_by_name, [], {}), (bucket.download_to_stream, [1, sio], {}), - (bucket.download_to_stream_by_name, ['f', sio], {}), + (bucket.download_to_stream_by_name, ["f", sio], {}), (find, [], {}), - (bucket.rename, [1, 'f2'], {}), + (bucket.rename, [1, "f2"], {}), # Delete both files so _test_ops can run these operations twice. (bucket.delete, [1], {}), - (bucket.delete, [2], {})) + (bucket.delete, [2], {}), + ) def test_gridfsbucket_cursor(self): client = self.client @@ -463,7 +454,7 @@ def test_gridfsbucket_cursor(self): for file_id in 1, 2: stream = bucket.open_upload_stream_with_id(file_id, str(file_id)) - stream.write(b'a' * 1048576) + stream.write(b"a" * 1048576) stream.close() with client.start_session() as s: @@ -512,10 +503,7 @@ def test_aggregate(self): coll = client.pymongo_test.collection def agg(session=None): - list(coll.aggregate( - [], - batchSize=2, - session=session)) + list(coll.aggregate([], batchSize=2, session=session)) # With empty collection. 
self._test_ops(client, (agg, [], {})) @@ -547,11 +535,11 @@ def test_aggregate_error(self): listener.results.clear() with self.assertRaises(OperationFailure): - coll.aggregate([{'$badOperation': {'bar': 1}}]) + coll.aggregate([{"$badOperation": {"bar": 1}}]) event = listener.first_command_started() - self.assertEqual(event.command_name, 'aggregate') - lsid = event.command['lsid'] + self.assertEqual(event.command_name, "aggregate") + lsid = event.command["lsid"] # Session was returned to pool despite error. self.assertIn(lsid, session_ids(client)) @@ -562,7 +550,7 @@ def _test_cursor_helper(self, create_cursor, close_cursor): cursor = create_cursor(coll, None) next(cursor) # Session is "owned" by cursor. - session = getattr(cursor, '_%s__session' % cursor.__class__.__name__) + session = getattr(cursor, "_%s__session" % cursor.__class__.__name__) self.assertIsNotNone(session) lsid = session.session_id next(cursor) @@ -585,45 +573,46 @@ def _test_cursor_helper(self, create_cursor, close_cursor): def test_cursor_close(self): self._test_cursor_helper( - lambda coll, session: coll.find(session=session), - lambda cursor: cursor.close()) + lambda coll, session: coll.find(session=session), lambda cursor: cursor.close() + ) def test_command_cursor_close(self): self._test_cursor_helper( - lambda coll, session: coll.aggregate([], session=session), - lambda cursor: cursor.close()) + lambda coll, session: coll.aggregate([], session=session), lambda cursor: cursor.close() + ) def test_cursor_del(self): self._test_cursor_helper( - lambda coll, session: coll.find(session=session), - lambda cursor: cursor.__del__()) + lambda coll, session: coll.find(session=session), lambda cursor: cursor.__del__() + ) def test_command_cursor_del(self): self._test_cursor_helper( lambda coll, session: coll.aggregate([], session=session), - lambda cursor: cursor.__del__()) + lambda cursor: cursor.__del__(), + ) def test_cursor_exhaust(self): self._test_cursor_helper( - lambda coll, session: coll.find(session=session), - lambda cursor: list(cursor)) + lambda coll, session: coll.find(session=session), lambda cursor: list(cursor) + ) def test_command_cursor_exhaust(self): self._test_cursor_helper( - lambda coll, session: coll.aggregate([], session=session), - lambda cursor: list(cursor)) + lambda coll, session: coll.aggregate([], session=session), lambda cursor: list(cursor) + ) def test_cursor_limit_reached(self): self._test_cursor_helper( - lambda coll, session: coll.find(limit=4, batch_size=2, - session=session), - lambda cursor: list(cursor)) + lambda coll, session: coll.find(limit=4, batch_size=2, session=session), + lambda cursor: list(cursor), + ) def test_command_cursor_limit_reached(self): self._test_cursor_helper( - lambda coll, session: coll.aggregate([], batchSize=900, - session=session), - lambda cursor: list(cursor)) + lambda coll, session: coll.aggregate([], batchSize=900, session=session), + lambda cursor: list(cursor), + ) def _test_unacknowledged_ops(self, client, *ops): listener = client.options.event_listeners[0] @@ -634,23 +623,23 @@ def _test_unacknowledged_ops(self, client, *ops): # In case "f" modifies its inputs. 
args = copy.copy(args) kw = copy.copy(kw) - kw['session'] = s + kw["session"] = s with self.assertRaises( - ConfigurationError, - msg="%s did not raise ConfigurationError" % ( - f.__name__,)): + ConfigurationError, msg="%s did not raise ConfigurationError" % (f.__name__,) + ): f(*args, **kw) - if f.__name__ == 'create_collection': + if f.__name__ == "create_collection": # create_collection runs listCollections first. - event = listener.results['started'].pop(0) - self.assertEqual('listCollections', event.command_name) - self.assertIn('lsid', event.command, - "%s sent no lsid with %s" % ( - f.__name__, event.command_name)) + event = listener.results["started"].pop(0) + self.assertEqual("listCollections", event.command_name) + self.assertIn( + "lsid", + event.command, + "%s sent no lsid with %s" % (f.__name__, event.command_name), + ) # Should not run any command before raising an error. - self.assertFalse(listener.results['started'], - "%s sent command" % (f.__name__,)) + self.assertFalse(listener.results["started"], "%s sent command" % (f.__name__,)) self.assertTrue(s.has_ended) @@ -658,20 +647,22 @@ def _test_unacknowledged_ops(self, client, *ops): for f, args, kw in ops: listener.results.clear() f(*args, **kw) - self.assertGreaterEqual(len(listener.results['started']), 1) + self.assertGreaterEqual(len(listener.results["started"]), 1) - if f.__name__ == 'create_collection': + if f.__name__ == "create_collection": # create_collection runs listCollections first. - event = listener.results['started'].pop(0) - self.assertEqual('listCollections', event.command_name) - self.assertIn('lsid', event.command, - "%s sent no lsid with %s" % ( - f.__name__, event.command_name)) - - for event in listener.results['started']: - self.assertNotIn('lsid', event.command, - "%s sent lsid with %s" % ( - f.__name__, event.command_name)) + event = listener.results["started"].pop(0) + self.assertEqual("listCollections", event.command_name) + self.assertIn( + "lsid", + event.command, + "%s sent no lsid with %s" % (f.__name__, event.command_name), + ) + + for event in listener.results["started"]: + self.assertNotIn( + "lsid", event.command, "%s sent lsid with %s" % (f.__name__, event.command_name) + ) def test_unacknowledged_writes(self): # Ensure the collection exists. 
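The hunks on either side of this point reformat _test_unacknowledged_ops, which verifies two session rules: an explicit session passed to a w=0 operation must raise ConfigurationError without running the write, and an implicit session must omit the lsid from the command entirely. A minimal sketch of the behavior under test, assuming a reachable mongod; the testdb.items namespace is hypothetical and not part of this diff:

from pymongo import MongoClient
from pymongo.errors import ConfigurationError

client = MongoClient(w=0)  # unacknowledged write concern on every operation
coll = client.testdb.items  # hypothetical collection, for illustration only

# Explicit session + w=0: PyMongo raises instead of sending an lsid.
with client.start_session() as s:
    try:
        coll.insert_one({"x": 1}, session=s)
    except ConfigurationError:
        pass  # expected: explicit sessions are incompatible with w=0

# Implicit session + w=0: the insert is sent, but the command carries no lsid.
coll.insert_one({"x": 1})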
@@ -682,8 +673,8 @@ def test_unacknowledged_writes(self): coll = db.test_unacked_writes ops = [ (client.drop_database, [db.name], {}), - (db.create_collection, ['collection'], {}), - (db.drop_collection, ['collection'], {}), + (db.create_collection, ["collection"], {}), + (db.drop_collection, ["collection"], {}), ] ops.extend(self.collection_write_ops(coll)) self._test_unacknowledged_ops(client, *ops) @@ -699,21 +690,17 @@ def drop_db(): return False raise - wait_until(drop_db, 'dropped database after w=0 writes') + wait_until(drop_db, "dropped database after w=0 writes") def test_snapshot_incompatible_with_causal_consistency(self): - with self.client.start_session(causal_consistency=False, - snapshot=False): + with self.client.start_session(causal_consistency=False, snapshot=False): pass - with self.client.start_session(causal_consistency=False, - snapshot=True): + with self.client.start_session(causal_consistency=False, snapshot=True): pass - with self.client.start_session(causal_consistency=True, - snapshot=False): + with self.client.start_session(causal_consistency=True, snapshot=False): pass with self.assertRaises(ConfigurationError): - with self.client.start_session(causal_consistency=True, - snapshot=True): + with self.client.start_session(causal_consistency=True, snapshot=True): pass def test_session_not_copyable(self): @@ -721,8 +708,8 @@ def test_session_not_copyable(self): with client.start_session() as s: self.assertRaises(TypeError, lambda: copy.copy(s)) -class TestCausalConsistency(unittest.TestCase): +class TestCausalConsistency(unittest.TestCase): @classmethod def setUpClass(cls): cls.listener = SessionTestListener() @@ -743,33 +730,32 @@ def test_core(self): self.assertIsNone(sess.operation_time) self.listener.results.clear() self.client.pymongo_test.test.find_one(session=sess) - started = self.listener.results['started'][0] + started = self.listener.results["started"][0] cmd = started.command - self.assertIsNone(cmd.get('readConcern')) + self.assertIsNone(cmd.get("readConcern")) op_time = sess.operation_time self.assertIsNotNone(op_time) - succeeded = self.listener.results['succeeded'][0] + succeeded = self.listener.results["succeeded"][0] reply = succeeded.reply - self.assertEqual(op_time, reply.get('operationTime')) + self.assertEqual(op_time, reply.get("operationTime")) # No explicit session self.client.pymongo_test.test.insert_one({}) self.assertEqual(sess.operation_time, op_time) self.listener.results.clear() try: - self.client.pymongo_test.command('doesntexist', session=sess) + self.client.pymongo_test.command("doesntexist", session=sess) except: pass - failed = self.listener.results['failed'][0] - failed_op_time = failed.failure.get('operationTime') + failed = self.listener.results["failed"][0] + failed_op_time = failed.failure.get("operationTime") # Some older builds of MongoDB 3.5 / 3.6 return None for # operationTime when a command fails. Make sure we don't # change operation_time to None. 
if failed_op_time is None: self.assertIsNotNone(sess.operation_time) else: - self.assertEqual( - sess.operation_time, failed_op_time) + self.assertEqual(sess.operation_time, failed_op_time) with self.client.start_session() as sess2: self.assertIsNone(sess2.cluster_time) @@ -795,36 +781,32 @@ def _test_reads(self, op, exception=None): op(coll, sess) else: op(coll, sess) - act = self.listener.results['started'][0].command.get( - 'readConcern', {}).get('afterClusterTime') + act = ( + self.listener.results["started"][0] + .command.get("readConcern", {}) + .get("afterClusterTime") + ) self.assertEqual(operation_time, act) @client_context.require_no_standalone def test_reads(self): # Make sure the collection exists. self.client.pymongo_test.test.insert_one({}) + self._test_reads(lambda coll, session: list(coll.aggregate([], session=session))) + self._test_reads(lambda coll, session: list(coll.find({}, session=session))) + self._test_reads(lambda coll, session: coll.find_one({}, session=session)) + self._test_reads(lambda coll, session: coll.count_documents({}, session=session)) + self._test_reads(lambda coll, session: coll.distinct("foo", session=session)) self._test_reads( - lambda coll, session: list(coll.aggregate([], session=session))) - self._test_reads( - lambda coll, session: list(coll.find({}, session=session))) - self._test_reads( - lambda coll, session: coll.find_one({}, session=session)) - self._test_reads( - lambda coll, session: coll.count_documents({}, session=session)) - self._test_reads( - lambda coll, session: coll.distinct('foo', session=session)) - self._test_reads( - lambda coll, session: list(coll.aggregate_raw_batches( - [], session=session))) - self._test_reads( - lambda coll, session: list(coll.find_raw_batches( - {}, session=session))) + lambda coll, session: list(coll.aggregate_raw_batches([], session=session)) + ) + self._test_reads(lambda coll, session: list(coll.find_raw_batches({}, session=session))) self.assertRaises( ConfigurationError, self._test_reads, - lambda coll, session: coll.estimated_document_count( - session=session)) + lambda coll, session: coll.estimated_document_count(session=session), + ) def _test_writes(self, op): coll = self.client.pymongo_test.test @@ -834,50 +816,46 @@ def _test_writes(self, op): self.assertIsNotNone(operation_time) self.listener.results.clear() coll.find_one({}, session=sess) - act = self.listener.results['started'][0].command.get( - 'readConcern', {}).get('afterClusterTime') + act = ( + self.listener.results["started"][0] + .command.get("readConcern", {}) + .get("afterClusterTime") + ) self.assertEqual(operation_time, act) @client_context.require_no_standalone def test_writes(self): + self._test_writes(lambda coll, session: coll.bulk_write([InsertOne({})], session=session)) + self._test_writes(lambda coll, session: coll.insert_one({}, session=session)) + self._test_writes(lambda coll, session: coll.insert_many([{}], session=session)) self._test_writes( - lambda coll, session: coll.bulk_write( - [InsertOne({})], session=session)) + lambda coll, session: coll.replace_one({"_id": 1}, {"x": 1}, session=session) + ) self._test_writes( - lambda coll, session: coll.insert_one({}, session=session)) + lambda coll, session: coll.update_one({}, {"$set": {"X": 1}}, session=session) + ) self._test_writes( - lambda coll, session: coll.insert_many([{}], session=session)) + lambda coll, session: coll.update_many({}, {"$set": {"x": 1}}, session=session) + ) + self._test_writes(lambda coll, session: coll.delete_one({}, session=session)) + 
self._test_writes(lambda coll, session: coll.delete_many({}, session=session)) self._test_writes( - lambda coll, session: coll.replace_one( - {'_id': 1}, {'x': 1}, session=session)) - self._test_writes( - lambda coll, session: coll.update_one( - {}, {'$set': {'X': 1}}, session=session)) - self._test_writes( - lambda coll, session: coll.update_many( - {}, {'$set': {'x': 1}}, session=session)) - self._test_writes( - lambda coll, session: coll.delete_one({}, session=session)) - self._test_writes( - lambda coll, session: coll.delete_many({}, session=session)) - self._test_writes( - lambda coll, session: coll.find_one_and_replace( - {'x': 1}, {'y': 1}, session=session)) + lambda coll, session: coll.find_one_and_replace({"x": 1}, {"y": 1}, session=session) + ) self._test_writes( lambda coll, session: coll.find_one_and_update( - {'y': 1}, {'$set': {'x': 1}}, session=session)) - self._test_writes( - lambda coll, session: coll.find_one_and_delete( - {'x': 1}, session=session)) - self._test_writes( - lambda coll, session: coll.create_index("foo", session=session)) + {"y": 1}, {"$set": {"x": 1}}, session=session + ) + ) + self._test_writes(lambda coll, session: coll.find_one_and_delete({"x": 1}, session=session)) + self._test_writes(lambda coll, session: coll.create_index("foo", session=session)) self._test_writes( lambda coll, session: coll.create_indexes( - [IndexModel([("bar", ASCENDING)])], session=session)) - self._test_writes( - lambda coll, session: coll.drop_index("foo_1", session=session)) - self._test_writes( - lambda coll, session: coll.drop_indexes(session=session)) + [IndexModel([("bar", ASCENDING)])], session=session + ) + ) + self._test_writes(lambda coll, session: coll.drop_index("foo_1", session=session)) + self._test_writes(lambda coll, session: coll.drop_indexes(session=session)) def _test_no_read_concern(self, op): coll = self.client.pymongo_test.test @@ -887,61 +865,56 @@ def _test_no_read_concern(self, op): self.assertIsNotNone(operation_time) self.listener.results.clear() op(coll, sess) - rc = self.listener.results['started'][0].command.get( - 'readConcern') + rc = self.listener.results["started"][0].command.get("readConcern") self.assertIsNone(rc) @client_context.require_no_standalone def test_writes_do_not_include_read_concern(self): self._test_no_read_concern( - lambda coll, session: coll.bulk_write( - [InsertOne({})], session=session)) + lambda coll, session: coll.bulk_write([InsertOne({})], session=session) + ) + self._test_no_read_concern(lambda coll, session: coll.insert_one({}, session=session)) + self._test_no_read_concern(lambda coll, session: coll.insert_many([{}], session=session)) self._test_no_read_concern( - lambda coll, session: coll.insert_one({}, session=session)) + lambda coll, session: coll.replace_one({"_id": 1}, {"x": 1}, session=session) + ) self._test_no_read_concern( - lambda coll, session: coll.insert_many([{}], session=session)) + lambda coll, session: coll.update_one({}, {"$set": {"X": 1}}, session=session) + ) self._test_no_read_concern( - lambda coll, session: coll.replace_one( - {'_id': 1}, {'x': 1}, session=session)) + lambda coll, session: coll.update_many({}, {"$set": {"x": 1}}, session=session) + ) + self._test_no_read_concern(lambda coll, session: coll.delete_one({}, session=session)) + self._test_no_read_concern(lambda coll, session: coll.delete_many({}, session=session)) self._test_no_read_concern( - lambda coll, session: coll.update_one( - {}, {'$set': {'X': 1}}, session=session)) - self._test_no_read_concern( - lambda coll, session: 
coll.update_many( - {}, {'$set': {'x': 1}}, session=session)) - self._test_no_read_concern( - lambda coll, session: coll.delete_one({}, session=session)) - self._test_no_read_concern( - lambda coll, session: coll.delete_many({}, session=session)) - self._test_no_read_concern( - lambda coll, session: coll.find_one_and_replace( - {'x': 1}, {'y': 1}, session=session)) + lambda coll, session: coll.find_one_and_replace({"x": 1}, {"y": 1}, session=session) + ) self._test_no_read_concern( lambda coll, session: coll.find_one_and_update( - {'y': 1}, {'$set': {'x': 1}}, session=session)) - self._test_no_read_concern( - lambda coll, session: coll.find_one_and_delete( - {'x': 1}, session=session)) + {"y": 1}, {"$set": {"x": 1}}, session=session + ) + ) self._test_no_read_concern( - lambda coll, session: coll.create_index("foo", session=session)) + lambda coll, session: coll.find_one_and_delete({"x": 1}, session=session) + ) + self._test_no_read_concern(lambda coll, session: coll.create_index("foo", session=session)) self._test_no_read_concern( lambda coll, session: coll.create_indexes( - [IndexModel([("bar", ASCENDING)])], session=session)) - self._test_no_read_concern( - lambda coll, session: coll.drop_index("foo_1", session=session)) - self._test_no_read_concern( - lambda coll, session: coll.drop_indexes(session=session)) + [IndexModel([("bar", ASCENDING)])], session=session + ) + ) + self._test_no_read_concern(lambda coll, session: coll.drop_index("foo_1", session=session)) + self._test_no_read_concern(lambda coll, session: coll.drop_indexes(session=session)) # Not a write, but explain also doesn't support readConcern. - self._test_no_read_concern( - lambda coll, session: coll.find({}, session=session).explain()) + self._test_no_read_concern(lambda coll, session: coll.find({}, session=session).explain()) @client_context.require_no_standalone @client_context.require_version_max(4, 1, 0) def test_aggregate_out_does_not_include_read_concern(self): self._test_no_read_concern( - lambda coll, session: list( - coll.aggregate([{"$out": "aggout"}], session=session))) + lambda coll, session: list(coll.aggregate([{"$out": "aggout"}], session=session)) + ) @client_context.require_no_standalone def test_get_more_does_not_include_read_concern(self): @@ -955,17 +928,20 @@ def test_get_more_does_not_include_read_concern(self): next(cursor) self.listener.results.clear() list(cursor) - started = self.listener.results['started'][0] - self.assertEqual(started.command_name, 'getMore') - self.assertIsNone(started.command.get('readConcern')) + started = self.listener.results["started"][0] + self.assertEqual(started.command_name, "getMore") + self.assertIsNone(started.command.get("readConcern")) def test_session_not_causal(self): with self.client.start_session(causal_consistency=False) as s: self.client.pymongo_test.test.insert_one({}, session=s) self.listener.results.clear() self.client.pymongo_test.test.find_one({}, session=s) - act = self.listener.results['started'][0].command.get( - 'readConcern', {}).get('afterClusterTime') + act = ( + self.listener.results["started"][0] + .command.get("readConcern", {}) + .get("afterClusterTime") + ) self.assertIsNone(act) @client_context.require_standalone @@ -974,8 +950,11 @@ def test_server_not_causal(self): self.client.pymongo_test.test.insert_one({}, session=s) self.listener.results.clear() self.client.pymongo_test.test.find_one({}, session=s) - act = self.listener.results['started'][0].command.get( - 'readConcern', {}).get('afterClusterTime') + act = ( + 
self.listener.results["started"][0] + .command.get("readConcern", {}) + .get("afterClusterTime") + ) self.assertIsNone(act) @client_context.require_no_standalone @@ -986,28 +965,25 @@ def test_read_concern(self): coll.insert_one({}, session=s) self.listener.results.clear() coll.find_one({}, session=s) - read_concern = self.listener.results['started'][0].command.get( - 'readConcern') + read_concern = self.listener.results["started"][0].command.get("readConcern") self.assertIsNotNone(read_concern) - self.assertIsNone(read_concern.get('level')) - self.assertIsNotNone(read_concern.get('afterClusterTime')) + self.assertIsNone(read_concern.get("level")) + self.assertIsNotNone(read_concern.get("afterClusterTime")) coll = coll.with_options(read_concern=ReadConcern("majority")) self.listener.results.clear() coll.find_one({}, session=s) - read_concern = self.listener.results['started'][0].command.get( - 'readConcern') + read_concern = self.listener.results["started"][0].command.get("readConcern") self.assertIsNotNone(read_concern) - self.assertEqual(read_concern.get('level'), 'majority') - self.assertIsNotNone(read_concern.get('afterClusterTime')) + self.assertEqual(read_concern.get("level"), "majority") + self.assertIsNotNone(read_concern.get("afterClusterTime")) @client_context.require_no_standalone def test_cluster_time_with_server_support(self): self.client.pymongo_test.test.insert_one({}) self.listener.results.clear() self.client.pymongo_test.test.find_one({}) - after_cluster_time = self.listener.results['started'][0].command.get( - '$clusterTime') + after_cluster_time = self.listener.results["started"][0].command.get("$clusterTime") self.assertIsNotNone(after_cluster_time) @client_context.require_standalone @@ -1015,22 +991,20 @@ def test_cluster_time_no_server_support(self): self.client.pymongo_test.test.insert_one({}) self.listener.results.clear() self.client.pymongo_test.test.find_one({}) - after_cluster_time = self.listener.results['started'][0].command.get( - '$clusterTime') + after_cluster_time = self.listener.results["started"][0].command.get("$clusterTime") self.assertIsNone(after_cluster_time) class TestClusterTime(IntegrationTest): def setUp(self): super(TestClusterTime, self).setUp() - if '$clusterTime' not in client_context.hello: - raise SkipTest('$clusterTime not supported') + if "$clusterTime" not in client_context.hello: + raise SkipTest("$clusterTime not supported") def test_cluster_time(self): listener = SessionTestListener() # Prevent heartbeats from updating $clusterTime between operations. - client = rs_or_single_client(event_listeners=[listener], - heartbeatFrequencyMS=999999) + client = rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) self.addCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). @@ -1041,7 +1015,7 @@ def test_cluster_time(self): def rename_and_drop(): # Ensure collection exists. collection.insert_one({}) - collection.rename('collection2') + collection.rename("collection2") client.pymongo_test.collection2.drop() def insert_and_find(): @@ -1064,22 +1038,19 @@ def insert_and_aggregate(): ops = [ # Tests from Driver Sessions Spec. 
- ('ping', lambda: client.admin.command('ping')), - ('aggregate', lambda: list(collection.aggregate([]))), - ('find', lambda: list(collection.find())), - ('insert_one', lambda: collection.insert_one({})), - + ("ping", lambda: client.admin.command("ping")), + ("aggregate", lambda: list(collection.aggregate([]))), + ("find", lambda: list(collection.find())), + ("insert_one", lambda: collection.insert_one({})), # Additional PyMongo tests. - ('insert_and_find', insert_and_find), - ('insert_and_aggregate', insert_and_aggregate), - ('update_one', - lambda: collection.update_one({}, {'$set': {'x': 1}})), - ('update_many', - lambda: collection.update_many({}, {'$set': {'x': 1}})), - ('delete_one', lambda: collection.delete_one({})), - ('delete_many', lambda: collection.delete_many({})), - ('bulk_write', lambda: collection.bulk_write([InsertOne({})])), - ('rename_and_drop', rename_and_drop), + ("insert_and_find", insert_and_find), + ("insert_and_aggregate", insert_and_aggregate), + ("update_one", lambda: collection.update_one({}, {"$set": {"x": 1}})), + ("update_many", lambda: collection.update_many({}, {"$set": {"x": 1}})), + ("delete_one", lambda: collection.delete_one({})), + ("delete_many", lambda: collection.delete_many({})), + ("bulk_write", lambda: collection.bulk_write([InsertOne({})])), + ("rename_and_drop", rename_and_drop), ] for name, f in ops: @@ -1090,48 +1061,48 @@ def insert_and_aggregate(): collection.insert_one({}) f() - self.assertGreaterEqual(len(listener.results['started']), 1) - for i, event in enumerate(listener.results['started']): + self.assertGreaterEqual(len(listener.results["started"]), 1) + for i, event in enumerate(listener.results["started"]): self.assertTrue( - '$clusterTime' in event.command, - "%s sent no $clusterTime with %s" % ( - f.__name__, event.command_name)) + "$clusterTime" in event.command, + "%s sent no $clusterTime with %s" % (f.__name__, event.command_name), + ) if i > 0: - succeeded = listener.results['succeeded'][i - 1] + succeeded = listener.results["succeeded"][i - 1] self.assertTrue( - '$clusterTime' in succeeded.reply, - "%s received no $clusterTime with %s" % ( - f.__name__, succeeded.command_name)) + "$clusterTime" in succeeded.reply, + "%s received no $clusterTime with %s" + % (f.__name__, succeeded.command_name), + ) self.assertTrue( - event.command['$clusterTime']['clusterTime'] >= - succeeded.reply['$clusterTime']['clusterTime'], - "%s sent wrong $clusterTime with %s" % ( - f.__name__, event.command_name)) + event.command["$clusterTime"]["clusterTime"] + >= succeeded.reply["$clusterTime"]["clusterTime"], + "%s sent wrong $clusterTime with %s" % (f.__name__, event.command_name), + ) class TestSpec(SpecRunner): RUN_ON_SERVERLESS = True # Location of JSON test specifications. 
- TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'sessions', 'legacy') + TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sessions", "legacy") def last_two_command_events(self): """Return the last two command started events.""" - started_events = self.listener.results['started'][-2:] + started_events = self.listener.results["started"][-2:] self.assertEqual(2, len(started_events)) return started_events def assert_same_lsid_on_last_two_commands(self): """Run the assertSameLsidOnLastTwoCommands test operation.""" event1, event2 = self.last_two_command_events() - self.assertEqual(event1.command['lsid'], event2.command['lsid']) + self.assertEqual(event1.command["lsid"], event2.command["lsid"]) def assert_different_lsid_on_last_two_commands(self): """Run the assertDifferentLsidOnLastTwoCommands test operation.""" event1, event2 = self.last_two_command_events() - self.assertNotEqual(event1.command['lsid'], event2.command['lsid']) + self.assertNotEqual(event1.command["lsid"], event2.command["lsid"]) def assert_session_dirty(self, session): """Run the assertSessionDirty test operation. diff --git a/test/test_sessions_unified.py b/test/test_sessions_unified.py index fe25536e7e..2320d52718 100644 --- a/test/test_sessions_unified.py +++ b/test/test_sessions_unified.py @@ -23,8 +23,7 @@ from test.unified_format import generate_test_classes # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'sessions', 'unified') +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sessions", "unified") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_son.py b/test/test_son.py index edddd6b8b8..69beb81439 100644 --- a/test/test_son.py +++ b/test/test_son.py @@ -21,9 +21,11 @@ sys.path[0:0] = [""] -from bson.son import SON -from test import unittest from collections import OrderedDict +from test import unittest + +from bson.son import SON + class TestSON(unittest.TestCase): def test_ordered_dict(self): @@ -31,9 +33,9 @@ def test_ordered_dict(self): a1["hello"] = "world" a1["mike"] = "awesome" a1["hello_"] = "mike" - self.assertEqual(list(a1.items()), [("hello", "world"), - ("mike", "awesome"), - ("hello_", "mike")]) + self.assertEqual( + list(a1.items()), [("hello", "world"), ("mike", "awesome"), ("hello_", "mike")] + ) b2 = SON({"hello": "world"}) self.assertEqual(b2["hello"], "world") @@ -41,38 +43,28 @@ def test_ordered_dict(self): def test_equality(self): a1 = SON({"hello": "world"}) - b2 = SON((('hello', 'world'), ('mike', 'awesome'), ('hello_', 'mike'))) + b2 = SON((("hello", "world"), ("mike", "awesome"), ("hello_", "mike"))) self.assertEqual(a1, SON({"hello": "world"})) - self.assertEqual(b2, SON((('hello', 'world'), - ('mike', 'awesome'), - ('hello_', 'mike')))) - self.assertEqual(b2, dict((('hello_', 'mike'), - ('mike', 'awesome'), - ('hello', 'world')))) + self.assertEqual(b2, SON((("hello", "world"), ("mike", "awesome"), ("hello_", "mike")))) + self.assertEqual(b2, dict((("hello_", "mike"), ("mike", "awesome"), ("hello", "world")))) self.assertNotEqual(a1, b2) - self.assertNotEqual(b2, SON((('hello_', 'mike'), - ('mike', 'awesome'), - ('hello', 'world')))) + self.assertNotEqual(b2, SON((("hello_", "mike"), ("mike", "awesome"), ("hello", "world")))) # Explicitly test inequality self.assertFalse(a1 != SON({"hello": "world"})) - self.assertFalse(b2 != SON((('hello', 'world'), - ('mike', 'awesome'), - ('hello_', 
'mike')))) - self.assertFalse(b2 != dict((('hello_', 'mike'), - ('mike', 'awesome'), - ('hello', 'world')))) + self.assertFalse(b2 != SON((("hello", "world"), ("mike", "awesome"), ("hello_", "mike")))) + self.assertFalse(b2 != dict((("hello_", "mike"), ("mike", "awesome"), ("hello", "world")))) # Embedded SON. - d4 = SON([('blah', {'foo': SON()})]) - self.assertEqual(d4, {'blah': {'foo': {}}}) - self.assertEqual(d4, {'blah': {'foo': SON()}}) - self.assertNotEqual(d4, {'blah': {'foo': []}}) + d4 = SON([("blah", {"foo": SON()})]) + self.assertEqual(d4, {"blah": {"foo": {}}}) + self.assertEqual(d4, {"blah": {"foo": SON()}}) + self.assertNotEqual(d4, {"blah": {"foo": []}}) # Original data unaffected. - self.assertEqual(SON, d4['blah']['foo'].__class__) + self.assertEqual(SON, d4["blah"]["foo"].__class__) def test_to_dict(self): a1 = SON() @@ -89,19 +81,17 @@ def test_to_dict(self): self.assertEqual(dict, d4.to_dict()["blah"]["foo"].__class__) # Original data unaffected. - self.assertEqual(SON, d4['blah']['foo'].__class__) + self.assertEqual(SON, d4["blah"]["foo"].__class__) def test_pickle(self): simple_son = SON([]) - complex_son = SON([('son', simple_son), - ('list', [simple_son, simple_son])]) + complex_son = SON([("son", simple_son), ("list", [simple_son, simple_son])]) for protocol in range(pickle.HIGHEST_PROTOCOL + 1): - pickled = pickle.loads(pickle.dumps(complex_son, - protocol=protocol)) - self.assertEqual(pickled['son'], pickled['list'][0]) - self.assertEqual(pickled['son'], pickled['list'][1]) + pickled = pickle.loads(pickle.dumps(complex_son, protocol=protocol)) + self.assertEqual(pickled["son"], pickled["list"][0]) + self.assertEqual(pickled["son"], pickled["list"][1]) def test_pickle_backwards_compatability(self): # This string was generated by pickling a SON object in pymongo @@ -109,16 +99,16 @@ def test_pickle_backwards_compatability(self): pickled_with_2_1_1 = ( "ccopy_reg\n_reconstructor\np0\n(cbson.son\nSON\np1\n" "c__builtin__\ndict\np2\n(dp3\ntp4\nRp5\n(dp6\n" - "S'_SON__keys'\np7\n(lp8\nsb.").encode('utf8') + "S'_SON__keys'\np7\n(lp8\nsb." 
+ ).encode("utf8") son_2_1_1 = pickle.loads(pickled_with_2_1_1) self.assertEqual(son_2_1_1, SON([])) def test_copying(self): simple_son = SON([]) - complex_son = SON([('son', simple_son), - ('list', [simple_son, simple_son])]) + complex_son = SON([("son", simple_son), ("list", [simple_son, simple_son])]) regex_son = SON([("x", re.compile("^hello.*"))]) - reflexive_son = SON([('son', simple_son)]) + reflexive_son = SON([("son", simple_son)]) reflexive_son["reflexive"] = reflexive_son simple_son1 = copy.copy(simple_son) @@ -196,8 +186,10 @@ def test_keys(self): try: d - i().keys() except TypeError: - self.fail("SON().keys() is not returning an object compatible " - "with %s objects" % (str(i))) + self.fail( + "SON().keys() is not returning an object compatible " + "with %s objects" % (str(i)) + ) # Test to verify correctness d = SON({"k": "v"}).keys() for i in [OrderedDict, dict]: @@ -205,5 +197,6 @@ def test_keys(self): for i in [OrderedDict, dict]: self.assertEqual(d - i({"k": 0}).keys(), set()) + if __name__ == "__main__": unittest.main() diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 1e99277692..f269e9b35b 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -15,28 +15,30 @@ """Run the SRV support tests.""" import sys - from time import sleep sys.path[0:0] = [""] -import pymongo +from test import client_knobs, unittest +from test.utils import FunctionCallRecorder, wait_until +import pymongo from pymongo import common from pymongo.errors import ConfigurationError -from pymongo.srv_resolver import _HAVE_DNSPYTHON from pymongo.mongo_client import MongoClient -from test import client_knobs, unittest -from test.utils import wait_until, FunctionCallRecorder - +from pymongo.srv_resolver import _HAVE_DNSPYTHON WAIT_TIME = 0.1 class SrvPollingKnobs(object): - def __init__(self, ttl_time=None, min_srv_rescan_interval=None, - nodelist_callback=None, - count_resolver_calls=False): + def __init__( + self, + ttl_time=None, + min_srv_rescan_interval=None, + nodelist_callback=None, + count_resolver_calls=False, + ): self.ttl_time = ttl_time self.min_srv_rescan_interval = min_srv_rescan_interval self.nodelist_callback = nodelist_callback @@ -47,8 +49,7 @@ def __init__(self, ttl_time=None, min_srv_rescan_interval=None, def enable(self): self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL - self.old_dns_resolver_response = \ - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl + self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl if self.min_srv_rescan_interval is not None: common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval @@ -73,8 +74,7 @@ def __enter__(self): def disable(self): common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval - pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = \ - self.old_dns_resolver_response + pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = self.old_dns_resolver_response def __exit__(self, exc_type, exc_val, exc_tb): self.disable() @@ -84,18 +84,20 @@ class TestSrvPolling(unittest.TestCase): BASE_SRV_RESPONSE = [ ("localhost.test.build.10gen.cc", 27017), - ("localhost.test.build.10gen.cc", 27018)] + ("localhost.test.build.10gen.cc", 27018), + ] CONNECTION_STRING = "mongodb+srv://test1.test.build.10gen.cc" def setUp(self): if not _HAVE_DNSPYTHON: - raise unittest.SkipTest("SRV polling tests require the dnspython " - "module") + raise unittest.SkipTest("SRV polling tests require the dnspython " "module") # Patch timeouts to ensure short rescan 
SRV interval. self.client_knobs = client_knobs( - heartbeat_frequency=WAIT_TIME, min_heartbeat_interval=WAIT_TIME, - events_queue_frequency=WAIT_TIME) + heartbeat_frequency=WAIT_TIME, + min_heartbeat_interval=WAIT_TIME, + events_queue_frequency=WAIT_TIME, + ) self.client_knobs.enable() def tearDown(self): @@ -108,13 +110,14 @@ def assert_nodelist_change(self, expected_nodelist, client): """Check if the client._topology eventually sees all nodes in the expected_nodelist. """ + def predicate(): nodelist = self.get_nodelist(client) if set(expected_nodelist) == set(nodelist): return True return False - wait_until(predicate, "see expected nodelist", - timeout=100*WAIT_TIME) + + wait_until(predicate, "see expected nodelist", timeout=100 * WAIT_TIME) def assert_nodelist_nochange(self, expected_nodelist, client): """Check if the client._topology ever deviates from seeing all nodes @@ -122,20 +125,23 @@ def assert_nodelist_nochange(self, expected_nodelist, client): (WAIT_TIME * 10) seconds. Also check that the resolver is called at least once. """ - sleep(WAIT_TIME*10) + sleep(WAIT_TIME * 10) nodelist = self.get_nodelist(client) if set(expected_nodelist) != set(nodelist): msg = "Client nodelist %s changed unexpectedly (expected %s)" raise self.fail(msg % (nodelist, expected_nodelist)) self.assertGreaterEqual( pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count, - 1, "resolver was never called") + 1, + "resolver was never called", + ) return True def run_scenario(self, dns_response, expect_change): if callable(dns_response): dns_resolver_response = dns_response else: + def dns_resolver_response(): return dns_response @@ -149,34 +155,29 @@ def dns_resolver_response(): expected_response = self.BASE_SRV_RESPONSE # Patch timeouts to ensure short test running times. - with SrvPollingKnobs( - ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = MongoClient(self.CONNECTION_STRING) self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client) # Patch list of hosts returned by DNS query. 
with SrvPollingKnobs( - nodelist_callback=dns_resolver_response, - count_resolver_calls=count_resolver_calls): + nodelist_callback=dns_resolver_response, count_resolver_calls=count_resolver_calls + ): assertion_method(expected_response, client) def test_addition(self): response = self.BASE_SRV_RESPONSE[:] - response.append( - ("localhost.test.build.10gen.cc", 27019)) + response.append(("localhost.test.build.10gen.cc", 27019)) self.run_scenario(response, True) def test_removal(self): response = self.BASE_SRV_RESPONSE[:] - response.remove( - ("localhost.test.build.10gen.cc", 27018)) + response.remove(("localhost.test.build.10gen.cc", 27018)) self.run_scenario(response, True) def test_replace_one(self): response = self.BASE_SRV_RESPONSE[:] - response.remove( - ("localhost.test.build.10gen.cc", 27018)) - response.append( - ("localhost.test.build.10gen.cc", 27019)) + response.remove(("localhost.test.build.10gen.cc", 27018)) + response.append(("localhost.test.build.10gen.cc", 27019)) self.run_scenario(response, True) def test_replace_both_with_one(self): @@ -184,15 +185,20 @@ def test_replace_both_with_one(self): self.run_scenario(response, True) def test_replace_both_with_two(self): - response = [("localhost.test.build.10gen.cc", 27019), - ("localhost.test.build.10gen.cc", 27020)] + response = [ + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020), + ] self.run_scenario(response, True) def test_dns_failures(self): from dns import exception + for exc in (exception.FormError, exception.TooBig, exception.Timeout): + def response_callback(*args): raise exc("DNS Failure!") + self.run_scenario(response_callback, False) def test_dns_record_lookup_empty(self): @@ -203,89 +209,95 @@ def _test_recover_from_initial(self, initial_callback): # Construct a valid final response callback distinct from base. response_final = self.BASE_SRV_RESPONSE[:] response_final.pop() + def final_callback(): return response_final with SrvPollingKnobs( - ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, - nodelist_callback=initial_callback, - count_resolver_calls=True): + ttl_time=WAIT_TIME, + min_srv_rescan_interval=WAIT_TIME, + nodelist_callback=initial_callback, + count_resolver_calls=True, + ): # Client uses unpatched method to get initial nodelist client = MongoClient(self.CONNECTION_STRING) # Invalid DNS resolver response should not change nodelist. self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client) with SrvPollingKnobs( - ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, - nodelist_callback=final_callback): + ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME, nodelist_callback=final_callback + ): # Nodelist should reflect new valid DNS resolver response. 
self.assert_nodelist_change(response_final, client) def test_recover_from_initially_empty_seedlist(self): def empty_seedlist(): return [] + self._test_recover_from_initial(empty_seedlist) def test_recover_from_initially_erroring_seedlist(self): def erroring_seedlist(): raise ConfigurationError + self._test_recover_from_initial(erroring_seedlist) def test_10_all_dns_selected(self): - response = [("localhost.test.build.10gen.cc", 27017), - ("localhost.test.build.10gen.cc", 27019), - ("localhost.test.build.10gen.cc", 27020)] + response = [ + ("localhost.test.build.10gen.cc", 27017), + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020), + ] def nodelist_callback(): return response - with SrvPollingKnobs(ttl_time=WAIT_TIME, - min_srv_rescan_interval=WAIT_TIME): + + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=0) self.addCleanup(client.close) with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) def test_11_all_dns_selected(self): - response = [("localhost.test.build.10gen.cc", 27019), - ("localhost.test.build.10gen.cc", 27020)] + response = [ + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27020), + ] def nodelist_callback(): return response - with SrvPollingKnobs( - ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2) self.addCleanup(client.close) with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) def test_12_new_dns_randomly_selected(self): - response = [("localhost.test.build.10gen.cc", 27020), - ("localhost.test.build.10gen.cc", 27019), - ("localhost.test.build.10gen.cc", 27017)] + response = [ + ("localhost.test.build.10gen.cc", 27020), + ("localhost.test.build.10gen.cc", 27019), + ("localhost.test.build.10gen.cc", 27017), + ] def nodelist_callback(): return response - with SrvPollingKnobs( - ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2) self.addCleanup(client.close) with SrvPollingKnobs(nodelist_callback=nodelist_callback): - sleep(2*common.MIN_SRV_RESCAN_INTERVAL) - final_topology = set( - client.topology_description.server_descriptions()) - self.assertIn(("localhost.test.build.10gen.cc", 27017), - final_topology) + sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) + final_topology = set(client.topology_description.server_descriptions()) + self.assertIn(("localhost.test.build.10gen.cc", 27017), final_topology) self.assertEqual(len(final_topology), 2) def test_does_not_flipflop(self): - with SrvPollingKnobs( - ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=1) self.addCleanup(client.close) old = set(client.topology_description.server_descriptions()) - sleep(4*WAIT_TIME) + sleep(4 * WAIT_TIME) new = set(client.topology_description.server_descriptions()) self.assertSetEqual(old, new) @@ -293,20 +305,19 @@ def test_srv_service_name(self): # Construct a valid final response callback distinct from base. 
response = [ ("localhost.test.build.10gen.cc.", 27019), - ("localhost.test.build.10gen.cc.", 27020) + ("localhost.test.build.10gen.cc.", 27020), ] def nodelist_callback(): return response - with SrvPollingKnobs( - ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): + with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): client = MongoClient( - "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName" - "=customname") + "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName" "=customname" + ) with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/test/test_ssl.py b/test/test_ssl.py index 0162eb3a0d..d4b91ee437 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -20,26 +20,21 @@ sys.path[0:0] = [""] +from test import HAVE_IPADDRESS, IntegrationTest, SkipTest, client_context, unittest +from test.utils import ( + EventListener, + cat_files, + connected, + ignore_deprecations, + remove_all_users, +) from urllib.parse import quote_plus from pymongo import MongoClient, ssl_support -from pymongo.errors import (ConfigurationError, - ConnectionFailure, - OperationFailure) +from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure from pymongo.hello import HelloCompat -from pymongo.ssl_support import HAVE_SSL, get_ssl_context, _ssl +from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context from pymongo.write_concern import WriteConcern -from test import (IntegrationTest, - client_context, - SkipTest, - unittest, - HAVE_IPADDRESS) -from test.utils import (EventListener, - cat_files, - connected, - ignore_deprecations, - remove_all_users) - _HAVE_PYOPENSSL = False try: @@ -47,9 +42,12 @@ import OpenSSL import requests import service_identity + # Ensure service_identity>=18.1 is installed from service_identity.pyopenssl import verify_ip_address + from pymongo.ocsp_support import _load_trusted_ca_certs + _HAVE_PYOPENSSL = True except ImportError: _load_trusted_ca_certs = None @@ -58,15 +56,13 @@ if HAVE_SSL: import ssl -CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'certificates') -CLIENT_PEM = os.path.join(CERT_PATH, 'client.pem') -CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, 'password_protected.pem') -CA_PEM = os.path.join(CERT_PATH, 'ca.pem') -CA_BUNDLE_PEM = os.path.join(CERT_PATH, 'trusted-ca.pem') -CRL_PEM = os.path.join(CERT_PATH, 'crl.pem') -MONGODB_X509_USERNAME = ( - "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client") +CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "certificates") +CLIENT_PEM = os.path.join(CERT_PATH, "client.pem") +CLIENT_ENCRYPTED_PEM = os.path.join(CERT_PATH, "password_protected.pem") +CA_PEM = os.path.join(CERT_PATH, "ca.pem") +CA_BUNDLE_PEM = os.path.join(CERT_PATH, "trusted-ca.pem") +CRL_PEM = os.path.join(CERT_PATH, "crl.pem") +MONGODB_X509_USERNAME = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client" _PY37PLUS = sys.version_info[:2] >= (3, 7) @@ -82,27 +78,24 @@ class TestClientSSL(unittest.TestCase): - - @unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what " - "happens without it.") + @unittest.skipIf( + HAVE_SSL, "The ssl module is available, can't test what " "happens without it." 
+ ) def test_no_ssl_module(self): # Explicit - self.assertRaises(ConfigurationError, - MongoClient, ssl=True) + self.assertRaises(ConfigurationError, MongoClient, ssl=True) # Implied - self.assertRaises(ConfigurationError, - MongoClient, tlsCertificateKeyFile=CLIENT_PEM) + self.assertRaises(ConfigurationError, MongoClient, tlsCertificateKeyFile=CLIENT_PEM) @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") @ignore_deprecations def test_config_ssl(self): # Tests various ssl configurations - self.assertRaises(ValueError, MongoClient, ssl='foo') - self.assertRaises(ConfigurationError, - MongoClient, - tls=False, - tlsCertificateKeyFile=CLIENT_PEM) + self.assertRaises(ValueError, MongoClient, ssl="foo") + self.assertRaises( + ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ) self.assertRaises(TypeError, MongoClient, ssl=0) self.assertRaises(TypeError, MongoClient, ssl=5.5) self.assertRaises(TypeError, MongoClient, ssl=[]) @@ -112,30 +105,20 @@ def test_config_ssl(self): self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=[]) # Test invalid combinations - self.assertRaises(ConfigurationError, - MongoClient, - tls=False, - tlsCertificateKeyFile=CLIENT_PEM) - self.assertRaises(ConfigurationError, - MongoClient, - tls=False, - tlsCAFile=CA_PEM) - self.assertRaises(ConfigurationError, - MongoClient, - tls=False, - tlsCRLFile=CRL_PEM) - self.assertRaises(ConfigurationError, - MongoClient, - tls=False, - tlsAllowInvalidCertificates=False) - self.assertRaises(ConfigurationError, - MongoClient, - tls=False, - tlsAllowInvalidHostnames=False) - self.assertRaises(ConfigurationError, - MongoClient, - tls=False, - tlsDisableOCSPEndpointCheck=False) + self.assertRaises( + ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ) + self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCAFile=CA_PEM) + self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCRLFile=CRL_PEM) + self.assertRaises( + ConfigurationError, MongoClient, tls=False, tlsAllowInvalidCertificates=False + ) + self.assertRaises( + ConfigurationError, MongoClient, tls=False, tlsAllowInvalidHostnames=False + ) + self.assertRaises( + ConfigurationError, MongoClient, tls=False, tlsDisableOCSPEndpointCheck=False + ) @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") def test_use_pyopenssl_when_available(self): @@ -148,13 +131,13 @@ def test_load_trusted_ca_certs(self): class TestSSL(IntegrationTest): - def assertClientWorks(self, client): coll = client.pymongo_test.ssl_test.with_options( - write_concern=WriteConcern(w=client_context.w)) + write_concern=WriteConcern(w=client_context.w) + ) coll.drop() - coll.insert_one({'ssl': True}) - self.assertTrue(coll.find_one()['ssl']) + coll.insert_one({"ssl": True}) + self.assertTrue(coll.find_one()["ssl"]) coll.drop() @classmethod @@ -183,30 +166,36 @@ def test_tlsCertificateKeyFilePassword(self): # # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem - if not hasattr(ssl, 'SSLContext') and not _ssl.IS_PYOPENSSL: + if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, MongoClient, - 'localhost', + "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, tlsCertificateKeyFilePassword="qwerty", tlsCAFile=CA_PEM, - serverSelectionTimeoutMS=100) + serverSelectionTimeoutMS=100, + ) else: - connected(MongoClient('localhost', - ssl=True, - 
tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, - tlsCertificateKeyFilePassword="qwerty", - tlsCAFile=CA_PEM, - serverSelectionTimeoutMS=5000, - **self.credentials)) - - uri_fmt = ("mongodb://localhost/?ssl=true" - "&tlsCertificateKeyFile=%s&tlsCertificateKeyFilePassword=qwerty" - "&tlsCAFile=%s&serverSelectionTimeoutMS=5000") - connected(MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), - **self.credentials)) + connected( + MongoClient( + "localhost", + ssl=True, + tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, + tlsCertificateKeyFilePassword="qwerty", + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=5000, + **self.credentials + ) + ) + + uri_fmt = ( + "mongodb://localhost/?ssl=true" + "&tlsCertificateKeyFile=%s&tlsCertificateKeyFilePassword=qwerty" + "&tlsCAFile=%s&serverSelectionTimeoutMS=5000" + ) + connected(MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials)) @client_context.require_tlsCertificateKeyFile @client_context.require_no_auth @@ -219,16 +208,21 @@ def test_cert_ssl_implicitly_set(self): # # test that setting tlsCertificateKeyFile causes ssl to be set to True - client = MongoClient(client_context.host, client_context.port, - tlsAllowInvalidCertificates=True, - tlsCertificateKeyFile=CLIENT_PEM) + client = MongoClient( + client_context.host, + client_context.port, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) response = client.admin.command(HelloCompat.LEGACY_CMD) - if 'setName' in response: - client = MongoClient(client_context.pair, - replicaSet=response['setName'], - w=len(response['hosts']), - tlsAllowInvalidCertificates=True, - tlsCertificateKeyFile=CLIENT_PEM) + if "setName" in response: + client = MongoClient( + client_context.pair, + replicaSet=response["setName"], + w=len(response["hosts"]), + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) self.assertClientWorks(client) @@ -241,33 +235,41 @@ def test_cert_ssl_validation(self): # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # - client = MongoClient('localhost', - ssl=True, - tlsCertificateKeyFile=CLIENT_PEM, - tlsAllowInvalidCertificates=False, - tlsCAFile=CA_PEM) + client = MongoClient( + "localhost", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + ) response = client.admin.command(HelloCompat.LEGACY_CMD) - if 'setName' in response: - if response['primary'].split(":")[0] != 'localhost': - raise SkipTest("No hosts in the replicaset for 'localhost'. " - "Cannot validate hostname in the certificate") - - client = MongoClient('localhost', - replicaSet=response['setName'], - w=len(response['hosts']), - ssl=True, - tlsCertificateKeyFile=CLIENT_PEM, - tlsAllowInvalidCertificates=False, - tlsCAFile=CA_PEM) + if "setName" in response: + if response["primary"].split(":")[0] != "localhost": + raise SkipTest( + "No hosts in the replicaset for 'localhost'. 
" + "Cannot validate hostname in the certificate" + ) + + client = MongoClient( + "localhost", + replicaSet=response["setName"], + w=len(response["hosts"]), + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + ) self.assertClientWorks(client) if HAVE_IPADDRESS: - client = MongoClient('127.0.0.1', - ssl=True, - tlsCertificateKeyFile=CLIENT_PEM, - tlsAllowInvalidCertificates=False, - tlsCAFile=CA_PEM) + client = MongoClient( + "127.0.0.1", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + ) self.assertClientWorks(client) @client_context.require_tlsCertificateKeyFile @@ -279,9 +281,11 @@ def test_cert_ssl_uri_support(self): # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # - uri_fmt = ("mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates" - "=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false") - client = MongoClient(uri_fmt % (CLIENT_PEM, 'true', CA_PEM)) + uri_fmt = ( + "mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates" + "=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false" + ) + client = MongoClient(uri_fmt % (CLIENT_PEM, "true", CA_PEM)) self.assertClientWorks(client) @client_context.require_tlsCertificateKeyFile @@ -307,81 +311,105 @@ def test_cert_ssl_validation_hostname_matching(self): response = self.client.admin.command(HelloCompat.LEGACY_CMD) with self.assertRaises(ConnectionFailure): - connected(MongoClient('server', - ssl=True, - tlsCertificateKeyFile=CLIENT_PEM, - tlsAllowInvalidCertificates=False, - tlsCAFile=CA_PEM, - serverSelectionTimeoutMS=500, - **self.credentials)) - - connected(MongoClient('server', - ssl=True, - tlsCertificateKeyFile=CLIENT_PEM, - tlsAllowInvalidCertificates=False, - tlsCAFile=CA_PEM, - tlsAllowInvalidHostnames=True, - serverSelectionTimeoutMS=500, - **self.credentials)) - - if 'setName' in response: + connected( + MongoClient( + "server", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=500, + **self.credentials + ) + ) + + connected( + MongoClient( + "server", + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + tlsAllowInvalidHostnames=True, + serverSelectionTimeoutMS=500, + **self.credentials + ) + ) + + if "setName" in response: with self.assertRaises(ConnectionFailure): - connected(MongoClient('server', - replicaSet=response['setName'], - ssl=True, - tlsCertificateKeyFile=CLIENT_PEM, - tlsAllowInvalidCertificates=False, - tlsCAFile=CA_PEM, - serverSelectionTimeoutMS=500, - **self.credentials)) - - connected(MongoClient('server', - replicaSet=response['setName'], - ssl=True, - tlsCertificateKeyFile=CLIENT_PEM, - tlsAllowInvalidCertificates=False, - tlsCAFile=CA_PEM, - tlsAllowInvalidHostnames=True, - serverSelectionTimeoutMS=500, - **self.credentials)) + connected( + MongoClient( + "server", + replicaSet=response["setName"], + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=500, + **self.credentials + ) + ) + + connected( + MongoClient( + "server", + replicaSet=response["setName"], + ssl=True, + tlsCertificateKeyFile=CLIENT_PEM, + tlsAllowInvalidCertificates=False, + tlsCAFile=CA_PEM, + tlsAllowInvalidHostnames=True, + serverSelectionTimeoutMS=500, + **self.credentials + ) + ) 
@client_context.require_tlsCertificateKeyFile @ignore_deprecations def test_tlsCRLFile_support(self): - if not hasattr(ssl, 'VERIFY_CRL_CHECK_LEAF') or _ssl.IS_PYOPENSSL: + if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, MongoClient, - 'localhost', + "localhost", ssl=True, tlsCAFile=CA_PEM, tlsCRLFile=CRL_PEM, - serverSelectionTimeoutMS=100) + serverSelectionTimeoutMS=100, + ) else: - connected(MongoClient('localhost', - ssl=True, - tlsCAFile=CA_PEM, - serverSelectionTimeoutMS=100, - **self.credentials)) + connected( + MongoClient( + "localhost", + ssl=True, + tlsCAFile=CA_PEM, + serverSelectionTimeoutMS=100, + **self.credentials + ) + ) with self.assertRaises(ConnectionFailure): - connected(MongoClient('localhost', - ssl=True, - tlsCAFile=CA_PEM, - tlsCRLFile=CRL_PEM, - serverSelectionTimeoutMS=100, - **self.credentials)) - - uri_fmt = ("mongodb://localhost/?ssl=true&" - "tlsCAFile=%s&serverSelectionTimeoutMS=100") - connected(MongoClient(uri_fmt % (CA_PEM,), - **self.credentials)) - - uri_fmt = ("mongodb://localhost/?ssl=true&tlsCRLFile=%s" - "&tlsCAFile=%s&serverSelectionTimeoutMS=100") + connected( + MongoClient( + "localhost", + ssl=True, + tlsCAFile=CA_PEM, + tlsCRLFile=CRL_PEM, + serverSelectionTimeoutMS=100, + **self.credentials + ) + ) + + uri_fmt = "mongodb://localhost/?ssl=true&" "tlsCAFile=%s&serverSelectionTimeoutMS=100" + connected(MongoClient(uri_fmt % (CA_PEM,), **self.credentials)) + + uri_fmt = ( + "mongodb://localhost/?ssl=true&tlsCRLFile=%s" + "&tlsCAFile=%s&serverSelectionTimeoutMS=100" + ) with self.assertRaises(ConnectionFailure): - connected(MongoClient(uri_fmt % (CRL_PEM, CA_PEM), - **self.credentials)) + connected(MongoClient(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials)) @client_context.require_tlsCertificateKeyFile @client_context.require_server_resolvable @@ -396,37 +424,39 @@ def test_validation_with_system_ca_certs(self): self.patch_system_certs(CA_PEM) with self.assertRaises(ConnectionFailure): # Server cert is verified but hostname matching fails - connected(MongoClient('server', - ssl=True, - serverSelectionTimeoutMS=100, - **self.credentials)) + connected( + MongoClient("server", ssl=True, serverSelectionTimeoutMS=100, **self.credentials) + ) # Server cert is verified. Disable hostname matching. - connected(MongoClient('server', - ssl=True, - tlsAllowInvalidHostnames=True, - serverSelectionTimeoutMS=100, - **self.credentials)) + connected( + MongoClient( + "server", + ssl=True, + tlsAllowInvalidHostnames=True, + serverSelectionTimeoutMS=100, + **self.credentials + ) + ) # Server cert and hostname are verified. - connected(MongoClient('localhost', - ssl=True, - serverSelectionTimeoutMS=100, - **self.credentials)) + connected( + MongoClient("localhost", ssl=True, serverSelectionTimeoutMS=100, **self.credentials) + ) # Server cert and hostname are verified. 
connected( MongoClient( - 'mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=100', - **self.credentials)) + "mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=100", **self.credentials + ) + ) def test_system_certs_config_error(self): ctx = get_ssl_context(None, None, None, None, True, True, False) - if ((sys.platform != "win32" - and hasattr(ctx, "set_default_verify_paths")) - or hasattr(ctx, "load_default_certs")): - raise SkipTest( - "Can't test when system CA certificates are loadable.") + if (sys.platform != "win32" and hasattr(ctx, "set_default_verify_paths")) or hasattr( + ctx, "load_default_certs" + ): + raise SkipTest("Can't test when system CA certificates are loadable.") have_certifi = ssl_support.HAVE_CERTIFI have_wincertstore = ssl_support.HAVE_WINCERTSTORE @@ -453,8 +483,7 @@ def test_certifi_support(self): # Force the test on Windows, regardless of environment. ssl_support.HAVE_WINCERTSTORE = False try: - ctx = get_ssl_context(None, None, CA_PEM, None, False, False, - False) + ctx = get_ssl_context(None, None, CA_PEM, None, False, False, False) ssl_sock = ctx.wrap_socket(socket.socket()) self.assertEqual(ssl_sock.ca_certs, CA_PEM) @@ -488,18 +517,24 @@ def test_wincertstore(self): @ignore_deprecations def test_mongodb_x509_auth(self): host, port = client_context.host, client_context.port - self.addCleanup(remove_all_users, client_context.client['$external']) + self.addCleanup(remove_all_users, client_context.client["$external"]) # Give x509 user all necessary privileges. - client_context.create_user('$external', MONGODB_X509_USERNAME, roles=[ - {'role': 'readWriteAnyDatabase', 'db': 'admin'}, - {'role': 'userAdminAnyDatabase', 'db': 'admin'}]) + client_context.create_user( + "$external", + MONGODB_X509_USERNAME, + roles=[ + {"role": "readWriteAnyDatabase", "db": "admin"}, + {"role": "userAdminAnyDatabase", "db": "admin"}, + ], + ) noauth = MongoClient( client_context.pair, ssl=True, tlsAllowInvalidCertificates=True, - tlsCertificateKeyFile=CLIENT_PEM) + tlsCertificateKeyFile=CLIENT_PEM, + ) self.addCleanup(noauth.close) with self.assertRaises(OperationFailure): @@ -508,11 +543,12 @@ def test_mongodb_x509_auth(self): listener = EventListener() auth = MongoClient( client_context.pair, - authMechanism='MONGODB-X509', + authMechanism="MONGODB-X509", ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM, - event_listeners=[listener]) + event_listeners=[listener], + ) self.addCleanup(auth.close) # No error @@ -520,64 +556,73 @@ def test_mongodb_x509_auth(self): names = listener.started_command_names() if client_context.version.at_least(4, 4, -1): # Speculative auth skips the authenticate command. 
- self.assertEqual(names, ['find']) + self.assertEqual(names, ["find"]) else: - self.assertEqual(names, ['authenticate', 'find']) - - uri = ('mongodb://%s@%s:%d/?authMechanism=' - 'MONGODB-X509' % ( - quote_plus(MONGODB_X509_USERNAME), host, port)) - client = MongoClient(uri, - ssl=True, - tlsAllowInvalidCertificates=True, - tlsCertificateKeyFile=CLIENT_PEM) + self.assertEqual(names, ["authenticate", "find"]) + + uri = "mongodb://%s@%s:%d/?authMechanism=" "MONGODB-X509" % ( + quote_plus(MONGODB_X509_USERNAME), + host, + port, + ) + client = MongoClient( + uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM + ) self.addCleanup(client.close) # No error client.pymongo_test.test.find_one() - uri = 'mongodb://%s:%d/?authMechanism=MONGODB-X509' % (host, port) - client = MongoClient(uri, - ssl=True, - tlsAllowInvalidCertificates=True, - tlsCertificateKeyFile=CLIENT_PEM) + uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port) + client = MongoClient( + uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM + ) self.addCleanup(client.close) # No error client.pymongo_test.test.find_one() # Auth should fail if username and certificate do not match - uri = ('mongodb://%s@%s:%d/?authMechanism=' - 'MONGODB-X509' % ( - quote_plus("not the username"), host, port)) + uri = "mongodb://%s@%s:%d/?authMechanism=" "MONGODB-X509" % ( + quote_plus("not the username"), + host, + port, + ) bad_client = MongoClient( - uri, ssl=True, tlsAllowInvalidCertificates=True, - tlsCertificateKeyFile=CLIENT_PEM) + uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM + ) self.addCleanup(bad_client.close) with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() bad_client = MongoClient( - client_context.pair, - username="not the username", - authMechanism='MONGODB-X509', - ssl=True, - tlsAllowInvalidCertificates=True, - tlsCertificateKeyFile=CLIENT_PEM) + client_context.pair, + username="not the username", + authMechanism="MONGODB-X509", + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CLIENT_PEM, + ) self.addCleanup(bad_client.close) with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() # Invalid certificate (using CA certificate as client certificate) - uri = ('mongodb://%s@%s:%d/?authMechanism=' - 'MONGODB-X509' % ( - quote_plus(MONGODB_X509_USERNAME), host, port)) + uri = "mongodb://%s@%s:%d/?authMechanism=" "MONGODB-X509" % ( + quote_plus(MONGODB_X509_USERNAME), + host, + port, + ) try: - connected(MongoClient(uri, - ssl=True, - tlsAllowInvalidCertificates=True, - tlsCertificateKeyFile=CA_PEM, - serverSelectionTimeoutMS=100)) + connected( + MongoClient( + uri, + ssl=True, + tlsAllowInvalidCertificates=True, + tlsCertificateKeyFile=CA_PEM, + serverSelectionTimeoutMS=100, + ) + ) except (ConnectionFailure, ConfigurationError): pass else: @@ -592,15 +637,14 @@ def remove(path): except OSError: pass - temp_ca_bundle = os.path.join(CERT_PATH, 'trusted-ca-bundle.pem') + temp_ca_bundle = os.path.join(CERT_PATH, "trusted-ca-bundle.pem") self.addCleanup(remove, temp_ca_bundle) # Add the CA cert file to the bundle. 
cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM) - with MongoClient('localhost', - tls=True, - tlsCertificateKeyFile=CLIENT_PEM, - tlsCAFile=temp_ca_bundle) as client: - self.assertTrue(client.admin.command('ping')) + with MongoClient( + "localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle + ) as client: + self.assertTrue(client.admin.command("ping")) if __name__ == "__main__": diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index 4715fbfee7..e8cd6d2fa6 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -19,18 +19,18 @@ sys.path[0:0] = [""] +from test import IntegrationTest, client_context, unittest +from test.utils import ( + HeartbeatEventListener, + ServerEventListener, + rs_or_single_client, + single_client, + wait_until, +) + from pymongo import monitoring from pymongo.hello import HelloCompat -from test import (client_context, - IntegrationTest, - unittest) -from test.utils import (HeartbeatEventListener, - rs_or_single_client, - single_client, - ServerEventListener, - wait_until) - class TestStreamingProtocol(IntegrationTest): @client_context.require_failCommand_appName @@ -38,33 +38,40 @@ def test_failCommand_streaming(self): listener = ServerEventListener() hb_listener = HeartbeatEventListener() client = rs_or_single_client( - event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, - appName='failingHeartbeatTest') + event_listeners=[listener, hb_listener], + heartbeatFrequencyMS=500, + appName="failingHeartbeatTest", + ) self.addCleanup(client.close) # Force a connection. - client.admin.command('ping') + client.admin.command("ping") address = client.address listener.reset() fail_hello = { - 'configureFailPoint': 'failCommand', - 'mode': {'times': 4}, - 'data': { - 'failCommands': [HelloCompat.LEGACY_CMD, 'hello'], - 'closeConnection': False, - 'errorCode': 10107, - 'appName': 'failingHeartbeatTest', + "configureFailPoint": "failCommand", + "mode": {"times": 4}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "closeConnection": False, + "errorCode": 10107, + "appName": "failingHeartbeatTest", }, } with self.fail_point(fail_hello): + def _marked_unknown(event): - return (event.server_address == address - and not event.new_description.is_server_type_known) + return ( + event.server_address == address + and not event.new_description.is_server_type_known + ) def _discovered_node(event): - return (event.server_address == address - and not event.previous_description.is_server_type_known - and event.new_description.is_server_type_known) + return ( + event.server_address == address + and not event.previous_description.is_server_type_known + and event.new_description.is_server_type_known + ) def marked_unknown(): return len(listener.matching(_marked_unknown)) >= 1 @@ -73,11 +80,11 @@ def rediscovered(): return len(listener.matching(_discovered_node)) >= 1 # Topology events are published asynchronously - wait_until(marked_unknown, 'mark node unknown') - wait_until(rediscovered, 'rediscover node') + wait_until(marked_unknown, "mark node unknown") + wait_until(rediscovered, "rediscover node") # Server should be selectable. - client.admin.command('ping') + client.admin.command("ping") @client_context.require_failCommand_appName def test_streaming_rtt(self): @@ -86,45 +93,46 @@ def test_streaming_rtt(self): # On Windows, RTT can actually be 0.0 because time.time() only has # 1-15 millisecond resolution. 
We need to delay the initial hello # to ensure that RTT is never zero. - name = 'streamingRttTest' + name = "streamingRttTest" delay_hello = { - 'configureFailPoint': 'failCommand', - 'mode': {'times': 1000}, - 'data': { - 'failCommands': [HelloCompat.LEGACY_CMD, 'hello'], - 'blockConnection': True, - 'blockTimeMS': 20, + "configureFailPoint": "failCommand", + "mode": {"times": 1000}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "blockConnection": True, + "blockTimeMS": 20, # This can be uncommented after SERVER-49220 is fixed. # 'appName': name, }, } with self.fail_point(delay_hello): client = rs_or_single_client( - event_listeners=[listener, hb_listener], - heartbeatFrequencyMS=500, - appName=name) + event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name + ) self.addCleanup(client.close) # Force a connection. - client.admin.command('ping') + client.admin.command("ping") address = client.address - delay_hello['data']['blockTimeMS'] = 500 - delay_hello['data']['appName'] = name + delay_hello["data"]["blockTimeMS"] = 500 + delay_hello["data"]["appName"] = name with self.fail_point(delay_hello): + def rtt_exceeds_250_ms(): # XXX: Add a public TopologyDescription getter to MongoClient? topology = client._topology sd = topology.description.server_descriptions()[address] return sd.round_trip_time > 0.250 - wait_until(rtt_exceeds_250_ms, 'exceed 250ms RTT') + wait_until(rtt_exceeds_250_ms, "exceed 250ms RTT") # Server should be selectable. - client.admin.command('ping') + client.admin.command("ping") def changed_event(event): - return (event.server_address == address and isinstance( - event, monitoring.ServerDescriptionChangedEvent)) + return event.server_address == address and isinstance( + event, monitoring.ServerDescriptionChangedEvent + ) # There should only be one event published, for the initial discovery. events = listener.matching(changed_event) @@ -137,21 +145,21 @@ def test_monitor_waits_after_server_check_error(self): # This test implements: # https://github.com/mongodb/specifications/blob/6c5b2ac/source/server-discovery-and-monitoring/server-discovery-and-monitoring-tests.rst#monitors-sleep-at-least-minheartbeatfreqencyms-between-checks fail_hello = { - 'mode': {'times': 5}, - 'data': { - 'failCommands': [HelloCompat.LEGACY_CMD, 'hello'], - 'errorCode': 1234, - 'appName': 'SDAMMinHeartbeatFrequencyTest', + "mode": {"times": 5}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "errorCode": 1234, + "appName": "SDAMMinHeartbeatFrequencyTest", }, } with self.fail_point(fail_hello): start = time.time() client = single_client( - appName='SDAMMinHeartbeatFrequencyTest', - serverSelectionTimeoutMS=5000) + appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000 + ) self.addCleanup(client.close) # Force a connection. - client.admin.command('ping') + client.admin.command("ping") duration = time.time() - start # Explanation of the expected events: # 0ms: run configureFailPoint @@ -172,11 +180,13 @@ def test_monitor_waits_after_server_check_error(self): def test_heartbeat_awaited_flag(self): hb_listener = HeartbeatEventListener() client = single_client( - event_listeners=[hb_listener], heartbeatFrequencyMS=500, - appName='heartbeatEventAwaitedFlag') + event_listeners=[hb_listener], + heartbeatFrequencyMS=500, + appName="heartbeatEventAwaitedFlag", + ) self.addCleanup(client.close) # Force a connection. 
- client.admin.command('ping') + client.admin.command("ping") def hb_succeeded(event): return isinstance(event, monitoring.ServerHeartbeatSucceededEvent) @@ -185,18 +195,17 @@ def hb_failed(event): return isinstance(event, monitoring.ServerHeartbeatFailedEvent) fail_heartbeat = { - 'mode': {'times': 2}, - 'data': { - 'failCommands': [HelloCompat.LEGACY_CMD, 'hello'], - 'closeConnection': True, - 'appName': 'heartbeatEventAwaitedFlag', + "mode": {"times": 2}, + "data": { + "failCommands": [HelloCompat.LEGACY_CMD, "hello"], + "closeConnection": True, + "appName": "heartbeatEventAwaitedFlag", }, } with self.fail_point(fail_heartbeat): - wait_until(lambda: hb_listener.matching(hb_failed), - "published failed event") + wait_until(lambda: hb_listener.matching(hb_failed), "published failed event") # Reconnect. - client.admin.command('ping') + client.admin.command("ping") hb_succeeded_events = hb_listener.matching(hb_succeeded) hb_failed_events = hb_listener.matching(hb_failed) @@ -205,10 +214,12 @@ def hb_failed(event): # Depending on thread scheduling, the failed heartbeat could occur on # the second or third check. events = [type(e) for e in hb_listener.events[:4]] - if events == [monitoring.ServerHeartbeatStartedEvent, - monitoring.ServerHeartbeatSucceededEvent, - monitoring.ServerHeartbeatStartedEvent, - monitoring.ServerHeartbeatFailedEvent]: + if events == [ + monitoring.ServerHeartbeatStartedEvent, + monitoring.ServerHeartbeatSucceededEvent, + monitoring.ServerHeartbeatStartedEvent, + monitoring.ServerHeartbeatFailedEvent, + ]: self.assertFalse(hb_succeeded_events[1].awaited) else: self.assertTrue(hb_succeeded_events[1].awaited) diff --git a/test/test_threads.py b/test/test_threads.py index a3cde207a2..064008ee32 100644 --- a/test/test_threads.py +++ b/test/test_threads.py @@ -15,12 +15,8 @@ """Test that pymongo is thread safe.""" import threading - -from test import (client_context, - IntegrationTest, - unittest) -from test.utils import rs_or_single_client -from test.utils import joinall +from test import IntegrationTest, client_context, unittest +from test.utils import joinall, rs_or_single_client @client_context.require_connection @@ -29,7 +25,6 @@ def setUpModule(): class AutoAuthenticateThreads(threading.Thread): - def __init__(self, collection, num): threading.Thread.__init__(self) self.coll = collection @@ -39,14 +34,13 @@ def __init__(self, collection, num): def run(self): for i in range(self.num): - self.coll.insert_one({'num': i}) - self.coll.find_one({'num': i}) + self.coll.insert_one({"num": i}) + self.coll.find_one({"num": i}) self.success = True class SaveAndFind(threading.Thread): - def __init__(self, collection): threading.Thread.__init__(self) self.collection = collection @@ -63,7 +57,6 @@ def run(self): class Insert(threading.Thread): - def __init__(self, collection, n, expect_exception): threading.Thread.__init__(self) self.collection = collection @@ -87,7 +80,6 @@ def run(self): class Update(threading.Thread): - def __init__(self, collection, n, expect_exception): threading.Thread.__init__(self) self.collection = collection @@ -100,8 +92,7 @@ def run(self): error = True try: - self.collection.update_one({"test": "unique"}, - {"$set": {"test": "update"}}) + self.collection.update_one({"test": "unique"}, {"$set": {"test": "update"}}) error = False except: if not self.expect_exception: diff --git a/test/test_timestamp.py b/test/test_timestamp.py index bb3358121c..3602fe2808 100644 --- a/test/test_timestamp.py +++ b/test/test_timestamp.py @@ -14,15 +14,17 @@ """Tests for 
the Timestamp class.""" -import datetime -import sys import copy +import datetime import pickle +import sys + sys.path[0:0] = [""] +from test import unittest + from bson.timestamp import Timestamp from bson.tz_util import utc -from test import unittest class TestTimestamp(unittest.TestCase): @@ -78,5 +80,6 @@ def test_repr(self): t = Timestamp(0, 0) self.assertEqual(repr(t), "Timestamp(0, 0)") + if __name__ == "__main__": unittest.main() diff --git a/test/test_topology.py b/test/test_topology.py index a309d622ab..d7bae9229f 100644 --- a/test/test_topology.py +++ b/test/test_topology.py @@ -18,27 +18,23 @@ sys.path[0:0] = [""] -from bson.objectid import ObjectId +from test import client_knobs, unittest +from test.pymongo_mocks import DummyMonitor +from test.utils import MockPool, wait_until +from bson.objectid import ObjectId from pymongo import common -from pymongo.errors import (AutoReconnect, - ConfigurationError, - ConnectionFailure) +from pymongo.errors import AutoReconnect, ConfigurationError, ConnectionFailure from pymongo.hello import Hello, HelloCompat from pymongo.monitor import Monitor from pymongo.pool import PoolOptions from pymongo.read_preferences import ReadPreference, Secondary from pymongo.server_description import ServerDescription -from pymongo.server_selectors import (any_server_selector, - writable_server_selector) +from pymongo.server_selectors import any_server_selector, writable_server_selector from pymongo.server_type import SERVER_TYPE from pymongo.settings import TopologySettings -from pymongo.topology import (_ErrorContext, - Topology) +from pymongo.topology import Topology, _ErrorContext from pymongo.topology_description import TOPOLOGY_TYPE -from test import client_knobs, unittest -from test.pymongo_mocks import DummyMonitor -from test.utils import MockPool, wait_until class SetNameDiscoverySettings(TopologySettings): @@ -46,20 +42,20 @@ def get_topology_type(self): return TOPOLOGY_TYPE.ReplicaSetNoPrimary -address = ('a', 27017) +address = ("a", 27017) def create_mock_topology( - seeds=None, - replica_set_name=None, - monitor_class=DummyMonitor, - direct_connection=False): - partitioned_seeds = list(map(common.partition_node, seeds or ['a'])) + seeds=None, replica_set_name=None, monitor_class=DummyMonitor, direct_connection=False +): + partitioned_seeds = list(map(common.partition_node, seeds or ["a"])) topology_settings = TopologySettings( partitioned_seeds, replica_set_name=replica_set_name, pool_class=MockPool, - monitor_class=monitor_class, direct_connection=direct_connection) + monitor_class=monitor_class, + direct_connection=direct_connection, + ) t = Topology(topology_settings) t.open() @@ -67,8 +63,7 @@ def create_mock_topology( def got_hello(topology, server_address, hello_response): - server_description = ServerDescription( - server_address, Hello(hello_response), 0) + server_description = ServerDescription(server_address, Hello(hello_response), 0) topology.on_change(server_description) @@ -108,7 +103,7 @@ def test_timeout_configuration(self): t.open() # Get the default server. - server = t.get_server_by_address(('localhost', 27017)) + server = t.get_server_by_address(("localhost", 27017)) # The pool for application operations obeys our settings. 
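Note: `wait_until`, imported from `test.utils` throughout these files, is a simple polling loop. A sketch of the idea (the timeout and poll interval here are assumptions; the real helper may differ in detail):

```python
import time


def wait_until(predicate, success_description, timeout=10):
    # Poll until the predicate returns a truthy value, failing with a
    # readable message if the timeout elapses first.
    start = time.time()
    while True:
        result = predicate()
        if result:
            return result
        if time.time() - start > timeout:
            raise AssertionError("Didn't ever %s" % success_description)
        time.sleep(0.01)
```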
self.assertEqual(1, server._pool.opts.connect_timeout) @@ -127,55 +122,53 @@ def test_timeout_configuration(self): class TestSingleServerTopology(TopologyTest): def test_direct_connection(self): for server_type, hello_response in [ - (SERVER_TYPE.RSPrimary, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'hosts': ['a'], - 'setName': 'rs', - 'maxWireVersion': 6}), - - (SERVER_TYPE.RSSecondary, { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'hosts': ['a'], - 'setName': 'rs', - 'maxWireVersion': 6}), - - (SERVER_TYPE.Mongos, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'msg': 'isdbgrid', - 'maxWireVersion': 6}), - - (SERVER_TYPE.RSArbiter, { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'arbiterOnly': True, - 'hosts': ['a'], - 'setName': 'rs', - 'maxWireVersion': 6}), - - (SERVER_TYPE.Standalone, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'maxWireVersion': 6}), - + ( + SERVER_TYPE.RSPrimary, + { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "hosts": ["a"], + "setName": "rs", + "maxWireVersion": 6, + }, + ), + ( + SERVER_TYPE.RSSecondary, + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "hosts": ["a"], + "setName": "rs", + "maxWireVersion": 6, + }, + ), + ( + SERVER_TYPE.Mongos, + {"ok": 1, HelloCompat.LEGACY_CMD: True, "msg": "isdbgrid", "maxWireVersion": 6}, + ), + ( + SERVER_TYPE.RSArbiter, + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "arbiterOnly": True, + "hosts": ["a"], + "setName": "rs", + "maxWireVersion": 6, + }, + ), + (SERVER_TYPE.Standalone, {"ok": 1, HelloCompat.LEGACY_CMD: True, "maxWireVersion": 6}), # A "slave" in a master-slave deployment. # This replication type was removed in MongoDB # 4.0. - (SERVER_TYPE.Standalone, { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'maxWireVersion': 6}), + (SERVER_TYPE.Standalone, {"ok": 1, HelloCompat.LEGACY_CMD: False, "maxWireVersion": 6}), ]: t = create_mock_topology(direct_connection=True) # Can't select a server while the only server is of type Unknown. - with self.assertRaisesRegex(ConnectionFailure, - 'No servers found yet'): - t.select_servers(any_server_selector, - server_selection_timeout=0) + with self.assertRaisesRegex(ConnectionFailure, "No servers found yet"): + t.select_servers(any_server_selector, server_selection_timeout=0) got_hello(t, address, hello_response) @@ -189,12 +182,13 @@ def test_direct_connection(self): # Topology type single is always readable and writable regardless # of server type or state. 
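Note: the hello responses in `test_direct_connection` above map one-to-one onto SDAM server types. A condensed sketch of that classification, following the SDAM spec (PyMongo's `ServerDescription` implements the full rules; this is not its actual code):

```python
def server_type_from_hello(doc):
    # Classify a hello/isMaster response the way the test table above does.
    if not doc.get("ok"):
        return "Unknown"
    if doc.get("msg") == "isdbgrid":
        return "Mongos"
    if "setName" in doc:
        if doc.get("isWritablePrimary") or doc.get("ismaster"):
            return "RSPrimary"
        if doc.get("secondary"):
            return "RSSecondary"
        if doc.get("arbiterOnly"):
            return "RSArbiter"
        return "RSOther"
    return "Standalone"


assert server_type_from_hello({"ok": 1, "msg": "isdbgrid"}) == "Mongos"
assert server_type_from_hello({"ok": 1, "ismaster": True}) == "Standalone"
```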
- self.assertEqual(t.description.topology_type_name, 'Single') + self.assertEqual(t.description.topology_type_name, "Single") self.assertTrue(t.description.has_writable_server()) self.assertTrue(t.description.has_readable_server()) self.assertTrue(t.description.has_readable_server(Secondary())) - self.assertTrue(t.description.has_readable_server( - Secondary(tag_sets=[{'tag': 'does-not-exist'}]))) + self.assertTrue( + t.description.has_readable_server(Secondary(tag_sets=[{"tag": "does-not-exist"}])) + ) def test_reopen(self): t = create_mock_topology() @@ -206,7 +200,7 @@ def test_reopen(self): def test_unavailable_seed(self): t = create_mock_topology() disconnected(t, address) - self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) + self.assertEqual(SERVER_TYPE.Unknown, get_type(t, "a")) def test_round_trip_time(self): round_trip_time = 125 @@ -215,10 +209,9 @@ def test_round_trip_time(self): class TestMonitor(Monitor): def _check_with_socket(self, *args, **kwargs): if available: - return (Hello({'ok': 1, 'maxWireVersion': 6}), - round_trip_time) + return (Hello({"ok": 1, "maxWireVersion": 6}), round_trip_time) else: - raise AutoReconnect('mock monitor error') + raise AutoReconnect("mock monitor error") t = create_mock_topology(monitor_class=TestMonitor) self.addCleanup(t.close) @@ -237,14 +230,13 @@ def _check_with_socket(self, *args, **kwargs): def raises_err(): try: - t.select_server(writable_server_selector, - server_selection_timeout=0.1) + t.select_server(writable_server_selector, server_selection_timeout=0.1) except ConnectionFailure: return True else: return False - wait_until(raises_err, 'discover server is down') + wait_until(raises_err, "discover server is down") self.assertIsNone(s.description.round_trip_time) # Bring it back, RTT is now 20 milliseconds. @@ -254,8 +246,10 @@ def raises_err(): def new_average(): # We reset the average to the most recent measurement. 
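Note: the reset the comment above refers to comes from the round-trip-time moving average: once a server is marked Unknown its average is discarded, so the next sample becomes the new average outright. A sketch of the exponentially weighted average the SDAM spec describes, with alpha 0.2 (PyMongo's internal class may differ in detail):

```python
class MovingAverage:
    def __init__(self):
        self.average = None

    def add_sample(self, sample):
        if self.average is None:
            # First sample after a reset is taken as-is.
            self.average = sample
        else:
            self.average = 0.8 * self.average + 0.2 * sample

    def reset(self):
        self.average = None


avg = MovingAverage()
avg.add_sample(0.125)
avg.reset()            # Server became Unknown.
avg.add_sample(0.020)
assert avg.average == 0.020  # Reset to the most recent measurement.
```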
description = s.description - return (description.round_trip_time is not None - and round(abs(20 - description.round_trip_time), 7) == 0) + return ( + description.round_trip_time is not None + and round(abs(20 - description.round_trip_time), 7) == 0 + ) tries = 0 while not new_average(): @@ -267,275 +261,289 @@ def new_average(): class TestMultiServerTopology(TopologyTest): def test_readable_writable(self): - t = create_mock_topology(replica_set_name='rs') - got_hello(t, ('a', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a', 'b']}) - - got_hello(t, ('b', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'setName': 'rs', - 'hosts': ['a', 'b']}) + t = create_mock_topology(replica_set_name="rs") + got_hello( + t, + ("a", 27017), + {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a", "b"]}, + ) + + got_hello( + t, + ("b", 27017), + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "setName": "rs", + "hosts": ["a", "b"], + }, + ) - self.assertEqual( - t.description.topology_type_name, 'ReplicaSetWithPrimary') + self.assertEqual(t.description.topology_type_name, "ReplicaSetWithPrimary") self.assertTrue(t.description.has_writable_server()) self.assertTrue(t.description.has_readable_server()) - self.assertTrue( - t.description.has_readable_server(Secondary())) - self.assertFalse( - t.description.has_readable_server( - Secondary(tag_sets=[{'tag': 'exists'}]))) - - t = create_mock_topology(replica_set_name='rs') - got_hello(t, ('a', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': False, - 'setName': 'rs', - 'hosts': ['a', 'b']}) - - got_hello(t, ('b', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'setName': 'rs', - 'hosts': ['a', 'b']}) + self.assertTrue(t.description.has_readable_server(Secondary())) + self.assertFalse(t.description.has_readable_server(Secondary(tag_sets=[{"tag": "exists"}]))) + + t = create_mock_topology(replica_set_name="rs") + got_hello( + t, + ("a", 27017), + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": False, + "setName": "rs", + "hosts": ["a", "b"], + }, + ) + + got_hello( + t, + ("b", 27017), + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "setName": "rs", + "hosts": ["a", "b"], + }, + ) - self.assertEqual( - t.description.topology_type_name, 'ReplicaSetNoPrimary') + self.assertEqual(t.description.topology_type_name, "ReplicaSetNoPrimary") self.assertFalse(t.description.has_writable_server()) self.assertFalse(t.description.has_readable_server()) - self.assertTrue( - t.description.has_readable_server(Secondary())) - self.assertFalse( - t.description.has_readable_server( - Secondary(tag_sets=[{'tag': 'exists'}]))) - - t = create_mock_topology(replica_set_name='rs') - got_hello(t, ('a', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a', 'b']}) - - got_hello(t, ('b', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'setName': 'rs', - 'hosts': ['a', 'b'], - 'tags': {'tag': 'exists'}}) - - self.assertEqual( - t.description.topology_type_name, 'ReplicaSetWithPrimary') + self.assertTrue(t.description.has_readable_server(Secondary())) + self.assertFalse(t.description.has_readable_server(Secondary(tag_sets=[{"tag": "exists"}]))) + + t = create_mock_topology(replica_set_name="rs") + got_hello( + t, + ("a", 27017), + {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a", "b"]}, + ) + + got_hello( + t, + ("b", 
27017), + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "setName": "rs", + "hosts": ["a", "b"], + "tags": {"tag": "exists"}, + }, + ) + + self.assertEqual(t.description.topology_type_name, "ReplicaSetWithPrimary") self.assertTrue(t.description.has_writable_server()) self.assertTrue(t.description.has_readable_server()) - self.assertTrue( - t.description.has_readable_server(Secondary())) - self.assertTrue( - t.description.has_readable_server( - Secondary(tag_sets=[{'tag': 'exists'}]))) + self.assertTrue(t.description.has_readable_server(Secondary())) + self.assertTrue(t.description.has_readable_server(Secondary(tag_sets=[{"tag": "exists"}]))) def test_close(self): - t = create_mock_topology(replica_set_name='rs') - got_hello(t, ('a', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a', 'b']}) - - got_hello(t, ('b', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'setName': 'rs', - 'hosts': ['a', 'b']}) - - self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, 'a')) - self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, 'b')) - self.assertTrue(get_monitor(t, 'a').opened) - self.assertTrue(get_monitor(t, 'b').opened) - self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, - t.description.topology_type) + t = create_mock_topology(replica_set_name="rs") + got_hello( + t, + ("a", 27017), + {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a", "b"]}, + ) + + got_hello( + t, + ("b", 27017), + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "setName": "rs", + "hosts": ["a", "b"], + }, + ) + + self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, "a")) + self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, "b")) + self.assertTrue(get_monitor(t, "a").opened) + self.assertTrue(get_monitor(t, "b").opened) + self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, t.description.topology_type) t.close() self.assertEqual(2, len(t.description.server_descriptions())) - self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) - self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'b')) - self.assertFalse(get_monitor(t, 'a').opened) - self.assertFalse(get_monitor(t, 'b').opened) - self.assertEqual('rs', t.description.replica_set_name) - self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, - t.description.topology_type) + self.assertEqual(SERVER_TYPE.Unknown, get_type(t, "a")) + self.assertEqual(SERVER_TYPE.Unknown, get_type(t, "b")) + self.assertFalse(get_monitor(t, "a").opened) + self.assertFalse(get_monitor(t, "b").opened) + self.assertEqual("rs", t.description.replica_set_name) + self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, t.description.topology_type) # A closed topology should not be updated when receiving a hello. - got_hello(t, ('a', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a', 'b', 'c']}) + got_hello( + t, + ("a", 27017), + {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a", "b", "c"]}, + ) self.assertEqual(2, len(t.description.server_descriptions())) - self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) - self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'b')) - self.assertFalse(get_monitor(t, 'a').opened) - self.assertFalse(get_monitor(t, 'b').opened) + self.assertEqual(SERVER_TYPE.Unknown, get_type(t, "a")) + self.assertEqual(SERVER_TYPE.Unknown, get_type(t, "b")) + self.assertFalse(get_monitor(t, "a").opened) + self.assertFalse(get_monitor(t, "b").opened) # Server c should not have been added. 
- self.assertEqual(None, get_server(t, 'c')) - self.assertEqual('rs', t.description.replica_set_name) - self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, - t.description.topology_type) + self.assertEqual(None, get_server(t, "c")) + self.assertEqual("rs", t.description.replica_set_name) + self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, t.description.topology_type) def test_handle_error(self): - t = create_mock_topology(replica_set_name='rs') - got_hello(t, ('a', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a', 'b']}) - - got_hello(t, ('b', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'setName': 'rs', - 'hosts': ['a', 'b']}) - - errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True, None) - t.handle_error(('a', 27017), errctx) - self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a')) - self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, 'b')) - self.assertEqual('rs', t.description.replica_set_name) - self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, - t.description.topology_type) - - got_hello(t, ('a', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a', 'b']}) - - self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, 'a')) - self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, - t.description.topology_type) - - t.handle_error(('b', 27017), errctx) - self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, 'a')) - self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'b')) - self.assertEqual('rs', t.description.replica_set_name) - self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, - t.description.topology_type) + t = create_mock_topology(replica_set_name="rs") + got_hello( + t, + ("a", 27017), + {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a", "b"]}, + ) + + got_hello( + t, + ("b", 27017), + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "setName": "rs", + "hosts": ["a", "b"], + }, + ) + + errctx = _ErrorContext(AutoReconnect("mock"), 0, 0, True, None) + t.handle_error(("a", 27017), errctx) + self.assertEqual(SERVER_TYPE.Unknown, get_type(t, "a")) + self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, "b")) + self.assertEqual("rs", t.description.replica_set_name) + self.assertEqual(TOPOLOGY_TYPE.ReplicaSetNoPrimary, t.description.topology_type) + + got_hello( + t, + ("a", 27017), + {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a", "b"]}, + ) + + self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, "a")) + self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, t.description.topology_type) + + t.handle_error(("b", 27017), errctx) + self.assertEqual(SERVER_TYPE.RSPrimary, get_type(t, "a")) + self.assertEqual(SERVER_TYPE.Unknown, get_type(t, "b")) + self.assertEqual("rs", t.description.replica_set_name) + self.assertEqual(TOPOLOGY_TYPE.ReplicaSetWithPrimary, t.description.topology_type) def test_handle_error_removed_server(self): - t = create_mock_topology(replica_set_name='rs') + t = create_mock_topology(replica_set_name="rs") # No error resetting a server not in the TopologyDescription. - errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True, None) - t.handle_error(('b', 27017), errctx) + errctx = _ErrorContext(AutoReconnect("mock"), 0, 0, True, None) + t.handle_error(("b", 27017), errctx) # Server was *not* added as type Unknown. 
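Note: `test_handle_error` and `test_handle_error_removed_server` above pin down two SDAM rules: a network error resets a known server to type Unknown, and an error for an address no longer in the topology is ignored rather than re-adding it. A toy model of just that bookkeeping (not PyMongo's actual data structures):

```python
def handle_error(server_types, address):
    # Only reset servers the topology still knows about; stale errors
    # for removed servers must not re-add them.
    if address in server_types:
        server_types[address] = "Unknown"


servers = {("a", 27017): "RSPrimary"}
handle_error(servers, ("b", 27017))  # Ignored: "b" was never added.
handle_error(servers, ("a", 27017))  # Primary demoted to Unknown.
assert servers == {("a", 27017): "Unknown"}
assert ("b", 27017) not in servers
```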
- self.assertFalse(t.has_server(('b', 27017))) + self.assertFalse(t.has_server(("b", 27017))) def test_discover_set_name_from_primary(self): # Discovering a replica set without the setName supplied by the user # is not yet supported by MongoClient, but Topology can do it. topology_settings = SetNameDiscoverySettings( - seeds=[address], - pool_class=MockPool, - monitor_class=DummyMonitor) + seeds=[address], pool_class=MockPool, monitor_class=DummyMonitor + ) t = Topology(topology_settings) self.assertEqual(t.description.replica_set_name, None) - self.assertEqual(t.description.topology_type, - TOPOLOGY_TYPE.ReplicaSetNoPrimary) + self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetNoPrimary) t.open() - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a']}) + got_hello( + t, address, {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a"]} + ) - self.assertEqual(t.description.replica_set_name, 'rs') - self.assertEqual(t.description.topology_type, - TOPOLOGY_TYPE.ReplicaSetWithPrimary) + self.assertEqual(t.description.replica_set_name, "rs") + self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetWithPrimary) # Another response from the primary. Tests the code that processes # primary response when topology type is already ReplicaSetWithPrimary. - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a']}) + got_hello( + t, address, {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a"]} + ) # No change. - self.assertEqual(t.description.replica_set_name, 'rs') - self.assertEqual(t.description.topology_type, - TOPOLOGY_TYPE.ReplicaSetWithPrimary) + self.assertEqual(t.description.replica_set_name, "rs") + self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetWithPrimary) def test_discover_set_name_from_secondary(self): # Discovering a replica set without the setName supplied by the user # is not yet supported by MongoClient, but Topology can do it. topology_settings = SetNameDiscoverySettings( - seeds=[address], - pool_class=MockPool, - monitor_class=DummyMonitor) + seeds=[address], pool_class=MockPool, monitor_class=DummyMonitor + ) t = Topology(topology_settings) self.assertEqual(t.description.replica_set_name, None) - self.assertEqual(t.description.topology_type, - TOPOLOGY_TYPE.ReplicaSetNoPrimary) + self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetNoPrimary) t.open() - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'setName': 'rs', - 'hosts': ['a']}) + got_hello( + t, + address, + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "setName": "rs", + "hosts": ["a"], + }, + ) - self.assertEqual(t.description.replica_set_name, 'rs') - self.assertEqual(t.description.topology_type, - TOPOLOGY_TYPE.ReplicaSetNoPrimary) + self.assertEqual(t.description.replica_set_name, "rs") + self.assertEqual(t.description.topology_type, TOPOLOGY_TYPE.ReplicaSetNoPrimary) def test_wire_version(self): - t = create_mock_topology(replica_set_name='rs') + t = create_mock_topology(replica_set_name="rs") t.description.check_compatible() # No error. - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a']}) + got_hello( + t, address, {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a"]} + ) # Use defaults. 
server = t.get_server_by_address(address) self.assertEqual(server.description.min_wire_version, 0) self.assertEqual(server.description.max_wire_version, 0) - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a'], - 'minWireVersion': 1, - 'maxWireVersion': 6}) + got_hello( + t, + address, + { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "setName": "rs", + "hosts": ["a"], + "minWireVersion": 1, + "maxWireVersion": 6, + }, + ) self.assertEqual(server.description.min_wire_version, 1) self.assertEqual(server.description.max_wire_version, 6) t.select_servers(any_server_selector) # Incompatible. - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a'], - 'minWireVersion': 21, - 'maxWireVersion': 22}) + got_hello( + t, + address, + { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "setName": "rs", + "hosts": ["a"], + "minWireVersion": 21, + "maxWireVersion": 22, + }, + ) try: t.select_servers(any_server_selector) @@ -544,19 +552,24 @@ def test_wire_version(self): self.assertEqual( str(e), "Server at a:27017 requires wire version 21, but this version " - "of PyMongo only supports up to %d." - % (common.MAX_SUPPORTED_WIRE_VERSION,)) + "of PyMongo only supports up to %d." % (common.MAX_SUPPORTED_WIRE_VERSION,), + ) else: - self.fail('No error with incompatible wire version') + self.fail("No error with incompatible wire version") # Incompatible. - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a'], - 'minWireVersion': 0, - 'maxWireVersion': 0}) + got_hello( + t, + address, + { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "setName": "rs", + "hosts": ["a"], + "minWireVersion": 0, + "maxWireVersion": 0, + }, + ) try: t.select_servers(any_server_selector) @@ -566,57 +579,72 @@ def test_wire_version(self): str(e), "Server at a:27017 reports wire version 0, but this version " "of PyMongo requires at least %d (MongoDB %s)." - % (common.MIN_SUPPORTED_WIRE_VERSION, - common.MIN_SUPPORTED_SERVER_VERSION)) + % (common.MIN_SUPPORTED_WIRE_VERSION, common.MIN_SUPPORTED_SERVER_VERSION), + ) else: - self.fail('No error with incompatible wire version') + self.fail("No error with incompatible wire version") def test_max_write_batch_size(self): - t = create_mock_topology(seeds=['a', 'b'], replica_set_name='rs') + t = create_mock_topology(seeds=["a", "b"], replica_set_name="rs") def write_batch_size(): s = t.select_server(writable_server_selector) return s.description.max_write_batch_size - got_hello(t, ('a', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a', 'b'], - 'maxWireVersion': 6, - 'maxWriteBatchSize': 1}) - - got_hello(t, ('b', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'setName': 'rs', - 'hosts': ['a', 'b'], - 'maxWireVersion': 6, - 'maxWriteBatchSize': 2}) + got_hello( + t, + ("a", 27017), + { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "setName": "rs", + "hosts": ["a", "b"], + "maxWireVersion": 6, + "maxWriteBatchSize": 1, + }, + ) + + got_hello( + t, + ("b", 27017), + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "setName": "rs", + "hosts": ["a", "b"], + "maxWireVersion": 6, + "maxWriteBatchSize": 2, + }, + ) # Uses primary's max batch size. self.assertEqual(1, write_batch_size()) # b becomes primary. 
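Note: the two incompatibility branches of `test_wire_version` above correspond to a symmetric pair of range checks. A sketch of the rule (the bounds and exception type below are placeholders, not PyMongo's actual `common.MIN_SUPPORTED_WIRE_VERSION`/`MAX_SUPPORTED_WIRE_VERSION` or error class):

```python
MIN_SUPPORTED_WIRE_VERSION = 2   # Placeholder value.
MAX_SUPPORTED_WIRE_VERSION = 13  # Placeholder value.


def check_compatible(server_min, server_max, address="a:27017"):
    if server_min > MAX_SUPPORTED_WIRE_VERSION:
        # Server is too new for this driver.
        raise RuntimeError(
            "Server at %s requires wire version %d, but this driver only "
            "supports up to %d." % (address, server_min, MAX_SUPPORTED_WIRE_VERSION)
        )
    if server_max < MIN_SUPPORTED_WIRE_VERSION:
        # Server is too old for this driver.
        raise RuntimeError(
            "Server at %s reports wire version %d, but this driver requires "
            "at least %d." % (address, server_max, MIN_SUPPORTED_WIRE_VERSION)
        )


check_compatible(1, 6)  # OK: the ranges overlap.
# check_compatible(21, 22) -> too new; check_compatible(0, 0) -> too old.
```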
- got_hello(t, ('b', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a', 'b'], - 'maxWireVersion': 6, - 'maxWriteBatchSize': 2}) + got_hello( + t, + ("b", 27017), + { + "ok": 1, + HelloCompat.LEGACY_CMD: True, + "setName": "rs", + "hosts": ["a", "b"], + "maxWireVersion": 6, + "maxWriteBatchSize": 2, + }, + ) self.assertEqual(2, write_batch_size()) def test_topology_repr(self): - t = create_mock_topology(replica_set_name='rs') + t = create_mock_topology(replica_set_name="rs") self.addCleanup(t.close) - got_hello(t, ('a', 27017), { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a', 'c', 'b']}) + got_hello( + t, + ("a", 27017), + {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a", "c", "b"]}, + ) self.assertEqual( repr(t.description), ", " "]>" % (t._topology_id,)) + " rtt: None>]>" % (t._topology_id,), + ) def test_unexpected_load_balancer(self): # Note: This behavior should not be reachable in practice but we # should handle it gracefully nonetheless. See PYTHON-2791. # Load balancers are included in topology with a single seed. - t = create_mock_topology(seeds=['a']) - mock_lb_response = {'ok': 1, 'msg': 'isdbgrid', - 'serviceId': ObjectId(), 'maxWireVersion': 13} - got_hello(t, ('a', 27017), mock_lb_response) + t = create_mock_topology(seeds=["a"]) + mock_lb_response = { + "ok": 1, + "msg": "isdbgrid", + "serviceId": ObjectId(), + "maxWireVersion": 13, + } + got_hello(t, ("a", 27017), mock_lb_response) sds = t.description.server_descriptions() - self.assertIn(('a', 27017), sds) - self.assertEqual(sds[('a', 27017)].server_type_name, 'LoadBalancer') - self.assertEqual(t.description.topology_type_name, 'Single') + self.assertIn(("a", 27017), sds) + self.assertEqual(sds[("a", 27017)].server_type_name, "LoadBalancer") + self.assertEqual(t.description.topology_type_name, "Single") self.assertTrue(t.description.has_writable_server()) # Load balancers are removed from a topology with multiple seeds. 
- t = create_mock_topology(seeds=['a', 'b']) - got_hello(t, ('a', 27017), mock_lb_response) - self.assertNotIn(('a', 27017), t.description.server_descriptions()) - self.assertEqual(t.description.topology_type_name, 'Unknown') + t = create_mock_topology(seeds=["a", "b"]) + got_hello(t, ("a", 27017), mock_lb_response) + self.assertNotIn(("a", 27017), t.description.server_descriptions()) + self.assertEqual(t.description.topology_type_name, "Unknown") def wait_for_primary(topology): @@ -663,7 +696,7 @@ def get_primary(): except ConnectionFailure: return None - return wait_until(get_primary, 'find primary') + return wait_until(get_primary, "find primary") class TestTopologyErrors(TopologyTest): @@ -677,9 +710,9 @@ class TestMonitor(Monitor): def _check_with_socket(self, *args, **kwargs): hello_count[0] += 1 if hello_count[0] == 1: - return Hello({'ok': 1, 'maxWireVersion': 6}), 0 + return Hello({"ok": 1, "maxWireVersion": 6}), 0 else: - raise AutoReconnect('mock monitor error') + raise AutoReconnect("mock monitor error") t = create_mock_topology(monitor_class=TestMonitor) self.addCleanup(t.close) @@ -699,17 +732,15 @@ class TestMonitor(Monitor): def _check_with_socket(self, *args, **kwargs): hello_count[0] += 1 if hello_count[0] in (1, 3): - return Hello({'ok': 1, 'maxWireVersion': 6}), 0 + return Hello({"ok": 1, "maxWireVersion": 6}), 0 else: - raise AutoReconnect( - 'mock monitor error #%s' % (hello_count[0],)) + raise AutoReconnect("mock monitor error #%s" % (hello_count[0],)) t = create_mock_topology(monitor_class=TestMonitor) self.addCleanup(t.close) server = wait_for_primary(t) self.assertEqual(1, hello_count[0]) - self.assertEqual(SERVER_TYPE.Standalone, - server.description.server_type) + self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type) # Second hello call, server is marked Unknown, then the monitor # immediately runs a retry (third hello). @@ -718,12 +749,11 @@ def _check_with_socket(self, *args, **kwargs): # after the failed check triggered by request_check_all. Wait until # the server becomes known again. 
server = t.select_server(writable_server_selector, 0.250) - self.assertEqual(SERVER_TYPE.Standalone, - server.description.server_type) + self.assertEqual(SERVER_TYPE.Standalone, server.description.server_type) self.assertEqual(3, hello_count[0]) def test_internal_monitor_error(self): - exception = AssertionError('internal error') + exception = AssertionError("internal error") class TestMonitor(Monitor): def _check_with_socket(self, *args, **kwargs): @@ -731,9 +761,8 @@ def _check_with_socket(self, *args, **kwargs): t = create_mock_topology(monitor_class=TestMonitor) self.addCleanup(t.close) - with self.assertRaisesRegex(ConnectionFailure, 'internal error'): - t.select_server(any_server_selector, - server_selection_timeout=0.5) + with self.assertRaisesRegex(ConnectionFailure, "internal error"): + t.select_server(any_server_selector, server_selection_timeout=0.5) class TestServerSelectionErrors(TopologyTest): @@ -744,69 +773,80 @@ def assertMessage(self, message, topology, selector=any_server_selector): self.assertIn(message, str(context.exception)) def test_no_primary(self): - t = create_mock_topology(replica_set_name='rs') - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'setName': 'rs', - 'hosts': ['a']}) + t = create_mock_topology(replica_set_name="rs") + got_hello( + t, + address, + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "setName": "rs", + "hosts": ["a"], + }, + ) - self.assertMessage('No replica set members match selector "Primary()"', - t, ReadPreference.PRIMARY) + self.assertMessage( + 'No replica set members match selector "Primary()"', t, ReadPreference.PRIMARY + ) - self.assertMessage('No primary available for writes', - t, writable_server_selector) + self.assertMessage("No primary available for writes", t, writable_server_selector) def test_no_secondary(self): - t = create_mock_topology(replica_set_name='rs') - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: True, - 'setName': 'rs', - 'hosts': ['a']}) + t = create_mock_topology(replica_set_name="rs") + got_hello( + t, address, {"ok": 1, HelloCompat.LEGACY_CMD: True, "setName": "rs", "hosts": ["a"]} + ) self.assertMessage( - 'No replica set members match selector' + "No replica set members match selector" ' "Secondary(tag_sets=None, max_staleness=-1, hedge=None)"', - t, ReadPreference.SECONDARY) + t, + ReadPreference.SECONDARY, + ) self.assertMessage( "No replica set members match selector" " \"Secondary(tag_sets=[{'dc': 'ny'}], max_staleness=-1, " - "hedge=None)\"", - t, Secondary(tag_sets=[{'dc': 'ny'}])) + 'hedge=None)"', + t, + Secondary(tag_sets=[{"dc": "ny"}]), + ) def test_bad_replica_set_name(self): - t = create_mock_topology(replica_set_name='rs') - got_hello(t, address, { - 'ok': 1, - HelloCompat.LEGACY_CMD: False, - 'secondary': True, - 'setName': 'wrong', - 'hosts': ['a']}) + t = create_mock_topology(replica_set_name="rs") + got_hello( + t, + address, + { + "ok": 1, + HelloCompat.LEGACY_CMD: False, + "secondary": True, + "setName": "wrong", + "hosts": ["a"], + }, + ) - self.assertMessage( - 'No replica set members available for replica set name "rs"', t) + self.assertMessage('No replica set members available for replica set name "rs"', t) def test_multiple_standalones(self): # Standalones are removed from a topology with multiple seeds. 
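Note: outside these unit tests, the same "no servers match" messages reach applications as a `ServerSelectionTimeoutError` once the selection timeout expires. A sketch of provoking one (the URI and timeout are illustrative):

```python
from pymongo import MongoClient, ReadPreference
from pymongo.errors import ServerSelectionTimeoutError

client = MongoClient(
    "mongodb://localhost:27017/?replicaSet=rs", serverSelectionTimeoutMS=500
)
coll = client.test.get_collection("test", read_preference=ReadPreference.SECONDARY)
try:
    coll.find_one()
except ServerSelectionTimeoutError as exc:
    # e.g. 'No replica set members match selector "Secondary(...)"'
    print(exc)
```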
- t = create_mock_topology(seeds=['a', 'b']) - got_hello(t, ('a', 27017), {'ok': 1}) - got_hello(t, ('b', 27017), {'ok': 1}) - self.assertMessage('No servers available', t) + t = create_mock_topology(seeds=["a", "b"]) + got_hello(t, ("a", 27017), {"ok": 1}) + got_hello(t, ("b", 27017), {"ok": 1}) + self.assertMessage("No servers available", t) def test_no_mongoses(self): # Standalones are removed from a topology with multiple seeds. - t = create_mock_topology(seeds=['a', 'b']) + t = create_mock_topology(seeds=["a", "b"]) # Discover a mongos and change topology type to Sharded. - got_hello(t, ('a', 27017), {'ok': 1, 'msg': 'isdbgrid'}) + got_hello(t, ("a", 27017), {"ok": 1, "msg": "isdbgrid"}) # Oops, both servers are standalone now. Remove them. - got_hello(t, ('a', 27017), {'ok': 1}) - got_hello(t, ('b', 27017), {'ok': 1}) - self.assertMessage('No mongoses available', t) + got_hello(t, ("a", 27017), {"ok": 1}) + got_hello(t, ("b", 27017), {"ok": 1}) + self.assertMessage("No mongoses available", t) if __name__ == "__main__": diff --git a/test/test_transactions.py b/test/test_transactions.py index 32f02f8437..54ae7eecce 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -16,35 +16,38 @@ import os import sys - from io import BytesIO sys.path[0:0] = [""] -from pymongo import client_session, WriteConcern +from test import client_context, unittest +from test.utils import ( + OvertCommandListener, + TestCreator, + rs_client, + single_client, + wait_until, +) +from test.utils_spec_runner import SpecRunner + +from gridfs import GridFS, GridFSBucket +from pymongo import WriteConcern, client_session from pymongo.client_session import TransactionOptions -from pymongo.errors import (CollectionInvalid, - ConfigurationError, - ConnectionFailure, - InvalidOperation, - OperationFailure) +from pymongo.errors import ( + CollectionInvalid, + ConfigurationError, + ConnectionFailure, + InvalidOperation, + OperationFailure, +) from pymongo.operations import IndexModel, InsertOne from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference -from gridfs import GridFS, GridFSBucket - -from test import unittest, client_context -from test.utils import (rs_client, single_client, - wait_until, OvertCommandListener, - TestCreator) -from test.utils_spec_runner import SpecRunner - # Location of JSON test specifications. -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'transactions', 'legacy') +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "transactions", "legacy") -_TXN_TESTS_DEBUG = os.environ.get('TRANSACTION_TESTS_DEBUG') +_TXN_TESTS_DEBUG = os.environ.get("TRANSACTION_TESTS_DEBUG") # Max number of operations to perform after a transaction to prove unpinning # occurs. Chosen so that there's a low false positive rate. 
With 2 mongoses, @@ -59,7 +62,7 @@ def setUpClass(cls): super(TransactionsBase, cls).setUpClass() if client_context.supports_transactions(): for address in client_context.mongoses: - cls.mongos_clients.append(single_client('%s:%s' % address)) + cls.mongos_clients.append(single_client("%s:%s" % address)) @classmethod def tearDownClass(cls): @@ -69,14 +72,17 @@ def tearDownClass(cls): def maybe_skip_scenario(self, test): super(TransactionsBase, self).maybe_skip_scenario(test) - if ('secondary' in self.id() and - not client_context.is_mongos and - not client_context.has_secondaries): - raise unittest.SkipTest('No secondaries') + if ( + "secondary" in self.id() + and not client_context.is_mongos + and not client_context.has_secondaries + ): + raise unittest.SkipTest("No secondaries") class TestTransactions(TransactionsBase): RUN_ON_SERVERLESS = True + @client_context.require_transactions def test_transaction_options_validation(self): default_options = TransactionOptions() @@ -85,23 +91,23 @@ def test_transaction_options_validation(self): self.assertIsNone(default_options.read_preference) self.assertIsNone(default_options.max_commit_time_ms) # No error when valid options are provided. - TransactionOptions(read_concern=ReadConcern(), - write_concern=WriteConcern(), - read_preference=ReadPreference.PRIMARY, - max_commit_time_ms=10000) + TransactionOptions( + read_concern=ReadConcern(), + write_concern=WriteConcern(), + read_preference=ReadPreference.PRIMARY, + max_commit_time_ms=10000, + ) with self.assertRaisesRegex(TypeError, "read_concern must be "): TransactionOptions(read_concern={}) with self.assertRaisesRegex(TypeError, "write_concern must be "): TransactionOptions(write_concern={}) with self.assertRaisesRegex( - ConfigurationError, - "transactions do not support unacknowledged write concern"): + ConfigurationError, "transactions do not support unacknowledged write concern" + ): TransactionOptions(write_concern=WriteConcern(w=0)) - with self.assertRaisesRegex( - TypeError, "is not valid for read_preference"): + with self.assertRaisesRegex(TypeError, "is not valid for read_preference"): TransactionOptions(read_preference={}) - with self.assertRaisesRegex( - TypeError, "max_commit_time_ms must be an integer or None"): + with self.assertRaisesRegex(TypeError, "max_commit_time_ms must be an integer or None"): TransactionOptions(max_commit_time_ms="10000") @client_context.require_transactions @@ -115,16 +121,11 @@ def test_transaction_write_concern_override(self): with client.start_session() as s: with s.start_transaction(write_concern=WriteConcern(w=1)): self.assertTrue(coll.insert_one({}, session=s).acknowledged) - self.assertTrue(coll.insert_many( - [{}, {}], session=s).acknowledged) - self.assertTrue(coll.bulk_write( - [InsertOne({})], session=s).acknowledged) - self.assertTrue(coll.replace_one( - {}, {}, session=s).acknowledged) - self.assertTrue(coll.update_one( - {}, {"$set": {"a": 1}}, session=s).acknowledged) - self.assertTrue(coll.update_many( - {}, {"$set": {"a": 1}}, session=s).acknowledged) + self.assertTrue(coll.insert_many([{}, {}], session=s).acknowledged) + self.assertTrue(coll.bulk_write([InsertOne({})], session=s).acknowledged) + self.assertTrue(coll.replace_one({}, {}, session=s).acknowledged) + self.assertTrue(coll.update_one({}, {"$set": {"a": 1}}, session=s).acknowledged) + self.assertTrue(coll.update_many({}, {"$set": {"a": 1}}, session=s).acknowledged) self.assertTrue(coll.delete_one({}, session=s).acknowledged) self.assertTrue(coll.delete_many({}, 
session=s).acknowledged) coll.find_one_and_delete({}, session=s) @@ -133,27 +134,29 @@ def test_transaction_write_concern_override(self): unsupported_txn_writes = [ (client.drop_database, [db.name], {}), - (db.drop_collection, ['collection'], {}), + (db.drop_collection, ["collection"], {}), (coll.drop, [], {}), - (coll.rename, ['collection2'], {}), + (coll.rename, ["collection2"], {}), # Drop collection2 between tests of "rename", above. - (coll.database.drop_collection, ['collection2'], {}), - (coll.create_indexes, [[IndexModel('a')]], {}), - (coll.create_index, ['a'], {}), - (coll.drop_index, ['a_1'], {}), + (coll.database.drop_collection, ["collection2"], {}), + (coll.create_indexes, [[IndexModel("a")]], {}), + (coll.create_index, ["a"], {}), + (coll.drop_index, ["a_1"], {}), (coll.drop_indexes, [], {}), (coll.aggregate, [[{"$out": "aggout"}]], {}), ] # Creating a collection in a transaction requires MongoDB 4.4+. if client_context.version < (4, 3, 4): - unsupported_txn_writes.extend([ - (db.create_collection, ['collection'], {}), - ]) + unsupported_txn_writes.extend( + [ + (db.create_collection, ["collection"], {}), + ] + ) for op in unsupported_txn_writes: op, args, kwargs = op with client.start_session() as s: - kwargs['session'] = s + kwargs["session"] = s s.start_transaction(write_concern=WriteConcern(w=1)) with self.assertRaises(OperationFailure): op(*args, **kwargs) @@ -164,8 +167,7 @@ def test_transaction_write_concern_override(self): def test_unpin_for_next_transaction(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = rs_client(client_context.mongos_seeds(), - localThresholdMS=1000) + client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. @@ -193,8 +195,7 @@ def test_unpin_for_next_transaction(self): def test_unpin_for_non_transaction_operation(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = rs_client(client_context.mongos_seeds(), - localThresholdMS=1000) + client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. 
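Note: the overrides exercised above all hang off the explicit transaction pattern. A minimal sketch of that pattern (database and collection names are illustrative; requires a replica set or sharded cluster):

```python
from pymongo import MongoClient, WriteConcern
from pymongo.read_concern import ReadConcern

client = MongoClient()
coll = client.test.test
coll.insert_one({})  # Pre-4.4 servers can't create collections in a txn.

with client.start_session() as session:
    with session.start_transaction(
        read_concern=ReadConcern("snapshot"),
        write_concern=WriteConcern(w="majority"),
    ):
        # Every operation must be passed the session explicitly; the
        # transaction commits when the block exits cleanly and aborts
        # if an exception propagates.
        coll.insert_one({"in_txn": True}, session=session)
```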
@@ -255,46 +256,71 @@ def gridfs_find(*args, **kwargs): return gfs.find(*args, **kwargs).next() def gridfs_open_upload_stream(*args, **kwargs): - bucket.open_upload_stream(*args, **kwargs).write(b'1') + bucket.open_upload_stream(*args, **kwargs).write(b"1") gridfs_ops = [ - (gfs.put, (b'123',)), + (gfs.put, (b"123",)), (gfs.get, (1,)), - (gfs.get_version, ('name',)), - (gfs.get_last_version, ('name',)), - (gfs.delete, (1, )), + (gfs.get_version, ("name",)), + (gfs.get_last_version, ("name",)), + (gfs.delete, (1,)), (gfs.list, ()), (gfs.find_one, ()), (gridfs_find, ()), (gfs.exists, ()), - (gridfs_open_upload_stream, ('name',)), - (bucket.upload_from_stream, ('name', b'data',)), - (bucket.download_to_stream, (1, BytesIO(),)), - (bucket.download_to_stream_by_name, ('name', BytesIO(),)), + (gridfs_open_upload_stream, ("name",)), + ( + bucket.upload_from_stream, + ( + "name", + b"data", + ), + ), + ( + bucket.download_to_stream, + ( + 1, + BytesIO(), + ), + ), + ( + bucket.download_to_stream_by_name, + ( + "name", + BytesIO(), + ), + ), (bucket.delete, (1,)), (bucket.find, ()), (bucket.open_download_stream, (1,)), - (bucket.open_download_stream_by_name, ('name',)), - (bucket.rename, (1, 'new-name',)), + (bucket.open_download_stream_by_name, ("name",)), + ( + bucket.rename, + ( + 1, + "new-name", + ), + ), ] with client.start_session() as s, s.start_transaction(): for op, args in gridfs_ops: with self.assertRaisesRegex( - InvalidOperation, - 'GridFS does not support multi-document transactions', + InvalidOperation, + "GridFS does not support multi-document transactions", ): op(*args, session=s) # Require 4.2+ for large (16MB+) transactions. @client_context.require_version_min(4, 2) @client_context.require_transactions - @unittest.skipIf(sys.platform == 'win32', - 'Our Windows machines are too slow to pass this test') + @unittest.skipIf(sys.platform == "win32", "Our Windows machines are too slow to pass this test") def test_transaction_starts_with_batched_write(self): - if 'PyPy' in sys.version and client_context.tls: - self.skipTest('PYTHON-2937 PyPy is so slow sending large ' - 'messages over TLS that this test fails') + if "PyPy" in sys.version and client_context.tls: + self.skipTest( + "PYTHON-2937 PyPy is so slow sending large " + "messages over TLS that this test fails" + ) # Start a transaction with a batch of operations that needs to be # split. listener = OvertCommandListener() @@ -304,27 +330,29 @@ def test_transaction_starts_with_batched_write(self): listener.reset() self.addCleanup(client.close) self.addCleanup(coll.drop) - large_str = '\0'*(10*1024*1024) - ops = [InsertOne({'a': large_str}) for _ in range(10)] + large_str = "\0" * (10 * 1024 * 1024) + ops = [InsertOne({"a": large_str}) for _ in range(10)] with client.start_session() as session: with session.start_transaction(): coll.bulk_write(ops, session=session) # Assert commands were constructed properly. 
- self.assertEqual(['insert', 'insert', 'insert', 'commitTransaction'], - listener.started_command_names()) - first_cmd = listener.results['started'][0].command - self.assertTrue(first_cmd['startTransaction']) - lsid = first_cmd['lsid'] - txn_number = first_cmd['txnNumber'] - for event in listener.results['started'][1:]: - self.assertNotIn('startTransaction', event.command) - self.assertEqual(lsid, event.command['lsid']) - self.assertEqual(txn_number, event.command['txnNumber']) + self.assertEqual( + ["insert", "insert", "insert", "commitTransaction"], listener.started_command_names() + ) + first_cmd = listener.results["started"][0].command + self.assertTrue(first_cmd["startTransaction"]) + lsid = first_cmd["lsid"] + txn_number = first_cmd["txnNumber"] + for event in listener.results["started"][1:]: + self.assertNotIn("startTransaction", event.command) + self.assertEqual(lsid, event.command["lsid"]) + self.assertEqual(txn_number, event.command["txnNumber"]) self.assertEqual(10, coll.count_documents({})) class PatchSessionTimeout(object): """Patches the client_session's with_transaction timeout for testing.""" + def __init__(self, mock_timeout): self.real_timeout = client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT self.mock_timeout = mock_timeout @@ -338,15 +366,18 @@ def __exit__(self, exc_type, exc_val, exc_tb): class TestTransactionsConvenientAPI(TransactionsBase): - TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), - 'transactions-convenient-api') + TEST_PATH = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "transactions-convenient-api" + ) @client_context.require_transactions def test_callback_raises_custom_error(self): - class _MyException(Exception):pass + class _MyException(Exception): + pass def raise_error(_): raise _MyException() + with self.client.start_session() as s: with self.assertRaises(_MyException): s.with_transaction(raise_error) @@ -354,17 +385,19 @@ def raise_error(_): @client_context.require_transactions def test_callback_returns_value(self): def callback(_): - return 'Foo' + return "Foo" + with self.client.start_session() as s: - self.assertEqual(s.with_transaction(callback), 'Foo') + self.assertEqual(s.with_transaction(callback), "Foo") self.db.test.insert_one({}) def callback(session): self.db.test.insert_one({}, session=session) - return 'Foo' + return "Foo" + with self.client.start_session() as s: - self.assertEqual(s.with_transaction(callback), 'Foo') + self.assertEqual(s.with_transaction(callback), "Foo") @client_context.require_transactions def test_callback_not_retried_after_timeout(self): @@ -376,13 +409,13 @@ def test_callback_not_retried_after_timeout(self): def callback(session): coll.insert_one({}, session=session) err = { - 'ok': 0, - 'errmsg': 'Transaction 7819 has been aborted.', - 'code': 251, - 'codeName': 'NoSuchTransaction', - 'errorLabels': ['TransientTransactionError'], + "ok": 0, + "errmsg": "Transaction 7819 has been aborted.", + "code": 251, + "codeName": "NoSuchTransaction", + "errorLabels": ["TransientTransactionError"], } - raise OperationFailure(err['errmsg'], err['code'], err) + raise OperationFailure(err["errmsg"], err["code"], err) # Create the collection. 
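Note: the convenient API wraps that pattern: `with_transaction` starts the transaction, runs the callback, commits for you, and retries on transient errors until a time limit (the roughly 120-second constant these tests patch via `PatchSessionTimeout`) expires. A sketch of normal use:

```python
from pymongo import MongoClient

client = MongoClient()
client.test.test.insert_one({})  # Ensure the collection exists first.


def callback(session):
    # Retried as a whole on TransientTransactionError; its return value
    # is passed through by with_transaction.
    session.client.test.test.insert_one({"done": True}, session=session)
    return "Foo"


with client.start_session() as session:
    assert session.with_transaction(callback) == "Foo"
```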
coll.insert_one({}) @@ -392,8 +425,7 @@ def callback(session): with self.assertRaises(OperationFailure): s.with_transaction(callback) - self.assertEqual(listener.started_command_names(), - ['insert', 'abortTransaction']) + self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) @client_context.require_test_commands @client_context.require_transactions @@ -408,14 +440,17 @@ def callback(session): # Create the collection. coll.insert_one({}) - self.set_fail_point({ - 'configureFailPoint': 'failCommand', 'mode': {'times': 1}, - 'data': { - 'failCommands': ['commitTransaction'], - 'errorCode': 251, # NoSuchTransaction - }}) - self.addCleanup(self.set_fail_point, { - 'configureFailPoint': 'failCommand', 'mode': 'off'}) + self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": {"times": 1}, + "data": { + "failCommands": ["commitTransaction"], + "errorCode": 251, # NoSuchTransaction + }, + } + ) + self.addCleanup(self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}) listener.results.clear() with client.start_session() as s: @@ -423,8 +458,7 @@ def callback(session): with self.assertRaises(OperationFailure): s.with_transaction(callback) - self.assertEqual(listener.started_command_names(), - ['insert', 'commitTransaction']) + self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) @client_context.require_test_commands @client_context.require_transactions @@ -439,13 +473,14 @@ def callback(session): # Create the collection. coll.insert_one({}) - self.set_fail_point({ - 'configureFailPoint': 'failCommand', 'mode': {'times': 2}, - 'data': { - 'failCommands': ['commitTransaction'], - 'closeConnection': True}}) - self.addCleanup(self.set_fail_point, { - 'configureFailPoint': 'failCommand', 'mode': 'off'}) + self.set_fail_point( + { + "configureFailPoint": "failCommand", + "mode": {"times": 2}, + "data": {"failCommands": ["commitTransaction"], "closeConnection": True}, + } + ) + self.addCleanup(self.set_fail_point, {"configureFailPoint": "failCommand", "mode": "off"}) listener.results.clear() with client.start_session() as s: @@ -455,8 +490,9 @@ def callback(session): # One insert for the callback and two commits (includes the automatic # retry). - self.assertEqual(listener.started_command_names(), - ['insert', 'commitTransaction', 'commitTransaction']) + self.assertEqual( + listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] + ) # Tested here because this supports Motor's convenient transactions API. @client_context.require_transactions @@ -489,6 +525,7 @@ def test_in_transaction_property(self): # Using a callback def callback(session): self.assertTrue(session.in_transaction) + with client.start_session() as s: self.assertFalse(s.in_transaction) s.with_transaction(callback) @@ -508,8 +545,9 @@ def run_scenario(self): test_creator.create_tests() -TestCreator(create_test, TestTransactionsConvenientAPI, - TestTransactionsConvenientAPI.TEST_PATH).create_tests() +TestCreator( + create_test, TestTransactionsConvenientAPI, TestTransactionsConvenientAPI.TEST_PATH +).create_tests() if __name__ == "__main__": diff --git a/test/test_transactions_unified.py b/test/test_transactions_unified.py index 37e8d06153..4f3aa233fa 100644 --- a/test/test_transactions_unified.py +++ b/test/test_transactions_unified.py @@ -23,8 +23,7 @@ from test.unified_format import generate_test_classes # Location of JSON test specifications. 
-TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'transactions', 'unified') +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "transactions", "unified") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) diff --git a/test/test_unified_format.py b/test/test_unified_format.py index 74770b6f3a..e36959a224 100644 --- a/test/test_unified_format.py +++ b/test/test_unified_format.py @@ -17,35 +17,39 @@ sys.path[0:0] = [""] -from bson import ObjectId - from test import unittest -from test.unified_format import generate_test_classes, MatchEvaluatorUtil +from test.unified_format import MatchEvaluatorUtil, generate_test_classes +from bson import ObjectId -_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'unified-test-format') +_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "unified-test-format") -globals().update(generate_test_classes( - os.path.join(_TEST_PATH, 'valid-pass'), - module=__name__, - class_name_prefix='UnifiedTestFormat', - expected_failures=[ - 'Client side error in command starting transaction', # PYTHON-1894 - ], - RUN_ON_SERVERLESS=False)) +globals().update( + generate_test_classes( + os.path.join(_TEST_PATH, "valid-pass"), + module=__name__, + class_name_prefix="UnifiedTestFormat", + expected_failures=[ + "Client side error in command starting transaction", # PYTHON-1894 + ], + RUN_ON_SERVERLESS=False, + ) +) -globals().update(generate_test_classes( - os.path.join(_TEST_PATH, 'valid-fail'), - module=__name__, - class_name_prefix='UnifiedTestFormat', - bypass_test_generation_errors=True, - expected_failures=[ - '.*', # All tests expected to fail - ], - RUN_ON_SERVERLESS=False)) +globals().update( + generate_test_classes( + os.path.join(_TEST_PATH, "valid-fail"), + module=__name__, + class_name_prefix="UnifiedTestFormat", + bypass_test_generation_errors=True, + expected_failures=[ + ".*", # All tests expected to fail + ], + RUN_ON_SERVERLESS=False, + ) +) class TestMatchEvaluatorUtil(unittest.TestCase): @@ -53,22 +57,27 @@ def setUp(self): self.match_evaluator = MatchEvaluatorUtil(self) def test_unsetOrMatches(self): - spec = {'$$unsetOrMatches': {'y': {'$$unsetOrMatches': 2}}} - for actual in [{}, {'y': 2}, None]: + spec = {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}} + for actual in [{}, {"y": 2}, None]: self.match_evaluator.match_result(spec, actual) - spec = {'x': {'$$unsetOrMatches': {'y': {'$$unsetOrMatches': 2}}}} - for actual in [{}, {'x': {}}, {'x': {'y': 2}}]: + spec = {"x": {"$$unsetOrMatches": {"y": {"$$unsetOrMatches": 2}}}} + for actual in [{}, {"x": {}}, {"x": {"y": 2}}]: self.match_evaluator.match_result(spec, actual) def test_type(self): self.match_evaluator.match_result( - {'operationType': 'insert', - 'ns': {'db': 'change-stream-tests', 'coll': 'test'}, - 'fullDocument': {'_id': {'$$type': 'objectId'}, 'x': 1}}, - {'operationType': 'insert', - 'fullDocument': {'_id': ObjectId('5fc93511ac93941052098f0c'), 'x': 1}, - 'ns': {'db': 'change-stream-tests', 'coll': 'test'}}) + { + "operationType": "insert", + "ns": {"db": "change-stream-tests", "coll": "test"}, + "fullDocument": {"_id": {"$$type": "objectId"}, "x": 1}, + }, + { + "operationType": "insert", + "fullDocument": {"_id": ObjectId("5fc93511ac93941052098f0c"), "x": 1}, + "ns": {"db": "change-stream-tests", "coll": "test"}, + }, + ) if __name__ == "__main__": diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py index 7e00bd9760..23eac3bffd 100644 --- 
a/test/test_uri_parser.py +++ b/test/test_uri_parser.py @@ -21,444 +21,431 @@ sys.path[0:0] = [""] +from test import unittest + from bson.binary import JAVA_LEGACY from pymongo import ReadPreference from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.uri_parser import (parse_userinfo, - split_hosts, - split_options, - parse_uri) -from test import unittest +from pymongo.uri_parser import parse_uri, parse_userinfo, split_hosts, split_options class TestURI(unittest.TestCase): - def test_validate_userinfo(self): - self.assertRaises(InvalidURI, parse_userinfo, - 'foo@') - self.assertRaises(InvalidURI, parse_userinfo, - ':password') - self.assertRaises(InvalidURI, parse_userinfo, - 'fo::o:p@ssword') - self.assertRaises(InvalidURI, parse_userinfo, ':') - self.assertTrue(parse_userinfo('user:password')) - self.assertEqual(('us:r', 'p@ssword'), - parse_userinfo('us%3Ar:p%40ssword')) - self.assertEqual(('us er', 'p ssword'), - parse_userinfo('us+er:p+ssword')) - self.assertEqual(('us er', 'p ssword'), - parse_userinfo('us%20er:p%20ssword')) - self.assertEqual(('us+er', 'p+ssword'), - parse_userinfo('us%2Ber:p%2Bssword')) - self.assertEqual(('dev1@FOO.COM', ''), - parse_userinfo('dev1%40FOO.COM')) - self.assertEqual(('dev1@FOO.COM', ''), - parse_userinfo('dev1%40FOO.COM:')) + self.assertRaises(InvalidURI, parse_userinfo, "foo@") + self.assertRaises(InvalidURI, parse_userinfo, ":password") + self.assertRaises(InvalidURI, parse_userinfo, "fo::o:p@ssword") + self.assertRaises(InvalidURI, parse_userinfo, ":") + self.assertTrue(parse_userinfo("user:password")) + self.assertEqual(("us:r", "p@ssword"), parse_userinfo("us%3Ar:p%40ssword")) + self.assertEqual(("us er", "p ssword"), parse_userinfo("us+er:p+ssword")) + self.assertEqual(("us er", "p ssword"), parse_userinfo("us%20er:p%20ssword")) + self.assertEqual(("us+er", "p+ssword"), parse_userinfo("us%2Ber:p%2Bssword")) + self.assertEqual(("dev1@FOO.COM", ""), parse_userinfo("dev1%40FOO.COM")) + self.assertEqual(("dev1@FOO.COM", ""), parse_userinfo("dev1%40FOO.COM:")) def test_split_hosts(self): - self.assertRaises(ConfigurationError, split_hosts, - 'localhost:27017,') - self.assertRaises(ConfigurationError, split_hosts, - ',localhost:27017') - self.assertRaises(ConfigurationError, split_hosts, - 'localhost:27017,,localhost:27018') - self.assertEqual([('localhost', 27017), ('example.com', 27017)], - split_hosts('localhost,example.com')) - self.assertEqual([('localhost', 27018), ('example.com', 27019)], - split_hosts('localhost:27018,example.com:27019')) - self.assertEqual([('/tmp/mongodb-27017.sock', None)], - split_hosts('/tmp/mongodb-27017.sock')) - self.assertEqual([('/tmp/mongodb-27017.sock', None), - ('example.com', 27017)], - split_hosts('/tmp/mongodb-27017.sock,' - 'example.com:27017')) - self.assertEqual([('example.com', 27017), - ('/tmp/mongodb-27017.sock', None)], - split_hosts('example.com:27017,' - '/tmp/mongodb-27017.sock')) - self.assertRaises(ValueError, split_hosts, '::1', 27017) - self.assertRaises(ValueError, split_hosts, '[::1:27017') - self.assertRaises(ValueError, split_hosts, '::1') - self.assertRaises(ValueError, split_hosts, '::1]:27017') - self.assertEqual([('::1', 27017)], split_hosts('[::1]:27017')) - self.assertEqual([('::1', 27017)], split_hosts('[::1]')) + self.assertRaises(ConfigurationError, split_hosts, "localhost:27017,") + self.assertRaises(ConfigurationError, split_hosts, ",localhost:27017") + self.assertRaises(ConfigurationError, split_hosts, "localhost:27017,,localhost:27018") + self.assertEqual( + 
[("localhost", 27017), ("example.com", 27017)], split_hosts("localhost,example.com") + ) + self.assertEqual( + [("localhost", 27018), ("example.com", 27019)], + split_hosts("localhost:27018,example.com:27019"), + ) + self.assertEqual( + [("/tmp/mongodb-27017.sock", None)], split_hosts("/tmp/mongodb-27017.sock") + ) + self.assertEqual( + [("/tmp/mongodb-27017.sock", None), ("example.com", 27017)], + split_hosts("/tmp/mongodb-27017.sock," "example.com:27017"), + ) + self.assertEqual( + [("example.com", 27017), ("/tmp/mongodb-27017.sock", None)], + split_hosts("example.com:27017," "/tmp/mongodb-27017.sock"), + ) + self.assertRaises(ValueError, split_hosts, "::1", 27017) + self.assertRaises(ValueError, split_hosts, "[::1:27017") + self.assertRaises(ValueError, split_hosts, "::1") + self.assertRaises(ValueError, split_hosts, "::1]:27017") + self.assertEqual([("::1", 27017)], split_hosts("[::1]:27017")) + self.assertEqual([("::1", 27017)], split_hosts("[::1]")) def test_split_options(self): - self.assertRaises(ConfigurationError, split_options, 'foo') - self.assertRaises(ConfigurationError, split_options, 'foo=bar;foo') - self.assertTrue(split_options('ssl=true')) - self.assertTrue(split_options('connect=true')) - self.assertTrue(split_options('tlsAllowInvalidHostnames=false')) + self.assertRaises(ConfigurationError, split_options, "foo") + self.assertRaises(ConfigurationError, split_options, "foo=bar;foo") + self.assertTrue(split_options("ssl=true")) + self.assertTrue(split_options("connect=true")) + self.assertTrue(split_options("tlsAllowInvalidHostnames=false")) # Test Invalid URI options that should throw warnings. with warnings.catch_warnings(): - warnings.filterwarnings('error') - self.assertRaises(Warning, split_options, - 'foo=bar', warn=True) - self.assertRaises(Warning, split_options, - 'socketTimeoutMS=foo', warn=True) - self.assertRaises(Warning, split_options, - 'socketTimeoutMS=0.0', warn=True) - self.assertRaises(Warning, split_options, - 'connectTimeoutMS=foo', warn=True) - self.assertRaises(Warning, split_options, - 'connectTimeoutMS=0.0', warn=True) - self.assertRaises(Warning, split_options, - 'connectTimeoutMS=1e100000', warn=True) - self.assertRaises(Warning, split_options, - 'connectTimeoutMS=-1e100000', warn=True) - self.assertRaises(Warning, split_options, - 'ssl=foo', warn=True) - self.assertRaises(Warning, split_options, - 'connect=foo', warn=True) - self.assertRaises(Warning, split_options, - 'tlsAllowInvalidHostnames=foo', warn=True) - self.assertRaises(Warning, split_options, - 'connectTimeoutMS=inf', warn=True) - self.assertRaises(Warning, split_options, - 'connectTimeoutMS=-inf', warn=True) - self.assertRaises(Warning, split_options, 'wtimeoutms=foo', - warn=True) - self.assertRaises(Warning, split_options, 'wtimeoutms=5.5', - warn=True) - self.assertRaises(Warning, split_options, 'fsync=foo', - warn=True) - self.assertRaises(Warning, split_options, 'fsync=5.5', - warn=True) - self.assertRaises(Warning, - split_options, 'authMechanism=foo', - warn=True) + warnings.filterwarnings("error") + self.assertRaises(Warning, split_options, "foo=bar", warn=True) + self.assertRaises(Warning, split_options, "socketTimeoutMS=foo", warn=True) + self.assertRaises(Warning, split_options, "socketTimeoutMS=0.0", warn=True) + self.assertRaises(Warning, split_options, "connectTimeoutMS=foo", warn=True) + self.assertRaises(Warning, split_options, "connectTimeoutMS=0.0", warn=True) + self.assertRaises(Warning, split_options, "connectTimeoutMS=1e100000", warn=True) + 
self.assertRaises(Warning, split_options, "connectTimeoutMS=-1e100000", warn=True) + self.assertRaises(Warning, split_options, "ssl=foo", warn=True) + self.assertRaises(Warning, split_options, "connect=foo", warn=True) + self.assertRaises(Warning, split_options, "tlsAllowInvalidHostnames=foo", warn=True) + self.assertRaises(Warning, split_options, "connectTimeoutMS=inf", warn=True) + self.assertRaises(Warning, split_options, "connectTimeoutMS=-inf", warn=True) + self.assertRaises(Warning, split_options, "wtimeoutms=foo", warn=True) + self.assertRaises(Warning, split_options, "wtimeoutms=5.5", warn=True) + self.assertRaises(Warning, split_options, "fsync=foo", warn=True) + self.assertRaises(Warning, split_options, "fsync=5.5", warn=True) + self.assertRaises(Warning, split_options, "authMechanism=foo", warn=True) # Test invalid options with warn=False. - self.assertRaises(ConfigurationError, split_options, 'foo=bar') - self.assertRaises(ValueError, split_options, 'socketTimeoutMS=foo') - self.assertRaises(ValueError, split_options, 'socketTimeoutMS=0.0') - self.assertRaises(ValueError, split_options, 'connectTimeoutMS=foo') - self.assertRaises(ValueError, split_options, 'connectTimeoutMS=0.0') - self.assertRaises(ValueError, split_options, - 'connectTimeoutMS=1e100000') - self.assertRaises(ValueError, split_options, - 'connectTimeoutMS=-1e100000') - self.assertRaises(ValueError, split_options, 'ssl=foo') - self.assertRaises(ValueError, split_options, 'connect=foo') - self.assertRaises(ValueError, split_options, 'tlsAllowInvalidHostnames=foo') - self.assertRaises(ValueError, split_options, 'connectTimeoutMS=inf') - self.assertRaises(ValueError, split_options, 'connectTimeoutMS=-inf') - self.assertRaises(ValueError, split_options, 'wtimeoutms=foo') - self.assertRaises(ValueError, split_options, 'wtimeoutms=5.5') - self.assertRaises(ValueError, split_options, 'fsync=foo') - self.assertRaises(ValueError, split_options, 'fsync=5.5') - self.assertRaises(ValueError, - split_options, 'authMechanism=foo') + self.assertRaises(ConfigurationError, split_options, "foo=bar") + self.assertRaises(ValueError, split_options, "socketTimeoutMS=foo") + self.assertRaises(ValueError, split_options, "socketTimeoutMS=0.0") + self.assertRaises(ValueError, split_options, "connectTimeoutMS=foo") + self.assertRaises(ValueError, split_options, "connectTimeoutMS=0.0") + self.assertRaises(ValueError, split_options, "connectTimeoutMS=1e100000") + self.assertRaises(ValueError, split_options, "connectTimeoutMS=-1e100000") + self.assertRaises(ValueError, split_options, "ssl=foo") + self.assertRaises(ValueError, split_options, "connect=foo") + self.assertRaises(ValueError, split_options, "tlsAllowInvalidHostnames=foo") + self.assertRaises(ValueError, split_options, "connectTimeoutMS=inf") + self.assertRaises(ValueError, split_options, "connectTimeoutMS=-inf") + self.assertRaises(ValueError, split_options, "wtimeoutms=foo") + self.assertRaises(ValueError, split_options, "wtimeoutms=5.5") + self.assertRaises(ValueError, split_options, "fsync=foo") + self.assertRaises(ValueError, split_options, "fsync=5.5") + self.assertRaises(ValueError, split_options, "authMechanism=foo") # Test splitting options works when valid. 
- self.assertTrue(split_options('socketTimeoutMS=300')) - self.assertTrue(split_options('connectTimeoutMS=300')) - self.assertEqual({'sockettimeoutms': 0.3}, - split_options('socketTimeoutMS=300')) - self.assertEqual({'sockettimeoutms': 0.0001}, - split_options('socketTimeoutMS=0.1')) - self.assertEqual({'connecttimeoutms': 0.3}, - split_options('connectTimeoutMS=300')) - self.assertEqual({'connecttimeoutms': 0.0001}, - split_options('connectTimeoutMS=0.1')) - self.assertTrue(split_options('connectTimeoutMS=300')) - self.assertTrue(isinstance(split_options('w=5')['w'], int)) - self.assertTrue(isinstance(split_options('w=5.5')['w'], str)) - self.assertTrue(split_options('w=foo')) - self.assertTrue(split_options('w=majority')) - self.assertTrue(split_options('wtimeoutms=500')) - self.assertEqual({'fsync': True}, split_options('fsync=true')) - self.assertEqual({'fsync': False}, split_options('fsync=false')) - self.assertEqual({'authmechanism': 'GSSAPI'}, - split_options('authMechanism=GSSAPI')) - self.assertEqual({'authmechanism': 'MONGODB-CR'}, - split_options('authMechanism=MONGODB-CR')) - self.assertEqual({'authmechanism': 'SCRAM-SHA-1'}, - split_options('authMechanism=SCRAM-SHA-1')) - self.assertEqual({'authsource': 'foobar'}, - split_options('authSource=foobar')) - self.assertEqual({'maxpoolsize': 50}, split_options('maxpoolsize=50')) + self.assertTrue(split_options("socketTimeoutMS=300")) + self.assertTrue(split_options("connectTimeoutMS=300")) + self.assertEqual({"sockettimeoutms": 0.3}, split_options("socketTimeoutMS=300")) + self.assertEqual({"sockettimeoutms": 0.0001}, split_options("socketTimeoutMS=0.1")) + self.assertEqual({"connecttimeoutms": 0.3}, split_options("connectTimeoutMS=300")) + self.assertEqual({"connecttimeoutms": 0.0001}, split_options("connectTimeoutMS=0.1")) + self.assertTrue(split_options("connectTimeoutMS=300")) + self.assertTrue(isinstance(split_options("w=5")["w"], int)) + self.assertTrue(isinstance(split_options("w=5.5")["w"], str)) + self.assertTrue(split_options("w=foo")) + self.assertTrue(split_options("w=majority")) + self.assertTrue(split_options("wtimeoutms=500")) + self.assertEqual({"fsync": True}, split_options("fsync=true")) + self.assertEqual({"fsync": False}, split_options("fsync=false")) + self.assertEqual({"authmechanism": "GSSAPI"}, split_options("authMechanism=GSSAPI")) + self.assertEqual({"authmechanism": "MONGODB-CR"}, split_options("authMechanism=MONGODB-CR")) + self.assertEqual( + {"authmechanism": "SCRAM-SHA-1"}, split_options("authMechanism=SCRAM-SHA-1") + ) + self.assertEqual({"authsource": "foobar"}, split_options("authSource=foobar")) + self.assertEqual({"maxpoolsize": 50}, split_options("maxpoolsize=50")) def test_parse_uri(self): self.assertRaises(InvalidURI, parse_uri, "http://foobar.com") self.assertRaises(InvalidURI, parse_uri, "http://foo@foobar.com") - self.assertRaises(ValueError, - parse_uri, "mongodb://::1", 27017) + self.assertRaises(ValueError, parse_uri, "mongodb://::1", 27017) orig = { - 'nodelist': [("localhost", 27017)], - 'username': None, - 'password': None, - 'database': None, - 'collection': None, - 'options': {}, - 'fqdn': None + "nodelist": [("localhost", 27017)], + "username": None, + "password": None, + "database": None, + "collection": None, + "options": {}, + "fqdn": None, } res = copy.deepcopy(orig) self.assertEqual(res, parse_uri("mongodb://localhost")) - res.update({'username': 'fred', 'password': 'foobar'}) + res.update({"username": "fred", "password": "foobar"}) self.assertEqual(res, 
parse_uri("mongodb://fred:foobar@localhost")) - res.update({'database': 'baz'}) + res.update({"database": "baz"}) self.assertEqual(res, parse_uri("mongodb://fred:foobar@localhost/baz")) res = copy.deepcopy(orig) - res['nodelist'] = [("example1.com", 27017), ("example2.com", 27017)] - self.assertEqual(res, - parse_uri("mongodb://example1.com:27017," - "example2.com:27017")) + res["nodelist"] = [("example1.com", 27017), ("example2.com", 27017)] + self.assertEqual(res, parse_uri("mongodb://example1.com:27017," "example2.com:27017")) res = copy.deepcopy(orig) - res['nodelist'] = [("localhost", 27017), - ("localhost", 27018), - ("localhost", 27019)] - self.assertEqual(res, - parse_uri("mongodb://localhost," - "localhost:27018,localhost:27019")) + res["nodelist"] = [("localhost", 27017), ("localhost", 27018), ("localhost", 27019)] + self.assertEqual(res, parse_uri("mongodb://localhost," "localhost:27018,localhost:27019")) res = copy.deepcopy(orig) - res['database'] = 'foo' + res["database"] = "foo" self.assertEqual(res, parse_uri("mongodb://localhost/foo")) res = copy.deepcopy(orig) self.assertEqual(res, parse_uri("mongodb://localhost/")) - res.update({'database': 'test', 'collection': 'yield_historical.in'}) - self.assertEqual(res, parse_uri("mongodb://" - "localhost/test.yield_historical.in")) + res.update({"database": "test", "collection": "yield_historical.in"}) + self.assertEqual(res, parse_uri("mongodb://" "localhost/test.yield_historical.in")) - res.update({'username': 'fred', 'password': 'foobar'}) - self.assertEqual(res, - parse_uri("mongodb://fred:foobar@localhost/" - "test.yield_historical.in")) + res.update({"username": "fred", "password": "foobar"}) + self.assertEqual( + res, parse_uri("mongodb://fred:foobar@localhost/" "test.yield_historical.in") + ) res = copy.deepcopy(orig) - res['nodelist'] = [("example1.com", 27017), ("example2.com", 27017)] - res.update({'database': 'test', 'collection': 'yield_historical.in'}) - self.assertEqual(res, - parse_uri("mongodb://example1.com:27017,example2.com" - ":27017/test.yield_historical.in")) + res["nodelist"] = [("example1.com", 27017), ("example2.com", 27017)] + res.update({"database": "test", "collection": "yield_historical.in"}) + self.assertEqual( + res, + parse_uri( + "mongodb://example1.com:27017,example2.com" ":27017/test.yield_historical.in" + ), + ) # Test socket path without escaped characters. - self.assertRaises(InvalidURI, parse_uri, - "mongodb:///tmp/mongodb-27017.sock") + self.assertRaises(InvalidURI, parse_uri, "mongodb:///tmp/mongodb-27017.sock") # Test with escaped characters. 
res = copy.deepcopy(orig) - res['nodelist'] = [("example2.com", 27017), - ("/tmp/mongodb-27017.sock", None)] - self.assertEqual(res, - parse_uri("mongodb://example2.com," - "%2Ftmp%2Fmongodb-27017.sock")) + res["nodelist"] = [("example2.com", 27017), ("/tmp/mongodb-27017.sock", None)] + self.assertEqual(res, parse_uri("mongodb://example2.com," "%2Ftmp%2Fmongodb-27017.sock")) res = copy.deepcopy(orig) - res['nodelist'] = [("shoe.sock.pants.co.uk", 27017), - ("/tmp/mongodb-27017.sock", None)] - res['database'] = "nethers_db" - self.assertEqual(res, - parse_uri("mongodb://shoe.sock.pants.co.uk," - "%2Ftmp%2Fmongodb-27017.sock/nethers_db")) + res["nodelist"] = [("shoe.sock.pants.co.uk", 27017), ("/tmp/mongodb-27017.sock", None)] + res["database"] = "nethers_db" + self.assertEqual( + res, + parse_uri("mongodb://shoe.sock.pants.co.uk," "%2Ftmp%2Fmongodb-27017.sock/nethers_db"), + ) res = copy.deepcopy(orig) - res['nodelist'] = [("/tmp/mongodb-27017.sock", None), - ("example2.com", 27017)] - res.update({'database': 'test', 'collection': 'yield_historical.in'}) - self.assertEqual(res, - parse_uri("mongodb://%2Ftmp%2Fmongodb-27017.sock," - "example2.com:27017" - "/test.yield_historical.in")) + res["nodelist"] = [("/tmp/mongodb-27017.sock", None), ("example2.com", 27017)] + res.update({"database": "test", "collection": "yield_historical.in"}) + self.assertEqual( + res, + parse_uri( + "mongodb://%2Ftmp%2Fmongodb-27017.sock," + "example2.com:27017" + "/test.yield_historical.in" + ), + ) res = copy.deepcopy(orig) - res['nodelist'] = [("/tmp/mongodb-27017.sock", None), - ("example2.com", 27017)] - res.update({'database': 'test', 'collection': 'yield_historical.sock'}) - self.assertEqual(res, - parse_uri("mongodb://%2Ftmp%2Fmongodb-27017.sock," - "example2.com:27017/test.yield_historical" - ".sock")) + res["nodelist"] = [("/tmp/mongodb-27017.sock", None), ("example2.com", 27017)] + res.update({"database": "test", "collection": "yield_historical.sock"}) + self.assertEqual( + res, + parse_uri( + "mongodb://%2Ftmp%2Fmongodb-27017.sock," + "example2.com:27017/test.yield_historical" + ".sock" + ), + ) res = copy.deepcopy(orig) - res['nodelist'] = [("example2.com", 27017)] - res.update({'database': 'test', 'collection': 'yield_historical.sock'}) - self.assertEqual(res, - parse_uri("mongodb://example2.com:27017" - "/test.yield_historical.sock")) + res["nodelist"] = [("example2.com", 27017)] + res.update({"database": "test", "collection": "yield_historical.sock"}) + self.assertEqual( + res, parse_uri("mongodb://example2.com:27017" "/test.yield_historical.sock") + ) res = copy.deepcopy(orig) - res['nodelist'] = [("/tmp/mongodb-27017.sock", None)] - res.update({'database': 'test', 'collection': 'mongodb-27017.sock'}) - self.assertEqual(res, - parse_uri("mongodb://%2Ftmp%2Fmongodb-27017.sock" - "/test.mongodb-27017.sock")) + res["nodelist"] = [("/tmp/mongodb-27017.sock", None)] + res.update({"database": "test", "collection": "mongodb-27017.sock"}) + self.assertEqual( + res, parse_uri("mongodb://%2Ftmp%2Fmongodb-27017.sock" "/test.mongodb-27017.sock") + ) res = copy.deepcopy(orig) - res['nodelist'] = [('/tmp/mongodb-27020.sock', None), - ("::1", 27017), - ("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 27018), - ("192.168.0.212", 27019), - ("localhost", 27018)] - self.assertEqual(res, parse_uri("mongodb://%2Ftmp%2Fmongodb-27020.sock" - ",[::1]:27017,[2001:0db8:" - "85a3:0000:0000:8a2e:0370:7334]," - "192.168.0.212:27019,localhost", - 27018)) + res["nodelist"] = [ + ("/tmp/mongodb-27020.sock", None), + ("::1", 27017), + 
("2001:0db8:85a3:0000:0000:8a2e:0370:7334", 27018), + ("192.168.0.212", 27019), + ("localhost", 27018), + ] + self.assertEqual( + res, + parse_uri( + "mongodb://%2Ftmp%2Fmongodb-27020.sock" + ",[::1]:27017,[2001:0db8:" + "85a3:0000:0000:8a2e:0370:7334]," + "192.168.0.212:27019,localhost", + 27018, + ), + ) res = copy.deepcopy(orig) - res.update({'username': 'fred', 'password': 'foobar'}) - res.update({'database': 'test', 'collection': 'yield_historical.in'}) - self.assertEqual(res, - parse_uri("mongodb://fred:foobar@localhost/" - "test.yield_historical.in")) + res.update({"username": "fred", "password": "foobar"}) + res.update({"database": "test", "collection": "yield_historical.in"}) + self.assertEqual( + res, parse_uri("mongodb://fred:foobar@localhost/" "test.yield_historical.in") + ) res = copy.deepcopy(orig) - res['database'] = 'test' - res['collection'] = 'name/with "delimiters' - self.assertEqual( - res, parse_uri("mongodb://localhost/test.name/with \"delimiters")) + res["database"] = "test" + res["collection"] = 'name/with "delimiters' + self.assertEqual(res, parse_uri('mongodb://localhost/test.name/with "delimiters')) res = copy.deepcopy(orig) - res['options'] = { - 'readpreference': ReadPreference.SECONDARY.mongos_mode - } - self.assertEqual(res, parse_uri( - "mongodb://localhost/?readPreference=secondary")) + res["options"] = {"readpreference": ReadPreference.SECONDARY.mongos_mode} + self.assertEqual(res, parse_uri("mongodb://localhost/?readPreference=secondary")) # Various authentication tests res = copy.deepcopy(orig) - res['options'] = {'authmechanism': 'MONGODB-CR'} - res['username'] = 'user' - res['password'] = 'password' - self.assertEqual(res, - parse_uri("mongodb://user:password@localhost/" - "?authMechanism=MONGODB-CR")) + res["options"] = {"authmechanism": "MONGODB-CR"} + res["username"] = "user" + res["password"] = "password" + self.assertEqual( + res, parse_uri("mongodb://user:password@localhost/" "?authMechanism=MONGODB-CR") + ) res = copy.deepcopy(orig) - res['options'] = {'authmechanism': 'MONGODB-CR', 'authsource': 'bar'} - res['username'] = 'user' - res['password'] = 'password' - res['database'] = 'foo' - self.assertEqual(res, - parse_uri("mongodb://user:password@localhost/foo" - "?authSource=bar;authMechanism=MONGODB-CR")) + res["options"] = {"authmechanism": "MONGODB-CR", "authsource": "bar"} + res["username"] = "user" + res["password"] = "password" + res["database"] = "foo" + self.assertEqual( + res, + parse_uri( + "mongodb://user:password@localhost/foo" "?authSource=bar;authMechanism=MONGODB-CR" + ), + ) res = copy.deepcopy(orig) - res['options'] = {'authmechanism': 'MONGODB-CR'} - res['username'] = 'user' - res['password'] = '' - self.assertEqual(res, - parse_uri("mongodb://user:@localhost/" - "?authMechanism=MONGODB-CR")) + res["options"] = {"authmechanism": "MONGODB-CR"} + res["username"] = "user" + res["password"] = "" + self.assertEqual(res, parse_uri("mongodb://user:@localhost/" "?authMechanism=MONGODB-CR")) res = copy.deepcopy(orig) - res['username'] = 'user@domain.com' - res['password'] = 'password' - res['database'] = 'foo' - self.assertEqual(res, - parse_uri("mongodb://user%40domain.com:password" - "@localhost/foo")) + res["username"] = "user@domain.com" + res["password"] = "password" + res["database"] = "foo" + self.assertEqual(res, parse_uri("mongodb://user%40domain.com:password" "@localhost/foo")) res = copy.deepcopy(orig) - res['options'] = {'authmechanism': 'GSSAPI'} - res['username'] = 'user@domain.com' - res['password'] = 'password' - 
res['database'] = 'foo' - self.assertEqual(res, - parse_uri("mongodb://user%40domain.com:password" - "@localhost/foo?authMechanism=GSSAPI")) + res["options"] = {"authmechanism": "GSSAPI"} + res["username"] = "user@domain.com" + res["password"] = "password" + res["database"] = "foo" + self.assertEqual( + res, + parse_uri("mongodb://user%40domain.com:password" "@localhost/foo?authMechanism=GSSAPI"), + ) res = copy.deepcopy(orig) - res['options'] = {'authmechanism': 'GSSAPI'} - res['username'] = 'user@domain.com' - res['password'] = '' - res['database'] = 'foo' - self.assertEqual(res, - parse_uri("mongodb://user%40domain.com" - "@localhost/foo?authMechanism=GSSAPI")) + res["options"] = {"authmechanism": "GSSAPI"} + res["username"] = "user@domain.com" + res["password"] = "" + res["database"] = "foo" + self.assertEqual( + res, parse_uri("mongodb://user%40domain.com" "@localhost/foo?authMechanism=GSSAPI") + ) res = copy.deepcopy(orig) - res['options'] = { - 'readpreference': ReadPreference.SECONDARY.mongos_mode, - 'readpreferencetags': [ - {'dc': 'west', 'use': 'website'}, - {'dc': 'east', 'use': 'website'} - ] + res["options"] = { + "readpreference": ReadPreference.SECONDARY.mongos_mode, + "readpreferencetags": [ + {"dc": "west", "use": "website"}, + {"dc": "east", "use": "website"}, + ], } - res['username'] = 'user@domain.com' - res['password'] = 'password' - res['database'] = 'foo' - self.assertEqual(res, - parse_uri("mongodb://user%40domain.com:password" - "@localhost/foo?readpreference=secondary&" - "readpreferencetags=dc:west,use:website&" - "readpreferencetags=dc:east,use:website")) + res["username"] = "user@domain.com" + res["password"] = "password" + res["database"] = "foo" + self.assertEqual( + res, + parse_uri( + "mongodb://user%40domain.com:password" + "@localhost/foo?readpreference=secondary&" + "readpreferencetags=dc:west,use:website&" + "readpreferencetags=dc:east,use:website" + ), + ) res = copy.deepcopy(orig) - res['options'] = { - 'readpreference': ReadPreference.SECONDARY.mongos_mode, - 'readpreferencetags': [ - {'dc': 'west', 'use': 'website'}, - {'dc': 'east', 'use': 'website'}, - {} - ] + res["options"] = { + "readpreference": ReadPreference.SECONDARY.mongos_mode, + "readpreferencetags": [ + {"dc": "west", "use": "website"}, + {"dc": "east", "use": "website"}, + {}, + ], } - res['username'] = 'user@domain.com' - res['password'] = 'password' - res['database'] = 'foo' - self.assertEqual(res, - parse_uri("mongodb://user%40domain.com:password" - "@localhost/foo?readpreference=secondary&" - "readpreferencetags=dc:west,use:website&" - "readpreferencetags=dc:east,use:website&" - "readpreferencetags=")) + res["username"] = "user@domain.com" + res["password"] = "password" + res["database"] = "foo" + self.assertEqual( + res, + parse_uri( + "mongodb://user%40domain.com:password" + "@localhost/foo?readpreference=secondary&" + "readpreferencetags=dc:west,use:website&" + "readpreferencetags=dc:east,use:website&" + "readpreferencetags=" + ), + ) res = copy.deepcopy(orig) - res['options'] = {'uuidrepresentation': JAVA_LEGACY} - res['username'] = 'user@domain.com' - res['password'] = 'password' - res['database'] = 'foo' - self.assertEqual(res, - parse_uri("mongodb://user%40domain.com:password" - "@localhost/foo?uuidrepresentation=" - "javaLegacy")) + res["options"] = {"uuidrepresentation": JAVA_LEGACY} + res["username"] = "user@domain.com" + res["password"] = "password" + res["database"] = "foo" + self.assertEqual( + res, + parse_uri( + "mongodb://user%40domain.com:password" + 
"@localhost/foo?uuidrepresentation=" + "javaLegacy" + ), + ) with warnings.catch_warnings(): - warnings.filterwarnings('error') - self.assertRaises(Warning, parse_uri, - "mongodb://user%40domain.com:password" - "@localhost/foo?uuidrepresentation=notAnOption", - warn=True) - self.assertRaises(ValueError, parse_uri, - "mongodb://user%40domain.com:password" - "@localhost/foo?uuidrepresentation=notAnOption") + warnings.filterwarnings("error") + self.assertRaises( + Warning, + parse_uri, + "mongodb://user%40domain.com:password" + "@localhost/foo?uuidrepresentation=notAnOption", + warn=True, + ) + self.assertRaises( + ValueError, + parse_uri, + "mongodb://user%40domain.com:password" "@localhost/foo?uuidrepresentation=notAnOption", + ) def test_parse_ssl_paths(self): # Turn off "validate" since these paths don't exist on filesystem. self.assertEqual( - {'collection': None, - 'database': None, - 'nodelist': [('/MongoDB.sock', None)], - 'options': {'tlsCertificateKeyFile': '/a/b'}, - 'password': 'foo/bar', - 'username': 'jesse', - 'fqdn': None}, + { + "collection": None, + "database": None, + "nodelist": [("/MongoDB.sock", None)], + "options": {"tlsCertificateKeyFile": "/a/b"}, + "password": "foo/bar", + "username": "jesse", + "fqdn": None, + }, parse_uri( - 'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?tlsCertificateKeyFile=/a/b', - validate=False)) + "mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?tlsCertificateKeyFile=/a/b", + validate=False, + ), + ) self.assertEqual( - {'collection': None, - 'database': None, - 'nodelist': [('/MongoDB.sock', None)], - 'options': {'tlsCertificateKeyFile': 'a/b'}, - 'password': 'foo/bar', - 'username': 'jesse', - 'fqdn': None}, + { + "collection": None, + "database": None, + "nodelist": [("/MongoDB.sock", None)], + "options": {"tlsCertificateKeyFile": "a/b"}, + "password": "foo/bar", + "username": "jesse", + "fqdn": None, + }, parse_uri( - 'mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?tlsCertificateKeyFile=a/b', - validate=False)) + "mongodb://jesse:foo%2Fbar@%2FMongoDB.sock/?tlsCertificateKeyFile=a/b", + validate=False, + ), + ) def test_tlsinsecure_simple(self): # check that tlsInsecure is expanded correctly. @@ -467,59 +454,68 @@ def test_tlsinsecure_simple(self): res = { "tlsAllowInvalidHostnames": True, "tlsAllowInvalidCertificates": True, - "tlsInsecure": True, 'tlsDisableOCSPEndpointCheck': True} + "tlsInsecure": True, + "tlsDisableOCSPEndpointCheck": True, + } self.assertEqual(res, parse_uri(uri)["options"]) def test_normalize_options(self): # check that options are converted to their internal names correctly. 
- uri = ("mongodb://example.com/?ssl=true&appname=myapp") + uri = "mongodb://example.com/?ssl=true&appname=myapp" res = {"tls": True, "appname": "myapp"} self.assertEqual(res, parse_uri(uri)["options"]) def test_unquote_after_parsing(self): quoted_val = "val%21%40%23%24%25%5E%26%2A%28%29_%2B%2C%3A+etc" unquoted_val = "val!@#$%^&*()_+,: etc" - uri = ("mongodb://user:password@localhost/?authMechanism=MONGODB-AWS" - "&authMechanismProperties=AWS_SESSION_TOKEN:"+quoted_val) + uri = ( + "mongodb://user:password@localhost/?authMechanism=MONGODB-AWS" + "&authMechanismProperties=AWS_SESSION_TOKEN:" + quoted_val + ) res = parse_uri(uri) options = { - 'authmechanism': 'MONGODB-AWS', - 'authmechanismproperties': { - 'AWS_SESSION_TOKEN': unquoted_val}} - self.assertEqual(options, res['options']) - - uri = (("mongodb://localhost/foo?readpreference=secondary&" - "readpreferencetags=dc:west,"+quoted_val+":"+quoted_val+"&" - "readpreferencetags=dc:east,use:"+quoted_val)) + "authmechanism": "MONGODB-AWS", + "authmechanismproperties": {"AWS_SESSION_TOKEN": unquoted_val}, + } + self.assertEqual(options, res["options"]) + + uri = ( + "mongodb://localhost/foo?readpreference=secondary&" + "readpreferencetags=dc:west," + quoted_val + ":" + quoted_val + "&" + "readpreferencetags=dc:east,use:" + quoted_val + ) res = parse_uri(uri) options = { - 'readpreference': ReadPreference.SECONDARY.mongos_mode, - 'readpreferencetags': [ - {'dc': 'west', unquoted_val: unquoted_val}, - {'dc': 'east', 'use': unquoted_val} - ] + "readpreference": ReadPreference.SECONDARY.mongos_mode, + "readpreferencetags": [ + {"dc": "west", unquoted_val: unquoted_val}, + {"dc": "east", "use": unquoted_val}, + ], } - self.assertEqual(options, res['options']) + self.assertEqual(options, res["options"]) def test_redact_AWS_SESSION_TOKEN(self): unquoted_colon = "token:" - uri = ("mongodb://user:password@localhost/?authMechanism=MONGODB-AWS" - "&authMechanismProperties=AWS_SESSION_TOKEN:"+unquoted_colon) + uri = ( + "mongodb://user:password@localhost/?authMechanism=MONGODB-AWS" + "&authMechanismProperties=AWS_SESSION_TOKEN:" + unquoted_colon + ) with self.assertRaisesRegex( - ValueError, - 'auth mechanism properties must be key:value pairs like ' - 'SERVICE_NAME:mongodb, not AWS_SESSION_TOKEN:' - ', did you forget to percent-escape the token with ' - 'quote_plus?'): + ValueError, + "auth mechanism properties must be key:value pairs like " + "SERVICE_NAME:mongodb, not AWS_SESSION_TOKEN:" + ", did you forget to percent-escape the token with " + "quote_plus?", + ): parse_uri(uri) def test_special_chars(self): user = "user@ /9+:?~!$&'()*+,;=" pwd = "pwd@ /9+:?~!$&'()*+,;=" - uri = 'mongodb://%s:%s@localhost' % (quote_plus(user), quote_plus(pwd)) + uri = "mongodb://%s:%s@localhost" % (quote_plus(user), quote_plus(pwd)) res = parse_uri(uri) - self.assertEqual(user, res['username']) - self.assertEqual(pwd, res['password']) + self.assertEqual(user, res["username"]) + self.assertEqual(pwd, res["password"]) if __name__ == "__main__": diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index 59457b57ac..d12abf3b91 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -22,19 +22,18 @@ sys.path[0:0] = [""] +from test import clear_warning_registry, unittest + from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate from pymongo.compression_support import _HAVE_SNAPPY from pymongo.srv_resolver import _HAVE_DNSPYTHON -from pymongo.uri_parser import parse_uri, SRV_SCHEME -from test import clear_warning_registry, unittest - +from 
pymongo.uri_parser import SRV_SCHEME, parse_uri CONN_STRING_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - os.path.join('connection_string', 'test')) + os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test") +) -URI_OPTIONS_TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'uri_options') +URI_OPTIONS_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "uri_options") TEST_DESC_SKIP_LIST = [ "Valid options specific to single-threaded drivers are parsed correctly", @@ -64,7 +63,8 @@ "tlsDisableOCSPEndpointCheck and tlsDisableCertificateRevocationCheck both present (and true) raises an error", "tlsDisableOCSPEndpointCheck=true and tlsDisableCertificateRevocationCheck=false raises an error", "tlsDisableOCSPEndpointCheck=false and tlsDisableCertificateRevocationCheck=true raises an error", - "tlsDisableOCSPEndpointCheck and tlsDisableCertificateRevocationCheck both present (and false) raises an error"] + "tlsDisableOCSPEndpointCheck and tlsDisableCertificateRevocationCheck both present (and false) raises an error", +] class TestAllScenarios(unittest.TestCase): @@ -73,8 +73,7 @@ def setUp(self): def get_error_message_template(expected, artefact): - return "%s %s for test '%s'" % ( - "Expected" if expected else "Unexpected", artefact, "%s") + return "%s %s for test '%s'" % ("Expected" if expected else "Unexpected", artefact, "%s") def run_scenario_in_dir(target_workdir): @@ -84,91 +83,107 @@ def modified_test_scenario(*args, **kwargs): os.chdir(target_workdir) func(*args, **kwargs) os.chdir(original_workdir) + return modified_test_scenario + return workdir_context_decorator def create_test(test, test_workdir): def run_scenario(self): - compressors = (test.get('options') or {}).get('compressors', []) - if 'snappy' in compressors and not _HAVE_SNAPPY: - self.skipTest('This test needs the snappy module.') - if test['uri'].startswith(SRV_SCHEME) and not _HAVE_DNSPYTHON: + compressors = (test.get("options") or {}).get("compressors", []) + if "snappy" in compressors and not _HAVE_SNAPPY: + self.skipTest("This test needs the snappy module.") + if test["uri"].startswith(SRV_SCHEME) and not _HAVE_DNSPYTHON: self.skipTest("This test needs dnspython package.") valid = True warning = False with warnings.catch_warnings(record=True) as ctx: - warnings.simplefilter('always') + warnings.simplefilter("always") try: - options = parse_uri(test['uri'], warn=True) + options = parse_uri(test["uri"], warn=True) except Exception: valid = False else: warning = len(ctx) > 0 - expected_valid = test.get('valid', True) + expected_valid = test.get("valid", True) self.assertEqual( - valid, expected_valid, get_error_message_template( - not expected_valid, "error") % test['description']) + valid, + expected_valid, + get_error_message_template(not expected_valid, "error") % test["description"], + ) if expected_valid: - expected_warning = test.get('warning', False) + expected_warning = test.get("warning", False) self.assertEqual( - warning, expected_warning, get_error_message_template( - expected_warning, "warning") % test['description']) + warning, + expected_warning, + get_error_message_template(expected_warning, "warning") % test["description"], + ) # Compare hosts and port. 
- if test['hosts'] is not None: + if test["hosts"] is not None: self.assertEqual( - len(test['hosts']), len(options['nodelist']), - "Incorrect number of hosts parsed from URI") - - for exp, actual in zip(test['hosts'], - options['nodelist']): - self.assertEqual(exp['host'], actual[0], - "Expected host %s but got %s" - % (exp['host'], actual[0])) - if exp['port'] is not None: - self.assertEqual(exp['port'], actual[1], - "Expected port %s but got %s" - % (exp['port'], actual)) + len(test["hosts"]), + len(options["nodelist"]), + "Incorrect number of hosts parsed from URI", + ) + + for exp, actual in zip(test["hosts"], options["nodelist"]): + self.assertEqual( + exp["host"], actual[0], "Expected host %s but got %s" % (exp["host"], actual[0]) + ) + if exp["port"] is not None: + self.assertEqual( + exp["port"], + actual[1], + "Expected port %s but got %s" % (exp["port"], actual), + ) # Compare auth options. - auth = test['auth'] + auth = test["auth"] if auth is not None: - auth['database'] = auth.pop('db') # db == database + auth["database"] = auth.pop("db") # db == database # Special case for PyMongo's collection parsing. - if options.get('collection') is not None: - options['database'] += "." + options['collection'] + if options.get("collection") is not None: + options["database"] += "." + options["collection"] for elm in auth: if auth[elm] is not None: # We have to do this because while the spec requires # "+"->"+", unquote_plus does "+"->" " options[elm] = options[elm].replace(" ", "+") - self.assertEqual(auth[elm], options[elm], - "Expected %s but got %s" - % (auth[elm], options[elm])) + self.assertEqual( + auth[elm], + options[elm], + "Expected %s but got %s" % (auth[elm], options[elm]), + ) # Compare URI options. err_msg = "For option %s expected %s but got %s" - if test['options']: - opts = options['options'] - for opt in test['options']: + if test["options"]: + opts = options["options"] + for opt in test["options"]: lopt = opt.lower() optname = INTERNAL_URI_OPTION_NAME_MAP.get(lopt, lopt) if opts.get(optname) is not None: - if opts[optname] == test['options'][opt]: - expected_value = test['options'][opt] + if opts[optname] == test["options"][opt]: + expected_value = test["options"][opt] else: - expected_value = validate( - lopt, test['options'][opt])[1] + expected_value = validate(lopt, test["options"][opt])[1] self.assertEqual( - opts[optname], expected_value, - err_msg % (opt, expected_value, opts[optname],)) + opts[optname], + expected_value, + err_msg + % ( + opt, + expected_value, + opts[optname], + ), + ) else: - self.fail( - "Missing expected option %s" % (opt,)) + self.fail("Missing expected option %s" % (opt,)) return run_scenario_in_dir(test_workdir)(run_scenario) @@ -176,27 +191,29 @@ def run_scenario(self): def create_tests(test_path): for dirpath, _, filenames in os.walk(test_path): dirname = os.path.split(dirpath) - dirname = os.path.split(dirname[-2])[-1] + '_' + dirname[-1] + dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1] for filename in filenames: - if not filename.endswith('.json'): + if not filename.endswith(".json"): # skip everything that is not a test specification continue json_path = os.path.join(dirpath, filename) with open(json_path, encoding="utf-8") as scenario_stream: scenario_def = json.load(scenario_stream) - for testcase in scenario_def['tests']: - dsc = testcase['description'] + for testcase in scenario_def["tests"]: + dsc = testcase["description"] if dsc in TEST_DESC_SKIP_LIST: print("Skipping test '%s'" % dsc) continue testmethod = 
create_test(testcase, dirpath) - testname = 'test_%s_%s_%s' % ( - dirname, os.path.splitext(filename)[0], - str(dsc).replace(' ', '_')) + testname = "test_%s_%s_%s" % ( + dirname, + os.path.splitext(filename)[0], + str(dsc).replace(" ", "_"), + ) testmethod.__name__ = testname setattr(TestAllScenarios, testmethod.__name__, testmethod) diff --git a/test/test_versioned_api.py b/test/test_versioned_api.py index 44fc89ac73..a2fd059d21 100644 --- a/test/test_versioned_api.py +++ b/test/test_versioned_api.py @@ -17,16 +17,14 @@ sys.path[0:0] = [""] -from pymongo.mongo_client import MongoClient -from pymongo.server_api import ServerApi, ServerApiVersion - -from test import client_context, IntegrationTest, unittest +from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes from test.utils import OvertCommandListener, rs_or_single_client +from pymongo.mongo_client import MongoClient +from pymongo.server_api import ServerApi, ServerApiVersion -TEST_PATH = os.path.join( - os.path.dirname(os.path.realpath(__file__)), 'versioned-api') +TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "versioned-api") # Generate unified tests. globals().update(generate_test_classes(TEST_PATH, module=__name__)) @@ -38,38 +36,38 @@ class TestServerApi(IntegrationTest): def test_server_api_defaults(self): api = ServerApi(ServerApiVersion.V1) - self.assertEqual(api.version, '1') + self.assertEqual(api.version, "1") self.assertIsNone(api.strict) self.assertIsNone(api.deprecation_errors) def test_server_api_explicit_false(self): - api = ServerApi('1', strict=False, deprecation_errors=False) - self.assertEqual(api.version, '1') + api = ServerApi("1", strict=False, deprecation_errors=False) + self.assertEqual(api.version, "1") self.assertFalse(api.strict) self.assertFalse(api.deprecation_errors) def test_server_api_strict(self): - api = ServerApi('1', strict=True, deprecation_errors=True) - self.assertEqual(api.version, '1') + api = ServerApi("1", strict=True, deprecation_errors=True) + self.assertEqual(api.version, "1") self.assertTrue(api.strict) self.assertTrue(api.deprecation_errors) def test_server_api_validation(self): with self.assertRaises(ValueError): - ServerApi('2') + ServerApi("2") with self.assertRaises(TypeError): - ServerApi('1', strict='not-a-bool') + ServerApi("1", strict="not-a-bool") with self.assertRaises(TypeError): - ServerApi('1', deprecation_errors='not-a-bool') + ServerApi("1", deprecation_errors="not-a-bool") with self.assertRaises(TypeError): - MongoClient(server_api='not-a-ServerApi') + MongoClient(server_api="not-a-ServerApi") def assertServerApi(self, event): - self.assertIn('apiVersion', event.command) - self.assertEqual(event.command['apiVersion'], '1') + self.assertIn("apiVersion", event.command) + self.assertEqual(event.command["apiVersion"], "1") def assertNoServerApi(self, event): - self.assertNotIn('apiVersion', event.command) + self.assertNotIn("apiVersion", event.command) def assertServerApiInAllCommands(self, events): for event in events: @@ -78,22 +76,20 @@ def assertServerApiInAllCommands(self, events): @client_context.require_version_min(4, 7) def test_command_options(self): listener = OvertCommandListener() - client = rs_or_single_client(server_api=ServerApi('1'), - event_listeners=[listener]) + client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) self.addCleanup(client.close) coll = client.test.test coll.insert_many([{} for _ in range(100)]) self.addCleanup(coll.delete_many, {}) 
         list(coll.find(batch_size=25))
-        client.admin.command('ping')
-        self.assertServerApiInAllCommands(listener.results['started'])
+        client.admin.command("ping")
+        self.assertServerApiInAllCommands(listener.results["started"])
 
     @client_context.require_version_min(4, 7)
     @client_context.require_transactions
     def test_command_options_txn(self):
         listener = OvertCommandListener()
-        client = rs_or_single_client(server_api=ServerApi('1'),
-                                     event_listeners=[listener])
+        client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
         self.addCleanup(client.close)
         coll = client.test.test
         coll.insert_many([{} for _ in range(100)])
@@ -103,8 +99,8 @@ def test_command_options_txn(self):
         with client.start_session() as s, s.start_transaction():
             coll.insert_many([{} for _ in range(100)], session=s)
             list(coll.find(batch_size=25, session=s))
-            client.test.command('find', 'test', session=s)
-        self.assertServerApiInAllCommands(listener.results['started'])
+            client.test.command("find", "test", session=s)
+        self.assertServerApiInAllCommands(listener.results["started"])
 
 
 if __name__ == "__main__":
diff --git a/test/test_write_concern.py b/test/test_write_concern.py
index 5e1c0f73f9..c2932fa4e9 100644
--- a/test/test_write_concern.py
+++ b/test/test_write_concern.py
@@ -22,7 +22,6 @@
 class TestWriteConcern(unittest.TestCase):
-
     def test_invalid(self):
         # Can't use fsync and j options together
         self.assertRaises(ConfigurationError, WriteConcern, j=True, fsync=True)
@@ -41,9 +40,7 @@ def test_equality_to_none(self):
         self.assertTrue(concern != None)  # noqa
 
     def test_equality_compatible_type(self):
-
         class _FakeWriteConcern(object):
-
             def __init__(self, **document):
                 self.document = document
@@ -66,9 +63,9 @@ def __ne__(self, other):
         self.assertNotEqual(WriteConcern(wtimeout=42), _FakeWriteConcern(wtimeout=2000))
 
     def test_equality_incompatible_type(self):
-        _fake_type = collections.namedtuple('NotAWriteConcern', ['document'])
-        self.assertNotEqual(WriteConcern(j=True), _fake_type({'j': True}))
+        _fake_type = collections.namedtuple("NotAWriteConcern", ["document"])
+        self.assertNotEqual(WriteConcern(j=True), _fake_type({"j": True}))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
diff --git a/test/unicode/test_utf8.py b/test/unicode/test_utf8.py
index 65738d5c04..7ce2936b7a 100644
--- a/test/unicode/test_utf8.py
+++ b/test/unicode/test_utf8.py
@@ -2,9 +2,11 @@
 sys.path[0:0] = [""]
 
+from test import unittest
+
 from bson import encode
 from bson.errors import InvalidStringData
-from test import unittest
+
 
 class TestUTF8(unittest.TestCase):
@@ -12,18 +14,19 @@ class TestUTF8(unittest.TestCase):
     # legal utf-8 if the first byte is 0xf4 (244)
     def _assert_same_utf8_validation(self, data):
         try:
-            data.decode('utf-8')
-            py_is_legal = True
+            data.decode("utf-8")
+            py_is_legal = True
         except UnicodeDecodeError:
             py_is_legal = False
 
         try:
-            encode({'x': data})
-            bson_is_legal = True
+            encode({"x": data})
+            bson_is_legal = True
         except InvalidStringData:
             bson_is_legal = False
 
         self.assertEqual(py_is_legal, bson_is_legal, data)
 
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/test/unified_format.py b/test/unified_format.py
index 25a980425f..fd0e938df3 100644
--- a/test/unified_format.py
+++ b/test/unified_format.py
@@ -25,48 +25,64 @@
 import sys
 import time
 import types
-
 from collections import abc
+from test import IntegrationTest, client_context, unittest
+from test.utils import (
+    CMAPListener,
+    camel_to_snake,
+    camel_to_snake_args,
+    get_pool,
+    parse_collection_options,
+
parse_spec_options, + prepare_spec_arguments, + rs_or_single_client, + single_client, + snake_to_camel, +) +from test.version import Version -from bson import json_util, Code, Decimal128, DBRef, SON, Int64, MaxKey, MinKey +from bson import SON, Code, DBRef, Decimal128, Int64, MaxKey, MinKey, json_util from bson.binary import Binary from bson.objectid import ObjectId -from bson.regex import Regex, RE_TYPE - +from bson.regex import RE_TYPE, Regex from gridfs import GridFSBucket - from pymongo import ASCENDING, MongoClient -from pymongo.client_session import ClientSession, TransactionOptions, _TxnState from pymongo.change_stream import ChangeStream +from pymongo.client_session import ClientSession, TransactionOptions, _TxnState from pymongo.collection import Collection from pymongo.database import Database from pymongo.errors import ( - BulkWriteError, ConnectionFailure, ConfigurationError, InvalidOperation, - NotPrimaryError, PyMongoError) + BulkWriteError, + ConfigurationError, + ConnectionFailure, + InvalidOperation, + NotPrimaryError, + PyMongoError, +) from pymongo.monitoring import ( - CommandFailedEvent, CommandListener, CommandStartedEvent, - CommandSucceededEvent, _SENSITIVE_COMMANDS, PoolCreatedEvent, - PoolReadyEvent, PoolClearedEvent, PoolClosedEvent, ConnectionCreatedEvent, - ConnectionReadyEvent, ConnectionClosedEvent, - ConnectionCheckOutStartedEvent, ConnectionCheckOutFailedEvent, - ConnectionCheckedOutEvent, ConnectionCheckedInEvent) + _SENSITIVE_COMMANDS, + CommandFailedEvent, + CommandListener, + CommandStartedEvent, + CommandSucceededEvent, + ConnectionCheckedInEvent, + ConnectionCheckedOutEvent, + ConnectionCheckOutFailedEvent, + ConnectionCheckOutStartedEvent, + ConnectionClosedEvent, + ConnectionCreatedEvent, + ConnectionReadyEvent, + PoolClearedEvent, + PoolClosedEvent, + PoolCreatedEvent, + PoolReadyEvent, +) from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference from pymongo.results import BulkWriteResult from pymongo.server_api import ServerApi from pymongo.write_concern import WriteConcern -from test import client_context, unittest, IntegrationTest -from test.utils import ( - camel_to_snake, get_pool, rs_or_single_client, single_client, - snake_to_camel, CMAPListener) - -from test.version import Version -from test.utils import ( - camel_to_snake_args, parse_collection_options, parse_spec_options, - prepare_spec_arguments) - - JSON_OPTS = json_util.JSONOptions(tz_aware=False) IS_INTERRUPTED = False @@ -86,14 +102,13 @@ def with_metaclass(meta, *bases): # metaclass for one level of class instantiation that replaces itself with # the actual metaclass. class metaclass(type): - def __new__(cls, name, this_bases, d): if sys.version_info[:2] >= (3, 7): # This version introduced PEP 560 that requires a bit # of extra care (we mimic what is done by __build_class__). 
resolved_bases = types.resolve_bases(bases) if resolved_bases is not bases: - d['__orig_bases__'] = bases + d["__orig_bases__"] = bases else: resolved_bases = bases return meta(name, resolved_bases, d) @@ -101,40 +116,38 @@ def __new__(cls, name, this_bases, d): @classmethod def __prepare__(cls, name, this_bases): return meta.__prepare__(name, bases) - return type.__new__(metaclass, 'temporary_class', (), {}) + + return type.__new__(metaclass, "temporary_class", (), {}) def is_run_on_requirement_satisfied(requirement): topology_satisfied = True - req_topologies = requirement.get('topologies') + req_topologies = requirement.get("topologies") if req_topologies: - topology_satisfied = client_context.is_topology_type( - req_topologies) + topology_satisfied = client_context.is_topology_type(req_topologies) server_version = Version(*client_context.version[:3]) min_version_satisfied = True - req_min_server_version = requirement.get('minServerVersion') + req_min_server_version = requirement.get("minServerVersion") if req_min_server_version: - min_version_satisfied = Version.from_string( - req_min_server_version) <= server_version + min_version_satisfied = Version.from_string(req_min_server_version) <= server_version max_version_satisfied = True - req_max_server_version = requirement.get('maxServerVersion') + req_max_server_version = requirement.get("maxServerVersion") if req_max_server_version: - max_version_satisfied = Version.from_string( - req_max_server_version) >= server_version + max_version_satisfied = Version.from_string(req_max_server_version) >= server_version - serverless = requirement.get('serverless') + serverless = requirement.get("serverless") if serverless == "require": serverless_satisfied = client_context.serverless elif serverless == "forbid": serverless_satisfied = not client_context.serverless - else: # unset or "allow" + else: # unset or "allow" serverless_satisfied = True params_satisfied = True - params = requirement.get('serverParameters') + params = requirement.get("serverParameters") if params: for param, val in params.items(): if param not in client_context.server_parameters: @@ -143,16 +156,21 @@ def is_run_on_requirement_satisfied(requirement): params_satisfied = False auth_satisfied = True - req_auth = requirement.get('auth') + req_auth = requirement.get("auth") if req_auth is not None: if req_auth: auth_satisfied = client_context.auth_enabled else: auth_satisfied = not client_context.auth_enabled - return (topology_satisfied and min_version_satisfied and - max_version_satisfied and serverless_satisfied and - params_satisfied and auth_satisfied) + return ( + topology_satisfied + and min_version_satisfied + and max_version_satisfied + and serverless_satisfied + and params_satisfied + and auth_satisfied + ) def parse_collection_or_database_options(options): @@ -160,15 +178,15 @@ def parse_collection_or_database_options(options): def parse_bulk_write_result(result): - upserted_ids = {str(int_idx): result.upserted_ids[int_idx] - for int_idx in result.upserted_ids} + upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids} return { - 'deletedCount': result.deleted_count, - 'insertedCount': result.inserted_count, - 'matchedCount': result.matched_count, - 'modifiedCount': result.modified_count, - 'upsertedCount': result.upserted_count, - 'upsertedIds': upserted_ids} + "deletedCount": result.deleted_count, + "insertedCount": result.inserted_count, + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + 
"upsertedCount": result.upserted_count, + "upsertedIds": upserted_ids, + } def parse_bulk_write_error_result(error): @@ -178,6 +196,7 @@ def parse_bulk_write_error_result(error): class NonLazyCursor(object): """A find cursor proxy that creates the remote cursor when initialized.""" + def __init__(self, find_cursor): self.find_cursor = find_cursor # Create the server side cursor. @@ -195,8 +214,9 @@ def close(self): class EventListenerUtil(CMAPListener, CommandListener): - def __init__(self, observe_events, ignore_commands, - observe_sensitive_commands, store_events, entity_map): + def __init__( + self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map + ): self._event_types = set(name.lower() for name in observe_events) if observe_sensitive_commands: self._observe_sensitive_commands = True @@ -204,7 +224,7 @@ def __init__(self, observe_events, ignore_commands, else: self._observe_sensitive_commands = False self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands) - self._ignore_commands.add('configurefailpoint') + self._ignore_commands.add("configurefailpoint") self._event_mapping = collections.defaultdict(list) self.entity_map = entity_map if store_events: @@ -217,20 +237,22 @@ def __init__(self, observe_events, ignore_commands, super(EventListenerUtil, self).__init__() def get_events(self, event_type): - if event_type == 'command': - return [e for e in self.events if 'Command' in type(e).__name__] - return [e for e in self.events if 'Command' not in type(e).__name__] + if event_type == "command": + return [e for e in self.events if "Command" in type(e).__name__] + return [e for e in self.events if "Command" not in type(e).__name__] def add_event(self, event): event_name = type(event).__name__.lower() if event_name in self._event_types: super(EventListenerUtil, self).add_event(event) for id in self._event_mapping[event_name]: - self.entity_map[id].append({ - "name": type(event).__name__, - "observedAt": time.time(), - "description": repr(event) - }) + self.entity_map[id].append( + { + "name": type(event).__name__, + "observedAt": time.time(), + "description": repr(event), + } + ) def _command_event(self, event): if event.command_name.lower() not in self._ignore_commands: @@ -259,6 +281,7 @@ def failed(self, event): class EntityMapUtil(object): """Utility class that implements an entity map as per the unified test format specification.""" + def __init__(self, test_class): self._entities = {} self._listeners = {} @@ -275,102 +298,100 @@ def __getitem__(self, item): try: return self._entities[item] except KeyError: - self.test.fail('Could not find entity named %s in map' % ( - item,)) + self.test.fail("Could not find entity named %s in map" % (item,)) def __setitem__(self, key, value): if not isinstance(key, str): - self.test.fail( - 'Expected entity name of type str, got %s' % (type(key))) + self.test.fail("Expected entity name of type str, got %s" % (type(key))) if key in self._entities: - self.test.fail('Entity named %s already in map' % (key,)) + self.test.fail("Entity named %s already in map" % (key,)) self._entities[key] = value def _create_entity(self, entity_spec, uri=None): if len(entity_spec) != 1: self.test.fail( - "Entity spec %s did not contain exactly one top-level key" % ( - entity_spec,)) + "Entity spec %s did not contain exactly one top-level key" % (entity_spec,) + ) entity_type, spec = next(iter(entity_spec.items())) - if entity_type == 'client': + if entity_type == "client": kwargs = {} - observe_events = 
spec.get('observeEvents', []) - ignore_commands = spec.get('ignoreCommandMonitoringEvents', []) - observe_sensitive_commands = spec.get( - 'observeSensitiveCommands', False) + observe_events = spec.get("observeEvents", []) + ignore_commands = spec.get("ignoreCommandMonitoringEvents", []) + observe_sensitive_commands = spec.get("observeSensitiveCommands", False) ignore_commands = [cmd.lower() for cmd in ignore_commands] listener = EventListenerUtil( - observe_events, ignore_commands, + observe_events, + ignore_commands, observe_sensitive_commands, - spec.get("storeEventsAsEntities"), self) - self._listeners[spec['id']] = listener - kwargs['event_listeners'] = [listener] - if spec.get('useMultipleMongoses'): + spec.get("storeEventsAsEntities"), + self, + ) + self._listeners[spec["id"]] = listener + kwargs["event_listeners"] = [listener] + if spec.get("useMultipleMongoses"): if client_context.load_balancer or client_context.serverless: - kwargs['h'] = client_context.MULTI_MONGOS_LB_URI + kwargs["h"] = client_context.MULTI_MONGOS_LB_URI elif client_context.is_mongos: - kwargs['h'] = client_context.mongos_seeds() - kwargs.update(spec.get('uriOptions', {})) - server_api = spec.get('serverApi') + kwargs["h"] = client_context.mongos_seeds() + kwargs.update(spec.get("uriOptions", {})) + server_api = spec.get("serverApi") if server_api: - kwargs['server_api'] = ServerApi( - server_api['version'], strict=server_api.get('strict'), - deprecation_errors=server_api.get('deprecationErrors')) + kwargs["server_api"] = ServerApi( + server_api["version"], + strict=server_api.get("strict"), + deprecation_errors=server_api.get("deprecationErrors"), + ) if uri: - kwargs['h'] = uri + kwargs["h"] = uri client = rs_or_single_client(**kwargs) - self[spec['id']] = client + self[spec["id"]] = client self.test.addCleanup(client.close) return - elif entity_type == 'database': - client = self[spec['client']] + elif entity_type == "database": + client = self[spec["client"]] if not isinstance(client, MongoClient): self.test.fail( - 'Expected entity %s to be of type MongoClient, got %s' % ( - spec['client'], type(client))) - options = parse_collection_or_database_options( - spec.get('databaseOptions', {})) - self[spec['id']] = client.get_database( - spec['databaseName'], **options) + "Expected entity %s to be of type MongoClient, got %s" + % (spec["client"], type(client)) + ) + options = parse_collection_or_database_options(spec.get("databaseOptions", {})) + self[spec["id"]] = client.get_database(spec["databaseName"], **options) return - elif entity_type == 'collection': - database = self[spec['database']] + elif entity_type == "collection": + database = self[spec["database"]] if not isinstance(database, Database): self.test.fail( - 'Expected entity %s to be of type Database, got %s' % ( - spec['database'], type(database))) - options = parse_collection_or_database_options( - spec.get('collectionOptions', {})) - self[spec['id']] = database.get_collection( - spec['collectionName'], **options) + "Expected entity %s to be of type Database, got %s" + % (spec["database"], type(database)) + ) + options = parse_collection_or_database_options(spec.get("collectionOptions", {})) + self[spec["id"]] = database.get_collection(spec["collectionName"], **options) return - elif entity_type == 'session': - client = self[spec['client']] + elif entity_type == "session": + client = self[spec["client"]] if not isinstance(client, MongoClient): self.test.fail( - 'Expected entity %s to be of type MongoClient, got %s' % ( - spec['client'], 
type(client))) - opts = camel_to_snake_args(spec.get('sessionOptions', {})) - if 'default_transaction_options' in opts: - txn_opts = parse_spec_options( - opts['default_transaction_options']) + "Expected entity %s to be of type MongoClient, got %s" + % (spec["client"], type(client)) + ) + opts = camel_to_snake_args(spec.get("sessionOptions", {})) + if "default_transaction_options" in opts: + txn_opts = parse_spec_options(opts["default_transaction_options"]) txn_opts = TransactionOptions(**txn_opts) opts = copy.deepcopy(opts) - opts['default_transaction_options'] = txn_opts + opts["default_transaction_options"] = txn_opts session = client.start_session(**dict(opts)) - self[spec['id']] = session - self._session_lsids[spec['id']] = copy.deepcopy(session.session_id) + self[spec["id"]] = session + self._session_lsids[spec["id"]] = copy.deepcopy(session.session_id) self.test.addCleanup(session.end_session) return - elif entity_type == 'bucket': + elif entity_type == "bucket": # TODO: implement the 'bucket' entity type - self.test.skipTest( - 'GridFS is not currently supported (PYTHON-2459)') - self.test.fail( - 'Unable to create entity of unknown type %s' % (entity_type,)) + self.test.skipTest("GridFS is not currently supported (PYTHON-2459)") + self.test.fail("Unable to create entity of unknown type %s" % (entity_type,)) def create_entities_from_spec(self, entity_spec, uri=None): for spec in entity_spec: @@ -380,13 +401,12 @@ def get_listener_for_client(self, client_name): client = self[client_name] if not isinstance(client, MongoClient): self.test.fail( - 'Expected entity %s to be of type MongoClient, got %s' % ( - client_name, type(client))) + "Expected entity %s to be of type MongoClient, got %s" % (client_name, type(client)) + ) listener = self._listeners.get(client_name) if not listener: - self.test.fail( - 'No listeners configured for client %s' % (client_name,)) + self.test.fail("No listeners configured for client %s" % (client_name,)) return listener @@ -394,8 +414,9 @@ def get_lsid_for_session(self, session_name): session = self[session_name] if not isinstance(session, ClientSession): self.test.fail( - 'Expected entity %s to be of type ClientSession, got %s' % ( - session_name, type(session))) + "Expected entity %s to be of type ClientSession, got %s" + % (session_name, type(session)) + ) try: return session.session_id @@ -412,32 +433,33 @@ def get_lsid_for_session(self, session_name): BSON_TYPE_ALIAS_MAP = { # https://docs.mongodb.com/manual/reference/operator/query/type/ # https://pymongo.readthedocs.io/en/stable/api/bson/index.html - 'double': (float,), - 'string': (str,), - 'object': (abc.Mapping,), - 'array': (abc.MutableSequence,), - 'binData': binary_types, - 'undefined': (type(None),), - 'objectId': (ObjectId,), - 'bool': (bool,), - 'date': (datetime.datetime,), - 'null': (type(None),), - 'regex': (Regex, RE_TYPE), - 'dbPointer': (DBRef,), - 'javascript': (unicode_type, Code), - 'symbol': (unicode_type,), - 'javascriptWithScope': (unicode_type, Code), - 'int': (int,), - 'long': (Int64,), - 'decimal': (Decimal128,), - 'maxKey': (MaxKey,), - 'minKey': (MinKey,), + "double": (float,), + "string": (str,), + "object": (abc.Mapping,), + "array": (abc.MutableSequence,), + "binData": binary_types, + "undefined": (type(None),), + "objectId": (ObjectId,), + "bool": (bool,), + "date": (datetime.datetime,), + "null": (type(None),), + "regex": (Regex, RE_TYPE), + "dbPointer": (DBRef,), + "javascript": (unicode_type, Code), + "symbol": (unicode_type,), + "javascriptWithScope": (unicode_type, 
Code), + "int": (int,), + "long": (Int64,), + "decimal": (Decimal128,), + "maxKey": (MaxKey,), + "minKey": (MinKey,), } class MatchEvaluatorUtil(object): """Utility class that implements methods for evaluating matches as per the unified test format specification.""" + def __init__(self, test_class): self.test = test_class @@ -447,19 +469,18 @@ def _operation_exists(self, spec, actual, key_to_compare): elif spec is False: self.test.assertNotIn(key_to_compare, actual) else: - self.test.fail( - 'Expected boolean value for $$exists operator, got %s' % ( - spec,)) + self.test.fail("Expected boolean value for $$exists operator, got %s" % (spec,)) def __type_alias_to_type(self, alias): if alias not in BSON_TYPE_ALIAS_MAP: - self.test.fail('Unrecognized BSON type alias %s' % (alias,)) + self.test.fail("Unrecognized BSON type alias %s" % (alias,)) return BSON_TYPE_ALIAS_MAP[alias] def _operation_type(self, spec, actual, key_to_compare): if isinstance(spec, abc.MutableSequence): - permissible_types = tuple([ - t for alias in spec for t in self.__type_alias_to_type(alias)]) + permissible_types = tuple( + [t for alias in spec for t in self.__type_alias_to_type(alias)] + ) else: permissible_types = self.__type_alias_to_type(spec) value = actual[key_to_compare] if key_to_compare else actual @@ -480,7 +501,7 @@ def _operation_unsetOrMatches(self, spec, actual, key_to_compare): if key_to_compare not in actual: # we add a dummy value for the compared key to pass map size check - actual[key_to_compare] = 'dummyValue' + actual[key_to_compare] = "dummyValue" return self.match_result(spec, actual[key_to_compare], in_recursive_call=True) @@ -488,19 +509,16 @@ def _operation_sessionLsid(self, spec, actual, key_to_compare): expected_lsid = self.test.entity_map.get_lsid_for_session(spec) self.test.assertEqual(expected_lsid, actual[key_to_compare]) - def _evaluate_special_operation(self, opname, spec, actual, - key_to_compare): - method_name = '_operation_%s' % (opname.strip('$'),) + def _evaluate_special_operation(self, opname, spec, actual, key_to_compare): + method_name = "_operation_%s" % (opname.strip("$"),) try: method = getattr(self, method_name) except AttributeError: - self.test.fail( - 'Unsupported special matching operator %s' % (opname,)) + self.test.fail("Unsupported special matching operator %s" % (opname,)) else: method(spec, actual, key_to_compare) - def _evaluate_if_special_operation(self, expectation, actual, - key_to_compare=None): + def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None): """Returns True if a special operation is evaluated, False otherwise. If the ``expectation`` map contains a single key, value pair we check it for a special operation. 
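# Standalone sketch (not part of the diff): how a "$$type" special operator
# resolves BSON type aliases to permissible Python types, mirroring the
# BSON_TYPE_ALIAS_MAP dispatch above. "_ALIASES" and "matches_type" are
# hypothetical names for illustration, assuming a small alias subset.
_ALIASES = {"double": (float,), "string": (str,), "bool": (bool,), "int": (int,)}

def matches_type(spec, value):
    # A spec may name one alias or a list of aliases; any match passes.
    aliases = spec if isinstance(spec, list) else [spec]
    permissible = tuple(t for alias in aliases for t in _ALIASES[alias])
    return isinstance(value, permissible)

assert matches_type("string", "abc")
assert matches_type(["int", "double"], 3.14)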
@@ -514,7 +532,7 @@ def _evaluate_if_special_operation(self, expectation, actual, is_special_op, opname, spec = False, False, False if key_to_compare is not None: - if key_to_compare.startswith('$$'): + if key_to_compare.startswith("$$"): is_special_op = True opname = key_to_compare spec = expectation[key_to_compare] @@ -523,20 +541,18 @@ def _evaluate_if_special_operation(self, expectation, actual, nested = expectation[key_to_compare] if isinstance(nested, abc.Mapping) and len(nested) == 1: opname, spec = next(iter(nested.items())) - if opname.startswith('$$'): + if opname.startswith("$$"): is_special_op = True elif len(expectation) == 1: opname, spec = next(iter(expectation.items())) - if opname.startswith('$$'): + if opname.startswith("$$"): is_special_op = True key_to_compare = None if is_special_op: self._evaluate_special_operation( - opname=opname, - spec=spec, - actual=actual, - key_to_compare=key_to_compare) + opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare + ) return True return False @@ -556,37 +572,33 @@ def _match_document(self, expectation, actual, is_root): if not is_root: expected_keys = set(expectation.keys()) for key, value in expectation.items(): - if value == {'$$exists': False}: + if value == {"$$exists": False}: expected_keys.remove(key) self.test.assertEqual(expected_keys, set(actual.keys())) - def match_result(self, expectation, actual, - in_recursive_call=False): + def match_result(self, expectation, actual, in_recursive_call=False): if isinstance(expectation, abc.Mapping): - return self._match_document( - expectation, actual, is_root=not in_recursive_call) + return self._match_document(expectation, actual, is_root=not in_recursive_call) if isinstance(expectation, abc.MutableSequence): self.test.assertIsInstance(actual, abc.MutableSequence) for e, a in zip(expectation, actual): if isinstance(e, abc.Mapping): - self._match_document( - e, a, is_root=not in_recursive_call) + self._match_document(e, a, is_root=not in_recursive_call) else: self.match_result(e, a, in_recursive_call=True) return # account for flexible numerics in element-wise comparison - if (isinstance(expectation, int) or - isinstance(expectation, float)): + if isinstance(expectation, int) or isinstance(expectation, float): self.test.assertEqual(expectation, actual) else: self.test.assertIsInstance(actual, type(expectation)) self.test.assertEqual(expectation, actual) def assertHasServiceId(self, spec, actual): - if 'hasServiceId' in spec: - if spec.get('hasServiceId'): + if "hasServiceId" in spec: + if spec.get("hasServiceId"): self.test.assertIsNotNone(actual.service_id) self.test.assertIsInstance(actual.service_id, ObjectId) else: @@ -596,85 +608,83 @@ def match_event(self, event_type, expectation, actual): name, spec = next(iter(expectation.items())) # every command event has the commandName field - if event_type == 'command': - command_name = spec.get('commandName') + if event_type == "command": + command_name = spec.get("commandName") if command_name: self.test.assertEqual(command_name, actual.command_name) - if name == 'commandStartedEvent': + if name == "commandStartedEvent": self.test.assertIsInstance(actual, CommandStartedEvent) - command = spec.get('command') - database_name = spec.get('databaseName') + command = spec.get("command") + database_name = spec.get("databaseName") if command: - if actual.command_name == 'update': + if actual.command_name == "update": # TODO: remove this once PYTHON-1744 is done. # Add upsert and multi fields back into expectations. 
- for update in command.get('updates', []): - update.setdefault('upsert', False) - update.setdefault('multi', False) + for update in command.get("updates", []): + update.setdefault("upsert", False) + update.setdefault("multi", False) self.match_result(command, actual.command) if database_name: - self.test.assertEqual( - database_name, actual.database_name) + self.test.assertEqual(database_name, actual.database_name) self.assertHasServiceId(spec, actual) - elif name == 'commandSucceededEvent': + elif name == "commandSucceededEvent": self.test.assertIsInstance(actual, CommandSucceededEvent) - reply = spec.get('reply') + reply = spec.get("reply") if reply: self.match_result(reply, actual.reply) self.assertHasServiceId(spec, actual) - elif name == 'commandFailedEvent': + elif name == "commandFailedEvent": self.test.assertIsInstance(actual, CommandFailedEvent) self.assertHasServiceId(spec, actual) - elif name == 'poolCreatedEvent': + elif name == "poolCreatedEvent": self.test.assertIsInstance(actual, PoolCreatedEvent) - elif name == 'poolReadyEvent': + elif name == "poolReadyEvent": self.test.assertIsInstance(actual, PoolReadyEvent) - elif name == 'poolClearedEvent': + elif name == "poolClearedEvent": self.test.assertIsInstance(actual, PoolClearedEvent) self.assertHasServiceId(spec, actual) - elif name == 'poolClosedEvent': + elif name == "poolClosedEvent": self.test.assertIsInstance(actual, PoolClosedEvent) - elif name == 'connectionCreatedEvent': + elif name == "connectionCreatedEvent": self.test.assertIsInstance(actual, ConnectionCreatedEvent) - elif name == 'connectionReadyEvent': + elif name == "connectionReadyEvent": self.test.assertIsInstance(actual, ConnectionReadyEvent) - elif name == 'connectionClosedEvent': + elif name == "connectionClosedEvent": self.test.assertIsInstance(actual, ConnectionClosedEvent) - if 'reason' in spec: - self.test.assertEqual(actual.reason, spec['reason']) - elif name == 'connectionCheckOutStartedEvent': + if "reason" in spec: + self.test.assertEqual(actual.reason, spec["reason"]) + elif name == "connectionCheckOutStartedEvent": self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent) - elif name == 'connectionCheckOutFailedEvent': + elif name == "connectionCheckOutFailedEvent": self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent) - if 'reason' in spec: - self.test.assertEqual(actual.reason, spec['reason']) - elif name == 'connectionCheckedOutEvent': + if "reason" in spec: + self.test.assertEqual(actual.reason, spec["reason"]) + elif name == "connectionCheckedOutEvent": self.test.assertIsInstance(actual, ConnectionCheckedOutEvent) - elif name == 'connectionCheckedInEvent': + elif name == "connectionCheckedInEvent": self.test.assertIsInstance(actual, ConnectionCheckedInEvent) else: - self.test.fail( - 'Unsupported event type %s' % (name,)) + self.test.fail("Unsupported event type %s" % (name,)) def coerce_result(opname, result): """Convert a pymongo result into the spec's result format.""" - if hasattr(result, 'acknowledged') and not result.acknowledged: - return {'acknowledged': False} - if opname == 'bulkWrite': + if hasattr(result, "acknowledged") and not result.acknowledged: + return {"acknowledged": False} + if opname == "bulkWrite": return parse_bulk_write_result(result) - if opname == 'insertOne': - return {'insertedId': result.inserted_id} - if opname == 'insertMany': + if opname == "insertOne": + return {"insertedId": result.inserted_id} + if opname == "insertMany": return {idx: _id for idx, _id in enumerate(result.inserted_ids)} 
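# Standalone sketch (not part of the diff) of the coercion coerce_result
# performs for insertMany: pymongo's list of inserted ids becomes the spec
# format's index-keyed map. FakeInsertManyResult is a stand-in for the real
# pymongo result class, for illustration only.
class FakeInsertManyResult:
    def __init__(self, ids):
        self.acknowledged = True
        self.inserted_ids = ids

result = FakeInsertManyResult(["id0", "id1", "id2"])
assert {idx: _id for idx, _id in enumerate(result.inserted_ids)} == {
    0: "id0", 1: "id1", 2: "id2"
}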
- if opname in ('deleteOne', 'deleteMany'): - return {'deletedCount': result.deleted_count} - if opname in ('updateOne', 'updateMany', 'replaceOne'): + if opname in ("deleteOne", "deleteMany"): + return {"deletedCount": result.deleted_count} + if opname in ("updateOne", "updateMany", "replaceOne"): return { - 'matchedCount': result.matched_count, - 'modifiedCount': result.modified_count, - 'upsertedCount': 0 if result.upserted_id is None else 1, + "matchedCount": result.matched_count, + "modifiedCount": result.modified_count, + "upsertedCount": 0 if result.upserted_id is None else 1, } return result @@ -688,7 +698,8 @@ class UnifiedSpecTestMixinV1(IntegrationTest): Specification of the test suite being currently run is available as a class attribute ``TEST_SPEC``. """ - SCHEMA_VERSION = Version.from_string('1.5') + + SCHEMA_VERSION = Version.from_string("1.5") RUN_ON_LOAD_BALANCER = True RUN_ON_SERVERLESS = True @@ -705,12 +716,13 @@ def should_run_on(run_on_spec): def insert_initial_data(self, initial_data): for collection_data in initial_data: - coll_name = collection_data['collectionName'] - db_name = collection_data['databaseName'] - documents = collection_data['documents'] + coll_name = collection_data["collectionName"] + db_name = collection_data["databaseName"] + documents = collection_data["documents"] coll = self.client.get_database(db_name).get_collection( - coll_name, write_concern=WriteConcern(w="majority")) + coll_name, write_concern=WriteConcern(w="majority") + ) coll.drop() if len(documents) > 0: @@ -718,56 +730,54 @@ def insert_initial_data(self, initial_data): else: # ensure collection exists result = coll.insert_one({}) - coll.delete_one({'_id': result.inserted_id}) + coll.delete_one({"_id": result.inserted_id}) @classmethod def setUpClass(cls): # super call creates internal client cls.client super(UnifiedSpecTestMixinV1, cls).setUpClass() # process file-level runOnRequirements - run_on_spec = cls.TEST_SPEC.get('runOnRequirements', []) + run_on_spec = cls.TEST_SPEC.get("runOnRequirements", []) if not cls.should_run_on(run_on_spec): - raise unittest.SkipTest( - '%s runOnRequirements not satisfied' % (cls.__name__,)) + raise unittest.SkipTest("%s runOnRequirements not satisfied" % (cls.__name__,)) # add any special-casing for skipping tests here - if client_context.storage_engine == 'mmapv1': - if 'retryable-writes' in cls.TEST_SPEC['description']: - raise unittest.SkipTest( - "MMAPv1 does not support retryWrites=True") + if client_context.storage_engine == "mmapv1": + if "retryable-writes" in cls.TEST_SPEC["description"]: + raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") def setUp(self): super(UnifiedSpecTestMixinV1, self).setUp() # process schemaVersion # note: we check major schema version during class generation # note: we do this here because we cannot run assertions in setUpClass - version = Version.from_string(self.TEST_SPEC['schemaVersion']) + version = Version.from_string(self.TEST_SPEC["schemaVersion"]) self.assertLessEqual( - version, self.SCHEMA_VERSION, - 'expected schema version %s or lower, got %s' % ( - self.SCHEMA_VERSION, version)) + version, + self.SCHEMA_VERSION, + "expected schema version %s or lower, got %s" % (self.SCHEMA_VERSION, version), + ) # initialize internals self.match_evaluator = MatchEvaluatorUtil(self) def maybe_skip_test(self, spec): # add any special-casing for skipping tests here - if client_context.storage_engine == 'mmapv1': - if 'Dirty explicit session is discarded' in spec['description']: - raise 
unittest.SkipTest( - "MMAPv1 does not support retryWrites=True") - elif 'Client side error in command starting transaction' in spec['description']: + if client_context.storage_engine == "mmapv1": + if "Dirty explicit session is discarded" in spec["description"]: + raise unittest.SkipTest("MMAPv1 does not support retryWrites=True") + elif "Client side error in command starting transaction" in spec["description"]: raise unittest.SkipTest("Implement PYTHON-1894") def process_error(self, exception, spec): - is_error = spec.get('isError') - is_client_error = spec.get('isClientError') - error_contains = spec.get('errorContains') - error_code = spec.get('errorCode') - error_code_name = spec.get('errorCodeName') - error_labels_contain = spec.get('errorLabelsContain') - error_labels_omit = spec.get('errorLabelsOmit') - expect_result = spec.get('expectResult') + is_error = spec.get("isError") + is_client_error = spec.get("isClientError") + error_contains = spec.get("errorContains") + error_code = spec.get("errorCode") + error_code_name = spec.get("errorCodeName") + error_labels_contain = spec.get("errorLabelsContain") + error_labels_omit = spec.get("errorLabelsOmit") + expect_result = spec.get("expectResult") if is_error: # already satisfied because exception was raised @@ -790,75 +800,72 @@ def process_error(self, exception, spec): self.assertIn(error_contains.lower(), errmsg) if error_code: - self.assertEqual( - error_code, exception.details.get('code')) + self.assertEqual(error_code, exception.details.get("code")) if error_code_name: - self.assertEqual( - error_code_name, exception.details.get('codeName')) + self.assertEqual(error_code_name, exception.details.get("codeName")) if error_labels_contain: - labels = [err_label for err_label in error_labels_contain - if exception.has_error_label(err_label)] + labels = [ + err_label + for err_label in error_labels_contain + if exception.has_error_label(err_label) + ] self.assertEqual(labels, error_labels_contain) if error_labels_omit: for err_label in error_labels_omit: if exception.has_error_label(err_label): - self.fail("Exception '%s' unexpectedly had label '%s'" % ( - exception, err_label)) + self.fail("Exception '%s' unexpectedly had label '%s'" % (exception, err_label)) if expect_result: if isinstance(exception, BulkWriteError): - result = parse_bulk_write_error_result( - exception) + result = parse_bulk_write_error_result(exception) self.match_evaluator.match_result(expect_result, result) else: - self.fail("expectResult can only be specified with %s " - "exceptions" % (BulkWriteError,)) + self.fail("expectResult can only be specified with %s exceptions" % (BulkWriteError,)) def __raise_if_unsupported(self, opname, target, *target_types): if not isinstance(target, target_types): - self.fail('Operation %s not supported for entity ' - 'of type %s' % (opname, type(target))) + self.fail("Operation %s not supported for entity of type %s" % (opname, type(target))) def __entityOperation_createChangeStream(self, target, *args, **kwargs): - if client_context.storage_engine == 'mmapv1': + if client_context.storage_engine == "mmapv1": self.skipTest("MMAPv1 does not support change streams") - self.__raise_if_unsupported( - 'createChangeStream', target, MongoClient, Database, Collection) + self.__raise_if_unsupported("createChangeStream", target, MongoClient, Database, Collection) stream = target.watch(*args, **kwargs) self.addCleanup(stream.close) return stream def _clientOperation_createChangeStream(self, target, *args, **kwargs): - return 
self.__entityOperation_createChangeStream( - target, *args, **kwargs) + return self.__entityOperation_createChangeStream(target, *args, **kwargs) def _databaseOperation_createChangeStream(self, target, *args, **kwargs): - return self.__entityOperation_createChangeStream( - target, *args, **kwargs) + return self.__entityOperation_createChangeStream(target, *args, **kwargs) def _collectionOperation_createChangeStream(self, target, *args, **kwargs): - return self.__entityOperation_createChangeStream( - target, *args, **kwargs) + return self.__entityOperation_createChangeStream(target, *args, **kwargs) def _databaseOperation_runCommand(self, target, **kwargs): - self.__raise_if_unsupported('runCommand', target, Database) + self.__raise_if_unsupported("runCommand", target, Database) # Ensure the first key is the command name. - ordered_command = SON([(kwargs.pop('command_name'), 1)]) - ordered_command.update(kwargs['command']) - kwargs['command'] = ordered_command + ordered_command = SON([(kwargs.pop("command_name"), 1)]) + ordered_command.update(kwargs["command"]) + kwargs["command"] = ordered_command return target.command(**kwargs) def _databaseOperation_listCollections(self, target, *args, **kwargs): - if 'batch_size' in kwargs: - kwargs['cursor'] = {'batchSize': kwargs.pop('batch_size')} + if "batch_size" in kwargs: + kwargs["cursor"] = {"batchSize": kwargs.pop("batch_size")} cursor = target.list_collections(*args, **kwargs) return list(cursor) def __entityOperation_aggregate(self, target, *args, **kwargs): - self.__raise_if_unsupported('aggregate', target, Database, Collection) + self.__raise_if_unsupported("aggregate", target, Database, Collection) return list(target.aggregate(*args, **kwargs)) def _databaseOperation_aggregate(self, target, *args, **kwargs): @@ -868,86 +875,84 @@ def _collectionOperation_aggregate(self, target, *args, **kwargs): return self.__entityOperation_aggregate(target, *args, **kwargs) def _collectionOperation_find(self, target, *args, **kwargs): - self.__raise_if_unsupported('find', target, Collection) + self.__raise_if_unsupported("find", target, Collection) find_cursor = target.find(*args, **kwargs) return list(find_cursor) def _collectionOperation_createFindCursor(self, target, *args, **kwargs): - self.__raise_if_unsupported('find', target, Collection) - if 'filter' not in kwargs: + self.__raise_if_unsupported("find", target, Collection) + if "filter" not in kwargs: self.fail('createFindCursor requires a "filter" argument') cursor = NonLazyCursor(target.find(*args, **kwargs)) self.addCleanup(cursor.close) return cursor def _collectionOperation_listIndexes(self, target, *args, **kwargs): - if 'batch_size' in kwargs: - self.skipTest('PyMongo does not support batch_size for ' - 'list_indexes') + if "batch_size" in kwargs: + self.skipTest("PyMongo does not support batch_size for list_indexes") return target.list_indexes(*args, **kwargs) def _sessionOperation_withTransaction(self, target, *args, **kwargs): - if client_context.storage_engine == 'mmapv1': - self.skipTest('MMAPv1 does not support document-level locking') - self.__raise_if_unsupported('withTransaction', target, ClientSession) + if client_context.storage_engine == "mmapv1": + self.skipTest("MMAPv1 does not support document-level locking") + self.__raise_if_unsupported("withTransaction", target, ClientSession) return target.with_transaction(*args, **kwargs) def _sessionOperation_startTransaction(self, target, *args, **kwargs): - if client_context.storage_engine == 'mmapv1': - self.skipTest('MMAPv1 does 
not support document-level locking') - self.__raise_if_unsupported('startTransaction', target, ClientSession) + if client_context.storage_engine == "mmapv1": + self.skipTest("MMAPv1 does not support document-level locking") + self.__raise_if_unsupported("startTransaction", target, ClientSession) return target.start_transaction(*args, **kwargs) - def _changeStreamOperation_iterateUntilDocumentOrError(self, target, - *args, **kwargs): - self.__raise_if_unsupported( - 'iterateUntilDocumentOrError', target, ChangeStream) + def _changeStreamOperation_iterateUntilDocumentOrError(self, target, *args, **kwargs): + self.__raise_if_unsupported("iterateUntilDocumentOrError", target, ChangeStream) return next(target) def _cursor_iterateUntilDocumentOrError(self, target, *args, **kwargs): - self.__raise_if_unsupported( - 'iterateUntilDocumentOrError', target, NonLazyCursor) + self.__raise_if_unsupported("iterateUntilDocumentOrError", target, NonLazyCursor) return next(target) def _cursor_close(self, target, *args, **kwargs): - self.__raise_if_unsupported('close', target, NonLazyCursor) + self.__raise_if_unsupported("close", target, NonLazyCursor) return target.close() def run_entity_operation(self, spec): - target = self.entity_map[spec['object']] - opname = spec['name'] - opargs = spec.get('arguments') - expect_error = spec.get('expectError') - save_as_entity = spec.get('saveResultAsEntity') - expect_result = spec.get('expectResult') - ignore = spec.get('ignoreResultAndError') + target = self.entity_map[spec["object"]] + opname = spec["name"] + opargs = spec.get("arguments") + expect_error = spec.get("expectError") + save_as_entity = spec.get("saveResultAsEntity") + expect_result = spec.get("expectResult") + ignore = spec.get("ignoreResultAndError") if ignore and (expect_error or save_as_entity or expect_result): raise ValueError( - 'ignoreResultAndError is incompatible with saveResultAsEntity' - ', expectError, and expectResult') + "ignoreResultAndError is incompatible with saveResultAsEntity, expectError, and expectResult" + ) if opargs: arguments = parse_spec_options(copy.deepcopy(opargs)) - prepare_spec_arguments(spec, arguments, camel_to_snake(opname), - self.entity_map, self.run_operations) + prepare_spec_arguments( + spec, arguments, camel_to_snake(opname), self.entity_map, self.run_operations + ) else: arguments = tuple() if isinstance(target, MongoClient): - method_name = '_clientOperation_%s' % (opname,) + method_name = "_clientOperation_%s" % (opname,) elif isinstance(target, Database): - method_name = '_databaseOperation_%s' % (opname,) + method_name = "_databaseOperation_%s" % (opname,) elif isinstance(target, Collection): - method_name = '_collectionOperation_%s' % (opname,) + method_name = "_collectionOperation_%s" % (opname,) elif isinstance(target, ChangeStream): - method_name = '_changeStreamOperation_%s' % (opname,) + method_name = "_changeStreamOperation_%s" % (opname,) elif isinstance(target, NonLazyCursor): - method_name = '_cursor_%s' % (opname,) + method_name = "_cursor_%s" % (opname,) elif isinstance(target, ClientSession): - method_name = '_sessionOperation_%s' % (opname,) + method_name = "_sessionOperation_%s" % (opname,) elif isinstance(target, GridFSBucket): raise NotImplementedError else: - method_name = 'doesNotExist' + method_name = "doesNotExist" try: method = getattr(self, method_name) @@ -955,8 +960,7 @@ def run_entity_operation(self, spec): try: cmd = getattr(target, camel_to_snake(opname)) except AttributeError: - self.fail('Unsupported operation %s on entity 
%s' % ( - opname, target)) + self.fail("Unsupported operation %s on entity %s" % (opname, target)) else: cmd = functools.partial(method, target) @@ -972,8 +976,9 @@ def run_entity_operation(self, spec): raise else: if expect_error: - self.fail('Excepted error %s but "%s" succeeded: %s' % ( - expect_error, opname, result)) + self.fail( + 'Expected error %s but "%s" succeeded: %s' % (expect_error, opname, result) + ) if expect_result: actual = coerce_result(opname, result) @@ -984,42 +989,43 @@ def run_entity_operation(self, spec): def __set_fail_point(self, client, command_args): if not client_context.test_commands_enabled: - self.skipTest('Test commands must be enabled') + self.skipTest("Test commands must be enabled") - cmd_on = SON([('configureFailPoint', 'failCommand')]) + cmd_on = SON([("configureFailPoint", "failCommand")]) cmd_on.update(command_args) client.admin.command(cmd_on) self.addCleanup( - client.admin.command, - 'configureFailPoint', cmd_on['configureFailPoint'], mode='off') + client.admin.command, "configureFailPoint", cmd_on["configureFailPoint"], mode="off" + ) def _testOperation_failPoint(self, spec): self.__set_fail_point( - client=self.entity_map[spec['client']], - command_args=spec['failPoint']) + client=self.entity_map[spec["client"]], command_args=spec["failPoint"] + ) def _testOperation_targetedFailPoint(self, spec): - session = self.entity_map[spec['session']] + session = self.entity_map[spec["session"]] if not session._pinned_address: - self.fail("Cannot use targetedFailPoint operation with unpinned " - "session %s" % (spec['session'],)) + self.fail( + "Cannot use targetedFailPoint operation with unpinned session %s" % (spec["session"],) + ) - client = single_client('%s:%s' % session._pinned_address) + client = single_client("%s:%s" % session._pinned_address) self.addCleanup(client.close) - self.__set_fail_point( - client=client, command_args=spec['failPoint']) + self.__set_fail_point(client=client, command_args=spec["failPoint"]) def _testOperation_assertSessionTransactionState(self, spec): - session = self.entity_map[spec['session']] - expected_state = getattr(_TxnState, spec['state'].upper()) + session = self.entity_map[spec["session"]] + expected_state = getattr(_TxnState, spec["state"].upper()) self.assertEqual(expected_state, session._transaction.state) def _testOperation_assertSessionPinned(self, spec): - session = self.entity_map[spec['session']] + session = self.entity_map[spec["session"]] self.assertIsNotNone(session._transaction.pinned_address) def _testOperation_assertSessionUnpinned(self, spec): - session = self.entity_map[spec['session']] + session = self.entity_map[spec["session"]] self.assertIsNone(session._pinned_address) self.assertIsNone(session._transaction.pinned_address) @@ -1029,61 +1035,61 @@ def __get_last_two_command_lsids(self, listener): if isinstance(event, CommandStartedEvent): cmd_started_events.append(event) if len(cmd_started_events) < 2: - self.fail('Needed 2 CommandStartedEvents to compare lsids, ' - 'got %s' % (len(cmd_started_events))) - return tuple([e.command['lsid'] for e in cmd_started_events][:2]) + self.fail( + "Needed 2 CommandStartedEvents to compare lsids, got %s" % (len(cmd_started_events)) + ) + return tuple([e.command["lsid"] for e in cmd_started_events][:2]) def _testOperation_assertDifferentLsidOnLastTwoCommands(self, spec): - listener = self.entity_map.get_listener_for_client(spec['client']) + listener = self.entity_map.get_listener_for_client(spec["client"])
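# Sketch (not part of the diff) of the document __set_fail_point above sends
# via the "configureFailPoint" test command. SON keeps "configureFailPoint"
# as the first key, since the server reads the command name from the first
# field; the mode/data values below are examples, not fixed by this suite.
from bson.son import SON

command_args = {"mode": {"times": 1}, "data": {"failCommands": ["insert"], "errorCode": 91}}
cmd_on = SON([("configureFailPoint", "failCommand")])
cmd_on.update(command_args)
# client.admin.command(cmd_on)  # enable (needs a server with test commands on)
# client.admin.command("configureFailPoint", "failCommand", mode="off")  # cleanup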
self.assertNotEqual(*self.__get_last_two_command_lsids(listener)) def _testOperation_assertSameLsidOnLastTwoCommands(self, spec): - listener = self.entity_map.get_listener_for_client(spec['client']) + listener = self.entity_map.get_listener_for_client(spec["client"]) self.assertEqual(*self.__get_last_two_command_lsids(listener)) def _testOperation_assertSessionDirty(self, spec): - session = self.entity_map[spec['session']] + session = self.entity_map[spec["session"]] self.assertTrue(session._server_session.dirty) def _testOperation_assertSessionNotDirty(self, spec): - session = self.entity_map[spec['session']] + session = self.entity_map[spec["session"]] return self.assertFalse(session._server_session.dirty) def _testOperation_assertCollectionExists(self, spec): - database_name = spec['databaseName'] - collection_name = spec['collectionName'] - collection_name_list = list( - self.client.get_database(database_name).list_collection_names()) + database_name = spec["databaseName"] + collection_name = spec["collectionName"] + collection_name_list = list(self.client.get_database(database_name).list_collection_names()) self.assertIn(collection_name, collection_name_list) def _testOperation_assertCollectionNotExists(self, spec): - database_name = spec['databaseName'] - collection_name = spec['collectionName'] - collection_name_list = list( - self.client.get_database(database_name).list_collection_names()) + database_name = spec["databaseName"] + collection_name = spec["collectionName"] + collection_name_list = list(self.client.get_database(database_name).list_collection_names()) self.assertNotIn(collection_name, collection_name_list) def _testOperation_assertIndexExists(self, spec): - collection = self.client[spec['databaseName']][spec['collectionName']] - index_names = [idx['name'] for idx in collection.list_indexes()] - self.assertIn(spec['indexName'], index_names) + collection = self.client[spec["databaseName"]][spec["collectionName"]] + index_names = [idx["name"] for idx in collection.list_indexes()] + self.assertIn(spec["indexName"], index_names) def _testOperation_assertIndexNotExists(self, spec): - collection = self.client[spec['databaseName']][spec['collectionName']] + collection = self.client[spec["databaseName"]][spec["collectionName"]] for index in collection.list_indexes(): - self.assertNotEqual(spec['indexName'], index['name']) + self.assertNotEqual(spec["indexName"], index["name"]) def _testOperation_assertNumberConnectionsCheckedOut(self, spec): - client = self.entity_map[spec['client']] + client = self.entity_map[spec["client"]] pool = get_pool(client) - self.assertEqual(spec['connections'], pool.active_sockets) + self.assertEqual(spec["connections"], pool.active_sockets) def _testOperation_loop(self, spec): - failure_key = spec.get('storeFailuresAsEntity') - error_key = spec.get('storeErrorsAsEntity') - successes_key = spec.get('storeSuccessesAsEntity') - iteration_key = spec.get('storeIterationsAsEntity') - iteration_limiter_key = spec.get('numIterations') + failure_key = spec.get("storeFailuresAsEntity") + error_key = spec.get("storeErrorsAsEntity") + successes_key = spec.get("storeSuccessesAsEntity") + iteration_key = spec.get("storeIterationsAsEntity") + iteration_limiter_key = spec.get("numIterations") for i in [failure_key, error_key]: if i: self.entity_map[i] = [] @@ -1112,37 +1118,34 @@ def _testOperation_loop(self, spec): key = error_key or failure_key if not key: raise - self.entity_map[key].append({ - "error": str(exc), - "time": time.time(), - "type": 
type(exc).__name__ - }) + self.entity_map[key].append( + {"error": str(exc), "time": time.time(), "type": type(exc).__name__} + ) def run_special_operation(self, spec): - opname = spec['name'] - method_name = '_testOperation_%s' % (opname,) + opname = spec["name"] + method_name = "_testOperation_%s" % (opname,) try: method = getattr(self, method_name) except AttributeError: - self.fail('Unsupported special test operation %s' % (opname,)) + self.fail("Unsupported special test operation %s" % (opname,)) else: - method(spec['arguments']) + method(spec["arguments"]) def run_operations(self, spec): for op in spec: - if op['object'] == 'testRunner': + if op["object"] == "testRunner": self.run_special_operation(op) else: self.run_entity_operation(op) - def check_events(self, spec): for event_spec in spec: - client_name = event_spec['client'] - events = event_spec['events'] + client_name = event_spec["client"] + events = event_spec["events"] # Valid types: 'command', 'cmap' - event_type = event_spec.get('eventType', 'command') - assert event_type in ('command', 'cmap') + event_type = event_spec.get("eventType", "command") + assert event_type in ("command", "cmap") listener = self.entity_map.get_listener_for_client(client_name) actual_events = listener.get_events(event_type) @@ -1151,80 +1154,76 @@ def check_events(self, spec): continue if len(events) > len(actual_events): - self.fail('Expected to see %s events, got %s' % ( - len(events), len(actual_events))) + self.fail("Expected to see %s events, got %s" % (len(events), len(actual_events))) for idx, expected_event in enumerate(events): - self.match_evaluator.match_event( - event_type, expected_event, actual_events[idx]) + self.match_evaluator.match_event(event_type, expected_event, actual_events[idx]) def verify_outcome(self, spec): for collection_data in spec: - coll_name = collection_data['collectionName'] - db_name = collection_data['databaseName'] - expected_documents = collection_data['documents'] + coll_name = collection_data["collectionName"] + db_name = collection_data["databaseName"] + expected_documents = collection_data["documents"] coll = self.client.get_database(db_name).get_collection( coll_name, read_preference=ReadPreference.PRIMARY, - read_concern=ReadConcern(level='local')) + read_concern=ReadConcern(level="local"), + ) if expected_documents: - sorted_expected_documents = sorted( - expected_documents, key=lambda doc: doc['_id']) - actual_documents = list( - coll.find({}, sort=[('_id', ASCENDING)])) - self.assertListEqual(sorted_expected_documents, - actual_documents) + sorted_expected_documents = sorted(expected_documents, key=lambda doc: doc["_id"]) + actual_documents = list(coll.find({}, sort=[("_id", ASCENDING)])) + self.assertListEqual(sorted_expected_documents, actual_documents) def run_scenario(self, spec, uri=None): # maybe skip test manually self.maybe_skip_test(spec) # process test-level runOnRequirements - run_on_spec = spec.get('runOnRequirements', []) + run_on_spec = spec.get("runOnRequirements", []) if not self.should_run_on(run_on_spec): - raise unittest.SkipTest('runOnRequirements not satisfied') + raise unittest.SkipTest("runOnRequirements not satisfied") # process skipReason - skip_reason = spec.get('skipReason', None) + skip_reason = spec.get("skipReason", None) if skip_reason is not None: - raise unittest.SkipTest('%s' % (skip_reason,)) + raise unittest.SkipTest("%s" % (skip_reason,)) # process createEntities self.entity_map = EntityMapUtil(self) - self.entity_map.create_entities_from_spec( - 
self.TEST_SPEC.get('createEntities', []), uri=uri) + self.entity_map.create_entities_from_spec(self.TEST_SPEC.get("createEntities", []), uri=uri) # process initialData - self.insert_initial_data(self.TEST_SPEC.get('initialData', [])) + self.insert_initial_data(self.TEST_SPEC.get("initialData", [])) # process operations - self.run_operations(spec['operations']) + self.run_operations(spec["operations"]) # process expectEvents - if 'expectEvents' in spec: - expect_events = spec['expectEvents'] - self.assertTrue(expect_events, 'expectEvents must be non-empty') + if "expectEvents" in spec: + expect_events = spec["expectEvents"] + self.assertTrue(expect_events, "expectEvents must be non-empty") self.check_events(expect_events) # process outcome - self.verify_outcome(spec.get('outcome', [])) + self.verify_outcome(spec.get("outcome", [])) class UnifiedSpecTestMeta(type): """Metaclass for generating test classes.""" + def __init__(cls, *args, **kwargs): super(UnifiedSpecTestMeta, cls).__init__(*args, **kwargs) def create_test(spec): def test_case(self): self.run_scenario(spec) + return test_case - for test_spec in cls.TEST_SPEC['tests']: - description = test_spec['description'] - test_name = 'test_%s' % (description.strip('. '). - replace(' ', '_').replace('.', '_'),) + for test_spec in cls.TEST_SPEC["tests"]: + description = test_spec["description"] + test_name = "test_%s" % (description.strip(". ").replace(" ", "_").replace(".", "_"),) test_method = create_test(copy.deepcopy(test_spec)) test_method.__name__ = str(test_name) @@ -1243,13 +1242,18 @@ def test_case(self): _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS = { - KLASS.SCHEMA_VERSION[0]: KLASS for KLASS in _ALL_MIXIN_CLASSES} + KLASS.SCHEMA_VERSION[0]: KLASS for KLASS in _ALL_MIXIN_CLASSES +} -def generate_test_classes(test_path, module=__name__, class_name_prefix='', - expected_failures=[], - bypass_test_generation_errors=False, - **kwargs): +def generate_test_classes( + test_path, + module=__name__, + class_name_prefix="", + expected_failures=[], + bypass_test_generation_errors=False, + **kwargs +): """Method for generating test classes. Returns a dictionary where keys are the names of test classes and values are the test class objects.""" test_klasses = {} @@ -1258,9 +1262,11 @@ def test_base_class_factory(test_spec): """Utility that creates the base class to use for test generation. This is needed to ensure that cls.TEST_SPEC is appropriately set when the metaclass __init__ is invoked.""" + class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): TEST_SPEC = test_spec EXPECTED_FAILURES = expected_failures + return SpecTestBase for dirpath, _, filenames in os.walk(test_path): @@ -1272,30 +1278,34 @@ class SpecTestBase(with_metaclass(UnifiedSpecTestMeta)): # Use tz_aware=False to match how CodecOptions decodes # dates. 
opts = json_util.JSONOptions(tz_aware=False) - scenario_def = json_util.loads( - scenario_stream.read(), json_options=opts) + scenario_def = json_util.loads(scenario_stream.read(), json_options=opts) test_type = os.path.splitext(filename)[0] - snake_class_name = 'Test%s_%s_%s' % ( - class_name_prefix, dirname.replace('-', '_'), - test_type.replace('-', '_').replace('.', '_')) + snake_class_name = "Test%s_%s_%s" % ( + class_name_prefix, + dirname.replace("-", "_"), + test_type.replace("-", "_").replace(".", "_"), + ) class_name = snake_to_camel(snake_class_name) try: - schema_version = Version.from_string( - scenario_def['schemaVersion']) - mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get( - schema_version[0]) + schema_version = Version.from_string(scenario_def["schemaVersion"]) + mixin_class = _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS.get(schema_version[0]) if mixin_class is None: raise ValueError( - "test file '%s' has unsupported schemaVersion '%s'" % ( - fpath, schema_version)) - module_dict = {'__module__': module} + "test file '%s' has unsupported schemaVersion '%s'" + % (fpath, schema_version) + ) + module_dict = {"__module__": module} module_dict.update(kwargs) test_klasses[class_name] = type( class_name, - (mixin_class, test_base_class_factory(scenario_def),), - module_dict) + ( + mixin_class, + test_base_class_factory(scenario_def), + ), + module_dict, + ) except Exception: if bypass_test_generation_errors: continue diff --git a/test/utils.py b/test/utils.py index 5b6f9fd264..dd2cd4817d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -26,16 +26,14 @@ import time import unittest import warnings - from collections import abc, defaultdict from functools import partial +from test import client_context, db_pwd, db_user from bson import json_util from bson.objectid import ObjectId from bson.son import SON - -from pymongo import (MongoClient, - monitoring, operations, read_preferences) +from pymongo import MongoClient, monitoring, operations, read_preferences from pymongo.collection import ReturnDocument from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat @@ -43,16 +41,10 @@ from pymongo.pool import _CancellationContext, _PoolGeneration from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference -from pymongo.server_selectors import (any_server_selector, - writable_server_selector) +from pymongo.server_selectors import any_server_selector, writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.write_concern import WriteConcern from pymongo.uri_parser import parse_uri - -from test import (client_context, - db_user, - db_pwd) - +from pymongo.write_concern import WriteConcern IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50) @@ -83,8 +75,7 @@ def matching(self, matcher): def wait_for_event(self, event, count): """Wait for a number of events to be published, or fail.""" - wait_until(lambda: self.event_count(event) >= count, - 'find %s %s event(s)' % (count, event)) + wait_until(lambda: self.event_count(event) >= count, "find %s %s event(s)" % (count, event)) class CMAPListener(BaseListener, monitoring.ConnectionPoolListener): @@ -123,22 +114,21 @@ def pool_closed(self, event): class EventListener(monitoring.CommandListener): - def __init__(self): self.results = defaultdict(list) def started(self, event): - self.results['started'].append(event) + self.results["started"].append(event) def succeeded(self, event): - self.results['succeeded'].append(event) + 
self.results["succeeded"].append(event) def failed(self, event): - self.results['failed'].append(event) + self.results["failed"].append(event) def started_command_names(self): """Return list of command names started.""" - return [event.command_name for event in self.results['started']] + return [event.command_name for event in self.results["started"]] def reset(self): """Reset the state of this listener.""" @@ -150,13 +140,13 @@ def __init__(self): self.results = defaultdict(list) def closed(self, event): - self.results['closed'].append(event) + self.results["closed"].append(event) def description_changed(self, event): - self.results['description_changed'].append(event) + self.results["description_changed"].append(event) def opened(self, event): - self.results['opened'].append(event) + self.results["opened"].append(event) def reset(self): """Reset the state of this listener.""" @@ -164,7 +154,6 @@ def reset(self): class AllowListEventListener(EventListener): - def __init__(self, *commands): self.commands = set(commands) super(AllowListEventListener, self).__init__() @@ -184,6 +173,7 @@ def failed(self, event): class OvertCommandListener(EventListener): """A CommandListener that ignores sensitive commands.""" + def started(self, event): if event.command_name.lower() not in _SENSITIVE_COMMANDS: super(OvertCommandListener, self).started(event) @@ -221,13 +211,11 @@ def reset(self): self.results = [] -class ServerEventListener(_ServerEventListener, - monitoring.ServerListener): +class ServerEventListener(_ServerEventListener, monitoring.ServerListener): """Listens to Server events.""" -class ServerAndTopologyEventListener(ServerEventListener, - monitoring.TopologyListener): +class ServerAndTopologyEventListener(ServerEventListener, monitoring.TopologyListener): """Listens to Server and Topology events.""" @@ -300,6 +288,7 @@ def remove_stale_sockets(self, *args, **kwargs): class ScenarioDict(dict): """Dict that returns {} for any unknown key, recursively.""" + def __init__(self, data): def convert(v): if isinstance(v, abc.Mapping): @@ -322,6 +311,7 @@ def __getitem__(self, item): class CompareType(object): """Class that compares equal to any object of the given type.""" + def __init__(self, type): self.type = type @@ -335,6 +325,7 @@ def __ne__(self, other): class FunctionCallRecorder(object): """Utility class to wrap a callable and record its invocations.""" + def __init__(self, function): self._function = function self._call_list = [] @@ -359,6 +350,7 @@ def call_count(self): class TestCreator(object): """Class to create test cases from specifications.""" + def __init__(self, create_test, test_class, test_path): """Create a TestCreator object. @@ -372,7 +364,7 @@ def __init__(self, create_test, test_class, test_path): test case. - `test_path`: path to the directory containing the JSON files with the test specifications. 
- """ + """ self._create_test = create_test self._test_class = test_class self.test_path = test_path @@ -380,67 +372,63 @@ def __init__(self, create_test, test_class, test_path): def _ensure_min_max_server_version(self, scenario_def, method): """Test modifier that enforces a version range for the server on a test case.""" - if 'minServerVersion' in scenario_def: - min_ver = tuple( - int(elt) for - elt in scenario_def['minServerVersion'].split('.')) + if "minServerVersion" in scenario_def: + min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split(".")) if min_ver is not None: method = client_context.require_version_min(*min_ver)(method) - if 'maxServerVersion' in scenario_def: - max_ver = tuple( - int(elt) for - elt in scenario_def['maxServerVersion'].split('.')) + if "maxServerVersion" in scenario_def: + max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split(".")) if max_ver is not None: method = client_context.require_version_max(*max_ver)(method) - if 'serverless' in scenario_def: - serverless = scenario_def['serverless'] + if "serverless" in scenario_def: + serverless = scenario_def["serverless"] if serverless == "require": serverless_satisfied = client_context.serverless elif serverless == "forbid": serverless_satisfied = not client_context.serverless - else: # unset or "allow" + else: # unset or "allow" serverless_satisfied = True method = unittest.skipUnless( - serverless_satisfied, - "Serverless requirement not satisfied")(method) + serverless_satisfied, "Serverless requirement not satisfied" + )(method) return method @staticmethod def valid_topology(run_on_req): return client_context.is_topology_type( - run_on_req.get('topology', ['single', 'replicaset', 'sharded', - 'load-balanced'])) + run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"]) + ) @staticmethod def min_server_version(run_on_req): - version = run_on_req.get('minServerVersion') + version = run_on_req.get("minServerVersion") if version: - min_ver = tuple(int(elt) for elt in version.split('.')) + min_ver = tuple(int(elt) for elt in version.split(".")) return client_context.version >= min_ver return True @staticmethod def max_server_version(run_on_req): - version = run_on_req.get('maxServerVersion') + version = run_on_req.get("maxServerVersion") if version: - max_ver = tuple(int(elt) for elt in version.split('.')) + max_ver = tuple(int(elt) for elt in version.split(".")) return client_context.version <= max_ver return True @staticmethod def valid_auth_enabled(run_on_req): - if 'authEnabled' in run_on_req: - if run_on_req['authEnabled']: + if "authEnabled" in run_on_req: + if run_on_req["authEnabled"]: return client_context.auth_enabled return not client_context.auth_enabled return True @staticmethod def serverless_ok(run_on_req): - serverless = run_on_req['serverless'] + serverless = run_on_req["serverless"] if serverless == "require": return client_context.serverless elif serverless == "forbid": @@ -449,30 +437,31 @@ def serverless_ok(run_on_req): return True def should_run_on(self, scenario_def): - run_on = scenario_def.get('runOn', []) + run_on = scenario_def.get("runOn", []) if not run_on: # Always run these tests. 
return True for req in run_on: - if (self.valid_topology(req) and - self.min_server_version(req) and - self.max_server_version(req) and - self.valid_auth_enabled(req) and - self.serverless_ok(req)): + if ( + self.valid_topology(req) + and self.min_server_version(req) + and self.max_server_version(req) + and self.valid_auth_enabled(req) + and self.serverless_ok(req) + ): return True return False def ensure_run_on(self, scenario_def, method): """Test modifier that enforces a 'runOn' on a test case.""" return client_context._require( - lambda: self.should_run_on(scenario_def), - "runOn not satisfied", - method) + lambda: self.should_run_on(scenario_def), "runOn not satisfied", method + ) def tests(self, scenario_def): """Allow CMAP spec test to override the location of test.""" - return scenario_def['tests'] + return scenario_def["tests"] def create_tests(self): for dirpath, _, filenames in os.walk(self.test_path): @@ -484,25 +473,22 @@ def create_tests(self): # dates. opts = json_util.JSONOptions(tz_aware=False) scenario_def = ScenarioDict( - json_util.loads(scenario_stream.read(), - json_options=opts)) + json_util.loads(scenario_stream.read(), json_options=opts) + ) test_type = os.path.splitext(filename)[0] # Construct test from scenario. for test_def in self.tests(scenario_def): - test_name = 'test_%s_%s_%s' % ( + test_name = "test_%s_%s_%s" % ( dirname, - test_type.replace("-", "_").replace('.', '_'), - str(test_def['description'].replace(" ", "_").replace( - '.', '_'))) + test_type.replace("-", "_").replace(".", "_"), + str(test_def["description"].replace(" ", "_").replace(".", "_")), + ) - new_test = self._create_test( - scenario_def, test_def, test_name) - new_test = self._ensure_min_max_server_version( - scenario_def, new_test) - new_test = self.ensure_run_on( - scenario_def, new_test) + new_test = self._create_test(scenario_def, test_def, test_name) + new_test = self._ensure_min_max_server_version(scenario_def, new_test) + new_test = self.ensure_run_on(scenario_def, new_test) new_test.__name__ = test_name setattr(self._test_class, new_test.__name__, new_test) @@ -514,35 +500,36 @@ def _connection_string(h): return "mongodb://%s" % (str(h),) -def _mongo_client(host, port, authenticate=True, directConnection=None, - **kwargs): +def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs): """Create a new client over SSL/TLS if necessary.""" host = host or client_context.host port = port or client_context.port client_options = client_context.default_client_options.copy() if client_context.replica_set_name and not directConnection: - client_options['replicaSet'] = client_context.replica_set_name + client_options["replicaSet"] = client_context.replica_set_name if directConnection is not None: - client_options['directConnection'] = directConnection + client_options["directConnection"] = directConnection client_options.update(kwargs) uri = _connection_string(host) if client_context.auth_enabled and authenticate: # Only add the default username or password if one is not provided. 
res = parse_uri(uri) - if (not res['username'] and not res['password'] and - 'username' not in client_options and - 'password' not in client_options): - client_options['username'] = db_user - client_options['password'] = db_pwd + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd return MongoClient(uri, port, **client_options) def single_client_noauth(h=None, p=None, **kwargs): """Make a direct connection. Don't authenticate.""" - return _mongo_client(h, p, authenticate=False, - directConnection=True, **kwargs) + return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs) def single_client(h=None, p=None, **kwargs): @@ -585,17 +572,16 @@ def ensure_all_connected(client): that are configured on the client. """ hello = client.admin.command(HelloCompat.LEGACY_CMD) - if 'setName' not in hello: + if "setName" not in hello: raise ConfigurationError("cluster is not a replica set") - target_host_list = set(hello['hosts']) - connected_host_list = set([hello['me']]) - admindb = client.get_database('admin') + target_host_list = set(hello["hosts"]) + connected_host_list = set([hello["me"]]) + admindb = client.get_database("admin") # Run hello until we have connected to each host at least once. while connected_host_list != target_host_list: - hello = admindb.command(HelloCompat.LEGACY_CMD, - read_preference=ReadPreference.SECONDARY) + hello = admindb.command(HelloCompat.LEGACY_CMD, read_preference=ReadPreference.SECONDARY) connected_host_list.update([hello["me"]]) @@ -612,19 +598,19 @@ def oid_generated_on_process(oid): def delay(sec): - return '''function() { sleep(%f * 1000); return true; }''' % sec + return """function() { sleep(%f * 1000); return true; }""" % sec def get_command_line(client): - command_line = client.admin.command('getCmdLineOpts') - assert command_line['ok'] == 1, "getCmdLineOpts() failed" + command_line = client.admin.command("getCmdLineOpts") + assert command_line["ok"] == 1, "getCmdLineOpts() failed" return command_line def camel_to_snake(camel): # Regex to convert CamelCase to snake_case. - snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', camel) - return re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower() + snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel) + return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower() def camel_to_upper_camel(camel): @@ -640,21 +626,18 @@ def camel_to_snake_args(arguments): def snake_to_camel(snake): # Regex to convert snake_case to lowerCamelCase. 
- return re.sub(r'_([a-z])', lambda m: m.group(1).upper(), snake) + return re.sub(r"_([a-z])", lambda m: m.group(1).upper(), snake) def parse_collection_options(opts): - if 'readPreference' in opts: - opts['read_preference'] = parse_read_preference( - opts.pop('readPreference')) + if "readPreference" in opts: + opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) - if 'writeConcern' in opts: - opts['write_concern'] = WriteConcern( - **dict(opts.pop('writeConcern'))) + if "writeConcern" in opts: + opts["write_concern"] = WriteConcern(**dict(opts.pop("writeConcern"))) - if 'readConcern' in opts: - opts['read_concern'] = ReadConcern( - **dict(opts.pop('readConcern'))) + if "readConcern" in opts: + opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) return opts @@ -666,11 +649,11 @@ def server_started_with_option(client, cmdline_opt, config_opt): - `config_opt`: The config file option (i.e. nojournal) """ command_line = get_command_line(client) - if 'parsed' in command_line: - parsed = command_line['parsed'] + if "parsed" in command_line: + parsed = command_line["parsed"] if config_opt in parsed: return parsed[config_opt] - argv = command_line['argv'] + argv = command_line["argv"] return cmdline_opt in argv @@ -678,39 +661,37 @@ def server_started_with_auth(client): try: command_line = get_command_line(client) except OperationFailure as e: - msg = e.details.get('errmsg', '') - if e.code == 13 or 'unauthorized' in msg or 'login' in msg: + msg = e.details.get("errmsg", "") + if e.code == 13 or "unauthorized" in msg or "login" in msg: # Unauthorized. return True raise # MongoDB >= 2.0 - if 'parsed' in command_line: - parsed = command_line['parsed'] + if "parsed" in command_line: + parsed = command_line["parsed"] # MongoDB >= 2.6 - if 'security' in parsed: - security = parsed['security'] + if "security" in parsed: + security = parsed["security"] # >= rc3 - if 'authorization' in security: - return security['authorization'] == 'enabled' + if "authorization" in security: + return security["authorization"] == "enabled" # < rc3 - return security.get('auth', False) or bool(security.get('keyFile')) - return parsed.get('auth', False) or bool(parsed.get('keyFile')) + return security.get("auth", False) or bool(security.get("keyFile")) + return parsed.get("auth", False) or bool(parsed.get("keyFile")) # Legacy - argv = command_line['argv'] - return '--auth' in argv or '--keyFile' in argv + argv = command_line["argv"] + return "--auth" in argv or "--keyFile" in argv def drop_collections(db): # Drop all non-system collections in this database. - for coll in db.list_collection_names( - filter={"name": {"$regex": r"^(?!system\.)"}}): + for coll in db.list_collection_names(filter={"name": {"$regex": r"^(?!system\.)"}}): db.drop_collection(coll) def remove_all_users(db): - db.command("dropAllUsersFromDatabase", 1, - writeConcern={"w": client_context.w}) + db.command("dropAllUsersFromDatabase", 1, writeConcern={"w": client_context.w}) def joinall(threads): @@ -726,7 +707,7 @@ def connected(client): # Ignore warning that ping is always routed to primary even # if client's read preference isn't PRIMARY. warnings.simplefilter("ignore", UserWarning) - client.admin.command('ping') # Force connection. + client.admin.command("ping") # Force connection. return client @@ -745,7 +726,7 @@ def wait_until(predicate, success_description, timeout=10): Returns the predicate's first true value. 
""" start = time.time() - interval = min(float(timeout)/100, 0.1) + interval = min(float(timeout) / 100, 0.1) while True: retval = predicate() if retval: @@ -759,17 +740,17 @@ def wait_until(predicate, success_description, timeout=10): def repl_set_step_down(client, **kwargs): """Run replSetStepDown, first unfreezing a secondary with replSetFreeze.""" - cmd = SON([('replSetStepDown', 1)]) + cmd = SON([("replSetStepDown", 1)]) cmd.update(kwargs) # Unfreeze a secondary to ensure a speedy election. - client.admin.command( - 'replSetFreeze', 0, read_preference=ReadPreference.SECONDARY) + client.admin.command("replSetFreeze", 0, read_preference=ReadPreference.SECONDARY) client.admin.command(cmd) + def is_mongos(client): res = client.admin.command(HelloCompat.LEGACY_CMD) - return res.get('msg', '') == 'isdbgrid' + return res.get("msg", "") == "isdbgrid" def assertRaisesExactly(cls, fn, *args, **kwargs): @@ -781,8 +762,7 @@ def assertRaisesExactly(cls, fn, *args, **kwargs): try: fn(*args, **kwargs) except Exception as e: - assert e.__class__ == cls, "got %s, expected %s" % ( - e.__class__.__name__, cls.__name__) + assert e.__class__ == cls, "got %s, expected %s" % (e.__class__.__name__, cls.__name__) else: raise AssertionError("%s not raised" % cls) @@ -797,6 +777,7 @@ def _ignore_deprecations(): def ignore_deprecations(wrapped=None): """A context manager or a decorator.""" if wrapped: + @functools.wraps(wrapped) def wrapper(*args, **kwargs): with _ignore_deprecations(): @@ -809,7 +790,6 @@ def wrapper(*args, **kwargs): class DeprecationFilter(object): - def __init__(self, action="ignore"): """Start filtering deprecations.""" self.warn_context = warnings.catch_warnings() @@ -831,9 +811,7 @@ def get_pool(client): def get_pools(client): """Get all pools.""" - return [ - server.pool for server in - client._get_topology().select_servers(any_server_selector)] + return [server.pool for server in client._get_topology().select_servers(any_server_selector)] # Constants for run_threads and lazy_client_trial. 
@@ -863,8 +841,8 @@ def run_threads(collection, target): def frequent_thread_switches(): """Make concurrency bugs more likely to manifest.""" interval = None - if not sys.platform.startswith('java'): - if hasattr(sys, 'getswitchinterval'): + if not sys.platform.startswith("java"): + if hasattr(sys, "getswitchinterval"): interval = sys.getswitchinterval() sys.setswitchinterval(1e-6) else: @@ -874,8 +852,8 @@ def frequent_thread_switches(): try: yield finally: - if not sys.platform.startswith('java'): - if hasattr(sys, 'setswitchinterval'): + if not sys.platform.startswith("java"): + if hasattr(sys, "setswitchinterval"): sys.setswitchinterval(interval) else: sys.setcheckinterval(interval) @@ -910,7 +888,9 @@ def gevent_monkey_patched(): warnings.simplefilter("ignore", ImportWarning) try: import socket + import gevent.socket + return socket.socket is gevent.socket.socket except ImportError: return False @@ -919,8 +899,8 @@ def gevent_monkey_patched(): def eventlet_monkey_patched(): """Check if eventlet's monkey patching is active.""" import threading - return (threading.current_thread.__module__ == - 'eventlet.green.threading') + + return threading.current_thread.__module__ == "eventlet.green.threading" def is_greenthread_patched(): @@ -931,20 +911,19 @@ def disable_replication(client): """Disable replication on all secondaries, requires MongoDB 3.2.""" for host, port in client.secondaries: secondary = single_client(host, port) - secondary.admin.command('configureFailPoint', 'stopReplProducer', - mode='alwaysOn') + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") def enable_replication(client): """Enable replication on all secondaries, requires MongoDB 3.2.""" for host, port in client.secondaries: secondary = single_client(host, port) - secondary.admin.command('configureFailPoint', 'stopReplProducer', - mode='off') + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") class ExceptionCatchingThread(threading.Thread): """A thread that stores any exception encountered from run().""" + def __init__(self, *args, **kwargs): self.exc = None super(ExceptionCatchingThread, self).__init__(*args, **kwargs) @@ -959,13 +938,14 @@ def run(self): def parse_read_preference(pref): # Make first letter lowercase to match read_pref's modes. - mode_string = pref.get('mode', 'primary') + mode_string = pref.get("mode", "primary") mode_string = mode_string[:1].lower() + mode_string[1:] mode = read_preferences.read_pref_mode_from_name(mode_string) - max_staleness = pref.get('maxStalenessSeconds', -1) - tag_sets = pref.get('tag_sets') + max_staleness = pref.get("maxStalenessSeconds", -1) + tag_sets = pref.get("tag_sets") return read_preferences.make_read_preference( - mode, tag_sets=tag_sets, max_staleness=max_staleness) + mode, tag_sets=tag_sets, max_staleness=max_staleness + ) def server_name_to_type(name): @@ -973,16 +953,16 @@ def server_name_to_type(name): # Special case, some tests in the spec include the PossiblePrimary # type, but only single-threaded drivers need that type. We call # possible primaries Unknown. 
- if name == 'PossiblePrimary': + if name == "PossiblePrimary": return SERVER_TYPE.Unknown return getattr(SERVER_TYPE, name) def cat_files(dest, *sources): """Cat multiple files into dest.""" - with open(dest, 'wb') as fdst: + with open(dest, "wb") as fdst: for src in sources: - with open(src, 'rb') as fsrc: + with open(src, "rb") as fsrc: shutil.copyfileobj(fsrc, fdst) @@ -992,64 +972,60 @@ def assertion_context(msg): try: yield except AssertionError as exc: - msg = '%s (%s)' % (exc, msg) + msg = "%s (%s)" % (exc, msg) exc_type, exc_val, exc_tb = sys.exc_info() raise exc_type(exc_val).with_traceback(exc_tb) def parse_spec_options(opts): - if 'readPreference' in opts: - opts['read_preference'] = parse_read_preference( - opts.pop('readPreference')) + if "readPreference" in opts: + opts["read_preference"] = parse_read_preference(opts.pop("readPreference")) - if 'writeConcern' in opts: - opts['write_concern'] = WriteConcern( - **dict(opts.pop('writeConcern'))) + if "writeConcern" in opts: + opts["write_concern"] = WriteConcern(**dict(opts.pop("writeConcern"))) - if 'readConcern' in opts: - opts['read_concern'] = ReadConcern( - **dict(opts.pop('readConcern'))) + if "readConcern" in opts: + opts["read_concern"] = ReadConcern(**dict(opts.pop("readConcern"))) - if 'maxTimeMS' in opts: - opts['max_time_ms'] = opts.pop('maxTimeMS') + if "maxTimeMS" in opts: + opts["max_time_ms"] = opts.pop("maxTimeMS") - if 'maxCommitTimeMS' in opts: - opts['max_commit_time_ms'] = opts.pop('maxCommitTimeMS') + if "maxCommitTimeMS" in opts: + opts["max_commit_time_ms"] = opts.pop("maxCommitTimeMS") - if 'hint' in opts: - hint = opts.pop('hint') + if "hint" in opts: + hint = opts.pop("hint") if not isinstance(hint, str): hint = list(hint.items()) - opts['hint'] = hint + opts["hint"] = hint # Properly format 'hint' arguments for the Bulk API tests. - if 'requests' in opts: - reqs = opts.pop('requests') + if "requests" in opts: + reqs = opts.pop("requests") for req in reqs: - if 'name' in req: + if "name" in req: # CRUD v2 format - args = req.pop('arguments', {}) - if 'hint' in args: - hint = args.pop('hint') + args = req.pop("arguments", {}) + if "hint" in args: + hint = args.pop("hint") if not isinstance(hint, str): hint = list(hint.items()) - args['hint'] = hint - req['arguments'] = args + args["hint"] = hint + req["arguments"] = args else: # Unified test format bulk_model, spec = next(iter(req.items())) - if 'hint' in spec: - hint = spec.pop('hint') + if "hint" in spec: + hint = spec.pop("hint") if not isinstance(hint, str): hint = list(hint.items()) - spec['hint'] = hint - opts['requests'] = reqs + spec["hint"] = hint + opts["requests"] = reqs return dict(opts) -def prepare_spec_arguments(spec, arguments, opname, entity_map, - with_txn_callback): +def prepare_spec_arguments(spec, arguments, opname, entity_map, with_txn_callback): for arg_name in list(arguments): c2s = camel_to_snake(arg_name) # PyMongo accepts sort as list of tuples. @@ -1060,8 +1036,7 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, if arg_name == "fieldName": arguments["key"] = arguments.pop(arg_name) # Aggregate uses "batchSize", while find uses batch_size. - elif ((arg_name == "batchSize" or arg_name == "allowDiskUse") and - opname == "aggregate"): + elif (arg_name == "batchSize" or arg_name == "allowDiskUse") and opname == "aggregate": continue # Requires boolean returnDocument. 
elif arg_name == "returnDocument": @@ -1070,7 +1045,7 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, # Parse each request into a bulk write model. requests = [] for request in arguments["requests"]: - if 'name' in request: + if "name" in request: # CRUD v2 format bulk_model = camel_to_upper_camel(request["name"]) bulk_class = getattr(operations, bulk_model) @@ -1083,39 +1058,37 @@ def prepare_spec_arguments(spec, arguments, opname, entity_map, requests.append(bulk_class(**dict(bulk_arguments))) arguments["requests"] = requests elif arg_name == "session": - arguments['session'] = entity_map[arguments['session']] - elif (opname in ('command', 'run_admin_command') and - arg_name == 'command'): + arguments["session"] = entity_map[arguments["session"]] + elif opname in ("command", "run_admin_command") and arg_name == "command": # Ensure the first key is the command name. - ordered_command = SON([(spec['command_name'], 1)]) - ordered_command.update(arguments['command']) - arguments['command'] = ordered_command - elif opname == 'open_download_stream' and arg_name == 'id': - arguments['file_id'] = arguments.pop(arg_name) - elif opname != 'find' and c2s == 'max_time_ms': + ordered_command = SON([(spec["command_name"], 1)]) + ordered_command.update(arguments["command"]) + arguments["command"] = ordered_command + elif opname == "open_download_stream" and arg_name == "id": + arguments["file_id"] = arguments.pop(arg_name) + elif opname != "find" and c2s == "max_time_ms": # find is the only method that accepts snake_case max_time_ms. # All other methods take kwargs which must use the server's # camelCase maxTimeMS. See PYTHON-1855. - arguments['maxTimeMS'] = arguments.pop('max_time_ms') - elif opname == 'with_transaction' and arg_name == 'callback': - if 'operations' in arguments[arg_name]: + arguments["maxTimeMS"] = arguments.pop("max_time_ms") + elif opname == "with_transaction" and arg_name == "callback": + if "operations" in arguments[arg_name]: # CRUD v2 format - callback_ops = arguments[arg_name]['operations'] + callback_ops = arguments[arg_name]["operations"] else: # Unified test format callback_ops = arguments[arg_name] - arguments['callback'] = lambda _: with_txn_callback( - copy.deepcopy(callback_ops)) - elif opname == 'drop_collection' and arg_name == 'collection': - arguments['name_or_collection'] = arguments.pop(arg_name) - elif opname == 'create_collection': - if arg_name == 'collection': - arguments['name'] = arguments.pop(arg_name) + arguments["callback"] = lambda _: with_txn_callback(copy.deepcopy(callback_ops)) + elif opname == "drop_collection" and arg_name == "collection": + arguments["name_or_collection"] = arguments.pop(arg_name) + elif opname == "create_collection": + if arg_name == "collection": + arguments["name"] = arguments.pop(arg_name) # Any other arguments to create_collection are passed through # **kwargs. 
- elif opname == 'create_index' and arg_name == 'keys': - arguments['keys'] = list(arguments.pop(arg_name).items()) - elif opname == 'drop_index' and arg_name == 'name': - arguments['index_or_name'] = arguments.pop(arg_name) + elif opname == "create_index" and arg_name == "keys": + arguments["keys"] = list(arguments.pop(arg_name).items()) + elif opname == "drop_index" and arg_name == "name": + arguments["index_or_name"] = arguments.pop(arg_name) else: arguments[c2s] = arguments.pop(arg_name) diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index 76125b6f15..e693fc25f0 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -20,37 +20,37 @@ sys.path[0:0] = [""] +from test import unittest +from test.pymongo_mocks import DummyMonitor +from test.utils import MockPool, parse_read_preference + from bson import json_util -from pymongo.common import clean_node, HEARTBEAT_FREQUENCY +from pymongo.common import HEARTBEAT_FREQUENCY, clean_node from pymongo.errors import AutoReconnect, ConfigurationError from pymongo.hello import Hello, HelloCompat from pymongo.server_description import ServerDescription -from pymongo.settings import TopologySettings from pymongo.server_selectors import writable_server_selector +from pymongo.settings import TopologySettings from pymongo.topology import Topology -from test import unittest -from test.utils import MockPool, parse_read_preference -from test.pymongo_mocks import DummyMonitor def get_addresses(server_list): seeds = [] hosts = [] for server in server_list: - seeds.append(clean_node(server['address'])) - hosts.append(server['address']) + seeds.append(clean_node(server["address"])) + hosts.append(server["address"]) return seeds, hosts def make_last_write_date(server): epoch = datetime.datetime.utcfromtimestamp(0) - millis = server.get('lastWrite', {}).get('lastWriteDate') + millis = server.get("lastWrite", {}).get("lastWriteDate") if millis: diff = ((millis % 1000) + 1000) % 1000 seconds = (millis - diff) / 1000 micros = diff * 1000 - return epoch + datetime.timedelta( - seconds=seconds, microseconds=micros) + return epoch + datetime.timedelta(seconds=seconds, microseconds=micros) else: # "Unknown" server. 
return epoch @@ -58,61 +58,59 @@ def make_last_write_date(server): def make_server_description(server, hosts): """Make a ServerDescription from server info in a JSON test.""" - server_type = server['type'] + server_type = server["type"] if server_type in ("Unknown", "PossiblePrimary"): - return ServerDescription(clean_node(server['address']), Hello({})) + return ServerDescription(clean_node(server["address"]), Hello({})) - hello_response = {'ok': True, 'hosts': hosts} + hello_response = {"ok": True, "hosts": hosts} if server_type not in ("Standalone", "Mongos", "RSGhost"): - hello_response['setName'] = "rs" + hello_response["setName"] = "rs" if server_type == "RSPrimary": hello_response[HelloCompat.LEGACY_CMD] = True elif server_type == "RSSecondary": - hello_response['secondary'] = True + hello_response["secondary"] = True elif server_type == "Mongos": - hello_response['msg'] = 'isdbgrid' + hello_response["msg"] = "isdbgrid" elif server_type == "RSGhost": - hello_response['isreplicaset'] = True + hello_response["isreplicaset"] = True elif server_type == "RSArbiter": - hello_response['arbiterOnly'] = True + hello_response["arbiterOnly"] = True - hello_response['lastWrite'] = { - 'lastWriteDate': make_last_write_date(server) - } + hello_response["lastWrite"] = {"lastWriteDate": make_last_write_date(server)} - for field in 'maxWireVersion', 'tags', 'idleWritePeriodMillis': + for field in "maxWireVersion", "tags", "idleWritePeriodMillis": if field in server: hello_response[field] = server[field] - hello_response.setdefault('maxWireVersion', 6) + hello_response.setdefault("maxWireVersion", 6) # Sets _last_update_time to now. - sd = ServerDescription(clean_node(server['address']), - Hello(hello_response), - round_trip_time=server['avg_rtt_ms'] / 1000.0) + sd = ServerDescription( + clean_node(server["address"]), + Hello(hello_response), + round_trip_time=server["avg_rtt_ms"] / 1000.0, + ) - if 'lastUpdateTime' in server: - sd._last_update_time = server['lastUpdateTime'] / 1000.0 # ms to sec. + if "lastUpdateTime" in server: + sd._last_update_time = server["lastUpdateTime"] / 1000.0 # ms to sec. return sd def get_topology_type_name(scenario_def): - td = scenario_def['topology_description'] - name = td['type'] - if name == 'Unknown': + td = scenario_def["topology_description"] + name = td["type"] + if name == "Unknown": # PyMongo never starts a topology in type Unknown. - return 'Sharded' if len(td['servers']) > 1 else 'Single' + return "Sharded" if len(td["servers"]) > 1 else "Single" else: return name def get_topology_settings_dict(**kwargs): settings = dict( - monitor_class=DummyMonitor, - heartbeat_frequency=HEARTBEAT_FREQUENCY, - pool_class=MockPool + monitor_class=DummyMonitor, heartbeat_frequency=HEARTBEAT_FREQUENCY, pool_class=MockPool ) settings.update(kwargs) return settings @@ -120,25 +118,20 @@ def get_topology_settings_dict(**kwargs): def create_topology(scenario_def, **kwargs): # Initialize topologies. 
- if 'heartbeatFrequencyMS' in scenario_def: - frequency = int(scenario_def['heartbeatFrequencyMS']) / 1000.0 + if "heartbeatFrequencyMS" in scenario_def: + frequency = int(scenario_def["heartbeatFrequencyMS"]) / 1000.0 else: frequency = HEARTBEAT_FREQUENCY - seeds, hosts = get_addresses( - scenario_def['topology_description']['servers']) + seeds, hosts = get_addresses(scenario_def["topology_description"]["servers"]) topology_type = get_topology_type_name(scenario_def) - if topology_type == 'LoadBalanced': - kwargs.setdefault('load_balanced', True) + if topology_type == "LoadBalanced": + kwargs.setdefault("load_balanced", True) # Force topology description to ReplicaSet - elif topology_type in ['ReplicaSetNoPrimary', 'ReplicaSetWithPrimary']: - kwargs.setdefault('replica_set_name', 'rs') - settings = get_topology_settings_dict( - heartbeat_frequency=frequency, - seeds=seeds, - **kwargs - ) + elif topology_type in ["ReplicaSetNoPrimary", "ReplicaSetWithPrimary"]: + kwargs.setdefault("replica_set_name", "rs") + settings = get_topology_settings_dict(heartbeat_frequency=frequency, seeds=seeds, **kwargs) # "Eligible servers" is defined in the server selection spec as # the set of servers matching both the ReadPreference's mode @@ -147,21 +140,21 @@ def create_topology(scenario_def, **kwargs): topology.open() # Update topologies with server descriptions. - for server in scenario_def['topology_description']['servers']: + for server in scenario_def["topology_description"]["servers"]: server_description = make_server_description(server, hosts) topology.on_change(server_description) # Assert that descriptions match - assert (scenario_def['topology_description']['type'] == - topology.description.topology_type_name), topology.description.topology_type_name + assert ( + scenario_def["topology_description"]["type"] == topology.description.topology_type_name + ), topology.description.topology_type_name return topology def create_test(scenario_def): def run_scenario(self): - _, hosts = get_addresses( - scenario_def['topology_description']['servers']) + _, hosts = get_addresses(scenario_def["topology_description"]["servers"]) # "Eligible servers" is defined in the server selection spec as # the set of servers matching both the ReadPreference's mode # and tag sets. @@ -170,16 +163,15 @@ def run_scenario(self): # "In latency window" is defined in the server selection # spec as the subset of suitable_servers that falls within the # allowable latency window. - top_suitable = create_topology( - scenario_def, local_threshold_ms=1000000) + top_suitable = create_topology(scenario_def, local_threshold_ms=1000000) # Create server selector. if scenario_def.get("operation") == "write": pref = writable_server_selector else: # Make first letter lowercase to match read_pref's modes. - pref_def = scenario_def['read_preference'] - if scenario_def.get('error'): + pref_def = scenario_def["read_preference"] + if scenario_def.get("error"): with self.assertRaises((ConfigurationError, ValueError)): # Error can be raised when making Read Pref or selecting. pref = parse_read_preference(pref_def) @@ -189,35 +181,33 @@ def run_scenario(self): pref = parse_read_preference(pref_def) # Select servers. 
- if not scenario_def.get('suitable_servers'): + if not scenario_def.get("suitable_servers"): with self.assertRaises(AutoReconnect): top_suitable.select_server(pref, server_selection_timeout=0) return - if not scenario_def['in_latency_window']: + if not scenario_def["in_latency_window"]: with self.assertRaises(AutoReconnect): top_latency.select_server(pref, server_selection_timeout=0) return - actual_suitable_s = top_suitable.select_servers( - pref, server_selection_timeout=0) - actual_latency_s = top_latency.select_servers( - pref, server_selection_timeout=0) + actual_suitable_s = top_suitable.select_servers(pref, server_selection_timeout=0) + actual_latency_s = top_latency.select_servers(pref, server_selection_timeout=0) expected_suitable_servers = {} - for server in scenario_def['suitable_servers']: + for server in scenario_def["suitable_servers"]: server_description = make_server_description(server, hosts) - expected_suitable_servers[server['address']] = server_description + expected_suitable_servers[server["address"]] = server_description actual_suitable_servers = {} for s in actual_suitable_s: - actual_suitable_servers["%s:%d" % (s.description.address[0], - s.description.address[1])] = s.description + actual_suitable_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description - self.assertEqual(len(actual_suitable_servers), - len(expected_suitable_servers)) + self.assertEqual(len(actual_suitable_servers), len(expected_suitable_servers)) for k, actual in actual_suitable_servers.items(): expected = expected_suitable_servers[k] self.assertEqual(expected.address, actual.address) @@ -227,18 +217,17 @@ def run_scenario(self): self.assertEqual(expected.all_hosts, actual.all_hosts) expected_latency_servers = {} - for server in scenario_def['in_latency_window']: + for server in scenario_def["in_latency_window"]: server_description = make_server_description(server, hosts) - expected_latency_servers[server['address']] = server_description + expected_latency_servers[server["address"]] = server_description actual_latency_servers = {} for s in actual_latency_s: - actual_latency_servers["%s:%d" % - (s.description.address[0], - s.description.address[1])] = s.description + actual_latency_servers[ + "%s:%d" % (s.description.address[0], s.description.address[1]) + ] = s.description - self.assertEqual(len(actual_latency_servers), - len(expected_latency_servers)) + self.assertEqual(len(actual_latency_servers), len(expected_latency_servers)) for k, actual in actual_latency_servers.items(): expected = expected_latency_servers[k] self.assertEqual(expected.address, actual.address) @@ -256,7 +245,7 @@ class TestAllScenarios(unittest.TestCase): for dirpath, _, filenames in os.walk(test_dir): dirname = os.path.split(dirpath) - dirname = os.path.split(dirname[-2])[-1] + '_' + dirname[-1] + dirname = os.path.split(dirname[-2])[-1] + "_" + dirname[-1] for filename in filenames: if os.path.splitext(filename)[1] != ".json": @@ -266,8 +255,7 @@ class TestAllScenarios(unittest.TestCase): # Construct test from scenario. 
new_test = create_test(scenario_def) - test_name = 'test_%s_%s' % ( - dirname, os.path.splitext(filename)[0]) + test_name = "test_%s_%s" % (dirname, os.path.splitext(filename)[0]) new_test.__name__ = test_name setattr(TestAllScenarios, new_test.__name__, new_test) diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index f3a9c4390a..e626cefa49 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -16,40 +16,34 @@ import functools import threading - from collections import abc +from test import IntegrationTest, client_context, client_knobs +from test.utils import ( + CMAPListener, + CompareType, + OvertCommandListener, + ServerAndTopologyEventListener, + camel_to_snake, + camel_to_snake_args, + parse_spec_options, + prepare_spec_arguments, + rs_client, +) from bson import decode, encode from bson.binary import Binary from bson.int64 import Int64 from bson.son import SON - from gridfs import GridFSBucket - from pymongo import client_session from pymongo.command_cursor import CommandCursor from pymongo.cursor import Cursor -from pymongo.errors import (BulkWriteError, - OperationFailure, - PyMongoError) +from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError from pymongo.read_concern import ReadConcern from pymongo.read_preferences import ReadPreference -from pymongo.results import _WriteResult, BulkWriteResult +from pymongo.results import BulkWriteResult, _WriteResult from pymongo.write_concern import WriteConcern -from test import (client_context, - client_knobs, - IntegrationTest) -from test.utils import (camel_to_snake, - camel_to_snake_args, - CompareType, - CMAPListener, - OvertCommandListener, - parse_spec_options, - prepare_spec_arguments, - rs_client, - ServerAndTopologyEventListener) - class SpecRunnerThread(threading.Thread): def __init__(self, name): @@ -73,7 +67,7 @@ def stop(self): def run(self): while not self.stopped or self.ops: - if not self. ops: + if not self.ops: with self.cond: self.cond.wait(10) if self.ops: @@ -86,15 +80,13 @@ def run(self): class SpecRunner(IntegrationTest): - @classmethod def setUpClass(cls): super(SpecRunner, cls).setUpClass() cls.mongos_clients = [] # Speed up the tests by decreasing the heartbeat frequency. - cls.knobs = client_knobs(heartbeat_frequency=0.1, - min_heartbeat_interval=0.1) + cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() @classmethod @@ -111,7 +103,7 @@ def setUp(self): self.maxDiff = None def _set_fail_point(self, client, command_args): - cmd = SON([('configureFailPoint', 'failCommand')]) + cmd = SON([("configureFailPoint", "failCommand")]) cmd.update(command_args) client.admin.command(cmd) @@ -128,7 +120,7 @@ def targeted_fail_point(self, session, fail_point): clients = {c.address: c for c in self.mongos_clients} client = clients[session._pinned_address] self._set_fail_point(client, fail_point) - self.addCleanup(self.set_fail_point, {'mode': 'off'}) + self.addCleanup(self.set_fail_point, {"mode": "off"}) def assert_session_pinned(self, session): """Run the assertSessionPinned test operation. 
@@ -158,12 +150,12 @@ def assert_collection_not_exists(self, database, collection): def assert_index_exists(self, database, collection, index): """Run the assertIndexExists test operation.""" coll = self.client[database][collection] - self.assertIn(index, [doc['name'] for doc in coll.list_indexes()]) + self.assertIn(index, [doc["name"] for doc in coll.list_indexes()]) def assert_index_not_exists(self, database, collection, index): """Run the assertIndexNotExists test operation.""" coll = self.client[database][collection] - self.assertNotIn(index, [doc['name'] for doc in coll.list_indexes()]) + self.assertNotIn(index, [doc["name"] for doc in coll.list_indexes()]) def assertErrorLabelsContain(self, exc, expected_labels): labels = [l for l in expected_labels if exc.has_error_label(l)] @@ -172,14 +164,14 @@ def assertErrorLabelsContain(self, exc, expected_labels): def assertErrorLabelsOmit(self, exc, omit_labels): for label in omit_labels: self.assertFalse( - exc.has_error_label(label), - msg='error labels should not contain %s' % (label,)) + exc.has_error_label(label), msg="error labels should not contain %s" % (label,) + ) def kill_all_sessions(self): clients = self.mongos_clients if self.mongos_clients else [self.client] for client in clients: try: - client.admin.command('killAllSessions', []) + client.admin.command("killAllSessions", []) except OperationFailure: # "operation was interrupted" by killing the command's # own session. @@ -201,8 +193,7 @@ def check_result(self, expected_result, result): for res in expected_result: prop = camel_to_snake(res) # SPEC-869: Only BulkWriteResult has upserted_count. - if (prop == "upserted_count" - and not isinstance(result, BulkWriteResult)): + if prop == "upserted_count" and not isinstance(result, BulkWriteResult): if result.upserted_id is not None: upserted_count = 1 else: @@ -211,8 +202,7 @@ def check_result(self, expected_result, result): elif prop == "inserted_ids": # BulkWriteResult does not have inserted_ids. if isinstance(result, BulkWriteResult): - self.assertEqual(len(expected_result[res]), - result.inserted_count) + self.assertEqual(len(expected_result[res]), result.inserted_count) else: # InsertManyResult may be compared to [id1] from the # crud spec or {"0": id1} from the retryable write spec. @@ -228,8 +218,7 @@ def check_result(self, expected_result, result): expected_ids[int(str_index)] = ids[str_index] self.assertEqual(expected_ids, result.upserted_ids, prop) else: - self.assertEqual( - getattr(result, prop), expected_result[res], prop) + self.assertEqual(getattr(result, prop), expected_result[res], prop) return True else: @@ -240,7 +229,7 @@ def get_object_name(self, op): Transaction spec says 'object' is required. 
""" - return op['object'] + return op["object"] @staticmethod def parse_options(opts): @@ -248,54 +237,54 @@ def parse_options(opts): def run_operation(self, sessions, collection, operation): original_collection = collection - name = camel_to_snake(operation['name']) - if name == 'run_command': - name = 'command' - elif name == 'download_by_name': - name = 'open_download_stream_by_name' - elif name == 'download': - name = 'open_download_stream' - elif name == 'map_reduce': - self.skipTest('PyMongo does not support mapReduce') - elif name == 'count': - self.skipTest('PyMongo does not support count') + name = camel_to_snake(operation["name"]) + if name == "run_command": + name = "command" + elif name == "download_by_name": + name = "open_download_stream_by_name" + elif name == "download": + name = "open_download_stream" + elif name == "map_reduce": + self.skipTest("PyMongo does not support mapReduce") + elif name == "count": + self.skipTest("PyMongo does not support count") database = collection.database collection = database.get_collection(collection.name) - if 'collectionOptions' in operation: + if "collectionOptions" in operation: collection = collection.with_options( - **self.parse_options(operation['collectionOptions'])) + **self.parse_options(operation["collectionOptions"]) + ) object_name = self.get_object_name(operation) - if object_name == 'gridfsbucket': + if object_name == "gridfsbucket": # Only create the GridFSBucket when we need it (for the gridfs # retryable reads tests). obj = GridFSBucket(database, bucket_name=collection.name) else: objects = { - 'client': database.client, - 'database': database, - 'collection': collection, - 'testRunner': self + "client": database.client, + "database": database, + "collection": collection, + "testRunner": self, } objects.update(sessions) obj = objects[object_name] # Combine arguments with options and handle special cases. - arguments = operation.get('arguments', {}) + arguments = operation.get("arguments", {}) arguments.update(arguments.pop("options", {})) self.parse_options(arguments) cmd = getattr(obj, name) with_txn_callback = functools.partial( - self.run_operations, sessions, original_collection, - in_with_transaction=True) - prepare_spec_arguments(operation, arguments, name, sessions, - with_txn_callback) + self.run_operations, sessions, original_collection, in_with_transaction=True + ) + prepare_spec_arguments(operation, arguments, name, sessions, with_txn_callback) - if name == 'run_on_thread': - args = {'sessions': sessions, 'collection': collection} + if name == "run_on_thread": + args = {"sessions": sessions, "collection": collection} args.update(arguments) arguments = args result = cmd(**dict(arguments)) @@ -308,10 +297,10 @@ def run_operation(self, sessions, collection, operation): if arguments["pipeline"] and "$out" in arguments["pipeline"][-1]: # Read from the primary to ensure causal consistency. 
out = collection.database.get_collection( - arguments["pipeline"][-1]["$out"], - read_preference=ReadPreference.PRIMARY) + arguments["pipeline"][-1]["$out"], read_preference=ReadPreference.PRIMARY + ) return out.find() - if 'download' in name: + if "download" in name: result = Binary(result.read()) if isinstance(result, Cursor) or isinstance(result, CommandCursor): @@ -324,10 +313,9 @@ def allowable_errors(self, op): return (PyMongoError,) def _run_op(self, sessions, collection, op, in_with_transaction): - expected_result = op.get('result') + expected_result = op.get("result") if expect_error(op): - with self.assertRaises(self.allowable_errors(op), - msg=op['name']) as context: + with self.assertRaises(self.allowable_errors(op), msg=op["name"]) as context: self.run_operation(sessions, collection, op.copy()) if expect_error_message(expected_result): @@ -335,19 +323,17 @@ def _run_op(self, sessions, collection, op, in_with_transaction): errmsg = str(context.exception.details).lower() else: errmsg = str(context.exception).lower() - self.assertIn(expected_result['errorContains'].lower(), - errmsg) + self.assertIn(expected_result["errorContains"].lower(), errmsg) if expect_error_code(expected_result): - self.assertEqual(expected_result['errorCodeName'], - context.exception.details.get('codeName')) + self.assertEqual( + expected_result["errorCodeName"], context.exception.details.get("codeName") + ) if expect_error_labels_contain(expected_result): self.assertErrorLabelsContain( - context.exception, - expected_result['errorLabelsContain']) + context.exception, expected_result["errorLabelsContain"] + ) if expect_error_labels_omit(expected_result): - self.assertErrorLabelsOmit( - context.exception, - expected_result['errorLabelsOmit']) + self.assertErrorLabelsOmit(context.exception, expected_result["errorLabelsOmit"]) # Reraise the exception if we're in the with_transaction # callback. @@ -355,65 +341,61 @@ def _run_op(self, sessions, collection, op, in_with_transaction): raise context.exception else: result = self.run_operation(sessions, collection, op.copy()) - if 'result' in op: - if op['name'] == 'runCommand': + if "result" in op: + if op["name"] == "runCommand": self.check_command_result(expected_result, result) else: self.check_result(expected_result, result) - def run_operations(self, sessions, collection, ops, - in_with_transaction=False): + def run_operations(self, sessions, collection, ops, in_with_transaction=False): for op in ops: self._run_op(sessions, collection, op, in_with_transaction) # TODO: factor with test_command_monitoring.py def check_events(self, test, listener, session_ids): res = listener.results - if not len(test['expectations']): + if not len(test["expectations"]): return # Give a nicer message when there are missing or extra events - cmds = decode_raw([event.command for event in res['started']]) - self.assertEqual( - len(res['started']), len(test['expectations']), cmds) - for i, expectation in enumerate(test['expectations']): + cmds = decode_raw([event.command for event in res["started"]]) + self.assertEqual(len(res["started"]), len(test["expectations"]), cmds) + for i, expectation in enumerate(test["expectations"]): event_type = next(iter(expectation)) - event = res['started'][i] + event = res["started"][i] # The tests substitute 42 for any number other than 0. 
- if (event.command_name == 'getMore' - and event.command['getMore']): - event.command['getMore'] = Int64(42) - elif event.command_name == 'killCursors': - event.command['cursors'] = [Int64(42)] - elif event.command_name == 'update': + if event.command_name == "getMore" and event.command["getMore"]: + event.command["getMore"] = Int64(42) + elif event.command_name == "killCursors": + event.command["cursors"] = [Int64(42)] + elif event.command_name == "update": # TODO: remove this once PYTHON-1744 is done. # Add upsert and multi fields back into expectations. - updates = expectation[event_type]['command']['updates'] + updates = expectation[event_type]["command"]["updates"] for update in updates: - update.setdefault('upsert', False) - update.setdefault('multi', False) + update.setdefault("upsert", False) + update.setdefault("multi", False) # Replace afterClusterTime: 42 with actual afterClusterTime. - expected_cmd = expectation[event_type]['command'] - expected_read_concern = expected_cmd.get('readConcern') + expected_cmd = expectation[event_type]["command"] + expected_read_concern = expected_cmd.get("readConcern") if expected_read_concern is not None: - time = expected_read_concern.get('afterClusterTime') + time = expected_read_concern.get("afterClusterTime") if time == 42: - actual_time = event.command.get( - 'readConcern', {}).get('afterClusterTime') + actual_time = event.command.get("readConcern", {}).get("afterClusterTime") if actual_time is not None: - expected_read_concern['afterClusterTime'] = actual_time + expected_read_concern["afterClusterTime"] = actual_time - recovery_token = expected_cmd.get('recoveryToken') + recovery_token = expected_cmd.get("recoveryToken") if recovery_token == 42: - expected_cmd['recoveryToken'] = CompareType(dict) + expected_cmd["recoveryToken"] = CompareType(dict) # Replace lsid with a name like "session0" to match test. 
- if 'lsid' in event.command: + if "lsid" in event.command: for name, lsid in session_ids.items(): - if event.command['lsid'] == lsid: - event.command['lsid'] = name + if event.command["lsid"] == lsid: + event.command["lsid"] = name break for attr, expected in expectation[event_type].items(): @@ -423,28 +405,27 @@ def check_events(self, test, listener, session_ids): for key, val in expected.items(): if val is None: if key in actual: - self.fail("Unexpected key [%s] in %r" % ( - key, actual)) + self.fail("Unexpected key [%s] in %r" % (key, actual)) elif key not in actual: - self.fail("Expected key [%s] in %r" % ( - key, actual)) + self.fail("Expected key [%s] in %r" % (key, actual)) else: - self.assertEqual(val, decode_raw(actual[key]), - "Key [%s] in %s" % (key, actual)) + self.assertEqual( + val, decode_raw(actual[key]), "Key [%s] in %s" % (key, actual) + ) else: self.assertEqual(actual, expected) def maybe_skip_scenario(self, test): - if test.get('skipReason'): - self.skipTest(test.get('skipReason')) + if test.get("skipReason"): + self.skipTest(test.get("skipReason")) def get_scenario_db_name(self, scenario_def): """Allow subclasses to override a test's database name.""" - return scenario_def['database_name'] + return scenario_def["database_name"] def get_scenario_coll_name(self, scenario_def): """Allow subclasses to override a test's collection name.""" - return scenario_def['collection_name'] + return scenario_def["collection_name"] def get_outcome_coll_name(self, outcome, collection): """Allow subclasses to override outcome collection.""" @@ -453,7 +434,7 @@ def get_outcome_coll_name(self, outcome, collection): def run_test_ops(self, sessions, collection, test): """Added to allow retryable writes spec to override a test's operation.""" - self.run_operations(sessions, collection, test['operations']) + self.run_operations(sessions, collection, test["operations"]) def parse_client_options(self, opts): """Allow encryption spec to override a clientOptions parsing.""" @@ -465,14 +446,13 @@ def setup_scenario(self, scenario_def): """Allow specs to override a test's setup.""" db_name = self.get_scenario_db_name(scenario_def) coll_name = self.get_scenario_coll_name(scenario_def) - db = client_context.client.get_database( - db_name, write_concern=WriteConcern(w='majority')) + db = client_context.client.get_database(db_name, write_concern=WriteConcern(w="majority")) coll = db[coll_name] coll.drop() db.create_collection(coll_name) - if scenario_def['data']: + if scenario_def["data"]: # Load data. - coll.insert_many(scenario_def['data']) + coll.insert_many(scenario_def["data"]) def run_scenario(self, scenario_def, test): self.maybe_skip_scenario(test) @@ -490,22 +470,22 @@ def run_scenario(self, scenario_def, test): c[database_name][collection_name].distinct("x") # Configure the fail point before creating the client. - if 'failPoint' in test: - fp = test['failPoint'] + if "failPoint" in test: + fp = test["failPoint"] self.set_fail_point(fp) - self.addCleanup(self.set_fail_point, { - 'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'}) + self.addCleanup( + self.set_fail_point, {"configureFailPoint": fp["configureFailPoint"], "mode": "off"} + ) listener = OvertCommandListener() pool_listener = CMAPListener() server_listener = ServerAndTopologyEventListener() # Create a new client, to avoid interference from pooled sessions. 
- client_options = self.parse_client_options(test['clientOptions']) + client_options = self.parse_client_options(test["clientOptions"]) # MMAPv1 does not support retryable writes. - if (client_options.get('retryWrites') is True and - client_context.storage_engine == 'mmapv1'): + if client_options.get("retryWrites") is True and client_context.storage_engine == "mmapv1": self.skipTest("MMAPv1 does not support retryWrites=True") - use_multi_mongos = test['useMultipleMongoses'] + use_multi_mongos = test["useMultipleMongoses"] host = None if use_multi_mongos: if client_context.load_balancer or client_context.serverless: @@ -513,9 +493,8 @@ def run_scenario(self, scenario_def, test): elif client_context.is_mongos: host = client_context.mongos_seeds() client = rs_client( - h=host, - event_listeners=[listener, pool_listener, server_listener], - **client_options) + h=host, event_listeners=[listener, pool_listener, server_listener], **client_options + ) self.scenario_client = client self.listener = listener self.pool_listener = pool_listener @@ -531,13 +510,12 @@ def run_scenario(self, scenario_def, test): # the running server version. if not client_context.sessions_enabled: break - session_name = 'session%d' % i - opts = camel_to_snake_args(test['sessionOptions'][session_name]) - if 'default_transaction_options' in opts: - txn_opts = self.parse_options( - opts['default_transaction_options']) + session_name = "session%d" % i + opts = camel_to_snake_args(test["sessionOptions"][session_name]) + if "default_transaction_options" in opts: + txn_opts = self.parse_options(opts["default_transaction_options"]) txn_opts = client_session.TransactionOptions(**txn_opts) - opts['default_transaction_options'] = txn_opts + opts["default_transaction_options"] = txn_opts s = client.start_session(**dict(opts)) @@ -555,74 +533,74 @@ def run_scenario(self, scenario_def, test): self.check_events(test, listener, session_ids) # Disable fail points. - if 'failPoint' in test: - fp = test['failPoint'] - self.set_fail_point({ - 'configureFailPoint': fp['configureFailPoint'], 'mode': 'off'}) + if "failPoint" in test: + fp = test["failPoint"] + self.set_fail_point({"configureFailPoint": fp["configureFailPoint"], "mode": "off"}) # Assert final state is expected. - outcome = test['outcome'] - expected_c = outcome.get('collection') + outcome = test["outcome"] + expected_c = outcome.get("collection") if expected_c is not None: - outcome_coll_name = self.get_outcome_coll_name( - outcome, collection) + outcome_coll_name = self.get_outcome_coll_name(outcome, collection) # Read from the primary with local read concern to ensure causal # consistency. - outcome_coll = client_context.client[ - collection.database.name].get_collection( + outcome_coll = client_context.client[collection.database.name].get_collection( outcome_coll_name, read_preference=ReadPreference.PRIMARY, - read_concern=ReadConcern('local')) - actual_data = list(outcome_coll.find(sort=[('_id', 1)])) + read_concern=ReadConcern("local"), + ) + actual_data = list(outcome_coll.find(sort=[("_id", 1)])) # The expected data needs to be the left hand side here otherwise # CompareType(Binary) doesn't work. 
- self.assertEqual(wrap_types(expected_c['data']), actual_data) + self.assertEqual(wrap_types(expected_c["data"]), actual_data) def expect_any_error(op): if isinstance(op, dict): - return op.get('error') + return op.get("error") return False def expect_error_message(expected_result): if isinstance(expected_result, dict): - return isinstance(expected_result['errorContains'], str) + return isinstance(expected_result["errorContains"], str) return False def expect_error_code(expected_result): if isinstance(expected_result, dict): - return expected_result['errorCodeName'] + return expected_result["errorCodeName"] return False def expect_error_labels_contain(expected_result): if isinstance(expected_result, dict): - return expected_result['errorLabelsContain'] + return expected_result["errorLabelsContain"] return False def expect_error_labels_omit(expected_result): if isinstance(expected_result, dict): - return expected_result['errorLabelsOmit'] + return expected_result["errorLabelsOmit"] return False def expect_error(op): - expected_result = op.get('result') - return (expect_any_error(op) or - expect_error_message(expected_result) - or expect_error_code(expected_result) - or expect_error_labels_contain(expected_result) - or expect_error_labels_omit(expected_result)) + expected_result = op.get("result") + return ( + expect_any_error(op) + or expect_error_message(expected_result) + or expect_error_code(expected_result) + or expect_error_labels_contain(expected_result) + or expect_error_labels_omit(expected_result) + ) def end_sessions(sessions): @@ -634,13 +612,13 @@ def end_sessions(sessions): def decode_raw(val): """Decode RawBSONDocuments in the given container.""" if isinstance(val, (list, abc.Mapping)): - return decode(encode({'v': val}))['v'] + return decode(encode({"v": val}))["v"] return val TYPES = { - 'binData': Binary, - 'long': Int64, + "binData": Binary, + "long": Int64, } @@ -649,7 +627,7 @@ def wrap_types(val): if isinstance(val, list): return [wrap_types(v) for v in val] if isinstance(val, abc.Mapping): - typ = val.get('$$type') + typ = val.get("$$type") if typ: return CompareType(TYPES[typ]) d = {} diff --git a/test/version.py b/test/version.py index 3348060bfc..e102db7111 100644 --- a/test/version.py +++ b/test/version.py @@ -16,7 +16,6 @@ class Version(tuple): - def __new__(cls, *version): padded_version = cls._padded(version, 4) return super(Version, cls).__new__(cls, tuple(padded_version)) @@ -43,16 +42,15 @@ def from_string(cls, version_string): version_string = version_string[0:-1] mod = -1 # Deal with '-rcX' substrings - if '-rc' in version_string: - version_string = version_string[0:version_string.find('-rc')] + if "-rc" in version_string: + version_string = version_string[0 : version_string.find("-rc")] mod = -1 # Deal with git describe generated substrings - elif '-' in version_string: - version_string = version_string[0:version_string.find('-')] + elif "-" in version_string: + version_string = version_string[0 : version_string.find("-")] mod = -1 bump_patch_level = True - version = [int(part) for part in version_string.split(".")] version = cls._padded(version, 3) # Make from_string and from_version_array agree. 
For example: @@ -77,9 +75,9 @@ def from_version_array(cls, version_array): @classmethod def from_client(cls, client): info = client.server_info() - if 'versionArray' in info: - return cls.from_version_array(info['versionArray']) - return cls.from_string(info['version']) + if "versionArray" in info: + return cls.from_version_array(info["versionArray"]) + return cls.from_string(info["version"]) def at_least(self, *other_version): return self >= Version(*other_version) diff --git a/tools/clean.py b/tools/clean.py index a5d383af4e..31279aeac9 100644 --- a/tools/clean.py +++ b/tools/clean.py @@ -34,12 +34,14 @@ try: from pymongo import _cmessage + sys.exit("could still import _cmessage") except ImportError: pass try: from bson import _cbson + sys.exit("could still import _cbson") except ImportError: pass diff --git a/tools/fail_if_no_c.py b/tools/fail_if_no_c.py index e6fd83a36b..a2d4954789 100644 --- a/tools/fail_if_no_c.py +++ b/tools/fail_if_no_c.py @@ -18,6 +18,7 @@ """ import sys + sys.path[0:0] = [""] import bson diff --git a/tools/ocsptest.py b/tools/ocsptest.py index 149da000ba..14df8a8fe3 100644 --- a/tools/ocsptest.py +++ b/tools/ocsptest.py @@ -21,18 +21,20 @@ # Enable logs in this format: # 2020-06-08 23:49:35,982 DEBUG ocsp_support Peer did not staple an OCSP response -FORMAT = '%(asctime)s %(levelname)s %(module)s %(message)s' +FORMAT = "%(asctime)s %(levelname)s %(module)s %(message)s" logging.basicConfig(format=FORMAT, level=logging.DEBUG) + def check_ocsp(host, port, capath): ctx = get_ssl_context( - None, # certfile - None, # passphrase + None, # certfile + None, # passphrase capath, # ca_certs - None, # crlfile - False, # allow_invalid_certificates - False, # allow_invalid_hostnames - False) # disable_ocsp_endpoint_check + None, # crlfile + False, # allow_invalid_certificates + False, # allow_invalid_hostnames + False, + ) # disable_ocsp_endpoint_check # Ensure we're using pyOpenSSL. assert isinstance(ctx, SSLContext) @@ -44,18 +46,15 @@ def check_ocsp(host, port, capath): finally: s.close() + def main(): - parser = argparse.ArgumentParser( - description='Debug OCSP') - parser.add_argument( - '--host', type=str, required=True, help="Host to connect to") - parser.add_argument( - '-p', '--port', type=int, default=443, help="Port to connect to") - parser.add_argument( - '--ca_file', type=str, default=None, help="CA file for host") + parser = argparse.ArgumentParser(description="Debug OCSP") + parser.add_argument("--host", type=str, required=True, help="Host to connect to") + parser.add_argument("-p", "--port", type=int, default=443, help="Port to connect to") + parser.add_argument("--ca_file", type=str, default=None, help="CA file for host") args = parser.parse_args() check_ocsp(args.host, args.port, args.ca_file) -if __name__ == '__main__': - main() +if __name__ == "__main__": + main()
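With the argparse options defined in tools/ocsptest.py above, the script would be invoked roughly as follows (host name and CA path are illustrative, not taken from the patch):

    # Check whether a public server staples a valid OCSP response:
    # python tools/ocsptest.py --host example.com --port 443
    #
    # Check a server whose certificate is signed by a private CA:
    # python tools/ocsptest.py --host myhost.local -p 8443 --ca_file ca.pem

The DEBUG-level log format configured at the top of the script then shows each step of the OCSP callback, e.g. whether the peer stapled a response.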
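Similarly, decode_raw in test/utils_spec_runner.py normalizes any RawBSONDocument instances inside a container by round-tripping the container through BSON; the {"v": val} wrapper exists because a top-level BSON value must be a document. A small illustration (a sketch assuming pymongo's bson package, not part of the patch):

    from bson import decode, encode
    from bson.raw_bson import RawBSONDocument

    raw = RawBSONDocument(encode({"a": 1}))
    # Encoding and decoding again replaces the RawBSONDocument
    # with a plain decoded mapping, ready for assertEqual comparisons.
    assert decode(encode({"v": [raw]}))["v"] == [{"a": 1}]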
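Finally, the two pure-regex case converters touched in test/utils.py are easy to sanity-check in isolation. A standalone round trip (converter bodies copied from the reformatted code; expected outputs worked out by hand):

    import re

    def camel_to_snake(camel):
        # Split before each capitalized word, then lowercase everything.
        snake = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", snake).lower()

    def snake_to_camel(snake):
        # Uppercase the letter after each underscore.
        return re.sub(r"_([a-z])", lambda m: m.group(1).upper(), snake)

    assert camel_to_snake("maxCommitTimeMS") == "max_commit_time_ms"
    assert snake_to_camel("max_time_ms") == "maxTimeMs"

Note the asymmetry in the second assert: snake_to_camel produces "maxTimeMs", not the server's "maxTimeMS", which is one reason prepare_spec_arguments passes maxTimeMS through explicitly for every operation except find (PYTHON-1855).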