You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

1059 lines
34 KiB

  1. # Copyright 2014-2022 The Matrix.org Foundation C.I.C.
  2. # Copyright 2020 Sorunome
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
import logging
import urllib
import urllib.parse
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Collection,
    Dict,
    Generator,
    Iterable,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

import attr
import ijson

from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
    FEDERATION_UNSTABLE_PREFIX,
    FEDERATION_V1_PREFIX,
    FEDERATION_V2_PREFIX,
)
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict, UserID
from synapse.util import ExceptionBundle
  47. if TYPE_CHECKING:
  48. from synapse.app.homeserver import HomeServer
  49. logger = logging.getLogger(__name__)
  50. class TransportLayerClient:
  51. """Sends federation HTTP requests to other servers"""
  52. def __init__(self, hs: "HomeServer"):
  53. self.client = hs.get_federation_http_client()
  54. self._is_mine_server_name = hs.is_mine_server_name
  55. async def get_room_state_ids(
  56. self, destination: str, room_id: str, event_id: str
  57. ) -> JsonDict:
  58. """Requests the IDs of all state for a given room at the given event.
  59. Args:
  60. destination: The host name of the remote homeserver we want
  61. to get the state from.
  62. room_id: the room we want the state of
  63. event_id: The event we want the context at.
  64. Returns:
  65. Results in a dict received from the remote homeserver.
  66. """
  67. logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
  68. path = _create_v1_path("/state_ids/%s", room_id)
  69. return await self.client.get_json(
  70. destination,
  71. path=path,
  72. args={"event_id": event_id},
  73. try_trailing_slash_on_400=True,
  74. )
  75. async def get_room_state(
  76. self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
  77. ) -> "StateRequestResponse":
  78. """Requests the full state for a given room at the given event.
  79. Args:
  80. room_version: the version of the room (required to build the event objects)
  81. destination: The host name of the remote homeserver we want
  82. to get the state from.
  83. room_id: the room we want the state of
  84. event_id: The event we want the context at.
  85. Returns:
  86. Results in a dict received from the remote homeserver.
  87. """
  88. path = _create_v1_path("/state/%s", room_id)
  89. return await self.client.get_json(
  90. destination,
  91. path=path,
  92. args={"event_id": event_id},
  93. # This can take a looooooong time for large rooms. Give this a generous
  94. # timeout of 10 minutes to avoid the partial state resync timing out early
  95. # and trying a bunch of servers who haven't seen our join yet.
  96. timeout=600_000,
  97. parser=_StateParser(room_version),
  98. )
  99. async def get_event(
  100. self, destination: str, event_id: str, timeout: Optional[int] = None
  101. ) -> JsonDict:
  102. """Requests the pdu with give id and origin from the given server.
  103. Args:
  104. destination: The host name of the remote homeserver we want
  105. to get the state from.
  106. event_id: The id of the event being requested.
  107. timeout: How long to try (in ms) the destination for before
  108. giving up. None indicates no timeout.
  109. Returns:
  110. Results in a dict received from the remote homeserver.
  111. """
  112. logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
  113. path = _create_v1_path("/event/%s", event_id)
  114. return await self.client.get_json(
  115. destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
  116. )
  117. async def backfill(
  118. self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
  119. ) -> Optional[Union[JsonDict, list]]:
  120. """Requests `limit` previous PDUs in a given context before list of
  121. PDUs.
  122. Args:
  123. destination
  124. room_id
  125. event_tuples:
  126. Must be a Collection that is falsy when empty.
  127. (Iterable is not enough here!)
  128. limit
  129. Returns:
  130. Results in a dict received from the remote homeserver.
  131. """
  132. logger.debug(
  133. "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
  134. destination,
  135. room_id,
  136. event_tuples,
  137. str(limit),
  138. )
  139. if not event_tuples:
  140. # TODO: raise?
  141. return None
  142. path = _create_v1_path("/backfill/%s", room_id)
  143. args = {"v": event_tuples, "limit": [str(limit)]}
  144. return await self.client.get_json(
  145. destination, path=path, args=args, try_trailing_slash_on_400=True
  146. )
  147. async def timestamp_to_event(
  148. self, destination: str, room_id: str, timestamp: int, direction: Direction
  149. ) -> Union[JsonDict, List]:
  150. """
  151. Calls a remote federating server at `destination` asking for their
  152. closest event to the given timestamp in the given direction.
  153. Args:
  154. destination: Domain name of the remote homeserver
  155. room_id: Room to fetch the event from
  156. timestamp: The point in time (inclusive) we should navigate from in
  157. the given direction to find the closest event.
  158. direction: indicates whether we should navigate forward
  159. or backward from the given timestamp to find the closest event.
  160. Returns:
  161. Response dict received from the remote homeserver.
  162. Raises:
  163. Various exceptions when the request fails
  164. """
  165. path = _create_v1_path(
  166. "/timestamp_to_event/%s",
  167. room_id,
  168. )
  169. args = {"ts": [str(timestamp)], "dir": [direction.value]}
  170. remote_response = await self.client.get_json(
  171. destination, path=path, args=args, try_trailing_slash_on_400=True
  172. )
  173. return remote_response
  174. async def send_transaction(
  175. self,
  176. transaction: Transaction,
  177. json_data_callback: Optional[Callable[[], JsonDict]] = None,
  178. ) -> JsonDict:
  179. """Sends the given Transaction to its destination
  180. Args:
  181. transaction
  182. Returns:
  183. Succeeds when we get a 2xx HTTP response. The result
  184. will be the decoded JSON body.
  185. Fails with ``HTTPRequestException`` if we get an HTTP response
  186. code >= 300.
  187. Fails with ``NotRetryingDestination`` if we are not yet ready
  188. to retry this server.
  189. Fails with ``FederationDeniedError`` if this destination
  190. is not on our federation whitelist
  191. """
  192. logger.debug(
  193. "send_data dest=%s, txid=%s",
  194. transaction.destination,
  195. transaction.transaction_id,
  196. )
  197. if self._is_mine_server_name(transaction.destination):
  198. raise RuntimeError("Transport layer cannot send to itself!")
  199. # FIXME: This is only used by the tests. The actual json sent is
  200. # generated by the json_data_callback.
  201. json_data = transaction.get_dict()
  202. path = _create_v1_path("/send/%s", transaction.transaction_id)
  203. return await self.client.put_json(
  204. transaction.destination,
  205. path=path,
  206. data=json_data,
  207. json_data_callback=json_data_callback,
  208. long_retries=True,
  209. try_trailing_slash_on_400=True,
  210. # Sending a transaction should always succeed, if it doesn't
  211. # then something is wrong and we should backoff.
  212. backoff_on_all_error_codes=True,
  213. )
  214. async def make_query(
  215. self,
  216. destination: str,
  217. query_type: str,
  218. args: QueryParams,
  219. retry_on_dns_fail: bool,
  220. ignore_backoff: bool = False,
  221. prefix: str = FEDERATION_V1_PREFIX,
  222. ) -> JsonDict:
  223. path = _create_path(prefix, "/query/%s", query_type)
  224. return await self.client.get_json(
  225. destination=destination,
  226. path=path,
  227. args=args,
  228. retry_on_dns_fail=retry_on_dns_fail,
  229. timeout=10000,
  230. ignore_backoff=ignore_backoff,
  231. )
  232. async def make_membership_event(
  233. self,
  234. destination: str,
  235. room_id: str,
  236. user_id: str,
  237. membership: str,
  238. params: Optional[Mapping[str, Union[str, Iterable[str]]]],
  239. ) -> JsonDict:
  240. """Asks a remote server to build and sign us a membership event
  241. Note that this does not append any events to any graphs.
  242. Args:
  243. destination: address of remote homeserver
  244. room_id: room to join/leave
  245. user_id: user to be joined/left
  246. membership: one of join/leave
  247. params: Query parameters to include in the request.
  248. Returns:
  249. Succeeds when we get a 2xx HTTP response. The result
  250. will be the decoded JSON body (ie, the new event).
  251. Fails with ``HTTPRequestException`` if we get an HTTP response
  252. code >= 300.
  253. Fails with ``NotRetryingDestination`` if we are not yet ready
  254. to retry this server.
  255. Fails with ``FederationDeniedError`` if the remote destination
  256. is not in our federation whitelist
  257. """
  258. valid_memberships = {Membership.JOIN, Membership.LEAVE, Membership.KNOCK}
  259. if membership not in valid_memberships:
  260. raise RuntimeError(
  261. "make_membership_event called with membership='%s', must be one of %s"
  262. % (membership, ",".join(valid_memberships))
  263. )
  264. path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
  265. ignore_backoff = False
  266. retry_on_dns_fail = False
  267. if membership == Membership.LEAVE:
  268. # we particularly want to do our best to send leave events. The
  269. # problem is that if it fails, we won't retry it later, so if the
  270. # remote server was just having a momentary blip, the room will be
  271. # out of sync.
  272. ignore_backoff = True
  273. retry_on_dns_fail = True
  274. return await self.client.get_json(
  275. destination=destination,
  276. path=path,
  277. args=params,
  278. retry_on_dns_fail=retry_on_dns_fail,
  279. timeout=20000,
  280. ignore_backoff=ignore_backoff,
  281. )
  282. async def send_join_v1(
  283. self,
  284. room_version: RoomVersion,
  285. destination: str,
  286. room_id: str,
  287. event_id: str,
  288. content: JsonDict,
  289. ) -> "SendJoinResponse":
  290. path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
  291. return await self.client.put_json(
  292. destination=destination,
  293. path=path,
  294. data=content,
  295. parser=SendJoinParser(room_version, v1_api=True),
  296. )
  297. async def send_join_v2(
  298. self,
  299. room_version: RoomVersion,
  300. destination: str,
  301. room_id: str,
  302. event_id: str,
  303. content: JsonDict,
  304. omit_members: bool,
  305. ) -> "SendJoinResponse":
  306. path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
  307. query_params: Dict[str, str] = {}
  308. # lazy-load state on join
  309. query_params["omit_members"] = "true" if omit_members else "false"
  310. return await self.client.put_json(
  311. destination=destination,
  312. path=path,
  313. args=query_params,
  314. data=content,
  315. parser=SendJoinParser(room_version, v1_api=False),
  316. )
  317. async def send_leave_v1(
  318. self, destination: str, room_id: str, event_id: str, content: JsonDict
  319. ) -> Tuple[int, JsonDict]:
  320. path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
  321. return await self.client.put_json(
  322. destination=destination,
  323. path=path,
  324. data=content,
  325. # we want to do our best to send this through. The problem is
  326. # that if it fails, we won't retry it later, so if the remote
  327. # server was just having a momentary blip, the room will be out of
  328. # sync.
  329. ignore_backoff=True,
  330. parser=LegacyJsonSendParser(),
  331. )
  332. async def send_leave_v2(
  333. self, destination: str, room_id: str, event_id: str, content: JsonDict
  334. ) -> JsonDict:
  335. path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
  336. return await self.client.put_json(
  337. destination=destination,
  338. path=path,
  339. data=content,
  340. # we want to do our best to send this through. The problem is
  341. # that if it fails, we won't retry it later, so if the remote
  342. # server was just having a momentary blip, the room will be out of
  343. # sync.
  344. ignore_backoff=True,
  345. )
  346. async def send_knock_v1(
  347. self,
  348. destination: str,
  349. room_id: str,
  350. event_id: str,
  351. content: JsonDict,
  352. ) -> JsonDict:
  353. """
  354. Sends a signed knock membership event to a remote server. This is the second
  355. step for knocking after make_knock.
  356. Args:
  357. destination: The remote homeserver.
  358. room_id: The ID of the room to knock on.
  359. event_id: The ID of the knock membership event that we're sending.
  360. content: The knock membership event that we're sending. Note that this is not the
  361. `content` field of the membership event, but the entire signed membership event
  362. itself represented as a JSON dict.
  363. Returns:
  364. The remote homeserver can optionally return some state from the room. The response
  365. dictionary is in the form:
  366. {"knock_room_state": [<state event dict>, ...]}
  367. The list of state events may be empty.
  368. """
  369. path = _create_v1_path("/send_knock/%s/%s", room_id, event_id)
  370. return await self.client.put_json(
  371. destination=destination, path=path, data=content
  372. )
  373. async def send_invite_v1(
  374. self, destination: str, room_id: str, event_id: str, content: JsonDict
  375. ) -> Tuple[int, JsonDict]:
  376. path = _create_v1_path("/invite/%s/%s", room_id, event_id)
  377. return await self.client.put_json(
  378. destination=destination,
  379. path=path,
  380. data=content,
  381. ignore_backoff=True,
  382. parser=LegacyJsonSendParser(),
  383. )
  384. async def send_invite_v2(
  385. self, destination: str, room_id: str, event_id: str, content: JsonDict
  386. ) -> JsonDict:
  387. path = _create_v2_path("/invite/%s/%s", room_id, event_id)
  388. return await self.client.put_json(
  389. destination=destination, path=path, data=content, ignore_backoff=True
  390. )
  391. async def get_public_rooms(
  392. self,
  393. remote_server: str,
  394. limit: Optional[int] = None,
  395. since_token: Optional[str] = None,
  396. search_filter: Optional[Dict] = None,
  397. include_all_networks: bool = False,
  398. third_party_instance_id: Optional[str] = None,
  399. ) -> JsonDict:
  400. """Get the list of public rooms from a remote homeserver
  401. See synapse.federation.federation_client.FederationClient.get_public_rooms for
  402. more information.
  403. """
  404. path = _create_v1_path("/publicRooms")
  405. if search_filter:
  406. # this uses MSC2197 (Search Filtering over Federation)
  407. data: Dict[str, Any] = {"include_all_networks": include_all_networks}
  408. if third_party_instance_id:
  409. data["third_party_instance_id"] = third_party_instance_id
  410. if limit:
  411. data["limit"] = limit
  412. if since_token:
  413. data["since"] = since_token
  414. data["filter"] = search_filter
  415. try:
  416. response = await self.client.post_json(
  417. destination=remote_server, path=path, data=data, ignore_backoff=True
  418. )
  419. except HttpResponseException as e:
  420. if e.code == 403:
  421. raise SynapseError(
  422. 403,
  423. "You are not allowed to view the public rooms list of %s"
  424. % (remote_server,),
  425. errcode=Codes.FORBIDDEN,
  426. )
  427. raise
  428. else:
  429. args: Dict[str, Union[str, Iterable[str]]] = {
  430. "include_all_networks": "true" if include_all_networks else "false"
  431. }
  432. if third_party_instance_id:
  433. args["third_party_instance_id"] = third_party_instance_id
  434. if limit:
  435. args["limit"] = str(limit)
  436. if since_token:
  437. args["since"] = since_token
  438. try:
  439. response = await self.client.get_json(
  440. destination=remote_server, path=path, args=args, ignore_backoff=True
  441. )
  442. except HttpResponseException as e:
  443. if e.code == 403:
  444. raise SynapseError(
  445. 403,
  446. "You are not allowed to view the public rooms list of %s"
  447. % (remote_server,),
  448. errcode=Codes.FORBIDDEN,
  449. )
  450. raise
  451. return response
  452. async def exchange_third_party_invite(
  453. self, destination: str, room_id: str, event_dict: JsonDict
  454. ) -> JsonDict:
  455. path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
  456. return await self.client.put_json(
  457. destination=destination, path=path, data=event_dict
  458. )
  459. async def get_event_auth(
  460. self, destination: str, room_id: str, event_id: str
  461. ) -> JsonDict:
  462. path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
  463. return await self.client.get_json(destination=destination, path=path)
  464. async def query_client_keys(
  465. self, destination: str, query_content: JsonDict, timeout: int
  466. ) -> JsonDict:
  467. """Query the device keys for a list of user ids hosted on a remote
  468. server.
  469. Request:
  470. {
  471. "device_keys": {
  472. "<user_id>": ["<device_id>"]
  473. }
  474. }
  475. Response:
  476. {
  477. "device_keys": {
  478. "<user_id>": {
  479. "<device_id>": {...}
  480. }
  481. },
  482. "master_key": {
  483. "<user_id>": {...}
  484. }
  485. },
  486. "self_signing_key": {
  487. "<user_id>": {...}
  488. }
  489. }
  490. Args:
  491. destination: The server to query.
  492. query_content: The user ids to query.
  493. Returns:
  494. A dict containing device and cross-signing keys.
  495. """
  496. path = _create_v1_path("/user/keys/query")
  497. return await self.client.post_json(
  498. destination=destination, path=path, data=query_content, timeout=timeout
  499. )
  500. async def query_user_devices(
  501. self, destination: str, user_id: str, timeout: int
  502. ) -> JsonDict:
  503. """Query the devices for a user id hosted on a remote server.
  504. Response:
  505. {
  506. "stream_id": "...",
  507. "devices": [ { ... } ],
  508. "master_key": {
  509. "user_id": "<user_id>",
  510. "usage": [...],
  511. "keys": {...},
  512. "signatures": {
  513. "<user_id>": {...}
  514. }
  515. },
  516. "self_signing_key": {
  517. "user_id": "<user_id>",
  518. "usage": [...],
  519. "keys": {...},
  520. "signatures": {
  521. "<user_id>": {...}
  522. }
  523. }
  524. }
  525. Args:
  526. destination: The server to query.
  527. query_content: The user ids to query.
  528. Returns:
  529. A dict containing device and cross-signing keys.
  530. """
  531. path = _create_v1_path("/user/devices/%s", user_id)
  532. return await self.client.get_json(
  533. destination=destination, path=path, timeout=timeout
  534. )
  535. async def claim_client_keys(
  536. self,
  537. user: UserID,
  538. destination: str,
  539. query_content: JsonDict,
  540. timeout: Optional[int],
  541. ) -> JsonDict:
  542. """Claim one-time keys for a list of devices hosted on a remote server.
  543. Request:
  544. {
  545. "one_time_keys": {
  546. "<user_id>": {
  547. "<device_id>": "<algorithm>"
  548. }
  549. }
  550. }
  551. Response:
  552. {
  553. "one_time_keys": {
  554. "<user_id>": {
  555. "<device_id>": {
  556. "<algorithm>:<key_id>": <OTK JSON>
  557. }
  558. }
  559. }
  560. }
  561. Args:
  562. user: the user_id of the requesting user
  563. destination: The server to query.
  564. query_content: The user ids to query.
  565. Returns:
  566. A dict containing the one-time keys.
  567. """
  568. path = _create_v1_path("/user/keys/claim")
  569. return await self.client.post_json(
  570. destination=destination,
  571. path=path,
  572. data={"one_time_keys": query_content},
  573. timeout=timeout,
  574. )
  575. async def claim_client_keys_unstable(
  576. self,
  577. user: UserID,
  578. destination: str,
  579. query_content: JsonDict,
  580. timeout: Optional[int],
  581. ) -> JsonDict:
  582. """Claim one-time keys for a list of devices hosted on a remote server.
  583. Request:
  584. {
  585. "one_time_keys": {
  586. "<user_id>": {
  587. "<device_id>": {"<algorithm>": <count>}
  588. }
  589. }
  590. }
  591. Response:
  592. {
  593. "one_time_keys": {
  594. "<user_id>": {
  595. "<device_id>": {
  596. "<algorithm>:<key_id>": <OTK JSON>
  597. }
  598. }
  599. }
  600. }
  601. Args:
  602. user: the user_id of the requesting user
  603. destination: The server to query.
  604. query_content: The user ids to query.
  605. Returns:
  606. A dict containing the one-time keys.
  607. """
  608. path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/user/keys/claim")
  609. return await self.client.post_json(
  610. destination=destination,
  611. path=path,
  612. data={"one_time_keys": query_content},
  613. timeout=timeout,
  614. )
  615. async def get_missing_events(
  616. self,
  617. destination: str,
  618. room_id: str,
  619. earliest_events: Iterable[str],
  620. latest_events: Iterable[str],
  621. limit: int,
  622. min_depth: int,
  623. timeout: int,
  624. ) -> JsonDict:
  625. path = _create_v1_path("/get_missing_events/%s", room_id)
  626. return await self.client.post_json(
  627. destination=destination,
  628. path=path,
  629. data={
  630. "limit": int(limit),
  631. "min_depth": int(min_depth),
  632. "earliest_events": earliest_events,
  633. "latest_events": latest_events,
  634. },
  635. timeout=timeout,
  636. )
  637. async def get_room_complexity(self, destination: str, room_id: str) -> JsonDict:
  638. """
  639. Args:
  640. destination: The remote server
  641. room_id: The room ID to ask about.
  642. """
  643. path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id)
  644. return await self.client.get_json(destination=destination, path=path)
  645. async def get_room_hierarchy(
  646. self, destination: str, room_id: str, suggested_only: bool
  647. ) -> JsonDict:
  648. """
  649. Args:
  650. destination: The remote server
  651. room_id: The room ID to ask about.
  652. suggested_only: if True, only suggested rooms will be returned
  653. """
  654. path = _create_v1_path("/hierarchy/%s", room_id)
  655. return await self.client.get_json(
  656. destination=destination,
  657. path=path,
  658. args={"suggested_only": "true" if suggested_only else "false"},
  659. )
  660. async def get_room_hierarchy_unstable(
  661. self, destination: str, room_id: str, suggested_only: bool
  662. ) -> JsonDict:
  663. """
  664. Args:
  665. destination: The remote server
  666. room_id: The room ID to ask about.
  667. suggested_only: if True, only suggested rooms will be returned
  668. """
  669. path = _create_path(
  670. FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc2946/hierarchy/%s", room_id
  671. )
  672. return await self.client.get_json(
  673. destination=destination,
  674. path=path,
  675. args={"suggested_only": "true" if suggested_only else "false"},
  676. )
  677. async def get_account_status(
  678. self, destination: str, user_ids: List[str]
  679. ) -> JsonDict:
  680. """
  681. Args:
  682. destination: The remote server.
  683. user_ids: The user ID(s) for which to request account status(es).
  684. """
  685. path = _create_path(
  686. FEDERATION_UNSTABLE_PREFIX, "/org.matrix.msc3720/account_status"
  687. )
  688. return await self.client.post_json(
  689. destination=destination, path=path, data={"user_ids": user_ids}
  690. )
  691. def _create_path(federation_prefix: str, path: str, *args: str) -> str:
  692. """
  693. Ensures that all args are url encoded.
  694. """
  695. return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
  696. def _create_v1_path(path: str, *args: str) -> str:
  697. """Creates a path against V1 federation API from the path template and
  698. args. Ensures that all args are url encoded.
  699. Example:
  700. _create_v1_path("/event/%s", event_id)
  701. Args:
  702. path: String template for the path
  703. args: Args to insert into path. Each arg will be url encoded
  704. """
  705. return _create_path(FEDERATION_V1_PREFIX, path, *args)
  706. def _create_v2_path(path: str, *args: str) -> str:
  707. """Creates a path against V2 federation API from the path template and
  708. args. Ensures that all args are url encoded.
  709. Example:
  710. _create_v2_path("/event/%s", event_id)
  711. Args:
  712. path: String template for the path
  713. args: Args to insert into path. Each arg will be url encoded
  714. """
  715. return _create_path(FEDERATION_V2_PREFIX, path, *args)
@attr.s(slots=True, auto_attribs=True)
class SendJoinResponse:
    """The parsed response of a `/send_join` request.

    SendJoinParser creates an empty instance and fills in its fields while
    the response body is streamed in.
    """

    # The list of auth events from the /send_join response.
    auth_events: List[EventBase]
    # The list of state from the /send_join response.
    state: List[EventBase]
    # The raw join event from the /send_join response.
    event_dict: JsonDict

    # The parsed join event from the /send_join response. This will be None if
    # "event" is not included in the response.
    event: Optional[EventBase] = None

    # The room state is incomplete (set from "members_omitted"; only parsed
    # for the v2 API).
    members_omitted: bool = False

    # List of servers in the room (set from "servers_in_room"; only parsed
    # for the v2 API).
    servers_in_room: Optional[List[str]] = None
@attr.s(slots=True, auto_attribs=True)
class StateRequestResponse:
    """The parsed response of a `/state` request."""

    # Events parsed from the response's "auth_chain" list.
    auth_events: List[EventBase]
    # Events parsed from the response's "pdus" list.
    state: List[EventBase]
  737. @ijson.coroutine
  738. def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
  739. """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
  740. to add them to a given dictionary.
  741. """
  742. while True:
  743. key, value = yield
  744. event_dict[key] = value
  745. @ijson.coroutine
  746. def _event_list_parser(
  747. room_version: RoomVersion, events: List[EventBase]
  748. ) -> Generator[None, JsonDict, None]:
  749. """Helper function for use with `ijson.items_coro` to parse an array of
  750. events and add them to the given list.
  751. """
  752. while True:
  753. obj = yield
  754. event = make_event_from_dict(obj, room_version)
  755. events.append(event)
  756. @ijson.coroutine
  757. def _members_omitted_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
  758. """Helper function for use with `ijson.items_coro`
  759. Parses the members_omitted field in send_join responses
  760. """
  761. while True:
  762. val = yield
  763. if not isinstance(val, bool):
  764. raise TypeError("members_omitted must be a boolean")
  765. response.members_omitted = val
  766. @ijson.coroutine
  767. def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
  768. """Helper function for use with `ijson.items_coro`
  769. Parses the servers_in_room field in send_join responses
  770. """
  771. while True:
  772. val = yield
  773. if not isinstance(val, list) or any(not isinstance(x, str) for x in val):
  774. raise TypeError("servers_in_room must be a list of strings")
  775. response.servers_in_room = val
  776. class SendJoinParser(ByteParser[SendJoinResponse]):
  777. """A parser for the response to `/send_join` requests.
  778. Args:
  779. room_version: The version of the room.
  780. v1_api: Whether the response is in the v1 format.
  781. """
  782. CONTENT_TYPE = "application/json"
  783. # /send_join responses can be huge, so we override the size limit here. The response
  784. # is parsed in a streaming manner, which helps alleviate the issue of memory
  785. # usage a bit.
  786. MAX_RESPONSE_SIZE = 500 * 1024 * 1024
  787. def __init__(self, room_version: RoomVersion, v1_api: bool):
  788. self._response = SendJoinResponse([], [], event_dict={})
  789. self._room_version = room_version
  790. self._coros: List[Generator[None, bytes, None]] = []
  791. # The V1 API has the shape of `[200, {...}]`, which we handle by
  792. # prefixing with `item.*`.
  793. prefix = "item." if v1_api else ""
  794. self._coros = [
  795. ijson.items_coro(
  796. _event_list_parser(room_version, self._response.state),
  797. prefix + "state.item",
  798. use_float=True,
  799. ),
  800. ijson.items_coro(
  801. _event_list_parser(room_version, self._response.auth_events),
  802. prefix + "auth_chain.item",
  803. use_float=True,
  804. ),
  805. ijson.kvitems_coro(
  806. _event_parser(self._response.event_dict),
  807. prefix + "event",
  808. use_float=True,
  809. ),
  810. ]
  811. if not v1_api:
  812. self._coros.append(
  813. ijson.items_coro(
  814. _members_omitted_parser(self._response),
  815. "members_omitted",
  816. use_float="True",
  817. )
  818. )
  819. # Again, stable field name comes last
  820. self._coros.append(
  821. ijson.items_coro(
  822. _servers_in_room_parser(self._response),
  823. "servers_in_room",
  824. use_float="True",
  825. )
  826. )
  827. def write(self, data: bytes) -> int:
  828. for c in self._coros:
  829. c.send(data)
  830. return len(data)
  831. def finish(self) -> SendJoinResponse:
  832. _close_coros(self._coros)
  833. if self._response.event_dict:
  834. self._response.event = make_event_from_dict(
  835. self._response.event_dict, self._room_version
  836. )
  837. return self._response
  838. class _StateParser(ByteParser[StateRequestResponse]):
  839. """A parser for the response to `/state` requests.
  840. Args:
  841. room_version: The version of the room.
  842. """
  843. CONTENT_TYPE = "application/json"
  844. # As with /send_join, /state responses can be huge.
  845. MAX_RESPONSE_SIZE = 500 * 1024 * 1024
  846. def __init__(self, room_version: RoomVersion):
  847. self._response = StateRequestResponse([], [])
  848. self._room_version = room_version
  849. self._coros: List[Generator[None, bytes, None]] = [
  850. ijson.items_coro(
  851. _event_list_parser(room_version, self._response.state),
  852. "pdus.item",
  853. use_float=True,
  854. ),
  855. ijson.items_coro(
  856. _event_list_parser(room_version, self._response.auth_events),
  857. "auth_chain.item",
  858. use_float=True,
  859. ),
  860. ]
  861. def write(self, data: bytes) -> int:
  862. for c in self._coros:
  863. c.send(data)
  864. return len(data)
  865. def finish(self) -> StateRequestResponse:
  866. _close_coros(self._coros)
  867. return self._response
  868. def _close_coros(coros: Iterable[Generator[None, bytes, None]]) -> None:
  869. """Close each of the given coroutines.
  870. Always calls .close() on each coroutine, even if doing so raises an exception.
  871. Any exceptions raised are aggregated into an ExceptionBundle.
  872. :raises ExceptionBundle: if at least one coroutine fails to close.
  873. """
  874. exceptions = []
  875. for c in coros:
  876. try:
  877. c.close()
  878. except Exception as e:
  879. exceptions.append(e)
  880. if exceptions:
  881. # raise from the first exception so that the traceback has slightly more context
  882. raise ExceptionBundle(
  883. f"There were {len(exceptions)} errors closing coroutines", exceptions
  884. ) from exceptions[0]