diff --git a/zerver/lib/streams.py b/zerver/lib/streams.py index 8b90df0b10..a858e11ef3 100644 --- a/zerver/lib/streams.py +++ b/zerver/lib/streams.py @@ -24,7 +24,7 @@ from zerver.lib.string_validation import check_stream_name from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.types import AnonymousSettingGroupDict, APIStreamDict from zerver.lib.user_groups import ( - get_recursive_group_members, + get_recursive_group_members_union_for_groups, get_recursive_membership_groups, get_role_based_system_groups_dict, user_has_permission_for_group_setting, @@ -181,17 +181,11 @@ def get_default_values_for_stream_permission_group_settings( def get_user_ids_with_metadata_access_via_permission_groups(stream: Stream) -> set[int]: - stream_admin_user_ids = set( - get_recursive_group_members(stream.can_administer_channel_group_id).values_list( - "id", flat=True - ) + return set( + get_recursive_group_members_union_for_groups( + [stream.can_add_subscribers_group_id, stream.can_administer_channel_group_id] + ).values_list("id", flat=True) ) - stream_add_subscribers_group_user_ids = set( - get_recursive_group_members(stream.can_add_subscribers_group_id).values_list( - "id", flat=True - ) - ) - return stream_admin_user_ids | stream_add_subscribers_group_user_ids @transaction.atomic(savepoint=False) diff --git a/zerver/lib/user_groups.py b/zerver/lib/user_groups.py index ffb60b3ad8..93714f142a 100644 --- a/zerver/lib/user_groups.py +++ b/zerver/lib/user_groups.py @@ -669,9 +669,9 @@ def get_direct_memberships_of_users(user_group: UserGroup, members: list[UserPro # https://code.djangoproject.com/ticket/28919 -def get_recursive_subgroups(user_group_id: int) -> QuerySet[UserGroup]: +def get_recursive_subgroups_union_for_groups(user_group_ids: list[int]) -> QuerySet[UserGroup]: cte = With.recursive( - lambda cte: UserGroup.objects.filter(id=user_group_id) + lambda cte: UserGroup.objects.filter(id__in=user_group_ids) .values(group_id=F("id")) .union( cte.join(NamedUserGroup, direct_supergroups=cte.col.group_id).values(group_id=F("id")) @@ -680,6 +680,10 @@ def get_recursive_subgroups(user_group_id: int) -> QuerySet[UserGroup]: return cte.join(UserGroup, id=cte.col.group_id).with_cte(cte) +def get_recursive_subgroups(user_group_id: int) -> QuerySet[UserGroup]: + return get_recursive_subgroups_union_for_groups([user_group_id]) + + def get_recursive_strict_subgroups(user_group: UserGroup) -> QuerySet[NamedUserGroup]: # Same as get_recursive_subgroups but does not include the # user_group passed. @@ -695,8 +699,15 @@ def get_recursive_strict_subgroups(user_group: UserGroup) -> QuerySet[NamedUserG def get_recursive_group_members(user_group_id: int) -> QuerySet[UserProfile]: + return get_recursive_group_members_union_for_groups([user_group_id]) + + +def get_recursive_group_members_union_for_groups( + user_group_ids: list[int], +) -> QuerySet[UserProfile]: return UserProfile.objects.filter( - is_active=True, direct_groups__in=get_recursive_subgroups(user_group_id) + is_active=True, + direct_groups__in=get_recursive_subgroups_union_for_groups(user_group_ids), ) diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index 7409161f49..381456fc4c 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -3248,7 +3248,7 @@ class StreamAdminTest(ZulipTestCase): are on. """ result = self.attempt_unsubscribe_of_principal( - query_count=19, + query_count=16, target_users=[self.example_user("cordelia")], is_realm_admin=True, is_subbed=True, @@ -3265,7 +3265,7 @@ class StreamAdminTest(ZulipTestCase): streams you aren't on. """ result = self.attempt_unsubscribe_of_principal( - query_count=19, + query_count=16, target_users=[self.example_user("cordelia")], is_realm_admin=True, is_subbed=False, @@ -5992,7 +5992,7 @@ class SubscriptionAPITest(ZulipTestCase): # Sends 3 peer-remove events, 2 unsubscribe events # and 2 stream delete events for private streams. with ( - self.assert_database_query_count(20), + self.assert_database_query_count(19), self.assert_memcached_count(3), self.capture_send_event_calls(expected_num_events=7) as events, ): @@ -6548,7 +6548,7 @@ class SubscriptionAPITest(ZulipTestCase): ) # Test creating private stream. - with self.assert_database_query_count(50): + with self.assert_database_query_count(48): self.subscribe_via_post( self.test_user, [new_streams[1]], diff --git a/zerver/tests/test_user_groups.py b/zerver/tests/test_user_groups.py index 540be342a0..a21e38ddf7 100644 --- a/zerver/tests/test_user_groups.py +++ b/zerver/tests/test_user_groups.py @@ -40,9 +40,11 @@ from zerver.lib.types import AnonymousSettingGroupDict from zerver.lib.user_groups import ( get_direct_user_groups, get_recursive_group_members, + get_recursive_group_members_union_for_groups, get_recursive_membership_groups, get_recursive_strict_subgroups, get_recursive_subgroups, + get_recursive_subgroups_union_for_groups, get_role_based_system_groups_dict, get_subgroup_ids, get_user_group_member_ids, @@ -249,6 +251,8 @@ class UserGroupTestCase(ZulipTestCase): iago = self.example_user("iago") desdemona = self.example_user("desdemona") shiva = self.example_user("shiva") + aaron = self.example_user("aaron") + prospero = self.example_user("prospero") leadership_group = check_add_user_group( realm, "Leadership", [desdemona], acting_user=desdemona @@ -257,8 +261,14 @@ class UserGroupTestCase(ZulipTestCase): staff_group = check_add_user_group(realm, "Staff", [iago], acting_user=iago) GroupGroupMembership.objects.create(supergroup=staff_group, subgroup=leadership_group) + manager_group = check_add_user_group( + realm, "Managers", [aaron, prospero], acting_user=aaron + ) + GroupGroupMembership.objects.create(supergroup=manager_group, subgroup=leadership_group) + everyone_group = check_add_user_group(realm, "Everyone", [shiva], acting_user=shiva) GroupGroupMembership.objects.create(supergroup=everyone_group, subgroup=staff_group) + GroupGroupMembership.objects.create(supergroup=everyone_group, subgroup=manager_group) self.assertCountEqual( list(get_recursive_subgroups(leadership_group.id)), [leadership_group.usergroup_ptr] @@ -273,6 +283,16 @@ class UserGroupTestCase(ZulipTestCase): leadership_group.usergroup_ptr, staff_group.usergroup_ptr, everyone_group.usergroup_ptr, + manager_group.usergroup_ptr, + ], + ) + + self.assertCountEqual( + list(get_recursive_subgroups_union_for_groups([staff_group.id, manager_group.id])), + [ + leadership_group.usergroup_ptr, + staff_group.usergroup_ptr, + manager_group.usergroup_ptr, ], ) @@ -280,28 +300,43 @@ class UserGroupTestCase(ZulipTestCase): self.assertCountEqual(list(get_recursive_strict_subgroups(staff_group)), [leadership_group]) self.assertCountEqual( list(get_recursive_strict_subgroups(everyone_group)), - [leadership_group, staff_group], + [leadership_group, staff_group, manager_group], ) self.assertCountEqual(list(get_recursive_group_members(leadership_group.id)), [desdemona]) self.assertCountEqual(list(get_recursive_group_members(staff_group.id)), [desdemona, iago]) self.assertCountEqual( - list(get_recursive_group_members(everyone_group.id)), [desdemona, iago, shiva] + list(get_recursive_group_members(everyone_group.id)), + [desdemona, iago, shiva, aaron, prospero], + ) + + self.assertCountEqual( + list(get_recursive_group_members_union_for_groups([staff_group.id, manager_group.id])), + [iago, desdemona, aaron, prospero], + ) + self.assertCountEqual( + list( + get_recursive_group_members_union_for_groups([leadership_group.id, staff_group.id]) + ), + [desdemona, iago], ) self.assertIn(leadership_group.usergroup_ptr, get_recursive_membership_groups(desdemona)) self.assertIn(staff_group.usergroup_ptr, get_recursive_membership_groups(desdemona)) self.assertIn(everyone_group.usergroup_ptr, get_recursive_membership_groups(desdemona)) + self.assertIn(manager_group.usergroup_ptr, get_recursive_membership_groups(desdemona)) self.assertIn(staff_group.usergroup_ptr, get_recursive_membership_groups(iago)) self.assertIn(everyone_group.usergroup_ptr, get_recursive_membership_groups(iago)) + self.assertNotIn(manager_group.usergroup_ptr, get_recursive_membership_groups(iago)) self.assertIn(everyone_group.usergroup_ptr, get_recursive_membership_groups(shiva)) do_deactivate_user(iago, acting_user=None) self.assertCountEqual(list(get_recursive_group_members(staff_group.id)), [desdemona]) self.assertCountEqual( - list(get_recursive_group_members(everyone_group.id)), [desdemona, shiva] + list(get_recursive_group_members(everyone_group.id)), + [desdemona, shiva, aaron, prospero], ) def test_subgroups_of_role_based_system_groups(self) -> None: diff --git a/zerver/tests/test_users.py b/zerver/tests/test_users.py index 13e038fb66..efb39abd5f 100644 --- a/zerver/tests/test_users.py +++ b/zerver/tests/test_users.py @@ -1021,7 +1021,7 @@ class QueryCountTest(ZulipTestCase): prereg_user = PreregistrationUser.objects.get(email="fred@zulip.com") with ( - self.assert_database_query_count(93), + self.assert_database_query_count(87), self.assert_memcached_count(19), self.capture_send_event_calls(expected_num_events=10) as events, ):