diff --git a/zerver/tests/test_queue_worker.py b/zerver/tests/test_queue_worker.py index b63d2e8c1d..2015ed6b89 100644 --- a/zerver/tests/test_queue_worker.py +++ b/zerver/tests/test_queue_worker.py @@ -1,4 +1,5 @@ import base64 +import datetime import os import signal import time @@ -19,7 +20,15 @@ from zerver.lib.remote_server import PushNotificationBouncerRetryLaterError from zerver.lib.send_email import EmailNotDeliveredException, FromAddress from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_helpers import mock_queue_publish, simulated_queue_client -from zerver.models import PreregistrationUser, UserActivity, get_client, get_realm, get_stream +from zerver.models import ( + NotificationTriggers, + PreregistrationUser, + ScheduledMessageNotificationEmail, + UserActivity, + get_client, + get_realm, + get_stream, +) from zerver.tornado.event_queue import build_offline_notification from zerver.worker import queue_processors from zerver.worker.queue_processors import ( @@ -144,17 +153,31 @@ class WorkerTest(ZulipTestCase): content="where art thou, othello?", ) - events = [ - dict(user_profile_id=hamlet.id, message_id=hamlet1_msg_id), - dict(user_profile_id=hamlet.id, message_id=hamlet2_msg_id), - dict(user_profile_id=othello.id, message_id=othello_msg_id), - ] + hamlet_event1 = dict( + user_profile_id=hamlet.id, + message_id=hamlet1_msg_id, + trigger=NotificationTriggers.PRIVATE_MESSAGE, + ) + hamlet_event2 = dict( + user_profile_id=hamlet.id, + message_id=hamlet2_msg_id, + trigger=NotificationTriggers.PRIVATE_MESSAGE, + mentioned_user_group_id=4, + ) + othello_event = dict( + user_profile_id=othello.id, + message_id=othello_msg_id, + trigger=NotificationTriggers.PRIVATE_MESSAGE, + ) + + events = [hamlet_event1, hamlet_event2, othello_event] fake_client = self.FakeClient() for event in events: fake_client.enqueue("missedmessage_emails", event) mmw = MissedMessageWorker() + batch_duration = datetime.timedelta(seconds=mmw.BATCH_DURATION) class MockTimer: is_running = False @@ -174,36 +197,100 @@ class WorkerTest(ZulipTestCase): send_mock = patch( "zerver.lib.email_notifications.do_send_missedmessage_events_reply_in_zulip", ) - mmw.BATCH_DURATION = 0 - bonus_event = dict(user_profile_id=hamlet.id, message_id=hamlet3_msg_id) + bonus_event_hamlet = dict( + user_profile_id=hamlet.id, + message_id=hamlet3_msg_id, + trigger=NotificationTriggers.PRIVATE_MESSAGE, + ) + + def check_row( + row: ScheduledMessageNotificationEmail, + scheduled_timestamp: datetime.datetime, + mentioned_user_group_id: Optional[int], + ) -> None: + self.assertEqual(row.trigger, NotificationTriggers.PRIVATE_MESSAGE) + self.assertEqual(row.scheduled_timestamp, scheduled_timestamp) + self.assertEqual(row.mentioned_user_group_id, mentioned_user_group_id) with send_mock as sm, timer_mock as tm: with simulated_queue_client(lambda: fake_client): self.assertFalse(timer.is_alive()) - mmw.setup() - mmw.start() - self.assertTrue(timer.is_alive()) - fake_client.enqueue("missedmessage_emails", bonus_event) - # Double-calling start is our way to get it to run again + time_zero = datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc) + expected_scheduled_timestamp = time_zero + batch_duration + with patch("zerver.worker.queue_processors.timezone_now", return_value=time_zero): + mmw.setup() + mmw.start() + + # The events should be saved in the database + hamlet_row1 = ScheduledMessageNotificationEmail.objects.get( + user_profile_id=hamlet.id, message_id=hamlet1_msg_id + ) + check_row(hamlet_row1, expected_scheduled_timestamp, None) + + hamlet_row2 = ScheduledMessageNotificationEmail.objects.get( + user_profile_id=hamlet.id, message_id=hamlet2_msg_id + ) + check_row(hamlet_row2, expected_scheduled_timestamp, 4) + + othello_row1 = ScheduledMessageNotificationEmail.objects.get( + user_profile_id=othello.id, message_id=othello_msg_id + ) + check_row(othello_row1, expected_scheduled_timestamp, None) + + # Additionally, the timer should have be started + self.assertTrue(timer.is_alive()) + + # If another event is received, test that it gets saved with the same + # `expected_scheduled_timestamp` as the earlier events. + fake_client.enqueue("missedmessage_emails", bonus_event_hamlet) self.assertTrue(timer.is_alive()) - mmw.start() - with self.assertLogs(level="INFO") as info_logs: - # Now, we actually send the emails. + few_moments_later = time_zero + datetime.timedelta(seconds=3) + with patch( + "zerver.worker.queue_processors.timezone_now", return_value=few_moments_later + ): + # Double-calling start is our way to get it to run again + mmw.start() + hamlet_row3 = ScheduledMessageNotificationEmail.objects.get( + user_profile_id=hamlet.id, message_id=hamlet3_msg_id + ) + check_row(hamlet_row3, expected_scheduled_timestamp, None) + + # Now let us test `maybe_send_batched_emails` + # If called too early, it shouldn't process the emails. + one_minute_premature = expected_scheduled_timestamp - datetime.timedelta(seconds=60) + with patch( + "zerver.worker.queue_processors.timezone_now", return_value=one_minute_premature + ): mmw.maybe_send_batched_emails() - self.assertEqual( - info_logs.output, - [ + self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 4) + + # If called after `expected_scheduled_timestamp`, it should process all emails. + one_minute_overdue = expected_scheduled_timestamp + datetime.timedelta(seconds=60) + with self.assertLogs(level="INFO") as info_logs, patch( + "zerver.worker.queue_processors.timezone_now", return_value=one_minute_overdue + ): + mmw.maybe_send_batched_emails() + self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 0) + + self.assert_length(info_logs.output, 2) + self.assertIn( f"INFO:root:Batch-processing 3 missedmessage_emails events for user {hamlet.id}", + info_logs.output, + ) + self.assertIn( f"INFO:root:Batch-processing 1 missedmessage_emails events for user {othello.id}", - ], - ) + info_logs.output, + ) - self.assertEqual(mmw.timer_event, None) + # All batches got processed. Verify that the timer isn't running. + self.assertEqual(mmw.timer_event, None) - self.assertEqual(tm.call_args[0][0], 5) # should sleep 5 seconds + # Check that the frequency of calling maybe_send_batched_emails is correct (5 seconds) + self.assertEqual(tm.call_args[0][0], 5) + # Verify the payloads now args = [c[0] for c in sm.call_args_list] arg_dict = { arg[0].id: dict( diff --git a/zerver/worker/queue_processors.py b/zerver/worker/queue_processors.py index 6cb8472e04..9aecaca87f 100644 --- a/zerver/worker/queue_processors.py +++ b/zerver/worker/queue_processors.py @@ -13,7 +13,7 @@ import tempfile import time import urllib from abc import ABC, abstractmethod -from collections import defaultdict, deque +from collections import deque from email.message import EmailMessage from functools import wraps from threading import Lock, Timer @@ -37,7 +37,7 @@ import orjson import sentry_sdk from django.conf import settings from django.core.mail.backends.smtp import EmailBackend -from django.db import connection +from django.db import connection, transaction from django.db.models import F from django.utils.timezone import now as timezone_now from django.utils.translation import gettext as _ @@ -93,6 +93,7 @@ from zerver.models import ( PreregistrationUser, Realm, RealmAuditLog, + ScheduledMessageNotificationEmail, UserMessage, UserProfile, filter_to_valid_prereg_users, @@ -561,17 +562,9 @@ class MissedMessageWorker(QueueProcessingWorker): # # The timer is running whenever; we poll at most every TIMER_FREQUENCY # seconds, to avoid excessive activity. - # - # TODO: Since this process keeps events in memory for up to 2 - # minutes, it now will lose approximately BATCH_DURATION worth of - # missed_message emails whenever it is restarted as part of a - # server restart. We should probably add some sort of save/reload - # mechanism for that case. TIMER_FREQUENCY = 5 BATCH_DURATION = 120 timer_event: Optional[Timer] = None - events_by_recipient: Dict[int, List[Dict[str, Any]]] = defaultdict(list) - batch_start_by_recipient: Dict[int, float] = {} # This lock protects access to all of the data structures declared # above. A lock is required because maybe_send_batched_emails, as @@ -589,11 +582,29 @@ class MissedMessageWorker(QueueProcessingWorker): with self.lock: logging.debug("Received missedmessage_emails event: %s", event) - # When we process an event, just put it into the queue and ensure we have a timer going. - user_profile_id = event["user_profile_id"] - if user_profile_id not in self.batch_start_by_recipient: - self.batch_start_by_recipient[user_profile_id] = time.time() - self.events_by_recipient[user_profile_id].append(event) + # When we consume an event, check if there are existing pending emails + # for that user, and if so use the same scheduled timestamp. + user_profile_id: int = event["user_profile_id"] + batch_duration = datetime.timedelta(seconds=self.BATCH_DURATION) + + with transaction.atomic(): + try: + pending_email = ScheduledMessageNotificationEmail.objects.filter( + user_profile_id=user_profile_id + )[0] + scheduled_timestamp = pending_email.scheduled_timestamp + except IndexError: + scheduled_timestamp = timezone_now() + batch_duration + + entry = ScheduledMessageNotificationEmail( + user_profile_id=user_profile_id, + message_id=event["message_id"], + trigger=event["trigger"], + scheduled_timestamp=scheduled_timestamp, + ) + if "mentioned_user_group_id" in event: + entry.mentioned_user_group_id = event["mentioned_user_group_id"] + entry.save() self.ensure_timer() @@ -615,25 +626,44 @@ class MissedMessageWorker(QueueProcessingWorker): # is active. self.timer_event = None - current_time = time.time() - for user_profile_id, timestamp in list(self.batch_start_by_recipient.items()): - if current_time - timestamp < self.BATCH_DURATION: - continue - events = self.events_by_recipient[user_profile_id] - logging.info( - "Batch-processing %s missedmessage_emails events for user %s", - len(events), - user_profile_id, - ) - handle_missedmessage_emails(user_profile_id, events) - del self.events_by_recipient[user_profile_id] - del self.batch_start_by_recipient[user_profile_id] + current_time = timezone_now() + + with transaction.atomic(): + events_to_process = ScheduledMessageNotificationEmail.objects.filter( + scheduled_timestamp__lte=current_time + ).select_related() + + # Batch the entries by user + events_by_recipient: Dict[int, List[Dict[str, Any]]] = {} + for event in events_to_process: + entry = dict( + user_profile_id=event.user_profile_id, + message_id=event.message_id, + trigger=event.trigger, + mentioned_user_group_id=event.mentioned_user_group_id, + ) + if event.user_profile_id in events_by_recipient: + events_by_recipient[event.user_profile_id].append(entry) + else: + events_by_recipient[event.user_profile_id] = [entry] + + for user_profile_id in events_by_recipient.keys(): + events: List[Dict[str, Any]] = events_by_recipient[user_profile_id] + + logging.info( + "Batch-processing %s missedmessage_emails events for user %s", + len(events), + user_profile_id, + ) + handle_missedmessage_emails(user_profile_id, events) + + events_to_process.delete() # By only restarting the timer if there are actually events in # the queue, we ensure this queue processor is idle when there # are no missed-message emails to process. This avoids # constant CPU usage when there is no work to do. - if len(self.batch_start_by_recipient) > 0: + if ScheduledMessageNotificationEmail.objects.exists(): self.ensure_timer()