diff --git a/zerver/lib/cache.py b/zerver/lib/cache.py index e5b82ae4a6..33fbec3355 100644 --- a/zerver/lib/cache.py +++ b/zerver/lib/cache.py @@ -7,12 +7,13 @@ import secrets import sys import time import traceback -from functools import lru_cache, wraps +from functools import _lru_cache_wrapper, lru_cache, wraps from typing import ( TYPE_CHECKING, Any, Callable, Dict, + Generic, Iterable, List, Optional, @@ -737,9 +738,41 @@ def flush_submessage(*, instance: "SubMessage", **kwargs: object) -> None: cache_delete(to_dict_cache_key_id(message_id)) +class IgnoreUnhashableLruCacheWrapper(Generic[ParamT, ReturnT]): + def __init__( + self, function: Callable[ParamT, ReturnT], cached_function: "_lru_cache_wrapper[ReturnT]" + ): + self.key_prefix = KEY_PREFIX + self.function = function + self.cached_function = cached_function + self.cache_info = cached_function.cache_info + self.cache_clear = cached_function.cache_clear + + def __call__(self, *args: ParamT.args, **kwargs: ParamT.kwargs) -> ReturnT: + if self.key_prefix != KEY_PREFIX: + # Clear cache when cache.KEY_PREFIX changes. This is used in + # tests. + self.cache_clear() + self.key_prefix = KEY_PREFIX + + try: + return self.cached_function( + *args, **kwargs # type: ignore[arg-type] # might be unhashable + ) + except TypeError: + # args or kwargs contains an element which is unhashable. In + # this case we don't cache the result. + pass + + # Deliberately calling this function from outside of exception + # handler to get a more descriptive traceback. Otherwise traceback + # can include the exception from cached_function as well. + return self.function(*args, **kwargs) + + def ignore_unhashable_lru_cache( maxsize: int = 128, typed: bool = False -) -> Callable[[Callable[ParamT, ReturnT]], Callable[ParamT, ReturnT]]: +) -> Callable[[Callable[ParamT, ReturnT]], IgnoreUnhashableLruCacheWrapper[ParamT, ReturnT]]: """ This is a wrapper over lru_cache function. It adds following features on top of lru_cache: @@ -749,42 +782,10 @@ def ignore_unhashable_lru_cache( """ internal_decorator = lru_cache(maxsize=maxsize, typed=typed) - def decorator(user_function: Callable[ParamT, ReturnT]) -> Callable[ParamT, ReturnT]: - if settings.DEVELOPMENT and not settings.TEST_SUITE: # nocoverage - # In the development environment, we want every file - # change to refresh the source files from disk. - return user_function - - cache_enabled_user_function = internal_decorator(user_function) - key_prefix = KEY_PREFIX - - def wrapper(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ReturnT: - nonlocal key_prefix - - if key_prefix != KEY_PREFIX: - # Clear cache when cache.KEY_PREFIX changes. This is used in - # tests. - cache_enabled_user_function.cache_clear() - key_prefix = KEY_PREFIX - - try: - return cache_enabled_user_function( - *args, **kwargs # type: ignore[arg-type] # might be unhashable - ) - except TypeError: - # args or kwargs contains an element which is unhashable. In - # this case we don't cache the result. - pass - - # Deliberately calling this function from outside of exception - # handler to get a more descriptive traceback. Otherwise traceback - # can include the exception from cached_enabled_user_function as - # well. - return user_function(*args, **kwargs) - - setattr(wrapper, "cache_info", cache_enabled_user_function.cache_info) - setattr(wrapper, "cache_clear", cache_enabled_user_function.cache_clear) - return wrapper + def decorator( + user_function: Callable[ParamT, ReturnT] + ) -> IgnoreUnhashableLruCacheWrapper[ParamT, ReturnT]: + return IgnoreUnhashableLruCacheWrapper(user_function, internal_decorator(user_function)) return decorator diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index abce3f66f1..3ac8ba75ca 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -2083,48 +2083,38 @@ class TestIgnoreUnhashableLRUCache(ZulipTestCase): def f(arg: Any) -> Any: return arg - def get_cache_info() -> Tuple[int, int, int]: - info = getattr(f, "cache_info")() - hits = getattr(info, "hits") - misses = getattr(info, "misses") - currsize = getattr(info, "currsize") - return hits, misses, currsize - - def clear_cache() -> None: - getattr(f, "cache_clear")() - # Check hashable argument. result = f(1) - hits, misses, currsize = get_cache_info() + info = f.cache_info() # First one should be a miss. - self.assertEqual(hits, 0) - self.assertEqual(misses, 1) - self.assertEqual(currsize, 1) + self.assertEqual(info.hits, 0) + self.assertEqual(info.misses, 1) + self.assertEqual(info.currsize, 1) self.assertEqual(result, 1) result = f(1) - hits, misses, currsize = get_cache_info() + info = f.cache_info() # Second one should be a hit. - self.assertEqual(hits, 1) - self.assertEqual(misses, 1) - self.assertEqual(currsize, 1) + self.assertEqual(info.hits, 1) + self.assertEqual(info.misses, 1) + self.assertEqual(info.currsize, 1) self.assertEqual(result, 1) # Check unhashable argument. result = f({1: 2}) - hits, misses, currsize = get_cache_info() + info = f.cache_info() # Cache should not be used. - self.assertEqual(hits, 1) - self.assertEqual(misses, 1) - self.assertEqual(currsize, 1) + self.assertEqual(info.hits, 1) + self.assertEqual(info.misses, 1) + self.assertEqual(info.currsize, 1) self.assertEqual(result, {1: 2}) # Clear cache. - clear_cache() - hits, misses, currsize = get_cache_info() - self.assertEqual(hits, 0) - self.assertEqual(misses, 0) - self.assertEqual(currsize, 0) + f.cache_clear() + info = f.cache_info() + self.assertEqual(info.hits, 0) + self.assertEqual(info.misses, 0) + self.assertEqual(info.currsize, 0) def test_cache_hit_dict_args(self) -> None: @ignore_unhashable_lru_cache() @@ -2132,60 +2122,50 @@ class TestIgnoreUnhashableLRUCache(ZulipTestCase): def g(arg: Any) -> Any: return arg - def get_cache_info() -> Tuple[int, int, int]: - info = getattr(g, "cache_info")() - hits = getattr(info, "hits") - misses = getattr(info, "misses") - currsize = getattr(info, "currsize") - return hits, misses, currsize - - def clear_cache() -> None: - getattr(g, "cache_clear")() - - # Not used as a decorator on the definition to allow defining - # get_cache_info and clear_cache + # Not used as a decorator on the definition to allow calling + # cache_info and cache_clear f = dict_to_items_tuple(g) # Check hashable argument. result = f(1) - hits, misses, currsize = get_cache_info() + info = g.cache_info() # First one should be a miss. - self.assertEqual(hits, 0) - self.assertEqual(misses, 1) - self.assertEqual(currsize, 1) + self.assertEqual(info.hits, 0) + self.assertEqual(info.misses, 1) + self.assertEqual(info.currsize, 1) self.assertEqual(result, 1) result = f(1) - hits, misses, currsize = get_cache_info() + info = g.cache_info() # Second one should be a hit. - self.assertEqual(hits, 1) - self.assertEqual(misses, 1) - self.assertEqual(currsize, 1) + self.assertEqual(info.hits, 1) + self.assertEqual(info.misses, 1) + self.assertEqual(info.currsize, 1) self.assertEqual(result, 1) # Check dict argument. result = f({1: 2}) - hits, misses, currsize = get_cache_info() + info = g.cache_info() # First one is a miss - self.assertEqual(hits, 1) - self.assertEqual(misses, 2) - self.assertEqual(currsize, 2) + self.assertEqual(info.hits, 1) + self.assertEqual(info.misses, 2) + self.assertEqual(info.currsize, 2) self.assertEqual(result, {1: 2}) result = f({1: 2}) - hits, misses, currsize = get_cache_info() + info = g.cache_info() # Second one should be a hit. - self.assertEqual(hits, 2) - self.assertEqual(misses, 2) - self.assertEqual(currsize, 2) + self.assertEqual(info.hits, 2) + self.assertEqual(info.misses, 2) + self.assertEqual(info.currsize, 2) self.assertEqual(result, {1: 2}) # Clear cache. - clear_cache() - hits, misses, currsize = get_cache_info() - self.assertEqual(hits, 0) - self.assertEqual(misses, 0) - self.assertEqual(currsize, 0) + g.cache_clear() + info = g.cache_info() + self.assertEqual(info.hits, 0) + self.assertEqual(info.misses, 0) + self.assertEqual(info.currsize, 0) class TestRequestNotes(ZulipTestCase):