Source code for datumaro.util.multi_procs_util
# Copyright (C) 2023 Intel Corporation
#
# SPDX-License-Identifier: MIT
import logging as log
from contextlib import contextmanager
from enum import IntEnum
from queue import Full, Queue
from threading import Condition, Thread
from typing import Any, Generator, Iterator, Optional, TypeVar
# Names exported by `from ... import *`: only the context manager is listed.
__all__ = ["consumer_generator"]
[docs]
class ProducerMessage(IntEnum):
    """Control messages the producer thread sends through the shared queue."""

    # Sent before the first item; the consumer skips it.
    START = 0
    # Sent after the last item (or after a producer failure); the consumer stops on it.
    END = 1


class _ConsumerTerminatedError(RuntimeError):
    """Raised inside the producer thread when the consumer left while items remained."""


# Type of the items forwarded from the producer to the consumer.
Item = TypeVar("Item")


@contextmanager
def consumer_generator(
    producer_generator: Iterator[Item],
    queue_size: int = 100,
    enqueue_timeout: float = 5.0,
    join_timeout: Optional[float] = 10.0,
) -> Generator[Iterator[Item], None, None]:
    """Context manager that creates a generator to consume items produced by another generator.

    This context manager sets up a producer thread that generates items from the
    `producer_generator` and enqueues them to be consumed by the consumer generator,
    which is also created by this function.

    Control messages are recognized by identity, not equality, so items that merely
    compare equal to them (e.g. `0`, `1`, `False`, `True`) are forwarded untouched.

    Parameters:
        producer_generator: A generator that produces items.
        queue_size: The maximum size of the shared queue between the producer and consumer.
        enqueue_timeout: The maximum time to wait for enqueuing an item to the queue if it's full.
            After each timeout the producer checks whether the context has been exited.
        join_timeout: The maximum time to wait for the producer thread to finish when exiting
            the context. If None, wait until the producer thread terminates.

    Returns:
        Iterator: A context for iterating over the generated items.

    Raises:
        Exception: Any exception raised by `producer_generator` is re-raised from the
            consumer iterator after the items produced before the failure were yielded.
    """
    queue: Queue = Queue(maxsize=queue_size)
    lock = Condition()
    is_terminated = False
    # Holds at most one exception raised by `producer_generator`. It is appended
    # before END is enqueued, so the consumer sees it once it receives END.
    producer_errors = []

    def _enqueue(item: Any, queue: Queue) -> None:
        # Retry until the item fits, giving up once the consumer side is gone.
        while True:
            try:
                queue.put(item, block=True, timeout=enqueue_timeout)
                return
            except Full:
                with lock:
                    if is_terminated:
                        raise _ConsumerTerminatedError(
                            "Item to enqueue is left. However, the main process is terminated."
                        )

    def _target(queue: Queue) -> None:
        try:
            _enqueue(ProducerMessage.START, queue)
            try:
                for item in producer_generator:
                    _enqueue(item, queue)
            except _ConsumerTerminatedError:
                raise
            except Exception as e:
                # Hand the failure over to the consumer instead of dying silently,
                # which would leave the consumer blocked on an empty queue forever.
                producer_errors.append(e)
            # Always signal the end of the stream so the consumer can stop waiting.
            _enqueue(ProducerMessage.END, queue)
        except _ConsumerTerminatedError as e:
            log.error(e)
            return

    producer = Thread(target=_target, args=(queue,))
    producer.start()

    def _generator() -> Iterator[Item]:
        while True:
            item = queue.get()
            # Identity checks: `==` would also match plain ints/bools because
            # ProducerMessage is an IntEnum.
            if item is ProducerMessage.START:
                continue
            if item is ProducerMessage.END:
                if producer_errors:
                    raise producer_errors[0]
                return
            yield item

    try:
        yield _generator()
    finally:
        # Tell a producer that is blocked on a full queue to stop retrying.
        with lock:
            is_terminated = True
        producer.join(timeout=join_timeout)