From 5e42c8cd2be798358767bb407987dfd576e8664f Mon Sep 17 00:00:00 2001 From: Prakhar Pratyush Date: Mon, 10 Jun 2024 18:12:36 +0530 Subject: [PATCH] user_topics: Handle IntegrityError during bulk insertion. When there was a race during bulk insertion of UserTopic rows, it resulted in an IntegrityError. We now update the 'last_updated' and 'visibility_policy' columns for conflicting rows. We also removed the separate update query for visibility_policy because the new SQL query handles the updates too. This results in fewer round trips to the database. --- zerver/lib/user_topics.py | 54 +++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/zerver/lib/user_topics.py b/zerver/lib/user_topics.py index 3898b4b1c5..6285f1d88e 100644 --- a/zerver/lib/user_topics.py +++ b/zerver/lib/user_topics.py @@ -3,9 +3,10 @@ from collections import defaultdict from datetime import datetime from typing import Callable, Dict, List, Optional, Tuple, TypedDict -from django.db import transaction +from django.db import connection, transaction from django.db.models import QuerySet from django.utils.timezone import now as timezone_now +from psycopg2.sql import SQL, Literal from sqlalchemy.sql import ClauseElement, and_, column, not_, or_ from sqlalchemy.types import Integer @@ -162,7 +163,9 @@ def bulk_set_user_topic_visibility_policy_in_database( assert last_updated is not None assert recipient_id is not None - user_profiles_seeking_visibility_policy_update: List[UserProfile] = [] + user_profiles_seeking_user_topic_update_or_create: List[UserProfile] = ( + user_profiles_without_visibility_policy + ) for row in rows: if row.visibility_policy == visibility_policy: logging.info( @@ -172,26 +175,39 @@ def bulk_set_user_topic_visibility_policy_in_database( ) continue # The request is to just 'update' the visibility policy of a topic - user_profiles_seeking_visibility_policy_update.append(row.user_profile) + 
user_profiles_seeking_user_topic_update_or_create.append(row.user_profile) - if user_profiles_seeking_visibility_policy_update: - rows.filter(user_profile__in=user_profiles_seeking_visibility_policy_update).update( - visibility_policy=visibility_policy, last_updated=last_updated - ) - - if user_profiles_without_visibility_policy: - UserTopic.objects.bulk_create( - UserTopic( - user_profile=user_profile, - stream_id=stream_id, - recipient_id=recipient_id, - topic_name=topic_name, - last_updated=last_updated, - visibility_policy=visibility_policy, + if user_profiles_seeking_user_topic_update_or_create: + user_profile_ids_array = SQL("ARRAY[{}]").format( + SQL(", ").join( + [ + Literal(user_profile.id) + for user_profile in user_profiles_seeking_user_topic_update_or_create + ] ) - for user_profile in user_profiles_without_visibility_policy ) - return user_profiles_seeking_visibility_policy_update + user_profiles_without_visibility_policy + + query = SQL(""" + INSERT INTO zerver_usertopic(user_profile_id, stream_id, recipient_id, topic_name, last_updated, visibility_policy) + SELECT * FROM UNNEST({user_profile_ids_array}) AS user_profile(user_profile_id) + CROSS JOIN (VALUES ({stream_id}, {recipient_id}, {topic_name}, {last_updated}, {visibility_policy})) + AS other_values(stream_id, recipient_id, topic_name, last_updated, visibility_policy) + ON CONFLICT (user_profile_id, stream_id, lower(topic_name)) DO UPDATE SET + last_updated = EXCLUDED.last_updated, + visibility_policy = EXCLUDED.visibility_policy; + """).format( + user_profile_ids_array=user_profile_ids_array, + stream_id=Literal(stream_id), + recipient_id=Literal(recipient_id), + topic_name=Literal(topic_name), + last_updated=Literal(last_updated), + visibility_policy=Literal(visibility_policy), + ) + + with connection.cursor() as cursor: + cursor.execute(query) + + return user_profiles_seeking_user_topic_update_or_create def topic_has_visibility_policy(