@@ -0,0 +1 @@ | |||
Describe which rate limiter was hit in logs. |
@@ -211,6 +211,11 @@ class SynapseError(CodeMessageException): | |||
def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": | |||
return cs_error(self.msg, self.errcode, **self._additional_fields) | |||
@property | |||
def debug_context(self) -> Optional[str]: | |||
"""Override this to add debugging context that shouldn't be sent to clients.""" | |||
return None | |||
class InvalidAPICallError(SynapseError): | |||
"""You called an existing API endpoint, but fed that endpoint | |||
@@ -508,8 +513,8 @@ class LimitExceededError(SynapseError): | |||
def __init__( | |||
self, | |||
limiter_name: str, | |||
code: int = 429, | |||
msg: str = "Too Many Requests", | |||
retry_after_ms: Optional[int] = None, | |||
errcode: str = Codes.LIMIT_EXCEEDED, | |||
): | |||
@@ -518,12 +523,17 @@ class LimitExceededError(SynapseError): | |||
if self.include_retry_after_header and retry_after_ms is not None | |||
else None | |||
) | |||
super().__init__(code, msg, errcode, headers=headers) | |||
super().__init__(code, "Too Many Requests", errcode, headers=headers) | |||
self.retry_after_ms = retry_after_ms | |||
self.limiter_name = limiter_name | |||
def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": | |||
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) | |||
@property | |||
def debug_context(self) -> Optional[str]: | |||
return self.limiter_name | |||
class RoomKeysVersionError(SynapseError): | |||
"""A client has tried to upload to a non-current version of the room_keys store""" | |||
@@ -61,12 +61,16 @@ class Ratelimiter: | |||
""" | |||
def __init__( | |||
self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int | |||
self, | |||
store: DataStore, | |||
clock: Clock, | |||
cfg: RatelimitSettings, | |||
): | |||
self.clock = clock | |||
self.rate_hz = rate_hz | |||
self.burst_count = burst_count | |||
self.rate_hz = cfg.per_second | |||
self.burst_count = cfg.burst_count | |||
self.store = store | |||
self._limiter_name = cfg.key | |||
# An ordered dictionary representing the token buckets tracked by this rate | |||
# limiter. Each entry maps a key of arbitrary type to a tuple representing: | |||
@@ -305,7 +309,8 @@ class Ratelimiter: | |||
if not allowed: | |||
raise LimitExceededError( | |||
retry_after_ms=int(1000 * (time_allowed - time_now_s)) | |||
limiter_name=self._limiter_name, | |||
retry_after_ms=int(1000 * (time_allowed - time_now_s)), | |||
) | |||
@@ -322,7 +327,9 @@ class RequestRatelimiter: | |||
# The rate_hz and burst_count are overridden on a per-user basis | |||
self.request_ratelimiter = Ratelimiter( | |||
store=self.store, clock=self.clock, rate_hz=0, burst_count=0 | |||
store=self.store, | |||
clock=self.clock, | |||
cfg=RatelimitSettings(key=rc_message.key, per_second=0, burst_count=0), | |||
) | |||
self._rc_message = rc_message | |||
@@ -332,8 +339,7 @@ class RequestRatelimiter: | |||
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( | |||
store=self.store, | |||
clock=self.clock, | |||
rate_hz=rc_admin_redaction.per_second, | |||
burst_count=rc_admin_redaction.burst_count, | |||
cfg=rc_admin_redaction, | |||
) | |||
else: | |||
self.admin_redaction_ratelimiter = None | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import Any, Dict, Optional | |||
from typing import Any, Dict, Optional, cast | |||
import attr | |||
@@ -21,16 +21,47 @@ from synapse.types import JsonDict | |||
from ._base import Config | |||
@attr.s(slots=True, frozen=True, auto_attribs=True) | |||
class RatelimitSettings: | |||
def __init__( | |||
self, | |||
config: Dict[str, float], | |||
key: str | |||
per_second: float | |||
burst_count: int | |||
@classmethod | |||
def parse( | |||
cls, | |||
config: Dict[str, Any], | |||
key: str, | |||
defaults: Optional[Dict[str, float]] = None, | |||
): | |||
) -> "RatelimitSettings": | |||
"""Parse config[key] as a new-style rate limiter config. | |||
The key may refer to a nested dictionary using a full stop (.) to separate | |||
each nested key. For example, use the key "a.b.c" to parse the following: | |||
a: | |||
b: | |||
c: | |||
per_second: 10 | |||
burst_count: 200 | |||
If this lookup fails, we'll fallback to the defaults. | |||
""" | |||
defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} | |||
self.per_second = config.get("per_second", defaults["per_second"]) | |||
self.burst_count = int(config.get("burst_count", defaults["burst_count"])) | |||
rl_config = config | |||
for part in key.split("."): | |||
rl_config = rl_config.get(part, {}) | |||
# By this point we should have hit the rate limiter parameters. | |||
# We don't actually check this though! | |||
rl_config = cast(Dict[str, float], rl_config) | |||
return cls( | |||
key=key, | |||
per_second=rl_config.get("per_second", defaults["per_second"]), | |||
burst_count=int(rl_config.get("burst_count", defaults["burst_count"])), | |||
) | |||
@attr.s(auto_attribs=True) | |||
@@ -49,15 +80,14 @@ class RatelimitConfig(Config): | |||
# Load the new-style messages config if it exists. Otherwise fall back | |||
# to the old method. | |||
if "rc_message" in config: | |||
self.rc_message = RatelimitSettings( | |||
config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0} | |||
self.rc_message = RatelimitSettings.parse( | |||
config, "rc_message", defaults={"per_second": 0.2, "burst_count": 10.0} | |||
) | |||
else: | |||
self.rc_message = RatelimitSettings( | |||
{ | |||
"per_second": config.get("rc_messages_per_second", 0.2), | |||
"burst_count": config.get("rc_message_burst_count", 10.0), | |||
} | |||
key="rc_messages", | |||
per_second=config.get("rc_messages_per_second", 0.2), | |||
burst_count=config.get("rc_message_burst_count", 10.0), | |||
) | |||
# Load the new-style federation config, if it exists. Otherwise, fall | |||
@@ -79,51 +109,59 @@ class RatelimitConfig(Config): | |||
} | |||
) | |||
self.rc_registration = RatelimitSettings(config.get("rc_registration", {})) | |||
self.rc_registration = RatelimitSettings.parse(config, "rc_registration", {}) | |||
self.rc_registration_token_validity = RatelimitSettings( | |||
config.get("rc_registration_token_validity", {}), | |||
self.rc_registration_token_validity = RatelimitSettings.parse( | |||
config, | |||
"rc_registration_token_validity", | |||
defaults={"per_second": 0.1, "burst_count": 5}, | |||
) | |||
# It is reasonable to login with a bunch of devices at once (i.e. when | |||
# setting up an account), but it is *not* valid to continually be | |||
# logging into new devices. | |||
rc_login_config = config.get("rc_login", {}) | |||
self.rc_login_address = RatelimitSettings( | |||
rc_login_config.get("address", {}), | |||
self.rc_login_address = RatelimitSettings.parse( | |||
config, | |||
"rc_login.address", | |||
defaults={"per_second": 0.003, "burst_count": 5}, | |||
) | |||
self.rc_login_account = RatelimitSettings( | |||
rc_login_config.get("account", {}), | |||
self.rc_login_account = RatelimitSettings.parse( | |||
config, | |||
"rc_login.account", | |||
defaults={"per_second": 0.003, "burst_count": 5}, | |||
) | |||
self.rc_login_failed_attempts = RatelimitSettings( | |||
rc_login_config.get("failed_attempts", {}) | |||
self.rc_login_failed_attempts = RatelimitSettings.parse( | |||
config, | |||
"rc_login.failed_attempts", | |||
{}, | |||
) | |||
self.federation_rr_transactions_per_room_per_second = config.get( | |||
"federation_rr_transactions_per_room_per_second", 50 | |||
) | |||
rc_admin_redaction = config.get("rc_admin_redaction") | |||
self.rc_admin_redaction = None | |||
if rc_admin_redaction: | |||
self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction) | |||
if "rc_admin_redaction" in config: | |||
self.rc_admin_redaction = RatelimitSettings.parse( | |||
config, "rc_admin_redaction", {} | |||
) | |||
self.rc_joins_local = RatelimitSettings( | |||
config.get("rc_joins", {}).get("local", {}), | |||
self.rc_joins_local = RatelimitSettings.parse( | |||
config, | |||
"rc_joins.local", | |||
defaults={"per_second": 0.1, "burst_count": 10}, | |||
) | |||
self.rc_joins_remote = RatelimitSettings( | |||
config.get("rc_joins", {}).get("remote", {}), | |||
self.rc_joins_remote = RatelimitSettings.parse( | |||
config, | |||
"rc_joins.remote", | |||
defaults={"per_second": 0.01, "burst_count": 10}, | |||
) | |||
# Track the rate of joins to a given room. If there are too many, temporarily | |||
# prevent local joins and remote joins via this server. | |||
self.rc_joins_per_room = RatelimitSettings( | |||
config.get("rc_joins_per_room", {}), | |||
self.rc_joins_per_room = RatelimitSettings.parse( | |||
config, | |||
"rc_joins_per_room", | |||
defaults={"per_second": 1, "burst_count": 10}, | |||
) | |||
@@ -132,31 +170,37 @@ class RatelimitConfig(Config): | |||
# * For requests received over federation this is keyed by the origin. | |||
# | |||
# Note that this isn't exposed in the configuration as it is obscure. | |||
self.rc_key_requests = RatelimitSettings( | |||
config.get("rc_key_requests", {}), | |||
self.rc_key_requests = RatelimitSettings.parse( | |||
config, | |||
"rc_key_requests", | |||
defaults={"per_second": 20, "burst_count": 100}, | |||
) | |||
self.rc_3pid_validation = RatelimitSettings( | |||
config.get("rc_3pid_validation") or {}, | |||
self.rc_3pid_validation = RatelimitSettings.parse( | |||
config, | |||
"rc_3pid_validation", | |||
defaults={"per_second": 0.003, "burst_count": 5}, | |||
) | |||
self.rc_invites_per_room = RatelimitSettings( | |||
config.get("rc_invites", {}).get("per_room", {}), | |||
self.rc_invites_per_room = RatelimitSettings.parse( | |||
config, | |||
"rc_invites.per_room", | |||
defaults={"per_second": 0.3, "burst_count": 10}, | |||
) | |||
self.rc_invites_per_user = RatelimitSettings( | |||
config.get("rc_invites", {}).get("per_user", {}), | |||
self.rc_invites_per_user = RatelimitSettings.parse( | |||
config, | |||
"rc_invites.per_user", | |||
defaults={"per_second": 0.003, "burst_count": 5}, | |||
) | |||
self.rc_invites_per_issuer = RatelimitSettings( | |||
config.get("rc_invites", {}).get("per_issuer", {}), | |||
self.rc_invites_per_issuer = RatelimitSettings.parse( | |||
config, | |||
"rc_invites.per_issuer", | |||
defaults={"per_second": 0.3, "burst_count": 10}, | |||
) | |||
self.rc_third_party_invite = RatelimitSettings( | |||
config.get("rc_third_party_invite", {}), | |||
self.rc_third_party_invite = RatelimitSettings.parse( | |||
config, | |||
"rc_third_party_invite", | |||
defaults={"per_second": 0.0025, "burst_count": 5}, | |||
) |
@@ -218,19 +218,17 @@ class AuthHandler: | |||
self._failed_uia_attempts_ratelimiter = Ratelimiter( | |||
store=self.store, | |||
clock=self.clock, | |||
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, | |||
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, | |||
cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, | |||
) | |||
# The number of seconds to keep a UI auth session active. | |||
self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout | |||
# Ratelimitier for failed /login attempts | |||
# Ratelimiter for failed /login attempts | |||
self._failed_login_attempts_ratelimiter = Ratelimiter( | |||
store=self.store, | |||
clock=hs.get_clock(), | |||
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, | |||
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count, | |||
cfg=self.hs.config.ratelimiting.rc_login_failed_attempts, | |||
) | |||
self._clock = self.hs.get_clock() | |||
@@ -90,8 +90,7 @@ class DeviceMessageHandler: | |||
self._ratelimiter = Ratelimiter( | |||
store=self.store, | |||
clock=hs.get_clock(), | |||
rate_hz=hs.config.ratelimiting.rc_key_requests.per_second, | |||
burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, | |||
cfg=hs.config.ratelimiting.rc_key_requests, | |||
) | |||
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: | |||
@@ -66,14 +66,12 @@ class IdentityHandler: | |||
self._3pid_validation_ratelimiter_ip = Ratelimiter( | |||
store=self.store, | |||
clock=hs.get_clock(), | |||
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, | |||
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, | |||
cfg=hs.config.ratelimiting.rc_3pid_validation, | |||
) | |||
self._3pid_validation_ratelimiter_address = Ratelimiter( | |||
store=self.store, | |||
clock=hs.get_clock(), | |||
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, | |||
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, | |||
cfg=hs.config.ratelimiting.rc_3pid_validation, | |||
) | |||
async def ratelimit_request_token_requests( | |||
@@ -112,8 +112,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
self._join_rate_limiter_local = Ratelimiter( | |||
store=self.store, | |||
clock=self.clock, | |||
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, | |||
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count, | |||
cfg=hs.config.ratelimiting.rc_joins_local, | |||
) | |||
# Tracks joins from local users to rooms this server isn't a member of. | |||
# I.e. joins this server makes by requesting /make_join /send_join from | |||
@@ -121,8 +120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
self._join_rate_limiter_remote = Ratelimiter( | |||
store=self.store, | |||
clock=self.clock, | |||
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, | |||
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, | |||
cfg=hs.config.ratelimiting.rc_joins_remote, | |||
) | |||
# TODO: find a better place to keep this Ratelimiter. | |||
# It needs to be | |||
@@ -135,8 +133,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
self._join_rate_per_room_limiter = Ratelimiter( | |||
store=self.store, | |||
clock=self.clock, | |||
rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, | |||
burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count, | |||
cfg=hs.config.ratelimiting.rc_joins_per_room, | |||
) | |||
# Ratelimiter for invites, keyed by room (across all issuers, all | |||
@@ -144,8 +141,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
self._invites_per_room_limiter = Ratelimiter( | |||
store=self.store, | |||
clock=self.clock, | |||
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, | |||
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, | |||
cfg=hs.config.ratelimiting.rc_invites_per_room, | |||
) | |||
# Ratelimiter for invites, keyed by recipient (across all rooms, all | |||
@@ -153,8 +149,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
self._invites_per_recipient_limiter = Ratelimiter( | |||
store=self.store, | |||
clock=self.clock, | |||
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, | |||
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, | |||
cfg=hs.config.ratelimiting.rc_invites_per_user, | |||
) | |||
# Ratelimiter for invites, keyed by issuer (across all rooms, all | |||
@@ -162,15 +157,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): | |||
self._invites_per_issuer_limiter = Ratelimiter( | |||
store=self.store, | |||
clock=self.clock, | |||
rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second, | |||
burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count, | |||
cfg=hs.config.ratelimiting.rc_invites_per_issuer, | |||
) | |||
self._third_party_invite_limiter = Ratelimiter( | |||
store=self.store, | |||
clock=self.clock, | |||
rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second, | |||
burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count, | |||
cfg=hs.config.ratelimiting.rc_third_party_invite, | |||
) | |||
self.request_ratelimiter = hs.get_request_ratelimiter() | |||
@@ -35,6 +35,7 @@ from synapse.api.errors import ( | |||
UnsupportedRoomVersionError, | |||
) | |||
from synapse.api.ratelimiting import Ratelimiter | |||
from synapse.config.ratelimiting import RatelimitSettings | |||
from synapse.events import EventBase | |||
from synapse.types import JsonDict, Requester, StrCollection | |||
from synapse.util.caches.response_cache import ResponseCache | |||
@@ -94,7 +95,9 @@ class RoomSummaryHandler: | |||
self._server_name = hs.hostname | |||
self._federation_client = hs.get_federation_client() | |||
self._ratelimiter = Ratelimiter( | |||
store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 | |||
store=self._store, | |||
clock=hs.get_clock(), | |||
cfg=RatelimitSettings("<room summary>", per_second=5, burst_count=10), | |||
) | |||
# If a user tries to fetch the same page multiple times in quick succession, | |||
@@ -115,7 +115,13 @@ def return_json_error( | |||
if exc.headers is not None: | |||
for header, value in exc.headers.items(): | |||
request.setHeader(header, value) | |||
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) | |||
error_ctx = exc.debug_context | |||
if error_ctx: | |||
logger.info( | |||
"%s SynapseError: %s - %s (%s)", request, error_code, exc.msg, error_ctx | |||
) | |||
else: | |||
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) | |||
elif f.check(CancelledError): | |||
error_code = HTTP_STATUS_REQUEST_CANCELLED | |||
error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN} | |||
@@ -120,14 +120,12 @@ class LoginRestServlet(RestServlet): | |||
self._address_ratelimiter = Ratelimiter( | |||
store=self._main_store, | |||
clock=hs.get_clock(), | |||
rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, | |||
burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count, | |||
cfg=self.hs.config.ratelimiting.rc_login_address, | |||
) | |||
self._account_ratelimiter = Ratelimiter( | |||
store=self._main_store, | |||
clock=hs.get_clock(), | |||
rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, | |||
burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count, | |||
cfg=self.hs.config.ratelimiting.rc_login_account, | |||
) | |||
# ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. | |||
@@ -16,6 +16,7 @@ import logging | |||
from typing import TYPE_CHECKING, Tuple | |||
from synapse.api.ratelimiting import Ratelimiter | |||
from synapse.config.ratelimiting import RatelimitSettings | |||
from synapse.http.server import HttpServer | |||
from synapse.http.servlet import RestServlet, parse_json_object_from_request | |||
from synapse.http.site import SynapseRequest | |||
@@ -66,15 +67,18 @@ class LoginTokenRequestServlet(RestServlet): | |||
self.token_timeout = hs.config.auth.login_via_existing_token_timeout | |||
self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth | |||
# Ratelimit aggressively to a maxmimum of 1 request per minute. | |||
# Ratelimit aggressively to a maximum of 1 request per minute. | |||
# | |||
# This endpoint can be used to spawn additional sessions and could be | |||
# abused by a malicious client to create many sessions. | |||
self._ratelimiter = Ratelimiter( | |||
store=self._main_store, | |||
clock=hs.get_clock(), | |||
rate_hz=1 / 60, | |||
burst_count=1, | |||
cfg=RatelimitSettings( | |||
key="<login token request>", | |||
per_second=1 / 60, | |||
burst_count=1, | |||
), | |||
) | |||
@interactive_auth_handler | |||
@@ -376,8 +376,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): | |||
self.ratelimiter = Ratelimiter( | |||
store=self.store, | |||
clock=hs.get_clock(), | |||
rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, | |||
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count, | |||
cfg=hs.config.ratelimiting.rc_registration_token_validity, | |||
) | |||
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: | |||
@@ -408,8 +408,7 @@ class HomeServer(metaclass=abc.ABCMeta): | |||
return Ratelimiter( | |||
store=self.get_datastores().main, | |||
clock=self.get_clock(), | |||
rate_hz=self.config.ratelimiting.rc_registration.per_second, | |||
burst_count=self.config.ratelimiting.rc_registration.burst_count, | |||
cfg=self.config.ratelimiting.rc_registration, | |||
) | |||
@cache_in_self | |||
@@ -291,7 +291,8 @@ class _PerHostRatelimiter: | |||
if self.metrics_name: | |||
rate_limit_reject_counter.labels(self.metrics_name).inc() | |||
raise LimitExceededError( | |||
retry_after_ms=int(self.window_size / self.sleep_limit) | |||
limiter_name="rc_federation", | |||
retry_after_ms=int(self.window_size / self.sleep_limit), | |||
) | |||
self.request_times.append(time_now) | |||
@@ -1,6 +1,5 @@ | |||
# Copyright 2023 The Matrix.org Foundation C.I.C. | |||
# | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
@@ -13,24 +12,32 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import json | |||
from synapse.api.errors import LimitExceededError | |||
from tests import unittest | |||
class ErrorsTestCase(unittest.TestCase): | |||
class LimitExceededErrorTestCase(unittest.TestCase): | |||
def test_key_appears_in_context_but_not_error_dict(self) -> None: | |||
err = LimitExceededError("needle") | |||
serialised = json.dumps(err.error_dict(None)) | |||
self.assertIn("needle", err.debug_context) | |||
self.assertNotIn("needle", serialised) | |||
# Create a sub-class to avoid mutating the class-level property. | |||
class LimitExceededErrorHeaders(LimitExceededError): | |||
include_retry_after_header = True | |||
def test_limit_exceeded_header(self) -> None: | |||
err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=100) | |||
err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100) | |||
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) | |||
assert err.headers is not None | |||
self.assertEqual(err.headers.get("Retry-After"), "1") | |||
def test_limit_exceeded_rounding(self) -> None: | |||
err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=3001) | |||
err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001) | |||
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) | |||
assert err.headers is not None | |||
self.assertEqual(err.headers.get("Retry-After"), "4") |
@@ -1,5 +1,6 @@ | |||
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter | |||
from synapse.appservice import ApplicationService | |||
from synapse.config.ratelimiting import RatelimitSettings | |||
from synapse.types import create_requester | |||
from tests import unittest | |||
@@ -10,8 +11,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=1, | |||
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), | |||
) | |||
allowed, time_allowed = self.get_success_or_raise( | |||
limiter.can_do_action(None, key="test_id", _time_now_s=0) | |||
@@ -43,8 +43,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=1, | |||
cfg=RatelimitSettings( | |||
key="", | |||
per_second=0.1, | |||
burst_count=1, | |||
), | |||
) | |||
allowed, time_allowed = self.get_success_or_raise( | |||
limiter.can_do_action(as_requester, _time_now_s=0) | |||
@@ -76,8 +79,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=1, | |||
cfg=RatelimitSettings( | |||
key="", | |||
per_second=0.1, | |||
burst_count=1, | |||
), | |||
) | |||
allowed, time_allowed = self.get_success_or_raise( | |||
limiter.can_do_action(as_requester, _time_now_s=0) | |||
@@ -101,8 +107,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=1, | |||
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), | |||
) | |||
# Shouldn't raise | |||
@@ -128,8 +133,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=1, | |||
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), | |||
) | |||
# First attempt should be allowed | |||
@@ -177,8 +181,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=1, | |||
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), | |||
) | |||
# First attempt should be allowed | |||
@@ -208,8 +211,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=1, | |||
cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), | |||
) | |||
self.get_success_or_raise( | |||
limiter.can_do_action(None, key="test_id_1", _time_now_s=0) | |||
@@ -244,7 +246,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
) | |||
) | |||
limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1) | |||
limiter = Ratelimiter( | |||
store=store, | |||
clock=self.clock, | |||
cfg=RatelimitSettings("", per_second=0.1, burst_count=1), | |||
) | |||
# Shouldn't raise | |||
for _ in range(20): | |||
@@ -254,8 +260,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=3, | |||
cfg=RatelimitSettings( | |||
key="", | |||
per_second=0.1, | |||
burst_count=3, | |||
), | |||
) | |||
# Test that 4 actions aren't allowed with a maximum burst of 3. | |||
allowed, time_allowed = self.get_success_or_raise( | |||
@@ -321,8 +330,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=3, | |||
cfg=RatelimitSettings("", per_second=0.1, burst_count=3), | |||
) | |||
def consume_at(time: float) -> bool: | |||
@@ -346,8 +354,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=3, | |||
cfg=RatelimitSettings( | |||
"", | |||
per_second=0.1, | |||
burst_count=3, | |||
), | |||
) | |||
# Observe two actions, leaving room in the bucket for one more. | |||
@@ -369,8 +380,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=3, | |||
cfg=RatelimitSettings( | |||
"", | |||
per_second=0.1, | |||
burst_count=3, | |||
), | |||
) | |||
# Observe three actions, filling up the bucket. | |||
@@ -398,8 +412,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): | |||
limiter = Ratelimiter( | |||
store=self.hs.get_datastores().main, | |||
clock=self.clock, | |||
rate_hz=0.1, | |||
burst_count=3, | |||
cfg=RatelimitSettings( | |||
"", | |||
per_second=0.1, | |||
burst_count=3, | |||
), | |||
) | |||
# Observe four actions, exceeding the bucket. | |||
@@ -12,11 +12,42 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.config.ratelimiting import RatelimitSettings | |||
from tests.unittest import TestCase | |||
from tests.utils import default_config | |||
class ParseRatelimitSettingsTestcase(TestCase): | |||
def test_depth_1(self) -> None: | |||
cfg = { | |||
"a": { | |||
"per_second": 5, | |||
"burst_count": 10, | |||
} | |||
} | |||
parsed = RatelimitSettings.parse(cfg, "a") | |||
self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) | |||
def test_depth_2(self) -> None: | |||
cfg = { | |||
"a": { | |||
"b": { | |||
"per_second": 5, | |||
"burst_count": 10, | |||
}, | |||
} | |||
} | |||
parsed = RatelimitSettings.parse(cfg, "a.b") | |||
self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10)) | |||
def test_missing(self) -> None: | |||
parsed = RatelimitSettings.parse( | |||
{}, "a", defaults={"per_second": 5, "burst_count": 10} | |||
) | |||
self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) | |||
class RatelimitConfigTestCase(TestCase): | |||
def test_parse_rc_federation(self) -> None: | |||
config_dict = default_config("test") | |||