diff --git a/zerver/lib/soft_deactivation.py b/zerver/lib/soft_deactivation.py index 7a25f8f3c4..07d13de98c 100644 --- a/zerver/lib/soft_deactivation.py +++ b/zerver/lib/soft_deactivation.py @@ -5,7 +5,7 @@ from typing import Any, DefaultDict, Dict, Iterable, List, Optional, Sequence, S from django.conf import settings from django.db import transaction -from django.db.models import Max, QuerySet +from django.db.models import Exists, Max, OuterRef, QuerySet from django.utils.timezone import now as timezone_now from sentry_sdk import capture_exception @@ -208,9 +208,18 @@ def add_missing_messages(user_profile: UserProfile) -> None: continue recipient_ids.append(sub["recipient_id"]) - all_stream_msgs = list( - Message.objects.filter( + new_stream_msgs = ( + Message.objects.annotate( + has_user_message=Exists( + UserMessage.objects.filter( + user_profile_id=user_profile, + message_id=OuterRef("id"), + ) + ) + ) + .filter( # Uses index: zerver_message_realm_recipient_id + has_user_message=0, realm_id=user_profile.realm_id, recipient_id__in=recipient_ids, id__gt=user_profile.last_active_message_id, @@ -218,20 +227,12 @@ def add_missing_messages(user_profile: UserProfile) -> None: .order_by("id") .values("id", "recipient__type_id") ) - already_created_ums = set( - UserMessage.objects.filter( - user_profile=user_profile, - message__recipient__type=Recipient.STREAM, - message_id__gt=user_profile.last_active_message_id, - ).values_list("message_id", flat=True) - ) - - # Filter those messages for which UserMessage rows have been already created - all_stream_msgs = [msg for msg in all_stream_msgs if msg["id"] not in already_created_ums] stream_messages: DefaultDict[int, List[MissingMessageDict]] = defaultdict(list) - for msg in all_stream_msgs: - stream_messages[msg["recipient__type_id"]].append(msg) + for msg in new_stream_msgs: + stream_messages[msg["recipient__type_id"]].append( + MissingMessageDict(id=msg["id"], recipient__type_id=msg["recipient__type_id"]) + ) # Calling this function to filter out stream messages based upon # subscription logs and then store all UserMessage objects for bulk insert diff --git a/zerver/tests/test_soft_deactivation.py b/zerver/tests/test_soft_deactivation.py index 451f257bc2..71f484f7e8 100644 --- a/zerver/tests/test_soft_deactivation.py +++ b/zerver/tests/test_soft_deactivation.py @@ -321,7 +321,7 @@ class SoftDeactivationMessageTest(ZulipTestCase): idle_user_msg_list = get_user_messages(long_term_idle_user) idle_user_msg_count = len(idle_user_msg_list) self.assertNotEqual(idle_user_msg_list[-1].content, message) - with self.assert_database_query_count(8): + with self.assert_database_query_count(7): reactivate_user_if_soft_deactivated(long_term_idle_user) self.assertFalse(long_term_idle_user.long_term_idle) self.assertEqual( @@ -382,7 +382,7 @@ class SoftDeactivationMessageTest(ZulipTestCase): idle_user_msg_list = get_user_messages(long_term_idle_user) idle_user_msg_count = len(idle_user_msg_list) self.assertNotEqual(idle_user_msg_list[-1], sent_message) - with self.assert_database_query_count(6): + with self.assert_database_query_count(5): add_missing_messages(long_term_idle_user) idle_user_msg_list = get_user_messages(long_term_idle_user) self.assert_length(idle_user_msg_list, idle_user_msg_count + 1) @@ -398,7 +398,7 @@ class SoftDeactivationMessageTest(ZulipTestCase): idle_user_msg_list = get_user_messages(long_term_idle_user) idle_user_msg_count = len(idle_user_msg_list) self.assertNotEqual(idle_user_msg_list[-1], sent_message) - with self.assert_database_query_count(7): + with self.assert_database_query_count(6): add_missing_messages(long_term_idle_user) idle_user_msg_list = get_user_messages(long_term_idle_user) self.assert_length(idle_user_msg_list, idle_user_msg_count + 1) @@ -422,7 +422,7 @@ class SoftDeactivationMessageTest(ZulipTestCase): idle_user_msg_count = len(idle_user_msg_list) for sent_message in sent_message_list: self.assertNotEqual(idle_user_msg_list.pop(), sent_message) - with self.assert_database_query_count(6): + with self.assert_database_query_count(5): add_missing_messages(long_term_idle_user) idle_user_msg_list = get_user_messages(long_term_idle_user) self.assert_length(idle_user_msg_list, idle_user_msg_count + 2) @@ -453,7 +453,7 @@ class SoftDeactivationMessageTest(ZulipTestCase): idle_user_msg_count = len(idle_user_msg_list) for sent_message in sent_message_list: self.assertNotEqual(idle_user_msg_list.pop(), sent_message) - with self.assert_database_query_count(6): + with self.assert_database_query_count(5): add_missing_messages(long_term_idle_user) idle_user_msg_list = get_user_messages(long_term_idle_user) self.assert_length(idle_user_msg_list, idle_user_msg_count + 2) @@ -487,7 +487,7 @@ class SoftDeactivationMessageTest(ZulipTestCase): self.assertEqual(idle_user_msg_list[-1].id, sent_message_id) # There are no streams to fetch missing messages from, so # the Message.objects query will be avoided. - with self.assert_database_query_count(4): + with self.assert_database_query_count(3): add_missing_messages(long_term_idle_user) idle_user_msg_list = get_user_messages(long_term_idle_user) # No new UserMessage rows should have been created. @@ -513,7 +513,7 @@ class SoftDeactivationMessageTest(ZulipTestCase): idle_user_msg_count = len(idle_user_msg_list) for sent_message in sent_message_list: self.assertNotEqual(idle_user_msg_list.pop(), sent_message) - with self.assert_database_query_count(6): + with self.assert_database_query_count(5): add_missing_messages(long_term_idle_user) idle_user_msg_list = get_user_messages(long_term_idle_user) self.assert_length(idle_user_msg_list, idle_user_msg_count + 2) @@ -550,7 +550,7 @@ class SoftDeactivationMessageTest(ZulipTestCase): idle_user_msg_list = get_user_messages(long_term_idle_user) idle_user_msg_count = len(idle_user_msg_list) - with self.assert_database_query_count(10): + with self.assert_database_query_count(9): add_missing_messages(long_term_idle_user) idle_user_msg_list = get_user_messages(long_term_idle_user) self.assert_length(idle_user_msg_list, idle_user_msg_count + num_new_messages)