diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py
index ac514db98f..a98c2b99c1 100644
--- a/pymongo/asynchronous/bulk.py
+++ b/pymongo/asynchronous/bulk.py
@@ -26,6 +26,8 @@
 from typing import (
     TYPE_CHECKING,
     Any,
+    Callable,
+    Iterable,
     Iterator,
     Mapping,
     Optional,
@@ -72,7 +74,7 @@ from pymongo.write_concern import WriteConcern

 if TYPE_CHECKING:
-    from pymongo.asynchronous.collection import AsyncCollection
+    from pymongo.asynchronous.collection import AsyncCollection, _WriteOp
     from pymongo.asynchronous.mongo_client import AsyncMongoClient
     from pymongo.asynchronous.pool import AsyncConnection
     from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
@@ -128,13 +130,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
             self.is_encrypted = False
         return _BulkWriteContext

-    def add_insert(self, document: _DocumentOut) -> None:
+    # @property
+    # def is_retryable(self) -> bool:
+    #     if self.current_run:
+    #         return self.current_run.is_retryable
+    #     return True
+    #
+    # @property
+    # def retrying(self) -> bool:
+    #     if self.current_run:
+    #         return self.current_run.retrying
+    #     return False
+    #
+    # @property
+    # def started_retryable_write(self) -> bool:
+    #     if self.current_run:
+    #         return self.current_run.started_retryable_write
+    #     return False
+
+    def add_insert(self, document: _DocumentOut) -> bool:
         """Add an insert document to the list of ops."""
         validate_is_document_type("document", document)
         # Generate ObjectId client side.
         if not (isinstance(document, RawBSONDocument) or "_id" in document):
             document["_id"] = ObjectId()
         self.ops.append((_INSERT, document))
+        return True

     def add_update(
         self,
@@ -146,7 +167,7 @@ def add_update(
         array_filters: Optional[list[Mapping[str, Any]]] = None,
         hint: Union[str, dict[str, Any], None] = None,
         sort: Optional[Mapping[str, Any]] = None,
-    ) -> None:
+    ) -> bool:
         """Create an update document and add it to the list of ops."""
         validate_ok_for_update(update)
         cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi}
@@ -164,10 +185,12 @@ def add_update(
         if sort is not None:
             self.uses_sort = True
             cmd["sort"] = sort
+
+        self.ops.append((_UPDATE, cmd))
         if multi:
             # A bulk_write containing an update_many is not retryable.
-            self.is_retryable = False
-        self.ops.append((_UPDATE, cmd))
+            return False
+        return True

     def add_replace(
         self,
@@ -177,7 +200,7 @@ def add_replace(
         collation: Optional[Mapping[str, Any]] = None,
         hint: Union[str, dict[str, Any], None] = None,
         sort: Optional[Mapping[str, Any]] = None,
-    ) -> None:
+    ) -> bool:
         """Create a replace document and add it to the list of ops."""
         validate_ok_for_replace(replacement)
         cmd: dict[str, Any] = {"q": selector, "u": replacement}
@@ -193,6 +216,7 @@ def add_replace(
             self.uses_sort = True
             cmd["sort"] = sort
         self.ops.append((_UPDATE, cmd))
+        return True

     def add_delete(
         self,
@@ -200,7 +224,7 @@ def add_delete(
         limit: int,
         collation: Optional[Mapping[str, Any]] = None,
         hint: Union[str, dict[str, Any], None] = None,
-    ) -> None:
+    ) -> bool:
         """Create a delete document and add it to the list of ops."""
         cmd: dict[str, Any] = {"q": selector, "limit": limit}
         if collation is not None:
@@ -209,33 +233,56 @@ def add_delete(
         if hint is not None:
             self.uses_hint_delete = True
             cmd["hint"] = hint
+
+        self.ops.append((_DELETE, cmd))
         if limit == _DELETE_ALL:
             # A bulk_write containing a delete_many is not retryable.
-            self.is_retryable = False
-        self.ops.append((_DELETE, cmd))
+            return False
+        return True

-    def gen_ordered(self) -> Iterator[Optional[_Run]]:
+    def gen_ordered(
+        self,
+        requests: Iterable[Any],
+        process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool],
+    ) -> Iterator[_Run]:
         """Generate batches of operations, batched by type of
         operation, in the order **provided**.
         """
         run = None
-        for idx, (op_type, operation) in enumerate(self.ops):
+        for idx, request in enumerate(requests):
+            retryable = process(request)
+            (op_type, operation) = self.ops[idx]
             if run is None:
                 run = _Run(op_type)
             elif run.op_type != op_type:
                 yield run
                 run = _Run(op_type)
             run.add(idx, operation)
+            self.is_retryable = self.is_retryable and retryable
+        if run is None:
+            raise InvalidOperation("No operations to execute")
         yield run

-    def gen_unordered(self) -> Iterator[_Run]:
+    def gen_unordered(
+        self,
+        requests: Iterable[Any],
+        process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool],
+    ) -> Iterator[_Run]:
         """Generate batches of operations, batched by type of
         operation, in arbitrary order.
         """
         operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
-        for idx, (op_type, operation) in enumerate(self.ops):
+        for idx, request in enumerate(requests):
+            retryable = process(request)
+            (op_type, operation) = self.ops[idx]
             operations[op_type].add(idx, operation)
-
+            self.is_retryable = self.is_retryable and retryable
+        if (
+            len(operations[_INSERT].ops) == 0
+            and len(operations[_UPDATE].ops) == 0
+            and len(operations[_DELETE].ops) == 0
+        ):
+            raise InvalidOperation("No operations to execute")
         for run in operations:
             if run.ops:
                 yield run
@@ -470,8 +517,8 @@ async def _execute_command(
         session: Optional[AsyncClientSession],
         conn: AsyncConnection,
         op_id: int,
-        retryable: bool,
         full_result: MutableMapping[str, Any],
+        validate: bool,
         final_write_concern: Optional[WriteConcern] = None,
     ) -> None:
         db_name = self.collection.database.name
@@ -523,10 +570,10 @@ async def _execute_command(
             if session:
                 # Start a new retryable write unless one was already
                 # started for this command.
-                if retryable and not self.started_retryable_write:
+                if self.is_retryable and not self.started_retryable_write:
                     session._start_retryable_write()
                     self.started_retryable_write = True
-                session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
+                session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn)
             conn.send_cluster_time(cmd, session, client)
             conn.add_server_api(cmd)
             # CSOT: apply timeout before encoding the command.
@@ -534,9 +581,10 @@ async def _execute_command(
             ops = islice(run.ops, run.idx_offset, None)

             # Run as many ops as possible in one command.
+            if validate:
+                await self.validate_batch(conn, write_concern)
             if write_concern.acknowledged:
                 result, to_send = await self._execute_batch(bwc, cmd, ops, client)
-
                 # Retryable writeConcernErrors halt the execution of this run.
                 wce = result.get("writeConcernError", {})
                 if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
@@ -549,8 +597,8 @@ async def _execute_command(
                 _merge_command(run, full_result, run.idx_offset, result)

                 # We're no longer in a retry once a command succeeds.
-                self.retrying = False
-                self.started_retryable_write = False
+                run.retrying = False
+                run.started_retryable_write = False

                 if self.ordered and "writeErrors" in result:
                     break
@@ -588,7 +636,8 @@ async def execute_command(
         op_id = _randint()

         async def retryable_bulk(
-            session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool
+            session: Optional[AsyncClientSession],
+            conn: AsyncConnection,
         ) -> None:
             await self._execute_command(
                 generator,
@@ -596,26 +645,24 @@ async def retryable_bulk(
                 session,
                 conn,
                 op_id,
-                retryable,
                 full_result,
+                validate=False,
             )

         client = self.collection.database.client
         _ = await client._retryable_write(
-            self.is_retryable,
             retryable_bulk,
             session,
             operation,
             bulk=self,  # type: ignore[arg-type]
             operation_id=op_id,
         )
-
         if full_result["writeErrors"] or full_result["writeConcernErrors"]:
             _raise_bulk_write_error(full_result)
         return full_result

     async def execute_op_msg_no_results(
-        self, conn: AsyncConnection, generator: Iterator[Any]
+        self, conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern
     ) -> None:
         """Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
         db_name = self.collection.database.name
@@ -649,6 +696,7 @@ async def execute_op_msg_no_results(
             conn.add_server_api(cmd)
             ops = islice(run.ops, run.idx_offset, None)
             # Run as many ops as possible.
+            await self.validate_batch(conn, write_concern)
             to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
             run.idx_offset += len(to_send)
         self.current_run = run = next(generator, None)
@@ -682,12 +730,15 @@ async def execute_command_no_results(
                 None,
                 conn,
                 op_id,
-                False,
                 full_result,
+                True,
                 write_concern,
             )
-        except OperationFailure:
-            pass
+        except OperationFailure as exc:
+            if "Cannot set bypass_document_validation with unacknowledged write concern" in str(
+                exc
+            ):
+                raise exc

     async def execute_no_results(
         self,
@@ -696,6 +747,11 @@ async def execute_no_results(
         write_concern: WriteConcern,
     ) -> None:
         """Execute all operations, returning no results (w=0)."""
+        if self.ordered:
+            return await self.execute_command_no_results(conn, generator, write_concern)
+        return await self.execute_op_msg_no_results(conn, generator, write_concern)
+
+    async def validate_batch(self, conn: AsyncConnection, write_concern: WriteConcern) -> None:
         if self.uses_collation:
             raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
         if self.uses_array_filters:
@@ -720,19 +776,15 @@ async def execute_no_results(
                 "Cannot set bypass_document_validation with unacknowledged write concern"
             )

-        if self.ordered:
-            return await self.execute_command_no_results(conn, generator, write_concern)
-        return await self.execute_op_msg_no_results(conn, generator)
-
     async def execute(
         self,
+        generator: Iterable[Any],
+        process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool],
         write_concern: WriteConcern,
         session: Optional[AsyncClientSession],
         operation: str,
     ) -> Any:
         """Execute operations."""
-        if not self.ops:
-            raise InvalidOperation("No operations to execute")
         if self.executed:
             raise InvalidOperation("Bulk operations can only be executed once.")
         self.executed = True
@@ -740,9 +792,9 @@ async def execute(
         session = _validate_session_write_concern(session, write_concern)

         if self.ordered:
-            generator = self.gen_ordered()
+            generator = self.gen_ordered(generator, process)
         else:
-            generator = self.gen_unordered()
+            generator = self.gen_unordered(generator, process)

         client = self.collection.database.client
         if not write_concern.acknowledged:
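The core of the bulk.py change is that batching is now driven by the caller's request iterable instead of a pre-materialized `self.ops` list, with per-op retryability folded in as each request is processed. A minimal, self-contained sketch of that contract (a toy `_Batcher`, not the real pymongo classes):

```python
from typing import Any, Callable, Iterable, Iterator


class _Batcher:
    """Toy model of the lazy batching now done by _AsyncBulk/_Bulk."""

    def __init__(self) -> None:
        self.ops: list[tuple[int, Any]] = []  # filled in by process()
        self.is_retryable = True  # folded in while requests are consumed

    def gen_ordered(
        self, requests: Iterable[Any], process: Callable[[Any], bool]
    ) -> Iterator[list[tuple[int, Any]]]:
        """Yield runs of same-typed ops in the order provided."""
        run: list[tuple[int, Any]] = []
        run_type = None
        for idx, request in enumerate(requests):
            # process() appends exactly one (op_type, op) pair to self.ops
            # and reports whether that single op is retryable.
            self.is_retryable = self.is_retryable and process(request)
            op_type, op = self.ops[idx]
            if run and run_type != op_type:
                yield run  # flush when the op type changes
                run = []
            run_type = op_type
            run.append((idx, op))
        if not run:
            raise ValueError("No operations to execute")
        yield run
```

Because requests are only consumed inside `gen_ordered`/`gen_unordered`, the "No operations to execute" check necessarily moves here from `execute()`, which can no longer inspect `self.ops` up front.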
diff --git a/pymongo/asynchronous/client_bulk.py b/pymongo/asynchronous/client_bulk.py
index 5f7ac013e9..dbbad9e0e8 100644
--- a/pymongo/asynchronous/client_bulk.py
+++ b/pymongo/asynchronous/client_bulk.py
@@ -116,6 +116,7 @@ def __init__(
         self.is_retryable = self.client.options.retry_writes
         self.retrying = False
         self.started_retryable_write = False
+        self.current_run = None

     @property
     def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
@@ -488,7 +489,6 @@ async def _execute_command(
         session: Optional[AsyncClientSession],
         conn: AsyncConnection,
         op_id: int,
-        retryable: bool,
         full_result: MutableMapping[str, Any],
         final_write_concern: Optional[WriteConcern] = None,
     ) -> None:
@@ -534,10 +534,10 @@ async def _execute_command(
             if session:
                 # Start a new retryable write unless one was already
                 # started for this command.
-                if retryable and not self.started_retryable_write:
+                if self.is_retryable and not self.started_retryable_write:
                     session._start_retryable_write()
                     self.started_retryable_write = True
-                session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
+                session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn)
             conn.send_cluster_time(cmd, session, self.client)
             conn.add_server_api(cmd)
             # CSOT: apply timeout before encoding the command.
@@ -564,7 +564,7 @@ async def _execute_command(

                 # Synthesize the full bulk result without modifying the
                 # current one because this write operation may be retried.
-                if retryable and (retryable_top_level_error or retryable_network_error):
+                if self.is_retryable and (retryable_top_level_error or retryable_network_error):
                     full = copy.deepcopy(full_result)
                     _merge_command(self.ops, self.idx_offset, full, result)
                     _throw_client_bulk_write_exception(full, self.verbose_results)
@@ -583,7 +583,7 @@ async def _execute_command(
                     _merge_command(self.ops, self.idx_offset, full_result, result)
                     break

-                if retryable:
+                if self.is_retryable:
                     # Retryable writeConcernErrors halt the execution of this batch.
                     wce = result.get("writeConcernError", {})
                     if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
@@ -638,7 +638,6 @@ async def execute_command(
         async def retryable_bulk(
             session: Optional[AsyncClientSession],
             conn: AsyncConnection,
-            retryable: bool,
         ) -> None:
             if conn.max_wire_version < 25:
                 raise InvalidOperation(
@@ -649,12 +648,10 @@ async def retryable_bulk(
                 session,
                 conn,
                 op_id,
-                retryable,
                 full_result,
             )

         await self.client._retryable_write(
-            self.is_retryable,
             retryable_bulk,
             session,
             operation,
diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py
index b808684dd4..b9d8449a34 100644
--- a/pymongo/asynchronous/client_session.py
+++ b/pymongo/asynchronous/client_session.py
@@ -854,13 +854,12 @@ async def _finish_transaction_with_retry(self, command_name: str) -> dict[str, A
         """

         async def func(
-            _session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable: bool
+            _session: Optional[AsyncClientSession],
+            conn: AsyncConnection,
         ) -> dict[str, Any]:
             return await self._finish_transaction(conn, command_name)

-        return await self._client._retry_internal(
-            func, self, None, retryable=True, operation=_Op.ABORT
-        )
+        return await self._client._retry_internal(func, self, None, operation=_Op.ABORT)

     async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]:
         self._transaction.attempt += 1
diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py
index 7fb20b7ab3..5ee67ddf89 100644
--- a/pymongo/asynchronous/collection.py
+++ b/pymongo/asynchronous/collection.py
@@ -699,7 +699,7 @@ async def _create(
     @_csot.apply
     async def bulk_write(
         self,
-        requests: Sequence[_WriteOp[_DocumentType]],
+        requests: Iterable[_WriteOp],
         ordered: bool = True,
         bypass_document_validation: Optional[bool] = None,
         session: Optional[AsyncClientSession] = None,
@@ -779,17 +779,21 @@ async def bulk_write(

         .. versionadded:: 3.0
         """
-        common.validate_list("requests", requests)
+        common.validate_list_or_generator("requests", requests)
         blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let)
-        for request in requests:
+
+        write_concern = self._write_concern_for(session)
+
+        def process_for_bulk(request: _WriteOp) -> bool:
             try:
-                request._add_to_bulk(blk)
+                return request._add_to_bulk(blk)
             except AttributeError:
                 raise TypeError(f"{request!r} is not a valid request") from None

-        write_concern = self._write_concern_for(session)
-        bulk_api_result = await blk.execute(write_concern, session, _Op.INSERT)
+        bulk_api_result = await blk.execute(
+            requests, process_for_bulk, write_concern, session, _Op.INSERT
+        )
         if bulk_api_result is not None:
             return BulkWriteResult(bulk_api_result, True)
         return BulkWriteResult({}, False)
@@ -806,17 +810,15 @@ async def _insert_one(
     ) -> Any:
         """Internal helper for inserting a single document."""
         write_concern = write_concern or self.write_concern
-        acknowledged = write_concern.acknowledged
         command = {"insert": self.name, "ordered": ordered, "documents": [doc]}
         if comment is not None:
             command["comment"] = comment

         async def _insert_command(
-            session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
+            session: Optional[AsyncClientSession], conn: AsyncConnection
         ) -> None:
             if bypass_doc_val is not None:
                 command["bypassDocumentValidation"] = bypass_doc_val
-
             result = await conn.command(
                 self._database.name,
                 command,
@@ -824,14 +826,11 @@ async def _insert_command(
                 codec_options=self._write_response_codec_options,
                 session=session,
                 client=self._database.client,
-                retryable_write=retryable_write,
             )

             _check_write_command_response(result)

-        await self._database.client._retryable_write(
-            acknowledged, _insert_command, session, operation=_Op.INSERT
-        )
+        await self._database.client._retryable_write(_insert_command, session, operation=_Op.INSERT)

         if not isinstance(doc, RawBSONDocument):
             return doc.get("_id")
@@ -960,20 +959,19 @@ async def insert_many(
             raise TypeError("documents must be a non-empty list")
         inserted_ids: list[ObjectId] = []

-        def gen() -> Iterator[tuple[int, Mapping[str, Any]]]:
-            """A generator that validates documents and handles _ids."""
-            for document in documents:
-                common.validate_is_document_type("document", document)
-                if not isinstance(document, RawBSONDocument):
-                    if "_id" not in document:
-                        document["_id"] = ObjectId()  # type: ignore[index]
-                    inserted_ids.append(document["_id"])
-                yield (message._INSERT, document)
+        def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool:
+            """Validate a document and handle its _id."""
+            common.validate_is_document_type("document", document)
+            if not isinstance(document, RawBSONDocument):
+                if "_id" not in document:
+                    document["_id"] = ObjectId()  # type: ignore[index]
+                inserted_ids.append(document["_id"])
+            blk.ops.append((message._INSERT, document))
+            return True

         write_concern = self._write_concern_for(session)
         blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment)
-        blk.ops = list(gen())
-        await blk.execute(write_concern, session, _Op.INSERT)
+        await blk.execute(documents, process_for_bulk, write_concern, session, _Op.INSERT)
         return InsertManyResult(inserted_ids, write_concern.acknowledged)

     async def _update(
@@ -991,7 +989,6 @@ async def _update(
         array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
         hint: Optional[_IndexKeyHint] = None,
         session: Optional[AsyncClientSession] = None,
-        retryable_write: bool = False,
         let: Optional[Mapping[str, Any]] = None,
         sort: Optional[Mapping[str, Any]] = None,
         comment: Optional[Any] = None,
@@ -1054,7 +1051,6 @@ async def _update(
                     codec_options=self._write_response_codec_options,
                     session=session,
                     client=self._database.client,
-                    retryable_write=retryable_write,
                 )
             ).copy()
         _check_write_command_response(result)
@@ -1094,7 +1090,7 @@ async def _update_retryable(
         """Internal update / replace helper."""

         async def _update(
-            session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
+            session: Optional[AsyncClientSession], conn: AsyncConnection
         ) -> Optional[Mapping[str, Any]]:
             return await self._update(
                 conn,
@@ -1110,14 +1106,12 @@ async def _update(
                 array_filters=array_filters,
                 hint=hint,
                 session=session,
-                retryable_write=retryable_write,
                 let=let,
                 sort=sort,
                 comment=comment,
             )

         return await self._database.client._retryable_write(
-            (write_concern or self.write_concern).acknowledged and not multi,
             _update,
             session,
             operation,
@@ -1507,7 +1501,6 @@ async def _delete(
         collation: Optional[_CollationIn] = None,
         hint: Optional[_IndexKeyHint] = None,
         session: Optional[AsyncClientSession] = None,
-        retryable_write: bool = False,
         let: Optional[Mapping[str, Any]] = None,
         comment: Optional[Any] = None,
     ) -> Mapping[str, Any]:
@@ -1547,7 +1540,6 @@ async def _delete(
             codec_options=self._write_response_codec_options,
             session=session,
             client=self._database.client,
-            retryable_write=retryable_write,
         )
         _check_write_command_response(result)
         return result
@@ -1568,7 +1560,7 @@ async def _delete_retryable(
         """Internal delete helper."""

         async def _delete(
-            session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
+            session: Optional[AsyncClientSession], conn: AsyncConnection
         ) -> Mapping[str, Any]:
             return await self._delete(
                 conn,
@@ -1580,13 +1572,11 @@ async def _delete(
                 collation=collation,
                 hint=hint,
                 session=session,
-                retryable_write=retryable_write,
                 let=let,
                 comment=comment,
             )

         return await self._database.client._retryable_write(
-            (write_concern or self.write_concern).acknowledged and not multi,
             _delete,
             session,
             operation=_Op.DELETE,
@@ -3231,7 +3221,7 @@ async def _find_and_modify(
         write_concern = self._write_concern_for_cmd(cmd, session)

         async def _find_and_modify_helper(
-            session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool
+            session: Optional[AsyncClientSession], conn: AsyncConnection
         ) -> Any:
             acknowledged = write_concern.acknowledged
             if array_filters is not None:
@@ -3257,7 +3247,6 @@ async def _find_and_modify_helper(
                 write_concern=write_concern,
                 collation=collation,
                 session=session,
-                retryable_write=retryable_write,
                 user_fields=_FIND_AND_MODIFY_DOC_FIELDS,
             )
             _check_write_command_response(out)
@@ -3265,7 +3254,6 @@ async def _find_and_modify_helper(
             return out.get("value")

         return await self._database.client._retryable_write(
-            write_concern.acknowledged,
             _find_and_modify_helper,
             session,
             operation=_Op.FIND_AND_MODIFY,
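With `requests` widened to `Iterable[_WriteOp]` and checked by `validate_list_or_generator`, callers can stream operations to `bulk_write` instead of building a list first. A usage sketch against this branch (database/collection names and documents are illustrative):

```python
import asyncio

from pymongo import AsyncMongoClient, DeleteOne, InsertOne


async def main() -> None:
    client: AsyncMongoClient = AsyncMongoClient()
    coll = client.test.things

    # A generator expression now works; bulk_write previously required a list.
    result = await coll.bulk_write(InsertOne({"n": i}) for i in range(100_000))
    print(result.inserted_count)

    # Mixed op types are fine too; each request is validated and batched
    # one at a time by process_for_bulk as the generator is consumed.
    def requests():
        yield InsertOne({"n": -1})
        yield DeleteOne({"n": 0})

    await coll.bulk_write(requests())


asyncio.run(main())
```

Note that a generator is consumed during execution, so a fresh one is needed per `bulk_write` call; the new tests near the end of this patch rebuild `requests` for the ordered case for exactly this reason.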
diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py
index a236b21348..92918981ac 100644
--- a/pymongo/asynchronous/mongo_client.py
+++ b/pymongo/asynchronous/mongo_client.py
@@ -149,9 +149,7 @@
 T = TypeVar("T")

-_WriteCall = Callable[
-    [Optional["AsyncClientSession"], "AsyncConnection", bool], Coroutine[Any, Any, T]
-]
+_WriteCall = Callable[[Optional["AsyncClientSession"], "AsyncConnection"], Coroutine[Any, Any, T]]
 _ReadCall = Callable[
     [Optional["AsyncClientSession"], "Server", "AsyncConnection", _ServerMode],
     Coroutine[Any, Any, T],
 ]
@@ -1931,7 +1929,6 @@ async def _cmd(

     async def _retry_with_session(
         self,
-        retryable: bool,
         func: _WriteCall[T],
         session: Optional[AsyncClientSession],
         bulk: Optional[Union[_AsyncBulk, _AsyncClientBulk]],
@@ -1947,15 +1944,11 @@ async def _retry_with_session(
         """
         # Ensure that the options supports retry_writes and there is a valid session not in
         # transaction, otherwise, we will not support retry behavior for this txn.
-        retryable = bool(
-            retryable and self.options.retry_writes and session and not session.in_transaction
-        )
         return await self._retry_internal(
             func=func,
             session=session,
             bulk=bulk,
             operation=operation,
-            retryable=retryable,
             operation_id=operation_id,
         )

@@ -1969,7 +1962,6 @@ async def _retry_internal(
         is_read: bool = False,
         address: Optional[_Address] = None,
         read_pref: Optional[_ServerMode] = None,
-        retryable: bool = False,
         operation_id: Optional[int] = None,
     ) -> T:
         """Internal retryable helper for all client transactions.
@@ -1994,7 +1986,6 @@ async def _retry_internal(
             session=session,
             read_pref=read_pref,
             address=address,
-            retryable=retryable,
             operation_id=operation_id,
         ).run()

@@ -2037,13 +2028,11 @@ async def _retryable_read(
             is_read=True,
             address=address,
             read_pref=read_pref,
-            retryable=retryable,
             operation_id=operation_id,
         )

     async def _retryable_write(
         self,
-        retryable: bool,
         func: _WriteCall[T],
         session: Optional[AsyncClientSession],
         operation: str,
@@ -2064,7 +2053,7 @@ async def _retryable_write(
         :param bulk: bulk abstraction to execute operations in bulk, defaults to None
         """
         async with self._tmp_session(session) as s:
-            return await self._retry_with_session(retryable, func, s, bulk, operation, operation_id)
+            return await self._retry_with_session(func, s, bulk, operation, operation_id)

     def _cleanup_cursor_no_lock(
         self,
@@ -2723,6 +2712,11 @@ def __init__(
         self._operation_id = operation_id
         self._attempt_number = 0

+    def _bulk_retryable(self) -> bool:
+        if self._bulk is not None:
+            return self._bulk.is_retryable
+        return True
+
     async def run(self) -> T:
         """Runs the supplied func() and attempts a retry

@@ -2733,7 +2727,12 @@ async def run(self) -> T:
         # Increment the transaction id up front to ensure any retry attempt
         # will use the proper txnNumber, even if server or socket selection
         # fails before the command can be sent.
-        if self._is_session_state_retryable() and self._retryable and not self._is_read:
+        if (
+            self._is_session_state_retryable()
+            and self._retryable
+            and self._bulk_retryable()
+            and not self._is_read
+        ):
             self._session._start_retryable_write()  # type: ignore
             if self._bulk:
                 self._bulk.started_retryable_write = True
@@ -2770,7 +2769,7 @@ async def run(self) -> T:

                 # Specialized catch on write operation
                 if not self._is_read:
-                    if not self._retryable:
+                    if not self._retryable and not self._bulk_retryable():
                         raise
                     if isinstance(exc, ClientBulkWriteException) and exc.error:
                         retryable_write_error_exc = isinstance(
@@ -2801,11 +2800,15 @@ async def run(self) -> T:

     def _is_not_eligible_for_retry(self) -> bool:
         """Checks if the exchange is not eligible for retry"""
-        return not self._retryable or (self._is_retrying() and not self._multiple_retries)
+        return (
+            not self._retryable
+            or not self._bulk_retryable()
+            or (self._is_retrying() and not self._multiple_retries)
+        )

     def _is_retrying(self) -> bool:
         """Checks if the exchange is currently undergoing a retry"""
-        return self._bulk.retrying if self._bulk else self._retrying
+        return self._bulk.retrying if self._bulk is not None else self._retrying

     def _is_session_state_retryable(self) -> bool:
         """Checks if provided session is eligible for retry
@@ -2865,6 +2868,8 @@ async def _write(self) -> T:
                     # not support sessions raise the last error.
                     self._check_last_error()
                     self._retryable = False
+                    if self._bulk:
+                        self._bulk.is_retryable = False
                 if self._retrying:
                     _debug_log(
                         _COMMAND_LOGGER,
@@ -2873,9 +2878,9 @@ async def _write(self) -> T:
                         commandName=self._operation,
                         operationId=self._operation_id,
                     )
-                return await self._func(self._session, conn, self._retryable)  # type: ignore
+                return await self._func(self._session, conn)  # type: ignore
             except PyMongoError as exc:
                 if not self._retryable or not self._bulk_retryable():
                     raise
                 # Add the RetryableWriteError label, if applicable.
                 _add_retryable_write_error(exc, max_wire_version, is_mongos)
@@ -2892,7 +2897,7 @@ async def _read(self) -> T:
             conn,
             read_pref,
         ):
-            if self._retrying and not self._retryable:
+            if self._retrying and (not self._retryable or not self._bulk_retryable()):
                 self._check_last_error()
             if self._retrying:
                 _debug_log(
diff --git a/pymongo/common.py b/pymongo/common.py
index 3d8095eedf..6d9bb2f37a 100644
--- a/pymongo/common.py
+++ b/pymongo/common.py
@@ -24,6 +24,7 @@
     TYPE_CHECKING,
     Any,
     Callable,
+    Generator,
     Iterator,
     Mapping,
     MutableMapping,
@@ -530,6 +531,13 @@ def validate_list(option: str, value: Any) -> list:
     return value


+def validate_list_or_generator(option: str, value: Any) -> Union[list, Generator]:
+    """Validates that 'value' is a list or generator."""
+    if isinstance(value, Generator):
+        return value
+    return validate_list(option, value)
+
+
 def validate_list_or_none(option: Any, value: Any) -> Optional[list]:
     """Validates that 'value' is a list or None."""
     if value is None:
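The new validator special-cases true generators only; any other non-list iterable still fails `validate_list`. A quick standalone illustration of the accepted and rejected shapes (with a stub `validate_list` so the snippet runs on its own):

```python
from typing import Any, Generator, Union


def validate_list(option: str, value: Any) -> list:
    """Stub of pymongo.common.validate_list for illustration."""
    if not isinstance(value, list):
        raise TypeError(f"{option} must be a list")
    return value


def validate_list_or_generator(option: str, value: Any) -> Union[list, Generator]:
    """Validates that 'value' is a list or generator."""
    if isinstance(value, Generator):  # runtime check against the Generator ABC
        return value
    return validate_list(option, value)


validate_list_or_generator("requests", [1, 2])  # ok: list
validate_list_or_generator("requests", (x for x in "ab"))  # ok: generator
try:
    validate_list_or_generator("requests", (1, 2))  # tuples are still rejected
except TypeError as exc:
    print(exc)
```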
"""Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" - bulkobj.add_update( + return bulkobj.add_update( self._filter, self._doc, True, diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index a528b09add..c3323ed841 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -26,6 +26,8 @@ from typing import ( TYPE_CHECKING, Any, + Callable, + Iterable, Iterator, Mapping, Optional, @@ -72,7 +74,7 @@ from pymongo.write_concern import WriteConcern if TYPE_CHECKING: - from pymongo.synchronous.collection import Collection + from pymongo.synchronous.collection import Collection, _WriteOp from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline @@ -128,13 +130,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]: self.is_encrypted = False return _BulkWriteContext - def add_insert(self, document: _DocumentOut) -> None: + # @property + # def is_retryable(self) -> bool: + # if self.current_run: + # return self.current_run.is_retryable + # return True + # + # @property + # def retrying(self) -> bool: + # if self.current_run: + # return self.current_run.retrying + # return False + # + # @property + # def started_retryable_write(self) -> bool: + # if self.current_run: + # return self.current_run.started_retryable_write + # return False + + def add_insert(self, document: _DocumentOut) -> bool: """Add an insert document to the list of ops.""" validate_is_document_type("document", document) # Generate ObjectId client side. if not (isinstance(document, RawBSONDocument) or "_id" in document): document["_id"] = ObjectId() self.ops.append((_INSERT, document)) + return True def add_update( self, @@ -146,7 +167,7 @@ def add_update( array_filters: Optional[list[Mapping[str, Any]]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create an update document and add it to the list of ops.""" validate_ok_for_update(update) cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi} @@ -164,10 +185,12 @@ def add_update( if sort is not None: self.uses_sort = True cmd["sort"] = sort + + self.ops.append((_UPDATE, cmd)) if multi: # A bulk_write containing an update_many is not retryable. - self.is_retryable = False - self.ops.append((_UPDATE, cmd)) + return False + return True def add_replace( self, @@ -177,7 +200,7 @@ def add_replace( collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, sort: Optional[Mapping[str, Any]] = None, - ) -> None: + ) -> bool: """Create a replace document and add it to the list of ops.""" validate_ok_for_replace(replacement) cmd: dict[str, Any] = {"q": selector, "u": replacement} @@ -193,6 +216,7 @@ def add_replace( self.uses_sort = True cmd["sort"] = sort self.ops.append((_UPDATE, cmd)) + return True def add_delete( self, @@ -200,7 +224,7 @@ def add_delete( limit: int, collation: Optional[Mapping[str, Any]] = None, hint: Union[str, dict[str, Any], None] = None, - ) -> None: + ) -> bool: """Create a delete document and add it to the list of ops.""" cmd: dict[str, Any] = {"q": selector, "limit": limit} if collation is not None: @@ -209,33 +233,56 @@ def add_delete( if hint is not None: self.uses_hint_delete = True cmd["hint"] = hint + + self.ops.append((_DELETE, cmd)) if limit == _DELETE_ALL: # A bulk_write containing a delete_many is not retryable. 
- self.is_retryable = False - self.ops.append((_DELETE, cmd)) + return False + return True - def gen_ordered(self) -> Iterator[Optional[_Run]]: + def gen_ordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in the order **provided**. """ run = None - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + retryable = process(request) + (op_type, operation) = self.ops[idx] if run is None: run = _Run(op_type) elif run.op_type != op_type: yield run run = _Run(op_type) run.add(idx, operation) + self.is_retryable = self.is_retryable and retryable + if run is None: + raise InvalidOperation("No operations to execute") yield run - def gen_unordered(self) -> Iterator[_Run]: + def gen_unordered( + self, + requests: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], + ) -> Iterator[_Run]: """Generate batches of operations, batched by type of operation, in arbitrary order. """ operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)] - for idx, (op_type, operation) in enumerate(self.ops): + for idx, request in enumerate(requests): + retryable = process(request) + (op_type, operation) = self.ops[idx] operations[op_type].add(idx, operation) - + self.is_retryable = self.is_retryable and retryable + if ( + len(operations[_INSERT].ops) == 0 + and len(operations[_UPDATE].ops) == 0 + and len(operations[_DELETE].ops) == 0 + ): + raise InvalidOperation("No operations to execute") for run in operations: if run.ops: yield run @@ -470,8 +517,8 @@ def _execute_command( session: Optional[ClientSession], conn: Connection, op_id: int, - retryable: bool, full_result: MutableMapping[str, Any], + validate: bool, final_write_concern: Optional[WriteConcern] = None, ) -> None: db_name = self.collection.database.name @@ -523,10 +570,10 @@ def _execute_command( if session: # Start a new retryable write unless one was already # started for this command. - if retryable and not self.started_retryable_write: + if self.is_retryable and not self.started_retryable_write: session._start_retryable_write() self.started_retryable_write = True - session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn) + session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn) conn.send_cluster_time(cmd, session, client) conn.add_server_api(cmd) # CSOT: apply timeout before encoding the command. @@ -534,9 +581,10 @@ def _execute_command( ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible in one command. + if validate: + self.validate_batch(conn, write_concern) if write_concern.acknowledged: result, to_send = self._execute_batch(bwc, cmd, ops, client) - # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) if wce.get("code", 0) in _RETRYABLE_ERROR_CODES: @@ -549,8 +597,8 @@ def _execute_command( _merge_command(run, full_result, run.idx_offset, result) # We're no longer in a retry once a command succeeds. 
- self.retrying = False - self.started_retryable_write = False + run.retrying = False + run.started_retryable_write = False if self.ordered and "writeErrors" in result: break @@ -588,7 +636,8 @@ def execute_command( op_id = _randint() def retryable_bulk( - session: Optional[ClientSession], conn: Connection, retryable: bool + session: Optional[ClientSession], + conn: Connection, ) -> None: self._execute_command( generator, @@ -596,25 +645,25 @@ def retryable_bulk( session, conn, op_id, - retryable, full_result, + validate=False, ) client = self.collection.database.client _ = client._retryable_write( - self.is_retryable, retryable_bulk, session, operation, bulk=self, # type: ignore[arg-type] operation_id=op_id, ) - if full_result["writeErrors"] or full_result["writeConcernErrors"]: _raise_bulk_write_error(full_result) return full_result - def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None: + def execute_op_msg_no_results( + self, conn: Connection, generator: Iterator[Any], write_concern: WriteConcern + ) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client @@ -647,6 +696,7 @@ def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. + self.validate_batch(conn, write_concern) to_send = self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) @@ -680,12 +730,15 @@ def execute_command_no_results( None, conn, op_id, - False, full_result, + True, write_concern, ) - except OperationFailure: - pass + except OperationFailure as exc: + if "Cannot set bypass_document_validation with unacknowledged write concern" in str( + exc + ): + raise exc def execute_no_results( self, @@ -694,6 +747,11 @@ def execute_no_results( write_concern: WriteConcern, ) -> None: """Execute all operations, returning no results (w=0).""" + if self.ordered: + return self.execute_command_no_results(conn, generator, write_concern) + return self.execute_op_msg_no_results(conn, generator, write_concern) + + def validate_batch(self, conn: Connection, write_concern: WriteConcern) -> None: if self.uses_collation: raise ConfigurationError("Collation is unsupported for unacknowledged writes.") if self.uses_array_filters: @@ -718,19 +776,15 @@ def execute_no_results( "Cannot set bypass_document_validation with unacknowledged write concern" ) - if self.ordered: - return self.execute_command_no_results(conn, generator, write_concern) - return self.execute_op_msg_no_results(conn, generator) - def execute( self, + generator: Iterable[Any], + process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool], write_concern: WriteConcern, session: Optional[ClientSession], operation: str, ) -> Any: """Execute operations.""" - if not self.ops: - raise InvalidOperation("No operations to execute") if self.executed: raise InvalidOperation("Bulk operations can only be executed once.") self.executed = True @@ -738,9 +792,9 @@ def execute( session = _validate_session_write_concern(session, write_concern) if self.ordered: - generator = self.gen_ordered() + generator = self.gen_ordered(generator, process) else: - generator = self.gen_unordered() + generator = self.gen_unordered(generator, process) client = self.collection.database.client if not write_concern.acknowledged: diff --git 
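Both the async and sync `add_*` methods now report retryability per operation instead of mutating `self.is_retryable` as a side effect, and `_add_to_bulk` in operations.py simply forwards that bool. The rules they encode reduce to a small table, sketched here for reference (an illustrative helper, not pymongo API):

```python
def op_is_retryable(op_type: str, multi: bool = False) -> bool:
    """Mirror of the add_insert/add_update/add_replace/add_delete returns."""
    if op_type in ("insert", "replace"):
        return True  # InsertOne, ReplaceOne
    if op_type in ("update", "delete"):
        return not multi  # UpdateOne/DeleteOne yes; UpdateMany/DeleteMany no
    raise ValueError(f"unknown op type: {op_type}")


assert op_is_retryable("insert")
assert op_is_retryable("update", multi=False)
assert not op_is_retryable("update", multi=True)  # update_many
assert not op_is_retryable("delete", multi=True)  # delete_many (_DELETE_ALL)
```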
diff --git a/pymongo/synchronous/client_bulk.py b/pymongo/synchronous/client_bulk.py
index d73bfb2a2b..0b0d4190f9 100644
--- a/pymongo/synchronous/client_bulk.py
+++ b/pymongo/synchronous/client_bulk.py
@@ -116,6 +116,7 @@ def __init__(
         self.is_retryable = self.client.options.retry_writes
         self.retrying = False
         self.started_retryable_write = False
+        self.current_run = None

     @property
     def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
@@ -486,7 +487,6 @@ def _execute_command(
         session: Optional[ClientSession],
         conn: Connection,
         op_id: int,
-        retryable: bool,
         full_result: MutableMapping[str, Any],
         final_write_concern: Optional[WriteConcern] = None,
     ) -> None:
@@ -532,10 +532,10 @@ def _execute_command(
             if session:
                 # Start a new retryable write unless one was already
                 # started for this command.
-                if retryable and not self.started_retryable_write:
+                if self.is_retryable and not self.started_retryable_write:
                     session._start_retryable_write()
                     self.started_retryable_write = True
-                session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
+                session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn)
             conn.send_cluster_time(cmd, session, self.client)
             conn.add_server_api(cmd)
             # CSOT: apply timeout before encoding the command.
@@ -562,7 +562,7 @@ def _execute_command(

                 # Synthesize the full bulk result without modifying the
                 # current one because this write operation may be retried.
-                if retryable and (retryable_top_level_error or retryable_network_error):
+                if self.is_retryable and (retryable_top_level_error or retryable_network_error):
                     full = copy.deepcopy(full_result)
                     _merge_command(self.ops, self.idx_offset, full, result)
                     _throw_client_bulk_write_exception(full, self.verbose_results)
@@ -581,7 +581,7 @@ def _execute_command(
                     _merge_command(self.ops, self.idx_offset, full_result, result)
                     break

-                if retryable:
+                if self.is_retryable:
                     # Retryable writeConcernErrors halt the execution of this batch.
                     wce = result.get("writeConcernError", {})
                     if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
@@ -636,7 +636,6 @@ def execute_command(
         def retryable_bulk(
             session: Optional[ClientSession],
             conn: Connection,
-            retryable: bool,
         ) -> None:
             if conn.max_wire_version < 25:
                 raise InvalidOperation(
@@ -647,12 +646,10 @@ def retryable_bulk(
                 session,
                 conn,
                 op_id,
-                retryable,
                 full_result,
             )

         self.client._retryable_write(
-            self.is_retryable,
             retryable_bulk,
             session,
             operation,
diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py
index aaf2d7574f..dc52a24911 100644
--- a/pymongo/synchronous/client_session.py
+++ b/pymongo/synchronous/client_session.py
@@ -851,11 +851,12 @@ def _finish_transaction_with_retry(self, command_name: str) -> dict[str, Any]:
         """

         def func(
-            _session: Optional[ClientSession], conn: Connection, _retryable: bool
+            _session: Optional[ClientSession],
+            conn: Connection,
         ) -> dict[str, Any]:
             return self._finish_transaction(conn, command_name)

-        return self._client._retry_internal(func, self, None, retryable=True, operation=_Op.ABORT)
+        return self._client._retry_internal(func, self, None, operation=_Op.ABORT)

     def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]:
         self._transaction.attempt += 1
diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py
index 8a71768318..27b2a072d3 100644
--- a/pymongo/synchronous/collection.py
+++ b/pymongo/synchronous/collection.py
@@ -698,7 +698,7 @@ def _create(
     @_csot.apply
     def bulk_write(
         self,
-        requests: Sequence[_WriteOp[_DocumentType]],
+        requests: Iterable[_WriteOp],
         ordered: bool = True,
         bypass_document_validation: Optional[bool] = None,
         session: Optional[ClientSession] = None,
@@ -778,17 +778,21 @@ def bulk_write(

         .. versionadded:: 3.0
         """
-        common.validate_list("requests", requests)
+        common.validate_list_or_generator("requests", requests)
         blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let)
-        for request in requests:
+
+        write_concern = self._write_concern_for(session)
+
+        def process_for_bulk(request: _WriteOp) -> bool:
             try:
-                request._add_to_bulk(blk)
+                return request._add_to_bulk(blk)
             except AttributeError:
                 raise TypeError(f"{request!r} is not a valid request") from None

-        write_concern = self._write_concern_for(session)
-        bulk_api_result = blk.execute(write_concern, session, _Op.INSERT)
+        bulk_api_result = blk.execute(
+            requests, process_for_bulk, write_concern, session, _Op.INSERT
+        )
         if bulk_api_result is not None:
             return BulkWriteResult(bulk_api_result, True)
         return BulkWriteResult({}, False)
@@ -805,17 +809,13 @@ def _insert_one(
     ) -> Any:
         """Internal helper for inserting a single document."""
         write_concern = write_concern or self.write_concern
-        acknowledged = write_concern.acknowledged
         command = {"insert": self.name, "ordered": ordered, "documents": [doc]}
         if comment is not None:
             command["comment"] = comment

-        def _insert_command(
-            session: Optional[ClientSession], conn: Connection, retryable_write: bool
-        ) -> None:
+        def _insert_command(session: Optional[ClientSession], conn: Connection) -> None:
             if bypass_doc_val is not None:
                 command["bypassDocumentValidation"] = bypass_doc_val
-
             result = conn.command(
                 self._database.name,
                 command,
@@ -823,14 +823,11 @@ def _insert_command(
                 codec_options=self._write_response_codec_options,
                 session=session,
                 client=self._database.client,
-                retryable_write=retryable_write,
             )

             _check_write_command_response(result)

-        self._database.client._retryable_write(
-            acknowledged, _insert_command, session, operation=_Op.INSERT
-        )
+        self._database.client._retryable_write(_insert_command, session, operation=_Op.INSERT)

         if not isinstance(doc, RawBSONDocument):
             return doc.get("_id")
@@ -959,20 +956,19 @@ def insert_many(
             raise TypeError("documents must be a non-empty list")
         inserted_ids: list[ObjectId] = []

-        def gen() -> Iterator[tuple[int, Mapping[str, Any]]]:
-            """A generator that validates documents and handles _ids."""
-            for document in documents:
-                common.validate_is_document_type("document", document)
-                if not isinstance(document, RawBSONDocument):
-                    if "_id" not in document:
-                        document["_id"] = ObjectId()  # type: ignore[index]
-                    inserted_ids.append(document["_id"])
-                yield (message._INSERT, document)
+        def process_for_bulk(document: Union[_DocumentType, RawBSONDocument]) -> bool:
+            """Validate a document and handle its _id."""
+            common.validate_is_document_type("document", document)
+            if not isinstance(document, RawBSONDocument):
+                if "_id" not in document:
+                    document["_id"] = ObjectId()  # type: ignore[index]
+                inserted_ids.append(document["_id"])
+            blk.ops.append((message._INSERT, document))
+            return True

         write_concern = self._write_concern_for(session)
         blk = _Bulk(self, ordered, bypass_document_validation, comment=comment)
-        blk.ops = list(gen())
-        blk.execute(write_concern, session, _Op.INSERT)
+        blk.execute(documents, process_for_bulk, write_concern, session, _Op.INSERT)
         return InsertManyResult(inserted_ids, write_concern.acknowledged)

     def _update(
@@ -990,7 +986,6 @@ def _update(
         array_filters: Optional[Sequence[Mapping[str, Any]]] = None,
         hint: Optional[_IndexKeyHint] = None,
         session: Optional[ClientSession] = None,
-        retryable_write: bool = False,
         let: Optional[Mapping[str, Any]] = None,
         sort: Optional[Mapping[str, Any]] = None,
         comment: Optional[Any] = None,
@@ -1053,7 +1048,6 @@ def _update(
                     codec_options=self._write_response_codec_options,
                     session=session,
                     client=self._database.client,
-                    retryable_write=retryable_write,
                 )
             ).copy()
         _check_write_command_response(result)
@@ -1093,7 +1087,7 @@ def _update_retryable(
         """Internal update / replace helper."""

         def _update(
-            session: Optional[ClientSession], conn: Connection, retryable_write: bool
+            session: Optional[ClientSession], conn: Connection
         ) -> Optional[Mapping[str, Any]]:
             return self._update(
                 conn,
@@ -1109,14 +1103,12 @@ def _update(
                 array_filters=array_filters,
                 hint=hint,
                 session=session,
-                retryable_write=retryable_write,
                 let=let,
                 sort=sort,
                 comment=comment,
             )

         return self._database.client._retryable_write(
-            (write_concern or self.write_concern).acknowledged and not multi,
             _update,
             session,
             operation,
@@ -1506,7 +1498,6 @@ def _delete(
         collation: Optional[_CollationIn] = None,
         hint: Optional[_IndexKeyHint] = None,
         session: Optional[ClientSession] = None,
-        retryable_write: bool = False,
         let: Optional[Mapping[str, Any]] = None,
         comment: Optional[Any] = None,
     ) -> Mapping[str, Any]:
@@ -1546,7 +1537,6 @@ def _delete(
             codec_options=self._write_response_codec_options,
             session=session,
             client=self._database.client,
-            retryable_write=retryable_write,
         )
         _check_write_command_response(result)
         return result
@@ -1566,9 +1556,7 @@ def _delete_retryable(
     ) -> Mapping[str, Any]:
         """Internal delete helper."""

-        def _delete(
-            session: Optional[ClientSession], conn: Connection, retryable_write: bool
-        ) -> Mapping[str, Any]:
+        def _delete(session: Optional[ClientSession], conn: Connection) -> Mapping[str, Any]:
             return self._delete(
                 conn,
                 criteria,
@@ -1579,13 +1567,11 @@ def _delete(
                 collation=collation,
                 hint=hint,
                 session=session,
-                retryable_write=retryable_write,
                 let=let,
                 comment=comment,
             )

         return self._database.client._retryable_write(
-            (write_concern or self.write_concern).acknowledged and not multi,
             _delete,
             session,
             operation=_Op.DELETE,
@@ -3223,9 +3209,7 @@ def _find_and_modify(
         write_concern = self._write_concern_for_cmd(cmd, session)

-        def _find_and_modify_helper(
-            session: Optional[ClientSession], conn: Connection, retryable_write: bool
-        ) -> Any:
+        def _find_and_modify_helper(session: Optional[ClientSession], conn: Connection) -> Any:
             acknowledged = write_concern.acknowledged
             if array_filters is not None:
                 if not acknowledged:
@@ -3250,7 +3234,6 @@ def _find_and_modify_helper(
                 write_concern=write_concern,
                 collation=collation,
                 session=session,
-                retryable_write=retryable_write,
                 user_fields=_FIND_AND_MODIFY_DOC_FIELDS,
             )
             _check_write_command_response(out)
@@ -3258,7 +3241,6 @@ def _find_and_modify_helper(
             return out.get("value")

         return self._database.client._retryable_write(
-            write_concern.acknowledged,
             _find_and_modify_helper,
             session,
             operation=_Op.FIND_AND_MODIFY,
diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py
index 99a517e5c1..a8ab74c63a 100644
--- a/pymongo/synchronous/mongo_client.py
+++ b/pymongo/synchronous/mongo_client.py
@@ -148,7 +148,7 @@
 T = TypeVar("T")

-_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T]
+_WriteCall = Callable[[Optional["ClientSession"], "Connection"], T]
 _ReadCall = Callable[
     [Optional["ClientSession"], "Server", "Connection", _ServerMode],
     T,
 ]
@@ -1925,7 +1925,6 @@ def _cmd(

     def _retry_with_session(
         self,
-        retryable: bool,
         func: _WriteCall[T],
         session: Optional[ClientSession],
         bulk: Optional[Union[_Bulk, _ClientBulk]],
@@ -1941,15 +1940,11 @@ def _retry_with_session(
         """
         # Ensure that the options supports retry_writes and there is a valid session not in
         # transaction, otherwise, we will not support retry behavior for this txn.
-        retryable = bool(
-            retryable and self.options.retry_writes and session and not session.in_transaction
-        )
         return self._retry_internal(
             func=func,
             session=session,
             bulk=bulk,
             operation=operation,
-            retryable=retryable,
             operation_id=operation_id,
         )

@@ -1963,7 +1958,6 @@ def _retry_internal(
         is_read: bool = False,
         address: Optional[_Address] = None,
         read_pref: Optional[_ServerMode] = None,
-        retryable: bool = False,
         operation_id: Optional[int] = None,
     ) -> T:
         """Internal retryable helper for all client transactions.
@@ -1988,7 +1982,6 @@ def _retry_internal(
             session=session,
             read_pref=read_pref,
             address=address,
-            retryable=retryable,
             operation_id=operation_id,
         ).run()

@@ -2031,13 +2024,11 @@ def _retryable_read(
             is_read=True,
             address=address,
             read_pref=read_pref,
-            retryable=retryable,
             operation_id=operation_id,
         )

     def _retryable_write(
         self,
-        retryable: bool,
         func: _WriteCall[T],
         session: Optional[ClientSession],
         operation: str,
@@ -2058,7 +2049,7 @@ def _retryable_write(
         :param bulk: bulk abstraction to execute operations in bulk, defaults to None
         """
         with self._tmp_session(session) as s:
-            return self._retry_with_session(retryable, func, s, bulk, operation, operation_id)
+            return self._retry_with_session(func, s, bulk, operation, operation_id)

     def _cleanup_cursor_no_lock(
         self,
@@ -2709,6 +2700,11 @@ def __init__(
         self._operation_id = operation_id
         self._attempt_number = 0

+    def _bulk_retryable(self) -> bool:
+        if self._bulk is not None:
+            return self._bulk.is_retryable
+        return True
+
     def run(self) -> T:
         """Runs the supplied func() and attempts a retry

@@ -2719,7 +2715,12 @@ def run(self) -> T:
         # Increment the transaction id up front to ensure any retry attempt
         # will use the proper txnNumber, even if server or socket selection
         # fails before the command can be sent.
-        if self._is_session_state_retryable() and self._retryable and not self._is_read:
+        if (
+            self._is_session_state_retryable()
+            and self._retryable
+            and self._bulk_retryable()
+            and not self._is_read
+        ):
             self._session._start_retryable_write()  # type: ignore
             if self._bulk:
                 self._bulk.started_retryable_write = True
@@ -2756,7 +2757,7 @@ def run(self) -> T:

                 # Specialized catch on write operation
                 if not self._is_read:
-                    if not self._retryable:
+                    if not self._retryable and not self._bulk_retryable():
                         raise
                     if isinstance(exc, ClientBulkWriteException) and exc.error:
                         retryable_write_error_exc = isinstance(
@@ -2787,11 +2788,15 @@ def run(self) -> T:

     def _is_not_eligible_for_retry(self) -> bool:
         """Checks if the exchange is not eligible for retry"""
-        return not self._retryable or (self._is_retrying() and not self._multiple_retries)
+        return (
+            not self._retryable
+            or not self._bulk_retryable()
+            or (self._is_retrying() and not self._multiple_retries)
+        )

     def _is_retrying(self) -> bool:
         """Checks if the exchange is currently undergoing a retry"""
-        return self._bulk.retrying if self._bulk else self._retrying
+        return self._bulk.retrying if self._bulk is not None else self._retrying

     def _is_session_state_retryable(self) -> bool:
         """Checks if provided session is eligible for retry
@@ -2851,6 +2856,8 @@ def _write(self) -> T:
                     # not support sessions raise the last error.
                     self._check_last_error()
                     self._retryable = False
+                    if self._bulk:
+                        self._bulk.is_retryable = False
                 if self._retrying:
                     _debug_log(
                         _COMMAND_LOGGER,
@@ -2859,9 +2866,9 @@ def _write(self) -> T:
                         commandName=self._operation,
                         operationId=self._operation_id,
                     )
-                return self._func(self._session, conn, self._retryable)  # type: ignore
+                return self._func(self._session, conn)  # type: ignore
             except PyMongoError as exc:
-                if not self._retryable:
+                if not self._retryable or not self._bulk_retryable():
                     raise
                 # Add the RetryableWriteError label, if applicable.
                 _add_retryable_write_error(exc, max_wire_version, is_mongos)
@@ -2878,7 +2885,7 @@ def _read(self) -> T:
             conn,
             read_pref,
         ):
-            if self._retrying and not self._retryable:
+            if self._retrying and (not self._retryable or not self._bulk_retryable()):
                 self._check_last_error()
             if self._retrying:
                 _debug_log(
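With the threaded `retryable` flag gone, `_ClientConnectionRetryable` now asks the bulk object itself whether a retry is allowed. A condensed model of the resulting gating (simplified; the real class also tracks sessions, reads, server selection, and attempt counts):

```python
from typing import Any, Optional


class RetryModel:
    """Simplified model of _ClientConnectionRetryable's new retry gating."""

    def __init__(
        self, retryable: bool, bulk: Optional[Any] = None, multiple_retries: bool = False
    ) -> None:
        self._retryable = retryable  # from client options / operation kind
        self._bulk = bulk  # _Bulk/_ClientBulk instance or None
        self._retrying = False
        self._multiple_retries = multiple_retries

    def _bulk_retryable(self) -> bool:
        # Bulk writes can veto retries, e.g. after an update_many was added.
        if self._bulk is not None:
            return self._bulk.is_retryable
        return True

    def _is_retrying(self) -> bool:
        return self._bulk.retrying if self._bulk is not None else self._retrying

    def _is_not_eligible_for_retry(self) -> bool:
        return (
            not self._retryable
            or not self._bulk_retryable()
            or (self._is_retrying() and not self._multiple_retries)
        )
```

One consequence worth noting: because `is_retryable` is only known once the request generator has been consumed, eligibility is evaluated lazily at retry time, which is why `_bulk_retryable()` is a method rather than a bool captured up front.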
diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py
index 65ed6e236a..4d2338eae2 100644
--- a/test/asynchronous/test_bulk.py
+++ b/test/asynchronous/test_bulk.py
@@ -299,6 +299,21 @@ async def test_numerous_inserts(self):
         self.assertEqual(n_docs, result.inserted_count)
         self.assertEqual(n_docs, await self.coll.count_documents({}))

+    async def test_numerous_inserts_generator(self):
+        # Ensure we don't exceed the server's maxWriteBatchSize limit.
+        n_docs = await async_client_context.max_write_batch_size + 100
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = await self.coll.bulk_write(requests, ordered=False)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, await self.coll.count_documents({}))
+
+        # Same with ordered bulk.
+        await self.coll.drop()
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = await self.coll.bulk_write(requests)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, await self.coll.count_documents({}))
+
     async def test_bulk_max_message_size(self):
         await self.coll.delete_many({})
         self.addAsyncCleanup(self.coll.delete_many, {})
@@ -338,11 +353,6 @@ async def test_bulk_write_no_results(self):
         self.assertRaises(InvalidOperation, lambda: result.upserted_ids)

     async def test_bulk_write_invalid_arguments(self):
-        # The requests argument must be a list.
-        generator = (InsertOne[dict]({}) for _ in range(10))
-        with self.assertRaises(TypeError):
-            await self.coll.bulk_write(generator)  # type: ignore[arg-type]
-
         # Document is not wrapped in a bulk write operation.
         with self.assertRaises(TypeError):
             await self.coll.bulk_write([{}])  # type: ignore[list-item]
diff --git a/test/test_bulk.py b/test/test_bulk.py
index 8a863cc49b..9696f6da1d 100644
--- a/test/test_bulk.py
+++ b/test/test_bulk.py
@@ -299,6 +299,21 @@ def test_numerous_inserts(self):
         self.assertEqual(n_docs, result.inserted_count)
         self.assertEqual(n_docs, self.coll.count_documents({}))

+    def test_numerous_inserts_generator(self):
+        # Ensure we don't exceed the server's maxWriteBatchSize limit.
+        n_docs = client_context.max_write_batch_size + 100
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = self.coll.bulk_write(requests, ordered=False)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, self.coll.count_documents({}))
+
+        # Same with ordered bulk.
+        self.coll.drop()
+        requests = (InsertOne[dict]({}) for _ in range(n_docs))
+        result = self.coll.bulk_write(requests)
+        self.assertEqual(n_docs, result.inserted_count)
+        self.assertEqual(n_docs, self.coll.count_documents({}))
+
     def test_bulk_max_message_size(self):
         self.coll.delete_many({})
         self.addCleanup(self.coll.delete_many, {})
@@ -338,11 +353,6 @@ def test_bulk_write_no_results(self):
         self.assertRaises(InvalidOperation, lambda: result.upserted_ids)

     def test_bulk_write_invalid_arguments(self):
-        # The requests argument must be a list.
-        generator = (InsertOne[dict]({}) for _ in range(10))
-        with self.assertRaises(TypeError):
-            self.coll.bulk_write(generator)  # type: ignore[arg-type]
-
         # Document is not wrapped in a bulk write operation.
         with self.assertRaises(TypeError):
             self.coll.bulk_write([{}])  # type: ignore[list-item]