decorator: Type cache_info, cache_clear for ignore_unhashable_lru_cache.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2022-06-26 11:21:11 -07:00 committed by Tim Abbott
parent 9db8a59a56
commit 53231aa9d9
2 changed files with 79 additions and 98 deletions

View File

@ -7,12 +7,13 @@ import secrets
import sys
import time
import traceback
from functools import lru_cache, wraps
from functools import _lru_cache_wrapper, lru_cache, wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
@ -737,9 +738,41 @@ def flush_submessage(*, instance: "SubMessage", **kwargs: object) -> None:
cache_delete(to_dict_cache_key_id(message_id))
class IgnoreUnhashableLruCacheWrapper(Generic[ParamT, ReturnT]):
def __init__(
self, function: Callable[ParamT, ReturnT], cached_function: "_lru_cache_wrapper[ReturnT]"
):
self.key_prefix = KEY_PREFIX
self.function = function
self.cached_function = cached_function
self.cache_info = cached_function.cache_info
self.cache_clear = cached_function.cache_clear
def __call__(self, *args: ParamT.args, **kwargs: ParamT.kwargs) -> ReturnT:
if self.key_prefix != KEY_PREFIX:
# Clear cache when cache.KEY_PREFIX changes. This is used in
# tests.
self.cache_clear()
self.key_prefix = KEY_PREFIX
try:
return self.cached_function(
*args, **kwargs # type: ignore[arg-type] # might be unhashable
)
except TypeError:
# args or kwargs contains an element which is unhashable. In
# this case we don't cache the result.
pass
# Deliberately calling this function from outside of exception
# handler to get a more descriptive traceback. Otherwise traceback
# can include the exception from cached_function as well.
return self.function(*args, **kwargs)
def ignore_unhashable_lru_cache(
maxsize: int = 128, typed: bool = False
) -> Callable[[Callable[ParamT, ReturnT]], Callable[ParamT, ReturnT]]:
) -> Callable[[Callable[ParamT, ReturnT]], IgnoreUnhashableLruCacheWrapper[ParamT, ReturnT]]:
"""
This is a wrapper over lru_cache function. It adds following features on
top of lru_cache:
@ -749,42 +782,10 @@ def ignore_unhashable_lru_cache(
"""
internal_decorator = lru_cache(maxsize=maxsize, typed=typed)
def decorator(user_function: Callable[ParamT, ReturnT]) -> Callable[ParamT, ReturnT]:
if settings.DEVELOPMENT and not settings.TEST_SUITE: # nocoverage
# In the development environment, we want every file
# change to refresh the source files from disk.
return user_function
cache_enabled_user_function = internal_decorator(user_function)
key_prefix = KEY_PREFIX
def wrapper(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ReturnT:
nonlocal key_prefix
if key_prefix != KEY_PREFIX:
# Clear cache when cache.KEY_PREFIX changes. This is used in
# tests.
cache_enabled_user_function.cache_clear()
key_prefix = KEY_PREFIX
try:
return cache_enabled_user_function(
*args, **kwargs # type: ignore[arg-type] # might be unhashable
)
except TypeError:
# args or kwargs contains an element which is unhashable. In
# this case we don't cache the result.
pass
# Deliberately calling this function from outside of exception
# handler to get a more descriptive traceback. Otherwise traceback
# can include the exception from cached_enabled_user_function as
# well.
return user_function(*args, **kwargs)
setattr(wrapper, "cache_info", cache_enabled_user_function.cache_info)
setattr(wrapper, "cache_clear", cache_enabled_user_function.cache_clear)
return wrapper
def decorator(
user_function: Callable[ParamT, ReturnT]
) -> IgnoreUnhashableLruCacheWrapper[ParamT, ReturnT]:
return IgnoreUnhashableLruCacheWrapper(user_function, internal_decorator(user_function))
return decorator

View File

@ -2083,48 +2083,38 @@ class TestIgnoreUnhashableLRUCache(ZulipTestCase):
def f(arg: Any) -> Any:
return arg
def get_cache_info() -> Tuple[int, int, int]:
info = getattr(f, "cache_info")()
hits = getattr(info, "hits")
misses = getattr(info, "misses")
currsize = getattr(info, "currsize")
return hits, misses, currsize
def clear_cache() -> None:
getattr(f, "cache_clear")()
# Check hashable argument.
result = f(1)
hits, misses, currsize = get_cache_info()
info = f.cache_info()
# First one should be a miss.
self.assertEqual(hits, 0)
self.assertEqual(misses, 1)
self.assertEqual(currsize, 1)
self.assertEqual(info.hits, 0)
self.assertEqual(info.misses, 1)
self.assertEqual(info.currsize, 1)
self.assertEqual(result, 1)
result = f(1)
hits, misses, currsize = get_cache_info()
info = f.cache_info()
# Second one should be a hit.
self.assertEqual(hits, 1)
self.assertEqual(misses, 1)
self.assertEqual(currsize, 1)
self.assertEqual(info.hits, 1)
self.assertEqual(info.misses, 1)
self.assertEqual(info.currsize, 1)
self.assertEqual(result, 1)
# Check unhashable argument.
result = f({1: 2})
hits, misses, currsize = get_cache_info()
info = f.cache_info()
# Cache should not be used.
self.assertEqual(hits, 1)
self.assertEqual(misses, 1)
self.assertEqual(currsize, 1)
self.assertEqual(info.hits, 1)
self.assertEqual(info.misses, 1)
self.assertEqual(info.currsize, 1)
self.assertEqual(result, {1: 2})
# Clear cache.
clear_cache()
hits, misses, currsize = get_cache_info()
self.assertEqual(hits, 0)
self.assertEqual(misses, 0)
self.assertEqual(currsize, 0)
f.cache_clear()
info = f.cache_info()
self.assertEqual(info.hits, 0)
self.assertEqual(info.misses, 0)
self.assertEqual(info.currsize, 0)
def test_cache_hit_dict_args(self) -> None:
@ignore_unhashable_lru_cache()
@ -2132,60 +2122,50 @@ class TestIgnoreUnhashableLRUCache(ZulipTestCase):
def g(arg: Any) -> Any:
return arg
def get_cache_info() -> Tuple[int, int, int]:
info = getattr(g, "cache_info")()
hits = getattr(info, "hits")
misses = getattr(info, "misses")
currsize = getattr(info, "currsize")
return hits, misses, currsize
def clear_cache() -> None:
getattr(g, "cache_clear")()
# Not used as a decorator on the definition to allow defining
# get_cache_info and clear_cache
# Not used as a decorator on the definition to allow calling
# cache_info and cache_clear
f = dict_to_items_tuple(g)
# Check hashable argument.
result = f(1)
hits, misses, currsize = get_cache_info()
info = g.cache_info()
# First one should be a miss.
self.assertEqual(hits, 0)
self.assertEqual(misses, 1)
self.assertEqual(currsize, 1)
self.assertEqual(info.hits, 0)
self.assertEqual(info.misses, 1)
self.assertEqual(info.currsize, 1)
self.assertEqual(result, 1)
result = f(1)
hits, misses, currsize = get_cache_info()
info = g.cache_info()
# Second one should be a hit.
self.assertEqual(hits, 1)
self.assertEqual(misses, 1)
self.assertEqual(currsize, 1)
self.assertEqual(info.hits, 1)
self.assertEqual(info.misses, 1)
self.assertEqual(info.currsize, 1)
self.assertEqual(result, 1)
# Check dict argument.
result = f({1: 2})
hits, misses, currsize = get_cache_info()
info = g.cache_info()
# First one is a miss
self.assertEqual(hits, 1)
self.assertEqual(misses, 2)
self.assertEqual(currsize, 2)
self.assertEqual(info.hits, 1)
self.assertEqual(info.misses, 2)
self.assertEqual(info.currsize, 2)
self.assertEqual(result, {1: 2})
result = f({1: 2})
hits, misses, currsize = get_cache_info()
info = g.cache_info()
# Second one should be a hit.
self.assertEqual(hits, 2)
self.assertEqual(misses, 2)
self.assertEqual(currsize, 2)
self.assertEqual(info.hits, 2)
self.assertEqual(info.misses, 2)
self.assertEqual(info.currsize, 2)
self.assertEqual(result, {1: 2})
# Clear cache.
clear_cache()
hits, misses, currsize = get_cache_info()
self.assertEqual(hits, 0)
self.assertEqual(misses, 0)
self.assertEqual(currsize, 0)
g.cache_clear()
info = g.cache_info()
self.assertEqual(info.hits, 0)
self.assertEqual(info.misses, 0)
self.assertEqual(info.currsize, 0)
class TestRequestNotes(ZulipTestCase):