diff --git a/zerver/lib/topic.py b/zerver/lib/topic.py index bec801246a..c5b18df037 100644 --- a/zerver/lib/topic.py +++ b/zerver/lib/topic.py @@ -6,6 +6,7 @@ from django.utils.timezone import now as timezone_now from sqlalchemy.sql import ( column, + literal, func, ) @@ -26,7 +27,7 @@ PREV_TOPIC = "prev_subject" def topic_match_sa(topic_name: str) -> Any: # _sa is short for Sql Alchemy, which we use mostly for # queries that search messages - topic_cond = func.upper(column("subject")) == func.upper(topic_name) + topic_cond = func.upper(column("subject")) == func.upper(literal(topic_name)) return topic_cond def filter_by_exact_message_topic(query: QuerySet, message: Message) -> QuerySet: diff --git a/zerver/tests/test_narrow.py b/zerver/tests/test_narrow.py index 27cf45851f..ec6fd3674c 100644 --- a/zerver/tests/test_narrow.py +++ b/zerver/tests/test_narrow.py @@ -2410,13 +2410,13 @@ class GetOldMessagesTest(ZulipTestCase): expected_query = ''' SELECT id AS message_id FROM zerver_message - WHERE NOT (recipient_id = :recipient_id_1 AND upper(subject) = upper(:upper_1)) + WHERE NOT (recipient_id = :recipient_id_1 AND upper(subject) = upper(:param_1)) ''' self.assertEqual(fix_ws(query), fix_ws(expected_query)) params = get_sqlalchemy_query_params(query) self.assertEqual(params['recipient_id_1'], get_recipient_id_for_stream_name(realm, 'Scotland')) - self.assertEqual(params['upper_1'], 'golf') + self.assertEqual(params['param_1'], 'golf') mute_stream(realm, user_profile, 'Verona') @@ -2435,15 +2435,15 @@ class GetOldMessagesTest(ZulipTestCase): FROM zerver_message WHERE recipient_id NOT IN (:recipient_id_1) AND NOT - (recipient_id = :recipient_id_2 AND upper(subject) = upper(:upper_1) OR - recipient_id = :recipient_id_3 AND upper(subject) = upper(:upper_2))''' + (recipient_id = :recipient_id_2 AND upper(subject) = upper(:param_1) OR + recipient_id = :recipient_id_3 AND upper(subject) = upper(:param_2))''' self.assertEqual(fix_ws(query), fix_ws(expected_query)) params = get_sqlalchemy_query_params(query) self.assertEqual(params['recipient_id_1'], get_recipient_id_for_stream_name(realm, 'Verona')) self.assertEqual(params['recipient_id_2'], get_recipient_id_for_stream_name(realm, 'Scotland')) - self.assertEqual(params['upper_1'], 'golf') + self.assertEqual(params['param_1'], 'golf') self.assertEqual(params['recipient_id_3'], get_recipient_id_for_stream_name(realm, 'web stuff')) - self.assertEqual(params['upper_2'], 'css') + self.assertEqual(params['param_2'], 'css') def test_get_messages_queries(self) -> None: query_ids = self.get_query_ids() diff --git a/zerver/views/messages.py b/zerver/views/messages.py index 5566f94a3e..da5baefa52 100644 --- a/zerver/views/messages.py +++ b/zerver/views/messages.py @@ -33,6 +33,9 @@ from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection from zerver.lib.streams import access_stream_by_id, can_access_stream_history_by_name from zerver.lib.timestamp import datetime_to_timestamp, convert_to_UTC from zerver.lib.timezone import get_timezone +from zerver.lib.topic import ( + topic_match_sa, +) from zerver.lib.topic_mutes import exclude_topic_mutes from zerver.lib.utils import statsd from zerver.lib.validator import \ @@ -241,36 +244,36 @@ class NarrowBuilder: # instance "personal" to be the same. if base_topic in ('', 'personal', '(instance "")'): cond = or_( - func.upper(column("subject")) == func.upper(literal("")), - func.upper(column("subject")) == func.upper(literal(".d")), - func.upper(column("subject")) == func.upper(literal(".d.d")), - func.upper(column("subject")) == func.upper(literal(".d.d.d")), - func.upper(column("subject")) == func.upper(literal(".d.d.d.d")), - func.upper(column("subject")) == func.upper(literal("personal")), - func.upper(column("subject")) == func.upper(literal("personal.d")), - func.upper(column("subject")) == func.upper(literal("personal.d.d")), - func.upper(column("subject")) == func.upper(literal("personal.d.d.d")), - func.upper(column("subject")) == func.upper(literal("personal.d.d.d.d")), - func.upper(column("subject")) == func.upper(literal('(instance "")')), - func.upper(column("subject")) == func.upper(literal('(instance "").d')), - func.upper(column("subject")) == func.upper(literal('(instance "").d.d')), - func.upper(column("subject")) == func.upper(literal('(instance "").d.d.d')), - func.upper(column("subject")) == func.upper(literal('(instance "").d.d.d.d')), + topic_match_sa(""), + topic_match_sa(".d"), + topic_match_sa(".d.d"), + topic_match_sa(".d.d.d"), + topic_match_sa(".d.d.d.d"), + topic_match_sa("personal"), + topic_match_sa("personal.d"), + topic_match_sa("personal.d.d"), + topic_match_sa("personal.d.d.d"), + topic_match_sa("personal.d.d.d.d"), + topic_match_sa('(instance "")'), + topic_match_sa('(instance "").d'), + topic_match_sa('(instance "").d.d'), + topic_match_sa('(instance "").d.d.d'), + topic_match_sa('(instance "").d.d.d.d'), ) else: # We limit `.d` counts, since postgres has much better # query planning for this than they do for a regular # expression (which would sometimes table scan). cond = or_( - func.upper(column("subject")) == func.upper(literal(base_topic)), - func.upper(column("subject")) == func.upper(literal(base_topic + ".d")), - func.upper(column("subject")) == func.upper(literal(base_topic + ".d.d")), - func.upper(column("subject")) == func.upper(literal(base_topic + ".d.d.d")), - func.upper(column("subject")) == func.upper(literal(base_topic + ".d.d.d.d")), + topic_match_sa(base_topic), + topic_match_sa(base_topic + ".d"), + topic_match_sa(base_topic + ".d.d"), + topic_match_sa(base_topic + ".d.d.d"), + topic_match_sa(base_topic + ".d.d.d.d"), ) return query.where(maybe_negate(cond)) - cond = func.upper(column("subject")) == func.upper(literal(operand)) + cond = topic_match_sa(operand) return query.where(maybe_negate(cond)) def by_sender(self, query: Query, operand: str, maybe_negate: ConditionTransform) -> Query: