diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index f5096c049c..9fff9a95a0 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -41,14 +41,13 @@ class RateLimitedObject(ABC): def rate_limit_request(self, request: HttpRequest) -> None: ratelimited, time = self.rate_limit() - entity_type = type(self).__name__ - if not hasattr(request, '_ratelimit'): - request._ratelimit = {} - request._ratelimit[entity_type] = RateLimitResult( + if not hasattr(request, '_ratelimits_applied'): + request._ratelimits_applied = [] + request._ratelimits_applied.append(RateLimitResult( entity=self, secs_to_freedom=time, over_limit=ratelimited - ) + )) # Abort this request if the user is over their rate limits if ratelimited: # Pass information about what kind of entity got limited in the exception: @@ -56,8 +55,8 @@ class RateLimitedObject(ABC): calls_remaining, time_reset = self.api_calls_left() - request._ratelimit[entity_type].remaining = calls_remaining - request._ratelimit[entity_type].secs_to_freedom = time_reset + request._ratelimits_applied[-1].remaining = calls_remaining + request._ratelimits_applied[-1].secs_to_freedom = time_reset def block_access(self, seconds: int) -> None: "Manually blocks an entity for the desired number of seconds" diff --git a/zerver/middleware.py b/zerver/middleware.py index 55dad7275d..f961504b75 100644 --- a/zerver/middleware.py +++ b/zerver/middleware.py @@ -370,9 +370,8 @@ class RateLimitMiddleware(MiddlewareMixin): return response # Add X-RateLimit-*** headers - if hasattr(request, '_ratelimit'): - rate_limit_results = list(request._ratelimit.values()) - self.set_response_headers(response, rate_limit_results) + if hasattr(request, '_ratelimits_applied'): + self.set_response_headers(response, request._ratelimits_applied) return response diff --git a/zproject/backends.py b/zproject/backends.py index 13fdf2b3fe..6423e84a8e 100644 --- a/zproject/backends.py +++ b/zproject/backends.py @@ -192,7 +192,11 @@ def rate_limit_authentication_by_username(request: HttpRequest, username: str) - RateLimitedAuthenticationByUsername(username).rate_limit_request(request) def auth_rate_limiting_already_applied(request: HttpRequest) -> bool: - return hasattr(request, '_ratelimit') and 'RateLimitedAuthenticationByUsername' in request._ratelimit + if not hasattr(request, '_ratelimits_applied'): + return False + + return any(isinstance(r.entity, RateLimitedAuthenticationByUsername) + for r in request._ratelimits_applied) # Django's authentication mechanism uses introspection on the various authenticate() functions # defined by backends, so we need a decorator that doesn't break function signatures.