From ebfd2b25b18c8cc2f3d7bfa2636348c6896097ed Mon Sep 17 00:00:00 2001 From: Zixuan James Li Date: Thu, 21 Jul 2022 09:19:40 -0400 Subject: [PATCH] user_status: Add UserInfoDict. The shared fields of `RawUserInfoDict` and `UserInfoDict` could have been reused if they both require all keys or none. This is unfortunately not the case, because subclassing does not override `__total__`. Signed-off-by: Zixuan James Li --- zerver/lib/user_status.py | 28 +++++++++++++++++++++++----- zerver/tests/test_user_status.py | 4 ++-- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/zerver/lib/user_status.py b/zerver/lib/user_status.py index 3009a47652..afffb68246 100644 --- a/zerver/lib/user_status.py +++ b/zerver/lib/user_status.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Dict, Optional, TypedDict from django.db.models import Q from django.utils.timezone import now as timezone_now @@ -6,14 +6,32 @@ from django.utils.timezone import now as timezone_now from zerver.models import UserStatus -def format_user_status(row: Dict[str, Any]) -> Dict[str, Any]: +class UserInfoDict(TypedDict, total=False): + status: int + status_text: str + emoji_name: str + emoji_code: str + reaction_type: str + away: bool + + +class RawUserInfoDict(TypedDict): + user_profile_id: int + status: int + status_text: str + emoji_name: str + emoji_code: str + reaction_type: str + + +def format_user_status(row: RawUserInfoDict) -> UserInfoDict: away = row["status"] == UserStatus.AWAY status_text = row["status_text"] emoji_name = row["emoji_name"] emoji_code = row["emoji_code"] reaction_type = row["reaction_type"] - dct = {} + dct: UserInfoDict = {} if away: dct["away"] = away if status_text: @@ -26,7 +44,7 @@ def format_user_status(row: Dict[str, Any]) -> Dict[str, Any]: return dct -def get_user_info_dict(realm_id: int) -> Dict[str, Dict[str, Any]]: +def get_user_info_dict(realm_id: int) -> Dict[str, UserInfoDict]: rows = ( UserStatus.objects.filter( user_profile__realm_id=realm_id, @@ -49,7 +67,7 @@ def get_user_info_dict(realm_id: int) -> Dict[str, Dict[str, Any]]: ) ) - user_dict: Dict[str, Dict[str, Any]] = {} + user_dict: Dict[str, UserInfoDict] = {} for row in rows: user_id = row["user_profile_id"] user_dict[str(user_id)] = format_user_status(row) diff --git a/zerver/tests/test_user_status.py b/zerver/tests/test_user_status.py index eb1c3691d2..8e0ae3f47d 100644 --- a/zerver/tests/test_user_status.py +++ b/zerver/tests/test_user_status.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Mapping, Set import orjson from zerver.lib.test_classes import ZulipTestCase -from zerver.lib.user_status import get_user_info_dict, update_user_status +from zerver.lib.user_status import UserInfoDict, get_user_info_dict, update_user_status from zerver.models import UserProfile, UserStatus, get_client @@ -13,7 +13,7 @@ def get_away_user_ids(realm_id: int) -> Set[int]: return {int(user_id) for user_id in user_dict if user_dict[user_id].get("away")} -def user_info(user: UserProfile) -> Dict[str, Any]: +def user_info(user: UserProfile) -> UserInfoDict: user_dict = get_user_info_dict(user.realm_id) return user_dict.get(str(user.id), {})