From f1925487e89455edee58abbecc1688ca3b548864 Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Thu, 6 Apr 2023 11:58:14 -0700 Subject: [PATCH] db: Force use of TimeTrackingCursor to work around Django 4.2 bug. Effectively revert commit b4cf9ad777e1e80a56bb441d62ad7ab6c7e14f42 to work around https://code.djangoproject.com/ticket/34466. Signed-off-by: Anders Kaseorg --- zerver/lib/db.py | 49 ++++++++++++++++++++++++++++++++- zproject/computed_settings.py | 3 +- zproject/test_extra_settings.py | 3 +- 3 files changed, 50 insertions(+), 5 deletions(-) diff --git a/zerver/lib/db.py b/zerver/lib/db.py index 10231aa9f5..5691682fa8 100644 --- a/zerver/lib/db.py +++ b/zerver/lib/db.py @@ -1,5 +1,17 @@ import time -from typing import Any, Callable, Dict, Iterable, List, Mapping, Sequence, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + TypeVar, + Union, + overload, +) from psycopg2.extensions import connection, cursor from psycopg2.sql import Composable @@ -48,6 +60,41 @@ class TimeTrackingConnection(connection): self.queries: List[Dict[str, str]] = [] super().__init__(*args, **kwargs) + @overload + def cursor( + self, + name: Union[str, bytes, None] = ..., + *, + withhold: bool = ..., + scrollable: Optional[bool] = ..., + ) -> TimeTrackingCursor: + ... + + @overload + def cursor( + self, + name: Union[str, bytes, None] = ..., + *, + cursor_factory: Callable[..., CursorT] = ..., + withhold: bool = ..., + scrollable: Optional[bool] = ..., + ) -> CursorT: + ... + + @overload + def cursor( + self, + name: Union[str, bytes, None], + cursor_factory: Callable[..., CursorT] = ..., + withhold: bool = ..., + scrollable: Optional[bool] = ..., + ) -> CursorT: + ... + + def cursor(self, *args: Any, **kwargs: Any) -> cursor: + kwargs["cursor_factory"] = TimeTrackingCursor + return super().cursor(*args, **kwargs) + def reset_queries() -> None: from django.db import connections diff --git a/zproject/computed_settings.py b/zproject/computed_settings.py index 13876b3e93..8d10662d12 100644 --- a/zproject/computed_settings.py +++ b/zproject/computed_settings.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Final, List, Tuple, Union from urllib.parse import urljoin from scripts.lib.zulip_tools import get_tornado_ports -from zerver.lib.db import TimeTrackingConnection, TimeTrackingCursor +from zerver.lib.db import TimeTrackingConnection from .config import ( DEPLOY_ROOT, @@ -282,7 +282,6 @@ DATABASES: Dict[str, Dict[str, Any]] = { "CONN_MAX_AGE": 600, "OPTIONS": { "connection_factory": TimeTrackingConnection, - "cursor_factory": TimeTrackingCursor, }, } } diff --git a/zproject/test_extra_settings.py b/zproject/test_extra_settings.py index cbea88b7d6..0a8f9b7e3f 100644 --- a/zproject/test_extra_settings.py +++ b/zproject/test_extra_settings.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Tuple import ldap from django_auth_ldap.config import LDAPSearch -from zerver.lib.db import TimeTrackingConnection, TimeTrackingCursor +from zerver.lib.db import TimeTrackingConnection from zproject.settings_types import OIDCIdPConfigDict, SAMLIdPConfigDict, SCIMConfigDict from .config import DEPLOY_ROOT, get_from_file_if_exists @@ -36,7 +36,6 @@ DATABASES["default"] = { "TEST_NAME": "django_zulip_tests", "OPTIONS": { "connection_factory": TimeTrackingConnection, - "cursor_factory": TimeTrackingCursor, }, }