diff --git a/zerver/lib/actions.py b/zerver/lib/actions.py index 1faf23c6a9..36ea1683bb 100644 --- a/zerver/lib/actions.py +++ b/zerver/lib/actions.py @@ -4,6 +4,7 @@ import logging import os import time from collections import defaultdict +from dataclasses import dataclass from operator import itemgetter from typing import ( AbstractSet, @@ -232,6 +233,13 @@ from zerver.models import ( ) from zerver.tornado.django_api import send_event + +@dataclass +class SubInfo: + user: UserProfile + sub: Subscription + stream: Stream + if settings.BILLING_ENABLED: from corporate.lib.stripe import ( downgrade_now_without_creating_additional_invoices, @@ -2724,19 +2732,21 @@ def get_subscriber_emails(stream: Stream, def send_subscription_add_events( realm: Realm, - sub_streams: List[Tuple[Subscription, Stream]], + sub_info_list: List[SubInfo], subscriber_dict: Dict[int, List[int]], ) -> None: - sub_tuples_by_user: Dict[int, List[Tuple[Subscription, Stream]]] = defaultdict(list) - for (sub, stream) in sub_streams: - sub_tuples_by_user[sub.user_profile.id].append((sub, stream)) + info_by_user: Dict[int, List[SubInfo]] = defaultdict(list) + for sub_info in sub_info_list: + info_by_user[sub_info.user.id].append(sub_info) - stream_ids = {stream.id for (sub, stream) in sub_streams} + stream_ids = {sub_info.stream.id for sub_info in sub_info_list} recent_traffic = get_streams_traffic(stream_ids=stream_ids) - for user_id, sub_streams in sub_tuples_by_user.items(): + for user_id, sub_infos in info_by_user.items(): sub_dicts = [] - for (subscription, stream) in sub_streams: + for sub_info in sub_infos: + stream = sub_info.stream + subscription = sub_info.sub sub_dict = stream.to_dict() for field_name in Subscription.API_FIELDS: if field_name == "active": @@ -2808,7 +2818,7 @@ def get_user_ids_for_streams(streams: Iterable[Stream]) -> Dict[int, List[int]]: return all_subscribers_by_stream -SubT = Tuple[List[Tuple[UserProfile, Stream]], List[Tuple[UserProfile, Stream]]] +SubT = Tuple[List[SubInfo], List[SubInfo]] def bulk_add_subscriptions( realm: Realm, streams: Iterable[Stream], @@ -2832,9 +2842,9 @@ def bulk_add_subscriptions( for sub in all_subs_query: subs_by_user[sub.user_profile_id].append(sub) - already_subscribed: List[Tuple[UserProfile, Stream]] = [] - subs_to_activate: List[Tuple[Subscription, Stream]] = [] - subs_to_add: List[Tuple[Subscription, Stream]] = [] + already_subscribed: List[SubInfo] = [] + subs_to_activate: List[SubInfo] = [] + subs_to_add: List[SubInfo] = [] for user_profile in users: my_subs = subs_by_user[user_profile.id] used_colors = {sub.color for sub in my_subs} @@ -2848,10 +2858,11 @@ def bulk_add_subscriptions( if sub.recipient_id in new_recipient_ids: new_recipient_ids.remove(sub.recipient_id) stream = recipient_id_to_stream[sub.recipient_id] + sub_info = SubInfo(user_profile, sub, stream) if sub.active: - already_subscribed.append((user_profile, stream)) + already_subscribed.append(sub_info) else: - subs_to_activate.append((sub, stream)) + subs_to_activate.append(sub_info) for recipient_id in new_recipient_ids: stream = recipient_id_to_stream[recipient_id] @@ -2862,9 +2873,14 @@ def bulk_add_subscriptions( color = pick_color(user_profile, used_colors) used_colors.add(color) - sub_to_add = Subscription(user_profile=user_profile, active=True, - color=color, recipient_id=recipient_id) - subs_to_add.append((sub_to_add, stream)) + sub = Subscription( + user_profile=user_profile, + active=True, + color=color, + recipient_id=recipient_id + ) + sub_info = SubInfo(user_profile, sub, stream) + subs_to_add.append(sub_info) bulk_add_subs_to_db_with_logging( realm=realm, @@ -2881,8 +2897,8 @@ def bulk_add_subscriptions( all_subscribers_by_stream = get_user_ids_for_streams(streams=streams) new_stream_user_ids: Dict[int, Set[int]] = defaultdict(set) - for (sub, stream) in subs_to_add + subs_to_activate: - new_stream_user_ids[stream.id].add(sub.user_profile_id) + for sub_info in subs_to_add + subs_to_activate: + new_stream_user_ids[sub_info.stream.id].add(sub_info.user.id) stream_dict = {stream.id: stream for stream in streams} @@ -2897,7 +2913,7 @@ def bulk_add_subscriptions( send_subscription_add_events( realm=realm, - sub_streams=subs_to_add + subs_to_activate, + sub_info_list=subs_to_add + subs_to_activate, subscriber_dict=all_subscribers_by_stream, ) @@ -2908,19 +2924,20 @@ def bulk_add_subscriptions( all_subscribers_by_stream=all_subscribers_by_stream, ) - return ([(sub.user_profile, stream) for (sub, stream) in subs_to_add] + - [(sub.user_profile, stream) for (sub, stream) in subs_to_activate], - already_subscribed) + return ( + subs_to_add + subs_to_activate, + already_subscribed, + ) def bulk_add_subs_to_db_with_logging( realm: Realm, acting_user: Optional[UserProfile], - subs_to_add: List[Tuple[Subscription, Stream]], - subs_to_activate: List[Tuple[Subscription, Stream]], + subs_to_add: List[SubInfo], + subs_to_activate: List[SubInfo], ) -> None: - Subscription.objects.bulk_create(sub for (sub, stream) in subs_to_add) - sub_ids = [sub.id for (sub, stream) in subs_to_activate] + Subscription.objects.bulk_create(info.sub for info in subs_to_add) + sub_ids = [info.sub.id for info in subs_to_activate] Subscription.objects.filter(id__in=sub_ids).update(active=True) # Log Subscription Activities in RealmAuditLog @@ -2928,19 +2945,19 @@ def bulk_add_subs_to_db_with_logging( event_last_message_id = get_last_message_id() all_subscription_logs: (List[RealmAuditLog]) = [] - for (sub, stream) in subs_to_add: + for sub_info in subs_to_add: all_subscription_logs.append(RealmAuditLog(realm=realm, acting_user=acting_user, - modified_user=sub.user_profile, - modified_stream=stream, + modified_user=sub_info.user, + modified_stream=sub_info.stream, event_last_message_id=event_last_message_id, event_type=RealmAuditLog.SUBSCRIPTION_CREATED, event_time=event_time)) - for (sub, stream) in subs_to_activate: + for sub_info in subs_to_activate: all_subscription_logs.append(RealmAuditLog(realm=realm, acting_user=acting_user, - modified_user=sub.user_profile, - modified_stream=stream, + modified_user=sub_info.user, + modified_stream=sub_info.stream, event_last_message_id=event_last_message_id, event_type=RealmAuditLog.SUBSCRIPTION_ACTIVATED, event_time=event_time)) diff --git a/zerver/management/commands/add_users_to_streams.py b/zerver/management/commands/add_users_to_streams.py index b7b9d9895c..b47d60d484 100644 --- a/zerver/management/commands/add_users_to_streams.py +++ b/zerver/management/commands/add_users_to_streams.py @@ -29,7 +29,7 @@ class Command(ZulipBaseCommand): for user_profile in user_profiles: stream = ensure_stream(realm, stream_name, acting_user=None) _ignore, already_subscribed = bulk_add_subscriptions(realm, [stream], [user_profile]) - was_there_already = user_profile.id in (tup[0].id for tup in already_subscribed) + was_there_already = user_profile.id in (info.user.id for info in already_subscribed) print("{} {} to {}".format( "Already subscribed" if was_there_already else "Subscribed", user_profile.delivery_email, stream_name)) diff --git a/zerver/views/streams.py b/zerver/views/streams.py index f7e4b90819..cf65126f76 100644 --- a/zerver/views/streams.py +++ b/zerver/views/streams.py @@ -489,10 +489,14 @@ def add_subscriptions_backend( email_to_user_profile: Dict[str, UserProfile] = {} result: Dict[str, Any] = dict(subscribed=defaultdict(list), already_subscribed=defaultdict(list)) - for (subscriber, stream) in subscribed: + for sub_info in subscribed: + subscriber = sub_info.user + stream = sub_info.stream result["subscribed"][subscriber.email].append(stream.name) email_to_user_profile[subscriber.email] = subscriber - for (subscriber, stream) in already_subscribed: + for sub_info in already_subscribed: + subscriber = sub_info.user + stream = sub_info.stream result["already_subscribed"][subscriber.email].append(stream.name) result["subscribed"] = dict(result["subscribed"])