Skip to content

PYTHON-1752 bulk_write should be able to accept a generator #2262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 88 additions & 36 deletions pymongo/asynchronous/bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Iterator,
Mapping,
Optional,
Expand Down Expand Up @@ -72,7 +74,7 @@
from pymongo.write_concern import WriteConcern

if TYPE_CHECKING:
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.collection import AsyncCollection, _WriteOp
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline
Expand Down Expand Up @@ -128,13 +130,32 @@ def bulk_ctx_class(self) -> Type[_BulkWriteContext]:
self.is_encrypted = False
return _BulkWriteContext

def add_insert(self, document: _DocumentOut) -> None:
# @property
# def is_retryable(self) -> bool:
# if self.current_run:
# return self.current_run.is_retryable
# return True
#
# @property
# def retrying(self) -> bool:
# if self.current_run:
# return self.current_run.retrying
# return False
#
# @property
# def started_retryable_write(self) -> bool:
# if self.current_run:
# return self.current_run.started_retryable_write
# return False

def add_insert(self, document: _DocumentOut) -> bool:
    """Add an insert document to the list of ops.

    Validates that ``document`` is a mapping type, assigns a
    client-side ``_id`` when one is absent (``RawBSONDocument``
    instances are left untouched), and appends an
    ``(_INSERT, document)`` entry to ``self.ops``.

    Returns ``True``, meaning a single-document insert does not
    disqualify the bulk write from being retried — contrast with
    ``add_update``/``add_delete``, which return ``False`` for
    multi-update / delete-all operations; the generator batching
    step folds these flags into ``self.is_retryable``.
    """
    validate_is_document_type("document", document)
    # Generate ObjectId client side; RawBSONDocument is immutable raw
    # BSON, so it is never modified here.
    if not (isinstance(document, RawBSONDocument) or "_id" in document):
        document["_id"] = ObjectId()
    self.ops.append((_INSERT, document))
    return True

def add_update(
self,
Expand All @@ -146,7 +167,7 @@ def add_update(
array_filters: Optional[list[Mapping[str, Any]]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
) -> bool:
"""Create an update document and add it to the list of ops."""
validate_ok_for_update(update)
cmd: dict[str, Any] = {"q": selector, "u": update, "multi": multi}
Expand All @@ -164,10 +185,12 @@ def add_update(
if sort is not None:
self.uses_sort = True
cmd["sort"] = sort

self.ops.append((_UPDATE, cmd))
if multi:
# A bulk_write containing an update_many is not retryable.
self.is_retryable = False
self.ops.append((_UPDATE, cmd))
return False
return True

def add_replace(
self,
Expand All @@ -177,7 +200,7 @@ def add_replace(
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
sort: Optional[Mapping[str, Any]] = None,
) -> None:
) -> bool:
"""Create a replace document and add it to the list of ops."""
validate_ok_for_replace(replacement)
cmd: dict[str, Any] = {"q": selector, "u": replacement}
Expand All @@ -193,14 +216,15 @@ def add_replace(
self.uses_sort = True
cmd["sort"] = sort
self.ops.append((_UPDATE, cmd))
return True

def add_delete(
self,
selector: Mapping[str, Any],
limit: int,
collation: Optional[Mapping[str, Any]] = None,
hint: Union[str, dict[str, Any], None] = None,
) -> None:
) -> bool:
"""Create a delete document and add it to the list of ops."""
cmd: dict[str, Any] = {"q": selector, "limit": limit}
if collation is not None:
Expand All @@ -209,33 +233,56 @@ def add_delete(
if hint is not None:
self.uses_hint_delete = True
cmd["hint"] = hint

self.ops.append((_DELETE, cmd))
if limit == _DELETE_ALL:
# A bulk_write containing a delete_many is not retryable.
self.is_retryable = False
self.ops.append((_DELETE, cmd))
return False
return True

def gen_ordered(self) -> Iterator[Optional[_Run]]:
def gen_ordered(
self,
requests: Iterable[Any],
process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool],
) -> Iterator[_Run]:
"""Generate batches of operations, batched by type of
operation, in the order **provided**.
"""
run = None
for idx, (op_type, operation) in enumerate(self.ops):
for idx, request in enumerate(requests):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the goal of this ticket is to avoid inflating the whole generator upfront and only iterate requests as they are needed at the encoding step. For example:

coll.bulk_write((InsertOne({'x': 'large'*1024*1024}) for _ in range(1_000_000)))

If we inflate all at once like we do here, then that code will need to allocate all 1 million documents at once.

retryable = process(request)
(op_type, operation) = self.ops[idx]
if run is None:
run = _Run(op_type)
elif run.op_type != op_type:
yield run
run = _Run(op_type)
run.add(idx, operation)
self.is_retryable = self.is_retryable and retryable
if run is None:
raise InvalidOperation("No operations to execute")
yield run

def gen_unordered(self) -> Iterator[_Run]:
def gen_unordered(
self,
requests: Iterable[Any],
process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool],
) -> Iterator[_Run]:
"""Generate batches of operations, batched by type of
operation, in arbitrary order.
"""
operations = [_Run(_INSERT), _Run(_UPDATE), _Run(_DELETE)]
for idx, (op_type, operation) in enumerate(self.ops):
for idx, request in enumerate(requests):
retryable = process(request)
(op_type, operation) = self.ops[idx]
operations[op_type].add(idx, operation)

self.is_retryable = self.is_retryable and retryable
if (
len(operations[_INSERT].ops) == 0
and len(operations[_UPDATE].ops) == 0
and len(operations[_DELETE].ops) == 0
):
raise InvalidOperation("No operations to execute")
for run in operations:
if run.ops:
yield run
Expand Down Expand Up @@ -470,8 +517,8 @@ async def _execute_command(
session: Optional[AsyncClientSession],
conn: AsyncConnection,
op_id: int,
retryable: bool,
full_result: MutableMapping[str, Any],
validate: bool,
final_write_concern: Optional[WriteConcern] = None,
) -> None:
db_name = self.collection.database.name
Expand Down Expand Up @@ -523,20 +570,21 @@ async def _execute_command(
if session:
# Start a new retryable write unless one was already
# started for this command.
if retryable and not self.started_retryable_write:
if self.is_retryable and not self.started_retryable_write:
session._start_retryable_write()
self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn)
conn.send_cluster_time(cmd, session, client)
conn.add_server_api(cmd)
# CSOT: apply timeout before encoding the command.
conn.apply_timeout(client, cmd)
ops = islice(run.ops, run.idx_offset, None)

# Run as many ops as possible in one command.
if validate:
await self.validate_batch(conn, write_concern)
if write_concern.acknowledged:
result, to_send = await self._execute_batch(bwc, cmd, ops, client)

# Retryable writeConcernErrors halt the execution of this run.
wce = result.get("writeConcernError", {})
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
Expand All @@ -549,8 +597,8 @@ async def _execute_command(
_merge_command(run, full_result, run.idx_offset, result)

# We're no longer in a retry once a command succeeds.
self.retrying = False
self.started_retryable_write = False
run.retrying = False
run.started_retryable_write = False

if self.ordered and "writeErrors" in result:
break
Expand Down Expand Up @@ -588,34 +636,33 @@ async def execute_command(
op_id = _randint()

async def retryable_bulk(
session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool
session: Optional[AsyncClientSession],
conn: AsyncConnection,
) -> None:
await self._execute_command(
generator,
write_concern,
session,
conn,
op_id,
retryable,
full_result,
validate=False,
)

client = self.collection.database.client
_ = await client._retryable_write(
self.is_retryable,
retryable_bulk,
session,
operation,
bulk=self, # type: ignore[arg-type]
operation_id=op_id,
)

if full_result["writeErrors"] or full_result["writeConcernErrors"]:
_raise_bulk_write_error(full_result)
return full_result

async def execute_op_msg_no_results(
self, conn: AsyncConnection, generator: Iterator[Any]
self, conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern
) -> None:
"""Execute write commands with OP_MSG and w=0 writeConcern, unordered."""
db_name = self.collection.database.name
Expand Down Expand Up @@ -649,6 +696,7 @@ async def execute_op_msg_no_results(
conn.add_server_api(cmd)
ops = islice(run.ops, run.idx_offset, None)
# Run as many ops as possible.
await self.validate_batch(conn, write_concern)
to_send = await self._execute_batch_unack(bwc, cmd, ops, client)
run.idx_offset += len(to_send)
self.current_run = run = next(generator, None)
Expand Down Expand Up @@ -682,12 +730,15 @@ async def execute_command_no_results(
None,
conn,
op_id,
False,
full_result,
True,
write_concern,
)
except OperationFailure:
pass
except OperationFailure as exc:
if "Cannot set bypass_document_validation with unacknowledged write concern" in str(
exc
):
raise exc

async def execute_no_results(
self,
Expand All @@ -696,6 +747,11 @@ async def execute_no_results(
write_concern: WriteConcern,
) -> None:
"""Execute all operations, returning no results (w=0)."""
if self.ordered:
return await self.execute_command_no_results(conn, generator, write_concern)
return await self.execute_op_msg_no_results(conn, generator, write_concern)

async def validate_batch(self, conn: AsyncConnection, write_concern: WriteConcern) -> None:
if self.uses_collation:
raise ConfigurationError("Collation is unsupported for unacknowledged writes.")
if self.uses_array_filters:
Expand All @@ -720,29 +776,25 @@ async def execute_no_results(
"Cannot set bypass_document_validation with unacknowledged write concern"
)

if self.ordered:
return await self.execute_command_no_results(conn, generator, write_concern)
return await self.execute_op_msg_no_results(conn, generator)

async def execute(
self,
generator: Iterable[Any],
process: Callable[[Union[_DocumentType, RawBSONDocument, _WriteOp]], bool],
write_concern: WriteConcern,
session: Optional[AsyncClientSession],
operation: str,
) -> Any:
"""Execute operations."""
if not self.ops:
raise InvalidOperation("No operations to execute")
if self.executed:
raise InvalidOperation("Bulk operations can only be executed once.")
self.executed = True
write_concern = write_concern or self.collection.write_concern
session = _validate_session_write_concern(session, write_concern)

if self.ordered:
generator = self.gen_ordered()
generator = self.gen_ordered(generator, process)
else:
generator = self.gen_unordered()
generator = self.gen_unordered(generator, process)

client = self.collection.database.client
if not write_concern.acknowledged:
Expand Down
13 changes: 5 additions & 8 deletions pymongo/asynchronous/client_bulk.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def __init__(
self.is_retryable = self.client.options.retry_writes
self.retrying = False
self.started_retryable_write = False
self.current_run = None

@property
def bulk_ctx_class(self) -> Type[_ClientBulkWriteContext]:
Expand Down Expand Up @@ -488,7 +489,6 @@ async def _execute_command(
session: Optional[AsyncClientSession],
conn: AsyncConnection,
op_id: int,
retryable: bool,
full_result: MutableMapping[str, Any],
final_write_concern: Optional[WriteConcern] = None,
) -> None:
Expand Down Expand Up @@ -534,10 +534,10 @@ async def _execute_command(
if session:
# Start a new retryable write unless one was already
# started for this command.
if retryable and not self.started_retryable_write:
if self.is_retryable and not self.started_retryable_write:
session._start_retryable_write()
self.started_retryable_write = True
session._apply_to(cmd, retryable, ReadPreference.PRIMARY, conn)
session._apply_to(cmd, self.is_retryable, ReadPreference.PRIMARY, conn)
conn.send_cluster_time(cmd, session, self.client)
conn.add_server_api(cmd)
# CSOT: apply timeout before encoding the command.
Expand All @@ -564,7 +564,7 @@ async def _execute_command(

# Synthesize the full bulk result without modifying the
# current one because this write operation may be retried.
if retryable and (retryable_top_level_error or retryable_network_error):
if self.is_retryable and (retryable_top_level_error or retryable_network_error):
full = copy.deepcopy(full_result)
_merge_command(self.ops, self.idx_offset, full, result)
_throw_client_bulk_write_exception(full, self.verbose_results)
Expand All @@ -583,7 +583,7 @@ async def _execute_command(
_merge_command(self.ops, self.idx_offset, full_result, result)
break

if retryable:
if self.is_retryable:
# Retryable writeConcernErrors halt the execution of this batch.
wce = result.get("writeConcernError", {})
if wce.get("code", 0) in _RETRYABLE_ERROR_CODES:
Expand Down Expand Up @@ -638,7 +638,6 @@ async def execute_command(
async def retryable_bulk(
session: Optional[AsyncClientSession],
conn: AsyncConnection,
retryable: bool,
) -> None:
if conn.max_wire_version < 25:
raise InvalidOperation(
Expand All @@ -649,12 +648,10 @@ async def retryable_bulk(
session,
conn,
op_id,
retryable,
full_result,
)

await self.client._retryable_write(
self.is_retryable,
retryable_bulk,
session,
operation,
Expand Down
Loading
Loading