You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

390 lines
13 KiB

  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. import contextlib
  16. import logging
  17. import threading
  18. import typing
  19. from typing import (
  20. Any,
  21. Callable,
  22. ContextManager,
  23. DefaultDict,
  24. Dict,
  25. Iterator,
  26. List,
  27. Mapping,
  28. MutableSet,
  29. Optional,
  30. Set,
  31. Tuple,
  32. )
  33. from weakref import WeakSet
  34. from prometheus_client.core import Counter
  35. from twisted.internet import defer
  36. from synapse.api.errors import LimitExceededError
  37. from synapse.config.ratelimiting import FederationRatelimitSettings
  38. from synapse.logging.context import (
  39. PreserveLoggingContext,
  40. make_deferred_yieldable,
  41. run_in_background,
  42. )
  43. from synapse.logging.opentracing import start_active_span
  44. from synapse.metrics import Histogram, LaterGauge
  45. from synapse.util import Clock
  46. if typing.TYPE_CHECKING:
  47. from contextlib import _GeneratorContextManager
  48. logger = logging.getLogger(__name__)
# Track how much the ratelimiter is affecting requests

# Requests that were put to sleep (delayed) by the limiter.
rate_limit_sleep_counter = Counter(
    "synapse_rate_limit_sleep",
    "Number of requests slept by the rate limiter",
    ["rate_limiter_name"],
)

# Requests that were rejected outright (LimitExceededError raised) rather
# than merely delayed.
rate_limit_reject_counter = Counter(
    "synapse_rate_limit_reject",
    "Number of requests rejected by the rate limiter",
    ["rate_limiter_name"],
)

# How long requests spent waiting for the limiter to let them through.
queue_wait_timer = Histogram(
    "synapse_rate_limit_queue_wait_time_seconds",
    "Amount of time spent waiting for the rate limiter to let our request through.",
    ["rate_limiter_name"],
    buckets=(
        0.005,
        0.01,
        0.025,
        0.05,
        0.1,
        0.25,
        0.5,
        0.75,
        1.0,
        2.5,
        5.0,
        10.0,
        20.0,
        "+Inf",
    ),
)

# Global registry of every FederationRateLimiter, so the LaterGauges below can
# aggregate counts across all of them at scrape time.
# This must be a `WeakSet`, otherwise we indirectly hold on to entire `HomeServer`s
# during trial test runs and leak a lot of memory.
_rate_limiter_instances: MutableSet["FederationRateLimiter"] = WeakSet()
# Protects the _rate_limiter_instances set from concurrent access
_rate_limiter_instances_lock = threading.Lock()
  86. def _get_counts_from_rate_limiter_instance(
  87. count_func: Callable[["FederationRateLimiter"], int]
  88. ) -> Mapping[Tuple[str, ...], int]:
  89. """Returns a count of something (slept/rejected hosts) by (metrics_name)"""
  90. # Cast to a list to prevent it changing while the Prometheus
  91. # thread is collecting metrics
  92. with _rate_limiter_instances_lock:
  93. rate_limiter_instances = list(_rate_limiter_instances)
  94. # Map from (metrics_name,) -> int, the number of something like slept hosts
  95. # or rejected hosts. The key type is Tuple[str], but we leave the length
  96. # unspecified for compatability with LaterGauge's annotations.
  97. counts: Dict[Tuple[str, ...], int] = {}
  98. for rate_limiter_instance in rate_limiter_instances:
  99. # Only track metrics if they provided a `metrics_name` to
  100. # differentiate this instance of the rate limiter.
  101. if rate_limiter_instance.metrics_name:
  102. key = (rate_limiter_instance.metrics_name,)
  103. counts[key] = count_func(rate_limiter_instance)
  104. return counts
# We track the number of affected hosts per time-period so we can
# differentiate one really noisy homeserver from a general
# ratelimit tuning problem across the federation.
LaterGauge(
    "synapse_rate_limit_sleep_affected_hosts",
    "Number of hosts that had requests put to sleep",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(
        # `should_sleep()` returns a bool, so the sum counts how many per-host
        # limiters are currently over the sleep threshold.
        lambda rate_limiter_instance: sum(
            ratelimiter.should_sleep()
            for ratelimiter in rate_limiter_instance.ratelimiters.values()
        )
    ),
)
LaterGauge(
    "synapse_rate_limit_reject_affected_hosts",
    "Number of hosts that had requests rejected",
    ["rate_limiter_name"],
    lambda: _get_counts_from_rate_limiter_instance(
        # Likewise: counts per-host limiters currently over the reject threshold.
        lambda rate_limiter_instance: sum(
            ratelimiter.should_reject()
            for ratelimiter in rate_limiter_instance.ratelimiters.values()
        )
    ),
)
  130. class FederationRateLimiter:
  131. """Used to rate limit request per-host."""
  132. def __init__(
  133. self,
  134. clock: Clock,
  135. config: FederationRatelimitSettings,
  136. metrics_name: Optional[str] = None,
  137. ):
  138. """
  139. Args:
  140. clock
  141. config
  142. metrics_name: The name of the rate limiter so we can differentiate it
  143. from the rest in the metrics. If `None`, we don't track metrics
  144. for this rate limiter.
  145. """
  146. self.metrics_name = metrics_name
  147. def new_limiter() -> "_PerHostRatelimiter":
  148. return _PerHostRatelimiter(
  149. clock=clock, config=config, metrics_name=metrics_name
  150. )
  151. self.ratelimiters: DefaultDict[
  152. str, "_PerHostRatelimiter"
  153. ] = collections.defaultdict(new_limiter)
  154. with _rate_limiter_instances_lock:
  155. _rate_limiter_instances.add(self)
  156. def ratelimit(self, host: str) -> "_GeneratorContextManager[defer.Deferred[None]]":
  157. """Used to ratelimit an incoming request from a given host
  158. Example usage:
  159. with rate_limiter.ratelimit(origin) as wait_deferred:
  160. yield wait_deferred
  161. # Handle request ...
  162. Args:
  163. host: Origin of incoming request.
  164. Returns:
  165. context manager which returns a deferred.
  166. """
  167. return self.ratelimiters[host].ratelimit(host)
class _PerHostRatelimiter:
    """Rate limiter state for a single remote host.

    Tracks recent request arrival times in a sliding window and, depending on
    the configured limits, lets a request straight through, sleeps it, queues
    it behind a concurrency cap, or rejects it outright.
    """

    def __init__(
        self,
        clock: Clock,
        config: FederationRatelimitSettings,
        metrics_name: Optional[str] = None,
    ):
        """
        Args:
            clock
            config
            metrics_name: The name of the rate limiter so we can differentiate it
                from the rest in the metrics. If `None`, we don't track metrics
                for this rate limiter.
        """
        self.clock = clock
        self.metrics_name = metrics_name

        # Limits from config. `sleep_delay` is in milliseconds; convert to
        # seconds for `clock.sleep`.
        self.window_size = config.window_size
        self.sleep_limit = config.sleep_limit
        self.sleep_sec = config.sleep_delay / 1000.0
        self.reject_limit = config.reject_limit
        self.concurrent_requests = config.concurrent

        # request_id objects for requests which have been slept
        self.sleeping_requests: Set[object] = set()

        # map from request_id object to Deferred for requests which are ready
        # for processing but have been queued
        self.ready_request_queue: collections.OrderedDict[
            object, defer.Deferred[None]
        ] = collections.OrderedDict()

        # request id objects for requests which are in progress
        self.current_processing: Set[object] = set()

        # times at which we have recently (within the last window_size ms)
        # received requests.
        self.request_times: List[int] = []

    @contextlib.contextmanager
    def ratelimit(self, host: str) -> "Iterator[defer.Deferred[None]]":
        # `contextlib.contextmanager` takes a generator and turns it into a
        # context manager. The generator should only yield once with a value
        # to be returned by manager.
        # Exceptions will be reraised at the yield.

        # NOTE(review): `self.host` is reassigned on every call; each instance
        # only serves a single host via FederationRateLimiter's per-host
        # defaultdict, so this is effectively set-once — confirm callers keep
        # it that way.
        self.host = host

        # Each request is identified by a fresh, unique sentinel object.
        request_id = object()
        # Ideally we'd use `Deferred.fromCoroutine()` here, to save on redundant
        # type-checking, but we'd need Twisted >= 21.2.
        ret = defer.ensureDeferred(self._on_enter_with_tracing(request_id))
        try:
            yield ret
        finally:
            # Always release our slot, even if the caller raised.
            self._on_exit(request_id)

    def should_reject(self) -> bool:
        """
        Whether to reject the request if we already have too many queued up
        (either sleeping or in the ready queue).
        """
        queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
        return queue_size > self.reject_limit

    def should_sleep(self) -> bool:
        """
        Whether to sleep the request if we already have too many requests coming
        through within the window.
        """
        return len(self.request_times) > self.sleep_limit

    async def _on_enter_with_tracing(self, request_id: object) -> None:
        # Wrap `_on_enter` in a tracing span, and — when metrics are enabled —
        # a timer recording how long the request waited in the limiter.
        maybe_metrics_cm: ContextManager = contextlib.nullcontext()
        if self.metrics_name:
            maybe_metrics_cm = queue_wait_timer.labels(self.metrics_name).time()
        with start_active_span("ratelimit wait"), maybe_metrics_cm:
            await self._on_enter(request_id)

    def _on_enter(self, request_id: object) -> "defer.Deferred[None]":
        """Admit, sleep, queue or reject the request identified by `request_id`.

        Returns a deferred which fires when the request may start processing.
        Raises LimitExceededError if too many requests are already waiting.
        """
        time_now = self.clock.time_msec()

        # remove any entries from request_times which aren't within the window
        self.request_times[:] = [
            r for r in self.request_times if time_now - r < self.window_size
        ]

        # reject the request if we already have too many queued up (either
        # sleeping or in the ready queue).
        if self.should_reject():
            logger.debug("Ratelimiter(%s): rejecting request", self.host)
            if self.metrics_name:
                rate_limit_reject_counter.labels(self.metrics_name).inc()
            raise LimitExceededError(
                retry_after_ms=int(self.window_size / self.sleep_limit)
            )

        self.request_times.append(time_now)

        def queue_request() -> "defer.Deferred[None]":
            # If we're at the concurrency cap, park the request on the ready
            # queue; `_on_exit` fires the deferred when a slot frees up.
            # Otherwise the request may proceed immediately.
            if len(self.current_processing) >= self.concurrent_requests:
                queue_defer: defer.Deferred[None] = defer.Deferred()
                self.ready_request_queue[request_id] = queue_defer
                logger.info(
                    "Ratelimiter(%s): queueing request (queue now %i items)",
                    self.host,
                    len(self.ready_request_queue),
                )
                return queue_defer
            else:
                return defer.succeed(None)

        logger.debug(
            "Ratelimit(%s) [%s]: len(self.request_times)=%d",
            self.host,
            id(request_id),
            len(self.request_times),
        )

        if self.should_sleep():
            logger.debug(
                "Ratelimiter(%s) [%s]: sleeping request for %f sec",
                self.host,
                id(request_id),
                self.sleep_sec,
            )
            if self.metrics_name:
                rate_limit_sleep_counter.labels(self.metrics_name).inc()
            ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)

            self.sleeping_requests.add(request_id)

            def on_wait_finished(_: Any) -> "defer.Deferred[None]":
                logger.debug(
                    "Ratelimit(%s) [%s]: Finished sleeping", self.host, id(request_id)
                )
                self.sleeping_requests.discard(request_id)
                # After sleeping, the request still goes through the
                # concurrency queue like any other.
                queue_defer = queue_request()
                return queue_defer

            ret_defer.addBoth(on_wait_finished)
        else:
            ret_defer = queue_request()

        def on_start(r: object) -> object:
            logger.debug(
                "Ratelimit(%s) [%s]: Processing req", self.host, id(request_id)
            )
            self.current_processing.add(request_id)
            return r

        def on_err(r: object) -> object:
            # XXX: why is this necessary? this is called before we start
            # processing the request so why would the request be in
            # current_processing?
            self.current_processing.discard(request_id)
            return r

        def on_both(r: object) -> object:
            # Ensure that we've properly cleaned up.
            self.sleeping_requests.discard(request_id)
            self.ready_request_queue.pop(request_id, None)
            return r

        ret_defer.addCallbacks(on_start, on_err)
        ret_defer.addBoth(on_both)
        return make_deferred_yieldable(ret_defer)

    def _on_exit(self, request_id: object) -> None:
        """Mark the request as finished and kick off the next queued request."""
        logger.debug("Ratelimit(%s) [%s]: Processed req", self.host, id(request_id))

        # When requests complete synchronously, we will recursively start the next
        # request in the queue. To avoid stack exhaustion, we defer starting the next
        # request until the next reactor tick.

        def start_next_request() -> None:
            # We only remove the completed request from the list when we're about to
            # start the next one, otherwise we can allow extra requests through.
            self.current_processing.discard(request_id)
            try:
                # start processing the next item on the queue.
                _, deferred = self.ready_request_queue.popitem(last=False)

                with PreserveLoggingContext():
                    deferred.callback(None)
            except KeyError:
                # Queue was empty: nothing waiting to start.
                pass

        self.clock.call_later(0.0, start_next_request)