# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019-2021 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import (
    TYPE_CHECKING,
    Any,
    Awaitable,
    Callable,
    Collection,
    Dict,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

from prometheus_client import Counter, Gauge, Histogram

from twisted.python import failure

from synapse.api.constants import (
    Direction,
    EduTypes,
    EventContentFields,
    EventTypes,
    Membership,
)
from synapse.api.errors import (
    AuthError,
    Codes,
    FederationError,
    IncompatibleRoomVersionError,
    NotFoundError,
    PartialStateConflictError,
    SynapseError,
    UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.federation.federation_base import (
    FederationBase,
    InvalidEventSignatureError,
    event_from_pdu_json,
)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
    make_deferred_yieldable,
    nested_logging_context,
    run_in_background,
)
from synapse.logging.opentracing import (
    SynapseTags,
    log_kv,
    set_tag,
    start_active_span_from_edu,
    tag_args,
    trace,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.replication.http.federation import (
    ReplicationFederationSendEduRestServlet,
    ReplicationGetQueryRestServlet,
)
from synapse.storage.databases.main.lock import Lock
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name

if TYPE_CHECKING:
    from synapse.server import HomeServer

# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10

logger = logging.getLogger(__name__)

received_pdus_counter = Counter("synapse_federation_server_received_pdus", "")

received_edus_counter = Counter("synapse_federation_server_received_edus", "")

received_queries_counter = Counter(
    "synapse_federation_server_received_queries", "", ["type"]
)

pdu_process_time = Histogram(
    "synapse_federation_server_pdu_process_time",
    "Time taken to process an event",
)

last_pdu_ts_metric = Gauge(
    "synapse_federation_last_received_pdu_time",
    "The timestamp of the last PDU which was successfully received from the given domain",
    labelnames=("server_name",),
)

# The name of the lock to use when processing events in a room received over
# federation.
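# Only one worker processes a given room's inbound events at a time: the lock
# holder drains that room's staging area, while other workers simply stage
# events (the periodic staged-event sweep catches anything left over, e.g.
# if a lock times out).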
_INBOUND_EVENT_HANDLING_LOCK_NAME = "federation_inbound_pdu"


class FederationServer(FederationBase):
    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)

        self.server_name = hs.hostname
        self.handler = hs.get_federation_handler()
        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
        self._federation_event_handler = hs.get_federation_event_handler()
        self.state = hs.get_state_handler()
        self._event_auth_handler = hs.get_event_auth_handler()
        self._room_member_handler = hs.get_room_member_handler()
        self._e2e_keys_handler = hs.get_e2e_keys_handler()
        self._worker_lock_handler = hs.get_worker_locks_handler()
        self._state_storage_controller = hs.get_storage_controllers().state

        self.device_handler = hs.get_device_handler()

        # Ensure the following handlers are loaded since they register callbacks
        # with FederationHandlerRegistry.
        hs.get_directory_handler()

        self._server_linearizer = Linearizer("fed_server")

        # origins that we are currently processing a transaction from.
        # a dict from origin to txn id.
        self._active_transactions: Dict[str, str] = {}

        # We cache results for transactions with the same ID
        self._transaction_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
            hs.get_clock(), "fed_txn_handler", timeout_ms=30000
        )
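        # (A remote server retries a failed transaction with the same
        # transaction ID, so caching the response keeps retries idempotent.)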

        self.transaction_actions = TransactionActions(self.store)

        self.registry = hs.get_federation_registry()

        # We cache responses to state queries, as they take a while and often
        # come in waves.
        self._state_resp_cache: ResponseCache[
            Tuple[str, Optional[str]]
        ] = ResponseCache(hs.get_clock(), "state_resp", timeout_ms=30000)
        self._state_ids_resp_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
            hs.get_clock(), "state_ids_resp", timeout_ms=30000
        )

        self._federation_metrics_domains = (
            hs.config.federation.federation_metrics_domains
        )

        self._room_prejoin_state_types = hs.config.api.room_prejoin_state

        # Whether we have started handling old events in the staging area.
        self._started_handling_of_staged_events = False

    @wrap_as_background_process("_handle_old_staged_events")
    async def _handle_old_staged_events(self) -> None:
        """Handle old staged events by fetching all rooms that have staged
        events and starting the processing of each of those rooms.
        """

        # Get all the room IDs with staged events.
        room_ids = await self.store.get_all_rooms_with_staged_incoming_events()

        # We then shuffle them so that if there are multiple instances doing
        # this work they're less likely to collide.
        random.shuffle(room_ids)

        for room_id in room_ids:
            room_version = await self.store.get_room_version(room_id)

            # Try and acquire the processing lock for the room, if we get it start a
            # background process for handling the events in the room.
            lock = await self.store.try_acquire_lock(
                _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
            )
            if lock:
                logger.info("Handling old staged inbound events in %s", room_id)
                self._process_incoming_pdus_in_room_inner(
                    room_id,
                    room_version,
                    lock,
                )

            # We pause a bit so that we don't start handling all rooms at once.
            await self._clock.sleep(random.uniform(0, 0.1))

    async def on_backfill_request(
        self, origin: str, room_id: str, versions: List[str], limit: int
    ) -> Tuple[int, Dict[str, Any]]:
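        """Handle a federation /backfill request: return up to `limit` events
        preceding the given extremity event IDs (`versions`) in `room_id`, in
        transaction format.
        """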
        async with self._server_linearizer.queue((origin, room_id)):
            origin_host, _ = parse_server_name(origin)
            await self.check_server_matches_acl(origin_host, room_id)

            pdus = await self.handler.on_backfill_request(
                origin, room_id, versions, limit
            )

            res = self._transaction_dict_from_pdus(pdus)

        return 200, res

    async def on_timestamp_to_event_request(
        self, origin: str, room_id: str, timestamp: int, direction: Direction
    ) -> Tuple[int, Dict[str, Any]]:
        """When we receive a federated `/timestamp_to_event` request,
        handle all of the logic for validating and fetching the event.

        Args:
            origin: The server we received the event from
            room_id: Room to fetch the event from
            timestamp: The point in time (inclusive) we should navigate from in
                the given direction to find the closest event.
            direction: indicates whether we should navigate forward
                or backward from the given timestamp to find the closest event.

        Returns:
            Tuple indicating the response status code and dictionary response
            body including `event_id`.
        """
        async with self._server_linearizer.queue((origin, room_id)):
            origin_host, _ = parse_server_name(origin)
            await self.check_server_matches_acl(origin_host, room_id)

            # We only try to fetch data from the local database
            event_id = await self.store.get_event_id_for_timestamp(
                room_id, timestamp, direction
            )
            if event_id:
                event = await self.store.get_event(
                    event_id, allow_none=False, allow_rejected=False
                )

                return 200, {
                    "event_id": event_id,
                    "origin_server_ts": event.origin_server_ts,
                }

        raise SynapseError(
            404,
            "Unable to find event from %s in direction %s" % (timestamp, direction),
            errcode=Codes.NOT_FOUND,
        )

    async def on_incoming_transaction(
        self,
        origin: str,
        transaction_id: str,
        destination: str,
        transaction_data: JsonDict,
    ) -> Tuple[int, JsonDict]:
        # If we receive a transaction we should kick off handling any old
        # events in the staging area.
        if not self._started_handling_of_staged_events:
            self._started_handling_of_staged_events = True
            self._handle_old_staged_events()

            # Start a periodic check for old staged events. This is to handle
            # the case where locks time out, e.g. if another process gets killed
            # without dropping its locks.
            self._clock.looping_call(self._handle_old_staged_events, 60 * 1000)

        # keep this as early as possible to make the calculated origin ts as
        # accurate as possible.
        request_time = self._clock.time_msec()

        transaction = Transaction(
            transaction_id=transaction_id,
            destination=destination,
            origin=origin,
            origin_server_ts=transaction_data.get("origin_server_ts"),  # type: ignore[arg-type]
            pdus=transaction_data.get("pdus"),
            edus=transaction_data.get("edus"),
        )

        if not transaction_id:
            raise Exception("Transaction missing transaction_id")

        logger.debug("[%s] Got transaction", transaction_id)
        # Reject malformed transactions early: reject if too many PDUs/EDUs
        if len(transaction.pdus) > 50 or len(transaction.edus) > 100:
            logger.info("Transaction PDU or EDU count too large. Returning 400")
            return 400, {}

        # we only process one transaction from each origin at a time. We need to do
        # this check here, rather than in _on_incoming_transaction_inner so that we
        # don't cache the rejection in _transaction_resp_cache (so that if the txn
        # arrives again later, we can process it).
        current_transaction = self._active_transactions.get(origin)
        if current_transaction and current_transaction != transaction_id:
            logger.warning(
                "Received another txn %s from %s while still processing %s",
                transaction_id,
                origin,
                current_transaction,
            )
            return 429, {
                "errcode": Codes.UNKNOWN,
                "error": "Too many concurrent transactions",
            }

        # CRITICAL SECTION: we must now not await until we populate _active_transactions
        # in _on_incoming_transaction_inner.

        # We wrap in a ResponseCache so that we de-duplicate retried
        # transactions.
        return await self._transaction_resp_cache.wrap(
            (origin, transaction_id),
            self._on_incoming_transaction_inner,
            origin,
            transaction,
            request_time,
        )

    async def _on_incoming_transaction_inner(
        self, origin: str, transaction: Transaction, request_time: int
    ) -> Tuple[int, Dict[str, Any]]:
        # CRITICAL SECTION: the first thing we must do (before awaiting) is
        # add an entry to _active_transactions.
        assert origin not in self._active_transactions
        self._active_transactions[origin] = transaction.transaction_id

        try:
            result = await self._handle_incoming_transaction(
                origin, transaction, request_time
            )
            return result
        finally:
            del self._active_transactions[origin]

    async def _handle_incoming_transaction(
        self, origin: str, transaction: Transaction, request_time: int
    ) -> Tuple[int, Dict[str, Any]]:
        """Process an incoming transaction and return the HTTP response

        Args:
            origin: the server making the request
            transaction: incoming transaction
            request_time: timestamp that the HTTP request arrived at

        Returns:
            HTTP response code and body
        """
        existing_response = await self.transaction_actions.have_responded(
            origin, transaction
        )

        if existing_response:
            logger.debug(
                "[%s] We've already responded to this request",
                transaction.transaction_id,
            )
            return existing_response

        logger.debug("[%s] Transaction is new", transaction.transaction_id)

        # We process PDUs and EDUs in parallel. This is important as we don't
        # want to block things like to-device messages from reaching clients
        # behind the potentially expensive handling of PDUs.
        pdu_results, _ = await make_deferred_yieldable(
            gather_results(
                (
                    run_in_background(
                        self._handle_pdus_in_txn, origin, transaction, request_time
                    ),
                    run_in_background(self._handle_edus_in_txn, origin, transaction),
                ),
                consumeErrors=True,
            ).addErrback(unwrapFirstError)
        )

        response = {"pdus": pdu_results}

        logger.debug("Returning: %s", str(response))

        await self.transaction_actions.set_response(origin, transaction, 200, response)
        return 200, response

    async def _handle_pdus_in_txn(
        self, origin: str, transaction: Transaction, request_time: int
    ) -> Dict[str, dict]:
        """Process the PDUs in a received transaction.

        Args:
            origin: the server making the request
            transaction: incoming transaction
            request_time: timestamp that the HTTP request arrived at

        Returns:
            A map from event ID of a processed PDU to any errors we should
            report back to the sending server.
        """

        received_pdus_counter.inc(len(transaction.pdus))

        origin_host, _ = parse_server_name(origin)

        pdus_by_room: Dict[str, List[EventBase]] = {}

        newest_pdu_ts = 0

        for p in transaction.pdus:
            # FIXME (richardv): I don't think this works:
            #  https://github.com/matrix-org/synapse/issues/8429
            if "unsigned" in p:
                unsigned = p["unsigned"]
                if "age" in unsigned:
                    p["age"] = unsigned["age"]
            if "age" in p:
                p["age_ts"] = request_time - int(p["age"])
                del p["age"]

            # We try and pull out an event ID so that if later checks fail we
            # can log something sensible. We don't mandate an event ID here in
            # case future event formats get rid of the key.
            possible_event_id = p.get("event_id", "<Unknown>")

            # Now we get the room ID so that we can check that we know the
            # version of the room.
            room_id = p.get("room_id")
            if not room_id:
                logger.info(
                    "Ignoring PDU as does not have a room_id. Event ID: %s",
                    possible_event_id,
                )
                continue

            try:
                room_version = await self.store.get_room_version(room_id)
            except NotFoundError:
                logger.info("Ignoring PDU for unknown room_id: %s", room_id)
                continue
            except UnsupportedRoomVersionError as e:
                # this can happen if support for a given room version is withdrawn,
                # in which case we may still receive events for that room.
                logger.info("Ignoring PDU: %s", e)
                continue

            event = event_from_pdu_json(p, room_version)
            pdus_by_room.setdefault(room_id, []).append(event)

            if event.origin_server_ts > newest_pdu_ts:
                newest_pdu_ts = event.origin_server_ts

        pdu_results = {}

        # we can process different rooms in parallel (which is useful if they
        # require callouts to other servers to fetch missing events), but
        # impose a limit to avoid going too crazy with ram/cpu.

        async def process_pdus_for_room(room_id: str) -> None:
            with nested_logging_context(room_id):
                logger.debug("Processing PDUs for %s", room_id)

                try:
                    await self.check_server_matches_acl(origin_host, room_id)
                except AuthError as e:
                    logger.warning(
                        "Ignoring PDUs for room %s from banned server", room_id
                    )
                    for pdu in pdus_by_room[room_id]:
                        event_id = pdu.event_id
                        pdu_results[event_id] = e.error_dict(self.hs.config)
                    return

                for pdu in pdus_by_room[room_id]:
                    pdu_results[pdu.event_id] = await process_pdu(pdu)

        async def process_pdu(pdu: EventBase) -> JsonDict:
            """
            Processes a pushed PDU sent to us via a `/send` transaction

            Returns:
                JsonDict representing a "PDU Processing Result" that will be bundled up
                with the other processed PDUs in the `/send` transaction and sent back
                to the remote homeserver.
            """
            event_id = pdu.event_id
            with nested_logging_context(event_id):
                try:
                    await self._handle_received_pdu(origin, pdu)
                    return {}
                except FederationError as e:
                    logger.warning("Error handling PDU %s: %s", event_id, e)
                    return {"error": str(e)}
                except Exception as e:
                    f = failure.Failure()
                    logger.error(
                        "Failed to handle PDU %s",
                        event_id,
                        exc_info=(f.type, f.value, f.getTracebackObject()),
                    )
                    return {"error": str(e)}

        await concurrently_execute(
            process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
        )

        if newest_pdu_ts and origin in self._federation_metrics_domains:
            last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)

        return pdu_results
    async def _handle_edus_in_txn(self, origin: str, transaction: Transaction) -> None:
        """Process the EDUs in a received transaction."""

        async def _process_edu(edu_dict: JsonDict) -> None:
            received_edus_counter.inc()

            edu = Edu(
                origin=origin,
                destination=self.server_name,
                edu_type=edu_dict["edu_type"],
                content=edu_dict["content"],
            )
            await self.registry.on_edu(edu.edu_type, origin, edu.content)

        await concurrently_execute(
            _process_edu,
            transaction.edus,
            TRANSACTION_CONCURRENCY_LIMIT,
        )

    async def on_room_state_request(
        self, origin: str, room_id: str, event_id: str
    ) -> Tuple[int, JsonDict]:
        await self._event_auth_handler.assert_host_in_room(room_id, origin)
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, room_id)

        # we grab the linearizer to protect ourselves from servers which hammer
        # us. In theory we might already have the response to this query
        # in the cache so we could return it without waiting for the linearizer
        # - but that's non-trivial to get right, and anyway somewhat defeats
        # the point of the linearizer.
        async with self._server_linearizer.queue((origin, room_id)):
            resp = await self._state_resp_cache.wrap(
                (room_id, event_id),
                self._on_context_state_request_compute,
                room_id,
                event_id,
            )

        return 200, resp

    @trace
    @tag_args
    async def on_state_ids_request(
        self, origin: str, room_id: str, event_id: str
    ) -> Tuple[int, JsonDict]:
        if not event_id:
            raise NotImplementedError("Specify an event")

        await self._event_auth_handler.assert_host_in_room(room_id, origin)
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, room_id)

        resp = await self._state_ids_resp_cache.wrap(
            (room_id, event_id),
            self._on_state_ids_request_compute,
            room_id,
            event_id,
        )

        return 200, resp

    @trace
    @tag_args
    async def _on_state_ids_request_compute(
        self, room_id: str, event_id: str
    ) -> JsonDict:
        state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
        auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
        return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}

    async def _on_context_state_request_compute(
        self, room_id: str, event_id: str
    ) -> Dict[str, list]:
        pdus: Collection[EventBase]
        event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
        pdus = await self.store.get_events_as_list(event_ids)

        auth_chain = await self.store.get_auth_chain(
            room_id, [pdu.event_id for pdu in pdus]
        )

        return {
            "pdus": [pdu.get_pdu_json() for pdu in pdus],
            "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
        }

    async def on_pdu_request(
        self, origin: str, event_id: str
    ) -> Tuple[int, Union[JsonDict, str]]:
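        """Handle a federation request for a single event: return it in
        transaction format if we have it, otherwise 404.
        """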
        pdu = await self.handler.get_persisted_pdu(origin, event_id)

        if pdu:
            return 200, self._transaction_dict_from_pdus([pdu])
        else:
            return 404, ""

    async def on_query_request(
        self, query_type: str, args: Dict[str, str]
    ) -> Tuple[int, Dict[str, Any]]:
        received_queries_counter.labels(query_type).inc()
        resp = await self.registry.on_query(query_type, args)
        return 200, resp

    async def on_make_join_request(
        self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
    ) -> Dict[str, Any]:
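        """Handle a /make_join request: return a templated (unsigned) join
        event for `user_id`, plus the room version, for the requesting server
        to sign and submit via /send_join.
        """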
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, room_id)

        room_version = await self.store.get_room_version_id(room_id)
        if room_version not in supported_versions:
            logger.warning(
                "Room version %s not in %s", room_version, supported_versions
            )
            raise IncompatibleRoomVersionError(room_version=room_version)

        # Refuse the request if that room has seen too many joins recently.
        # This is in addition to the HS-level rate limiting applied by
        # BaseFederationServlet.
        # type-ignore: mypy doesn't seem able to deduce the type of the limiter(!?)
        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
            requester=None,
            key=room_id,
            update=False,
        )
        pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
        return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}

    async def on_invite_request(
        self, origin: str, content: JsonDict, room_version_id: str
    ) -> Dict[str, Any]:
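        """Handle a federation /invite request: check the event's signatures
        and the server ACLs, then hand the invite to the federation handler,
        returning the resulting event.
        """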
        room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
        if not room_version:
            raise SynapseError(
                400,
                "Homeserver does not support this room version",
                Codes.UNSUPPORTED_ROOM_VERSION,
            )

        pdu = event_from_pdu_json(content, room_version)
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, pdu.room_id)
        try:
            pdu = await self._check_sigs_and_hash(room_version, pdu)
        except InvalidEventSignatureError as e:
            errmsg = f"event id {pdu.event_id}: {e}"
            logger.warning("%s", errmsg)
            raise SynapseError(403, errmsg, Codes.FORBIDDEN)
        ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
        time_now = self._clock.time_msec()
        return {"event": ret_pdu.get_pdu_json(time_now)}

    async def on_send_join_request(
        self,
        origin: str,
        content: JsonDict,
        room_id: str,
        caller_supports_partial_state: bool = False,
    ) -> Dict[str, Any]:
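        """Handle a /send_join request: persist the signed join event and
        return the room state and auth chain, omitting members if the caller
        opted in to partial state.
        """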
        set_tag(
            SynapseTags.SEND_JOIN_RESPONSE_IS_PARTIAL_STATE,
            caller_supports_partial_state,
        )

        await self._room_member_handler._join_rate_per_room_limiter.ratelimit(  # type: ignore[has-type]
            requester=None,
            key=room_id,
            update=False,
        )

        event, context = await self._on_send_membership_event(
            origin, content, Membership.JOIN, room_id
        )

        prev_state_ids = await context.get_prev_state_ids()

        state_event_ids: Collection[str]
        servers_in_room: Optional[Collection[str]]
        if caller_supports_partial_state:
            summary = await self.store.get_room_summary(room_id)
            state_event_ids = _get_event_ids_for_partial_state_join(
                event, prev_state_ids, summary
            )
            servers_in_room = await self.state.get_hosts_in_room_at_events(
                room_id, event_ids=event.prev_event_ids()
            )
        else:
            state_event_ids = prev_state_ids.values()
            servers_in_room = None

        auth_chain_event_ids = await self.store.get_auth_chain_ids(
            room_id, state_event_ids
        )

        # if the caller has opted in, we can omit any auth_chain events which are
        # already in state_event_ids
        if caller_supports_partial_state:
            auth_chain_event_ids.difference_update(state_event_ids)

        auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
        state_events = await self.store.get_events_as_list(state_event_ids)

        # we try to do all the async stuff before this point, so that time_now is as
        # accurate as possible.
        time_now = self._clock.time_msec()
        event_json = event.get_pdu_json(time_now)
        resp = {
            "event": event_json,
            "state": [p.get_pdu_json(time_now) for p in state_events],
            "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
            "members_omitted": caller_supports_partial_state,
        }

        if servers_in_room is not None:
            resp["servers_in_room"] = list(servers_in_room)

        return resp

    async def on_make_leave_request(
        self, origin: str, room_id: str, user_id: str
    ) -> Dict[str, Any]:
        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, room_id)
        pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)

        room_version = await self.store.get_room_version_id(room_id)

        return {"event": pdu.get_templated_pdu_json(), "room_version": room_version}

    async def on_send_leave_request(
        self, origin: str, content: JsonDict, room_id: str
    ) -> dict:
        logger.debug("on_send_leave_request: content: %s", content)
        await self._on_send_membership_event(origin, content, Membership.LEAVE, room_id)
        return {}
    async def on_make_knock_request(
        self, origin: str, room_id: str, user_id: str, supported_versions: List[str]
    ) -> JsonDict:
        """We've received a /make_knock/ request, so we create a partial knock
        event for the room and hand that back, along with the room version, to the knocking
        homeserver. We do *not* persist or process this event until the other server has
        signed it and sent it back.

        Args:
            origin: The (verified) server name of the requesting server.
            room_id: The room to create the knock event in.
            user_id: The user to create the knock for.
            supported_versions: The room versions supported by the requesting server.

        Returns:
            The partial knock event.
        """
        origin_host, _ = parse_server_name(origin)

        if await self.store.is_partial_state_room(room_id):
            # Before we do anything: check if the room is partial-stated.
            # Note that at the time this check was added, `on_make_knock_request` would
            # block due to https://github.com/matrix-org/synapse/issues/12997.
            raise SynapseError(
                404,
                "Unable to handle /make_knock right now; this server is not fully joined.",
                errcode=Codes.NOT_FOUND,
            )

        await self.check_server_matches_acl(origin_host, room_id)

        room_version = await self.store.get_room_version(room_id)

        # Check that this room version is supported by the remote homeserver
        if room_version.identifier not in supported_versions:
            logger.warning(
                "Room version %s not in %s", room_version.identifier, supported_versions
            )
            raise IncompatibleRoomVersionError(room_version=room_version.identifier)

        # Check that this room supports knocking as defined by its room version
        if not room_version.knock_join_rule:
            raise SynapseError(
                403,
                "This room version does not support knocking",
                errcode=Codes.FORBIDDEN,
            )

        pdu = await self.handler.on_make_knock_request(origin, room_id, user_id)
        return {
            "event": pdu.get_templated_pdu_json(),
            "room_version": room_version.identifier,
        }

    async def on_send_knock_request(
        self,
        origin: str,
        content: JsonDict,
        room_id: str,
    ) -> Dict[str, List[JsonDict]]:
        """
        We have received a knock event for a room. Verify and send the event into the room
        on the knocking homeserver's behalf. Then reply with some stripped state from the
        room for the knockee.

        Args:
            origin: The remote homeserver of the knocking user.
            content: The content of the request.
            room_id: The ID of the room to knock on.

        Returns:
            The stripped room state.
        """
        _, context = await self._on_send_membership_event(
            origin, content, Membership.KNOCK, room_id
        )

        # Retrieve stripped state events from the room and send them back to the remote
        # server. This will allow the remote server's clients to display information
        # related to the room while the knock request is pending.
        stripped_room_state = (
            await self.store.get_stripped_room_state_from_event_context(
                context, self._room_prejoin_state_types
            )
        )
        return {"knock_room_state": stripped_room_state}
    async def _on_send_membership_event(
        self, origin: str, content: JsonDict, membership_type: str, room_id: str
    ) -> Tuple[EventBase, EventContext]:
        """Handle an on_send_{join,leave,knock} request

        Does some preliminary validation before passing the request on to the
        federation handler.

        Args:
            origin: The (authenticated) requesting server
            content: The body of the send_* request - a complete membership event
            membership_type: The expected membership type (join or leave, depending
                on the endpoint)
            room_id: The room_id from the request, to be validated against the room_id
                in the event

        Returns:
            The event and context of the event after inserting it into the room graph.

        Raises:
            SynapseError if there is a problem with the request, including things like
                the room_id not matching or the event not being authorized.
        """
        assert_params_in_dict(content, ["room_id"])
        if content["room_id"] != room_id:
            raise SynapseError(
                400,
                "Room ID in body does not match that in request path",
                Codes.BAD_JSON,
            )

        # Note that get_room_version throws if the room does not exist here.
        room_version = await self.store.get_room_version(room_id)

        if await self.store.is_partial_state_room(room_id):
            # If our server is still only partially joined, we can't give a complete
            # response to /send_join, /send_knock or /send_leave.
            # This is because we will not be able to provide the server list (for partial
            # joins) or the full state (for full joins).
            # Return a 404 as we would if we weren't in the room at all.
            logger.info(
                f"Rejecting /send_{membership_type} to %s because it's a partial state room",
                room_id,
            )
            raise SynapseError(
                404,
                f"Unable to handle /send_{membership_type} right now; this server is not fully joined.",
                errcode=Codes.NOT_FOUND,
            )

        if membership_type == Membership.KNOCK and not room_version.knock_join_rule:
            raise SynapseError(
                403,
                "This room version does not support knocking",
                errcode=Codes.FORBIDDEN,
            )

        event = event_from_pdu_json(content, room_version)

        if event.type != EventTypes.Member or not event.is_state():
            raise SynapseError(400, "Not an m.room.member event", Codes.BAD_JSON)

        if event.content.get("membership") != membership_type:
            raise SynapseError(400, "Not a %s event" % membership_type, Codes.BAD_JSON)

        origin_host, _ = parse_server_name(origin)
        await self.check_server_matches_acl(origin_host, event.room_id)

        logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures)

        # Sign the event since we're vouching on behalf of the remote server that
        # the event is valid to be sent into the room. Currently this is only done
        # if the user is being joined via restricted join rules.
        if (
            room_version.restricted_join_rule
            and event.membership == Membership.JOIN
            and EventContentFields.AUTHORISING_USER in event.content
        ):
            # We can only authorise our own users.
            authorising_server = get_domain_from_id(
                event.content[EventContentFields.AUTHORISING_USER]
            )
            if not self._is_mine_server_name(authorising_server):
                raise SynapseError(
                    400,
                    f"Cannot authorise membership event for {authorising_server}. We can only authorise requests from our own homeserver",
                )

            event.signatures.update(
                compute_event_signature(
                    room_version,
                    event.get_pdu_json(),
                    self.hs.hostname,
                    self.hs.signing_key,
                )
            )

        try:
            event = await self._check_sigs_and_hash(room_version, event)
        except InvalidEventSignatureError as e:
            errmsg = f"event id {event.event_id}: {e}"
            logger.warning("%s", errmsg)
            raise SynapseError(403, errmsg, Codes.FORBIDDEN)

        try:
            return await self._federation_event_handler.on_send_membership_event(
                origin, event
            )
        except PartialStateConflictError:
            # The room was un-partial stated while we were persisting the event.
            # Try once more, with full state this time.
            logger.info(
                "Room %s was un-partial stated during `on_send_membership_event`, trying again.",
                room_id,
            )
            return await self._federation_event_handler.on_send_membership_event(
                origin, event
            )
    async def on_event_auth(
        self, origin: str, room_id: str, event_id: str
    ) -> Tuple[int, Dict[str, Any]]:
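        """Handle a federation /event_auth request: return the auth chain of
        the given event.
        """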
        async with self._server_linearizer.queue((origin, room_id)):
            await self._event_auth_handler.assert_host_in_room(room_id, origin)
            origin_host, _ = parse_server_name(origin)
            await self.check_server_matches_acl(origin_host, room_id)

            time_now = self._clock.time_msec()
            auth_pdus = await self.handler.on_event_auth(event_id)
            res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
        return 200, res

    async def on_query_client_keys(
        self, origin: str, content: Dict[str, str]
    ) -> Tuple[int, Dict[str, Any]]:
        return await self.on_query_request("client_keys", content)

    async def on_query_user_devices(
        self, origin: str, user_id: str
    ) -> Tuple[int, Dict[str, Any]]:
        keys = await self.device_handler.on_federation_query_user_devices(user_id)
        return 200, keys

    @trace
    async def on_claim_client_keys(
        self, query: List[Tuple[str, str, str, int]], always_include_fallback_keys: bool
    ) -> Dict[str, Any]:
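        """Claim one-time keys (and, where requested, fallback keys) for
        local users' devices on behalf of a remote server.
        """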
        if any(
            not self.hs.is_mine(UserID.from_string(user_id))
            for user_id, _, _, _ in query
        ):
            raise SynapseError(400, "User is not hosted on this homeserver")

        log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
        results = await self._e2e_keys_handler.claim_local_one_time_keys(
            query, always_include_fallback_keys=always_include_fallback_keys
        )

        json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
        for result in results:
            for user_id, device_keys in result.items():
                for device_id, keys in device_keys.items():
                    for key_id, key in keys.items():
                        json_result.setdefault(user_id, {}).setdefault(device_id, {})[
                            key_id
                        ] = key

        logger.info(
            "Claimed one-time-keys: %s",
            ",".join(
                (
                    "%s for %s:%s" % (key_id, user_id, device_id)
                    for user_id, user_keys in json_result.items()
                    for device_id, device_keys in user_keys.items()
                    for key_id, _ in device_keys.items()
                )
            ),
        )

        return {"one_time_keys": json_result}

    async def on_get_missing_events(
        self,
        origin: str,
        room_id: str,
        earliest_events: List[str],
        latest_events: List[str],
        limit: int,
    ) -> Dict[str, list]:
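        """Handle a /get_missing_events request: return up to `limit` events
        from the gap between `earliest_events` and `latest_events` in the
        room's DAG.
        """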
        async with self._server_linearizer.queue((origin, room_id)):
            origin_host, _ = parse_server_name(origin)
            await self.check_server_matches_acl(origin_host, room_id)

            logger.debug(
                "on_get_missing_events: earliest_events: %r, latest_events: %r,"
                " limit: %d",
                earliest_events,
                latest_events,
                limit,
            )

            missing_events = await self.handler.on_get_missing_events(
                origin, room_id, earliest_events, latest_events, limit
            )

            if len(missing_events) < 5:
                logger.debug(
                    "Returning %d events: %r", len(missing_events), missing_events
                )
            else:
                logger.debug("Returning %d events", len(missing_events))

            time_now = self._clock.time_msec()

        return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}

    async def on_openid_userinfo(self, token: str) -> Optional[str]:
        ts_now_ms = self._clock.time_msec()
        return await self.store.get_user_id_for_open_id_token(token, ts_now_ms)

    def _transaction_dict_from_pdus(self, pdu_list: List[EventBase]) -> JsonDict:
        """Returns a new Transaction containing the given PDUs suitable for
        transmission.
        """
        time_now = self._clock.time_msec()
        pdus = [p.get_pdu_json(time_now) for p in pdu_list]
        return Transaction(
            # Just need a dummy transaction ID and destination since it won't be used.
            transaction_id="",
            origin=self.server_name,
            pdus=pdus,
            origin_server_ts=int(time_now),
            destination="",
        ).get_dict()

    async def _handle_received_pdu(self, origin: str, pdu: EventBase) -> None:
        """Process a PDU received in a federation /send/ transaction.

        If the event is invalid, then this method throws a FederationError.
        (The error will then be logged and sent back to the sender (which
        probably won't do anything with it), and other events in the
        transaction will be processed as normal).

        It is likely that we'll then receive other events which refer to
        this rejected_event in their prev_events, etc. When that happens,
        we'll attempt to fetch the rejected event again, which will presumably
        fail, so those second-generation events will also get rejected.

        Eventually, we get to the point where there are more than 10 events
        between any new events and the original rejected event. Since we
        only try to backfill 10 events deep on received pdu, we then accept the
        new event, possibly introducing a discontinuity in the DAG, with new
        forward extremities, so normal service is approximately returned,
        until we try to backfill across the discontinuity.

        Args:
            origin: server which sent the pdu
            pdu: received pdu

        Raises: FederationError if the signatures / hash do not match, or
            if the event was unacceptable for any other reason (eg, too large,
            too many prev_events, couldn't find the prev_events)
        """
        # We've already checked that we know the room version by this point
        room_version = await self.store.get_room_version(pdu.room_id)

        # Check signature.
        try:
            pdu = await self._check_sigs_and_hash(room_version, pdu)
        except InvalidEventSignatureError as e:
            logger.warning("event id %s: %s", pdu.event_id, e)
            raise FederationError("ERROR", 403, str(e), affected=pdu.event_id)

        if await self._spam_checker_module_callbacks.should_drop_federated_event(pdu):
            logger.warning(
                "Unstaged federated event contains spam, dropping %s", pdu.event_id
            )
            return

        # Add the event to our staging area
        await self.store.insert_received_event_to_staging(origin, pdu)

        # Try and acquire the processing lock for the room, if we get it start a
        # background process for handling the events in the room.
        lock = await self.store.try_acquire_lock(
            _INBOUND_EVENT_HANDLING_LOCK_NAME, pdu.room_id
        )
        if lock:
            self._process_incoming_pdus_in_room_inner(
                pdu.room_id, room_version, lock, origin, pdu
            )
    async def _get_next_nonspam_staged_event_for_room(
        self, room_id: str, room_version: RoomVersion
    ) -> Optional[Tuple[str, EventBase]]:
        """Fetch the first non-spam event from the staging queue.

        Args:
            room_id: the room to fetch the first non-spam event in.
            room_version: the version of the room.

        Returns:
            The first non-spam event in that room.
        """
        while True:
            # We need to do this check outside the lock to avoid a race between
            # a new event being inserted by another instance and it attempting
            # to acquire the lock.
            next = await self.store.get_next_staged_event_for_room(
                room_id, room_version
            )
            if next is None:
                return None

            origin, event = next

            if await self._spam_checker_module_callbacks.should_drop_federated_event(
                event
            ):
                logger.warning(
                    "Staged federated event contains spam, dropping %s",
                    event.event_id,
                )
                continue

            return next

    @wrap_as_background_process("_process_incoming_pdus_in_room_inner")
    async def _process_incoming_pdus_in_room_inner(
        self,
        room_id: str,
        room_version: RoomVersion,
        lock: Lock,
        latest_origin: Optional[str] = None,
        latest_event: Optional[EventBase] = None,
    ) -> None:
        """Process events in the staging area for the given room.

        The latest_origin and latest_event args are the latest origin and event
        received (or None to simply pull the next event from the database).
        """

        # The common path is for the event we just received to be the only event in
        # the room, so instead of pulling the event out of the DB and parsing
        # the event we just pull out the next event ID and check if that matches.
        if latest_event is not None and latest_origin is not None:
            result = await self.store.get_next_staged_event_id_for_room(room_id)
            if result is None:
                latest_origin = None
                latest_event = None
            else:
                next_origin, next_event_id = result
                if (
                    next_origin != latest_origin
                    or next_event_id != latest_event.event_id
                ):
                    latest_origin = None
                    latest_event = None

        if latest_origin is None or latest_event is None:
            next = await self.store.get_next_staged_event_for_room(
                room_id, room_version
            )
            if not next:
                await lock.release()
                return

            origin, event = next
        else:
            origin = latest_origin
            event = latest_event

        # We loop round until there are no more events in the room in the
        # staging area, or we fail to get the lock (which means another process
        # has started processing).
        while True:
            async with lock:
                logger.info("handling received PDU in room %s: %s", room_id, event)
                try:
                    with nested_logging_context(event.event_id):
                        # We're taking out a lock within a lock, which could
                        # lead to deadlocks if we're not careful. However, it is
                        # safe on this occasion as we only ever take a write
                        # lock when deleting a room, which we would never do
                        # while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
                        # lock.
                        async with self._worker_lock_handler.acquire_read_write_lock(
                            NEW_EVENT_DURING_PURGE_LOCK_NAME, room_id, write=False
                        ):
                            await self._federation_event_handler.on_receive_pdu(
                                origin, event
                            )
                except FederationError as e:
                    # XXX: Ideally we'd inform the remote we failed to process
                    # the event, but we can't return an error in the transaction
                    # response (as we've already responded).
                    logger.warning("Error handling PDU %s: %s", event.event_id, e)
                except Exception:
                    f = failure.Failure()
                    logger.error(
                        "Failed to handle PDU %s",
                        event.event_id,
                        exc_info=(f.type, f.value, f.getTracebackObject()),
                    )

                received_ts = await self.store.remove_received_event_from_staging(
                    origin, event.event_id
                )
                if received_ts is not None:
                    pdu_process_time.observe(
                        (self._clock.time_msec() - received_ts) / 1000
                    )

            next = await self._get_next_nonspam_staged_event_for_room(
                room_id, room_version
            )
            if not next:
                break

            origin, event = next

            # Prune the event queue if it's getting large.
            #
            # We do this *after* handling the first event as the common case is
            # that the queue is empty (/has the single event in), and so there's
            # no need to do this check.
            pruned = await self.store.prune_staged_events_in_room(room_id, room_version)
            if pruned:
                # If we have pruned the queue, we need to refetch the next
                # event to handle.
                next = await self.store.get_next_staged_event_for_room(
                    room_id, room_version
                )
                if not next:
                    break

                origin, event = next

            new_lock = await self.store.try_acquire_lock(
                _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
            )
            if not new_lock:
                return

            lock = new_lock
    async def exchange_third_party_invite(
        self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
    ) -> None:
        await self.handler.exchange_third_party_invite(
            sender_user_id, target_user_id, room_id, signed
        )

    async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
        await self.handler.on_exchange_third_party_invite_request(event_dict)

    async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
        """Check if the given server is allowed by the server ACLs in the room

        Args:
            server_name: name of server, *without any port part*
            room_id: ID of the room to check

        Raises:
            AuthError if the server does not match the ACL
        """
        server_acl_evaluator = (
            await self._storage_controllers.state.get_server_acl_for_room(room_id)
        )
        if server_acl_evaluator and not server_acl_evaluator.server_matches_acl_event(
            server_name
        ):
            raise AuthError(code=403, msg="Server is banned from room")


class FederationHandlerRegistry:
    """Allows classes to register themselves as handlers for a given EDU or
    query type for incoming federation traffic.
    """
    def __init__(self, hs: "HomeServer"):
        self.config = hs.config
        self.clock = hs.get_clock()
        self._instance_name = hs.get_instance_name()

        # These are safe to load in monolith mode, but will explode if we try
        # and use them. However we have guards before we use them to ensure that
        # we don't route to ourselves, and in monolith mode that will always be
        # the case.
        self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
        self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)

        self.edu_handlers: Dict[str, Callable[[str, dict], Awaitable[None]]] = {}
        self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}

        # Map from type to instance names that we should route EDU handling to.
        # We randomly choose one instance from the list to route to for each new
        # EDU received.
        self._edu_type_to_instance: Dict[str, List[str]] = {}

    def register_edu_handler(
        self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
    ) -> None:
        """Sets the handler callable that will be used to handle an incoming
        federation EDU of the given type.

        Args:
            edu_type: The type of the incoming EDU to register handler for
            handler: A callable invoked on incoming EDU
                of the given type. The arguments are the origin server name and
                the EDU contents.
        """
        if edu_type in self.edu_handlers:
            raise KeyError("Already have an EDU handler for %s" % (edu_type,))

        logger.info("Registering federation EDU handler for %r", edu_type)

        self.edu_handlers[edu_type] = handler

    def register_query_handler(
        self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
    ) -> None:
        """Sets the handler callable that will be used to handle an incoming
        federation query of the given type.

        Args:
            query_type: Category name of the query, which should match
                the string used by make_query.
            handler: Invoked to handle
                incoming queries of this type. The return will be yielded
                on and the result used as the response to the query request.
        """
        if query_type in self.query_handlers:
            raise KeyError("Already have a Query handler for %s" % (query_type,))

        logger.info("Registering federation query handler for %r", query_type)

        self.query_handlers[query_type] = handler

    def register_instances_for_edu(
        self, edu_type: str, instance_names: List[str]
    ) -> None:
        """Register that the EDU handler is on multiple instances."""
        self._edu_type_to_instance[edu_type] = instance_names

    async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
        if not self.config.server.track_presence and edu_type == EduTypes.PRESENCE:
            return

        # Check if we have a handler on this instance
        handler = self.edu_handlers.get(edu_type)
        if handler:
            with start_active_span_from_edu(content, "handle_edu"):
                try:
                    await handler(origin, content)
                except SynapseError as e:
                    logger.info("Failed to handle edu %r: %r", edu_type, e)
                except Exception:
                    logger.exception("Failed to handle edu %r", edu_type)
            return

        # Check if we can route it somewhere else that isn't us
        instances = self._edu_type_to_instance.get(edu_type, ["master"])
        if self._instance_name not in instances:
            # Pick an instance randomly so that we don't overload one.
            route_to = random.choice(instances)

            try:
                await self._send_edu(
                    instance_name=route_to,
                    edu_type=edu_type,
                    origin=origin,
                    content=content,
                )
            except SynapseError as e:
                logger.info("Failed to handle edu %r: %r", edu_type, e)
            except Exception:
                logger.exception("Failed to handle edu %r", edu_type)
            return

        # Oh well, let's just log and move on.
        logger.warning("No handler registered for EDU type %s", edu_type)

    async def on_query(self, query_type: str, args: dict) -> JsonDict:
        handler = self.query_handlers.get(query_type)
        if handler:
            return await handler(args)

        # Check if we can route it somewhere else that isn't us
        if self._instance_name == "master":
            return await self._get_query_client(query_type=query_type, args=args)

        # Uh oh, no handler! Let's raise an exception so the request returns an
        # error.
        logger.warning("No handler registered for query type %s", query_type)
        raise NotFoundError("No handler for Query type '%s'" % (query_type,))


def _get_event_ids_for_partial_state_join(
    join_event: EventBase,
    prev_state_ids: StateMap[str],
    summary: Mapping[str, MemberSummary],
) -> Collection[str]:
    """Calculate state to be returned in a partial_state send_join

    Args:
        join_event: the join event being send_joined
        prev_state_ids: the event ids of the state before the join
        summary: the membership summary of the room, used to pick out heroes

    Returns:
        the event ids to be returned
    """

    # return all non-member events
    state_event_ids = {
        event_id
        for (event_type, state_key), event_id in prev_state_ids.items()
        if event_type != EventTypes.Member
    }

    # we also need the current state of the current user (it's going to
    # be an auth event for the new join, so we may as well return it)
    current_membership_event_id = prev_state_ids.get(
        (EventTypes.Member, join_event.state_key)
    )
    if current_membership_event_id is not None:
        state_event_ids.add(current_membership_event_id)

    name_id = prev_state_ids.get((EventTypes.Name, ""))
    canonical_alias_id = prev_state_ids.get((EventTypes.CanonicalAlias, ""))
    if not name_id and not canonical_alias_id:
        # Also include the hero members of the room (for DM rooms without a title).
        # To do this properly, we should select the correct subset of membership events
        # from `prev_state_ids`. Instead, we are lazier and use the (cached)
        # `get_room_summary` function, which is based on the current state of the room.
        # This introduces races; we choose to ignore them because a) they should be rare
        # and b) even if it's wrong, joining servers will get the full state eventually.
        heroes = extract_heroes_from_room_summary(summary, join_event.state_key)
        for hero in heroes:
            membership_event_id = prev_state_ids.get((EventTypes.Member, hero))
            if membership_event_id:
                state_event_ids.add(membership_event_id)

    return state_event_ids