From 44ecd66eaec6533778bdff3fbb31ceb0acc0419a Mon Sep 17 00:00:00 2001 From: Zixuan James Li <359101898@qq.com> Date: Wed, 25 May 2022 20:51:35 -0400 Subject: [PATCH] types: Better types for API fields. Signed-off-by: Zixuan James Li <359101898@qq.com> --- zerver/actions/default_streams.py | 3 ++- zerver/actions/streams.py | 45 +++++++++++++++++++++++++------ zerver/lib/streams.py | 7 ++--- zerver/lib/types.py | 43 +++++++++++++++++++++++++++++ zerver/models.py | 30 ++++++++++++--------- zerver/tests/test_subs.py | 33 ++++++++++++++++++++++- 6 files changed, 135 insertions(+), 26 deletions(-) diff --git a/zerver/actions/default_streams.py b/zerver/actions/default_streams.py index a8932c1227..dba7608c70 100644 --- a/zerver/actions/default_streams.py +++ b/zerver/actions/default_streams.py @@ -4,6 +4,7 @@ from django.db import transaction from django.utils.translation import gettext as _ from zerver.lib.exceptions import JsonableError +from zerver.lib.types import APIStreamDict from zerver.models import ( DefaultStream, DefaultStreamGroup, @@ -183,7 +184,7 @@ def get_default_streams_for_realm(realm_id: int) -> List[Stream]: # returns default streams in JSON serializable format -def streams_to_dicts_sorted(streams: List[Stream]) -> List[Dict[str, Any]]: +def streams_to_dicts_sorted(streams: List[Stream]) -> List[APIStreamDict]: return sorted((stream.to_dict() for stream in streams), key=lambda elt: elt["name"]) diff --git a/zerver/actions/streams.py b/zerver/actions/streams.py index 09c522db42..912ee8a7af 100644 --- a/zerver/actions/streams.py +++ b/zerver/actions/streams.py @@ -46,6 +46,7 @@ from zerver.lib.streams import ( send_stream_creation_event, ) from zerver.lib.subscription_info import get_subscribers_query +from zerver.lib.types import APISubscriptionDict from zerver.models import ( ArchivedAttachment, Attachment, @@ -177,19 +178,47 @@ def send_subscription_add_events( ) for user_id, sub_infos in info_by_user.items(): - sub_dicts = [] + sub_dicts: List[APISubscriptionDict] = [] for sub_info in sub_infos: stream = sub_info.stream stream_info = stream_info_dict[stream.id] subscription = sub_info.sub - sub_dict = stream.to_dict() - for field_name in Subscription.API_FIELDS: - sub_dict[field_name] = getattr(subscription, field_name) + stream_dict = stream.to_dict() + # This is verbose as we cannot unpack existing TypedDict + # to initialize another TypedDict while making mypy happy. + # https://github.com/python/mypy/issues/5382 + sub_dict = APISubscriptionDict( + # Fields from Subscription.API_FIELDS + audible_notifications=subscription.audible_notifications, + color=subscription.color, + desktop_notifications=subscription.desktop_notifications, + email_notifications=subscription.email_notifications, + is_muted=subscription.is_muted, + pin_to_top=subscription.pin_to_top, + push_notifications=subscription.push_notifications, + role=subscription.role, + wildcard_mentions_notify=subscription.wildcard_mentions_notify, + # Computed fields not present in Subscription.API_FIELDS + email_address=stream_info.email_address, + in_home_view=not subscription.is_muted, + stream_weekly_traffic=stream_info.stream_weekly_traffic, + subscribers=stream_info.subscribers, + # Fields from Stream.API_FIELDS + date_created=stream_dict["date_created"], + description=stream_dict["description"], + first_message_id=stream_dict["first_message_id"], + history_public_to_subscribers=stream_dict["history_public_to_subscribers"], + invite_only=stream_dict["invite_only"], + is_web_public=stream_dict["is_web_public"], + message_retention_days=stream_dict["message_retention_days"], + name=stream_dict["name"], + rendered_description=stream_dict["rendered_description"], + stream_id=stream_dict["stream_id"], + stream_post_policy=stream_dict["stream_post_policy"], + # Computed fields not present in Stream.API_FIELDS + is_announcement_only=stream_dict["is_announcement_only"], + ) - sub_dict["in_home_view"] = not subscription.is_muted - sub_dict["email_address"] = stream_info.email_address - sub_dict["stream_weekly_traffic"] = stream_info.stream_weekly_traffic - sub_dict["subscribers"] = stream_info.subscribers sub_dicts.append(sub_dict) # Send a notification to the user who subscribed. diff --git a/zerver/lib/streams.py b/zerver/lib/streams.py index 0d8a12542c..b978b13f3a 100644 --- a/zerver/lib/streams.py +++ b/zerver/lib/streams.py @@ -1,4 +1,4 @@ -from typing import Any, Collection, Dict, List, Optional, Set, Tuple, TypedDict, Union +from typing import Collection, List, Optional, Set, Tuple, TypedDict, Union from django.db import transaction from django.db.models import Exists, OuterRef, Q @@ -18,6 +18,7 @@ from zerver.lib.stream_subscription import ( get_subscribed_stream_ids_for_user, ) from zerver.lib.string_validation import check_stream_name +from zerver.lib.types import APIStreamDict from zerver.models import ( DefaultStreamGroup, Realm, @@ -785,7 +786,7 @@ def get_occupied_streams(realm: Realm) -> QuerySet: return occupied_streams -def get_web_public_streams(realm: Realm) -> List[Dict[str, Any]]: # nocoverage +def get_web_public_streams(realm: Realm) -> List[APIStreamDict]: # nocoverage query = get_web_public_streams_queryset(realm) streams = Stream.get_client_data(query) return streams @@ -799,7 +800,7 @@ def do_get_streams( include_all_active: bool = False, include_default: bool = False, include_owner_subscribed: bool = False, -) -> List[Dict[str, Any]]: +) -> List[APIStreamDict]: # This function is only used by API clients now. if include_all_active and not user_profile.is_realm_admin: diff --git a/zerver/lib/types.py b/zerver/lib/types.py index 8f41a453d7..38931fa694 100644 --- a/zerver/lib/types.py +++ b/zerver/lib/types.py @@ -221,6 +221,49 @@ class NeverSubscribedStreamDict(TypedDict): subscribers: NotRequired[List[int]] +class APIStreamDict(TypedDict): + """Stream information provided to Zulip clients as a dictionary via API. + It should contain all the fields specified in `zerver.models.Stream.API_FIELDS` + with few exceptions and possible additional fields. + """ + + date_created: int + description: str + first_message_id: Optional[int] + history_public_to_subscribers: bool + invite_only: bool + is_web_public: bool + message_retention_days: Optional[int] + name: str + rendered_description: str + stream_id: int # `stream_id`` represents `id` of the `Stream` object in `API_FIELDS` + stream_post_policy: int + # Computed fields not specified in `Stream.API_FIELDS` + is_announcement_only: bool + is_default: NotRequired[bool] + + +class APISubscriptionDict(APIStreamDict): + """Similar to StreamClientDict, it should contain all the fields specified in + `zerver.models.Subscription.API_FIELDS` and several additional fields. + """ + + audible_notifications: Optional[bool] + color: str + desktop_notifications: Optional[bool] + email_notifications: Optional[bool] + is_muted: bool + pin_to_top: bool + push_notifications: Optional[bool] + role: int + wildcard_mentions_notify: Optional[bool] + # Computed fields not specified in `Subscription.API_FIELDS` + email_address: str + in_home_view: bool + stream_weekly_traffic: Optional[int] + subscribers: List[int] + + @dataclass class SubscriptionInfo: subscriptions: List[SubscriptionStreamDict] diff --git a/zerver/models.py b/zerver/models.py index 6da17f4286..b82e31f5d1 100644 --- a/zerver/models.py +++ b/zerver/models.py @@ -83,6 +83,7 @@ from zerver.lib.exceptions import JsonableError, RateLimited from zerver.lib.pysa import mark_sanitized from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.types import ( + APIStreamDict, DisplayRecipientT, ExtendedFieldElement, ExtendedValidator, @@ -2483,22 +2484,25 @@ class Stream(models.Model): ] @staticmethod - def get_client_data(query: QuerySet) -> List[Dict[str, Any]]: + def get_client_data(query: QuerySet) -> List[APIStreamDict]: query = query.only(*Stream.API_FIELDS) return [row.to_dict() for row in query] - def to_dict(self) -> Dict[str, Any]: - result = {} - for field_name in self.API_FIELDS: - if field_name == "id": - result["stream_id"] = self.id - continue - elif field_name == "date_created": - result["date_created"] = datetime_to_timestamp(self.date_created) - continue - result[field_name] = getattr(self, field_name) - result["is_announcement_only"] = self.stream_post_policy == Stream.STREAM_POST_POLICY_ADMINS - return result + def to_dict(self) -> APIStreamDict: + return APIStreamDict( + date_created=datetime_to_timestamp(self.date_created), + description=self.description, + first_message_id=self.first_message_id, + history_public_to_subscribers=self.history_public_to_subscribers, + invite_only=self.invite_only, + is_web_public=self.is_web_public, + message_retention_days=self.message_retention_days, + name=self.name, + rendered_description=self.rendered_description, + stream_id=self.id, + stream_post_policy=self.stream_post_policy, + is_announcement_only=self.stream_post_policy == Stream.STREAM_POST_POLICY_ADMINS, + ) class Meta: indexes = [ diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index f4f7b47ef9..2746ad58f7 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -75,7 +75,12 @@ from zerver.lib.test_helpers import ( queries_captured, reset_emails_in_zulip_realm, ) -from zerver.lib.types import NeverSubscribedStreamDict, SubscriptionInfo +from zerver.lib.types import ( + APIStreamDict, + APISubscriptionDict, + NeverSubscribedStreamDict, + SubscriptionInfo, +) from zerver.models import ( Attachment, DefaultStream, @@ -205,6 +210,32 @@ class TestMiscStuff(ZulipTestCase): ) self.assertEqual(streams, []) + def test_api_fields(self) -> None: + """Verify that all the fields from `Stream.API_FIELDS` and `Subscription.API_FIELDS` present + in `APIStreamDict` and `APISubscriptionDict`, respectively. + """ + expected_fields = set(Stream.API_FIELDS) | {"stream_id"} + expected_fields -= {"id"} + + stream_dict_fields = set(APIStreamDict.__annotations__.keys()) + computed_fields = set(["is_announcement_only", "is_default"]) + + self.assertEqual(stream_dict_fields - computed_fields, expected_fields) + + expected_fields = set(Subscription.API_FIELDS) + + subscription_dict_fields = set(APISubscriptionDict.__annotations__.keys()) + computed_fields = set( + ["in_home_view", "email_address", "stream_weekly_traffic", "subscribers"] + ) + # `APISubscriptionDict` is a subclass of `APIStreamDict`, therefore having all the + # fields in addition to the computed fields and `Subscription.API_FIELDS` that + # need to be excluded here. + self.assertEqual( + subscription_dict_fields - computed_fields - stream_dict_fields, + expected_fields, + ) + class TestCreateStreams(ZulipTestCase): def test_creating_streams(self) -> None: