From 7c3507feefb32ccd1eee9af2ebbb9ec9e33dcce6 Mon Sep 17 00:00:00 2001 From: Alex Vandiver Date: Mon, 15 Nov 2021 12:03:55 -0800 Subject: [PATCH] queue: Allow passing down a prefetch count to pika. --- zerver/lib/queue.py | 13 +++++++++++-- zerver/tests/test_queue_worker.py | 2 +- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/zerver/lib/queue.py b/zerver/lib/queue.py index 7717eff349..3abaad80c1 100644 --- a/zerver/lib/queue.py +++ b/zerver/lib/queue.py @@ -32,10 +32,12 @@ class QueueClient(Generic[ChannelT], metaclass=ABCMeta): self, # Disable RabbitMQ heartbeats by default because BlockingConnection can't process them rabbitmq_heartbeat: Optional[int] = 0, + prefetch: int = 0, ) -> None: self.log = logging.getLogger("zulip.queue") self.queues: Set[str] = set() self.channel: Optional[ChannelT] = None + self.prefetch = prefetch self.consumers: Dict[str, Set[Consumer[ChannelT]]] = defaultdict(set) self.rabbitmq_heartbeat = rabbitmq_heartbeat self.is_consuming = False @@ -158,9 +160,12 @@ class SimpleQueueClient(QueueClient[BlockingChannel]): self._connect() assert self.channel is not None + self.channel.basic_qos(prefetch_count=self.prefetch) + if queue_name not in self.queues: self.channel.queue_declare(queue=queue_name, durable=True) self.queues.add(queue_name) + callback(self.channel) def start_json_consumer( @@ -329,9 +334,13 @@ class TornadoQueueClient(QueueClient[Channel]): self.connection.close() def ensure_queue(self, queue_name: str, callback: Callable[[Channel], object]) -> None: - def finish(frame: Any) -> None: + def set_qos(frame: Any) -> None: assert self.channel is not None self.queues.add(queue_name) + self.channel.basic_qos(prefetch_count=self.prefetch, callback=finish) + + def finish(frame: Any) -> None: + assert self.channel is not None callback(self.channel) if queue_name not in self.queues: @@ -342,7 +351,7 @@ class TornadoQueueClient(QueueClient[Channel]): return assert self.channel is not None - self.channel.queue_declare(queue=queue_name, durable=True, callback=finish) + self.channel.queue_declare(queue=queue_name, durable=True, callback=set_qos) else: assert self.channel is not None callback(self.channel) diff --git a/zerver/tests/test_queue_worker.py b/zerver/tests/test_queue_worker.py index 3fc41eab86..a9525b3a7c 100644 --- a/zerver/tests/test_queue_worker.py +++ b/zerver/tests/test_queue_worker.py @@ -47,7 +47,7 @@ Event = Dict[str, Any] class FakeClient: - def __init__(self) -> None: + def __init__(self, prefetch: int = 0) -> None: self.queues: Dict[str, List[Dict[str, Any]]] = defaultdict(list) def enqueue(self, queue_name: str, data: Dict[str, Any]) -> None: