diff --git a/zerver/lib/queue.py b/zerver/lib/queue.py index 6c58a813c7..0ec7ff4849 100644 --- a/zerver/lib/queue.py +++ b/zerver/lib/queue.py @@ -393,28 +393,20 @@ class TornadoQueueClient(QueueClient[Channel]): ) -queue_client: Optional[Union[SimpleQueueClient, TornadoQueueClient]] = None +thread_data = threading.local() def get_queue_client() -> Union[SimpleQueueClient, TornadoQueueClient]: - global queue_client - if queue_client is None: - if settings.RUNNING_INSIDE_TORNADO and settings.USING_RABBITMQ: - queue_client = TornadoQueueClient() - elif settings.USING_RABBITMQ: - queue_client = SimpleQueueClient() - else: + if not hasattr(thread_data, "queue_client"): + if not settings.USING_RABBITMQ: raise RuntimeError("Cannot get a queue client without USING_RABBITMQ") + thread_data.queue_client = SimpleQueueClient() - return queue_client + return thread_data.queue_client -# We using a simple lock to prevent multiple RabbitMQ messages being -# sent to the SimpleQueueClient at the same time; this is a workaround -# for an issue with the pika BlockingConnection where using -# BlockingConnection for multiple queues causes the channel to -# randomly close. -queue_lock = threading.RLock() +def set_queue_client(queue_client: Union[SimpleQueueClient, TornadoQueueClient]) -> None: + thread_data.queue_client = queue_client def queue_json_publish( @@ -422,16 +414,15 @@ def queue_json_publish( event: Dict[str, Any], processor: Optional[Callable[[Any], None]] = None, ) -> None: - with queue_lock: - if settings.USING_RABBITMQ: - get_queue_client().json_publish(queue_name, event) - elif processor: - processor(event) - else: - # Must be imported here: A top section import leads to circular imports - from zerver.worker.queue_processors import get_worker + if settings.USING_RABBITMQ: + get_queue_client().json_publish(queue_name, event) + elif processor: + processor(event) + else: + # Must be imported here: A top section import leads to circular imports + from zerver.worker.queue_processors import get_worker - get_worker(queue_name).consume_single_event(event) + get_worker(queue_name).consume_single_event(event) def retry_event( diff --git a/zerver/management/commands/runtornado.py b/zerver/management/commands/runtornado.py index ac0b3adfb1..ef8dd6c367 100644 --- a/zerver/management/commands/runtornado.py +++ b/zerver/management/commands/runtornado.py @@ -21,7 +21,7 @@ from zerver.tornado.event_queue import ( from zerver.tornado.sharding import notify_tornado_queue_name if settings.USING_RABBITMQ: - from zerver.lib.queue import TornadoQueueClient, get_queue_client + from zerver.lib.queue import TornadoQueueClient, set_queue_client class Command(BaseCommand): @@ -67,8 +67,8 @@ class Command(BaseCommand): print(f"Tornado server (re)started on port {port}") if settings.USING_RABBITMQ: - queue_client = get_queue_client() - assert isinstance(queue_client, TornadoQueueClient) + queue_client = TornadoQueueClient() + set_queue_client(queue_client) # Process notifications received via RabbitMQ queue_name = notify_tornado_queue_name(port) queue_client.start_json_consumer( @@ -88,7 +88,8 @@ class Command(BaseCommand): logging_data["port"] = str(port) setup_event_queue(http_server, port) add_client_gc_hook(missedmessage_hook) - setup_tornado_rabbitmq() + if settings.USING_RABBITMQ: + setup_tornado_rabbitmq(queue_client) instance = ioloop.IOLoop.instance() diff --git a/zerver/tornado/application.py b/zerver/tornado/application.py index 25fdb64596..6d29ceab18 100644 --- a/zerver/tornado/application.py +++ b/zerver/tornado/application.py @@ -4,16 +4,14 @@ import tornado.web from django.conf import settings from tornado import autoreload -from zerver.lib.queue import get_queue_client +from zerver.lib.queue import TornadoQueueClient from zerver.tornado.handlers import AsyncDjangoHandler -def setup_tornado_rabbitmq() -> None: # nocoverage +def setup_tornado_rabbitmq(queue_client: TornadoQueueClient) -> None: # nocoverage # When tornado is shut down, disconnect cleanly from RabbitMQ - if settings.USING_RABBITMQ: - queue_client = get_queue_client() - atexit.register(lambda: queue_client.close()) - autoreload.add_reload_hook(lambda: queue_client.close()) + atexit.register(lambda: queue_client.close()) + autoreload.add_reload_hook(lambda: queue_client.close()) def create_tornado_application() -> tornado.web.Application: