You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

555 lines
20 KiB

  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. # Copyright 2022 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. import urllib.parse
  17. from typing import (
  18. TYPE_CHECKING,
  19. Dict,
  20. Iterable,
  21. List,
  22. Mapping,
  23. Optional,
  24. Sequence,
  25. Tuple,
  26. TypeVar,
  27. Union,
  28. )
  29. from prometheus_client import Counter
  30. from typing_extensions import ParamSpec, TypeGuard
  31. from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind
  32. from synapse.api.errors import CodeMessageException, HttpResponseException
  33. from synapse.appservice import (
  34. ApplicationService,
  35. TransactionOneTimeKeysCount,
  36. TransactionUnusedFallbackKeys,
  37. )
  38. from synapse.events import EventBase
  39. from synapse.events.utils import SerializeEventConfig, serialize_event
  40. from synapse.http.client import SimpleHttpClient, is_unknown_endpoint
  41. from synapse.logging import opentracing
  42. from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID
  43. from synapse.util.caches.response_cache import ResponseCache
  44. if TYPE_CHECKING:
  45. from synapse.server import HomeServer
  logger = logging.getLogger(__name__)

  # Prometheus counters for HS -> AS traffic. Each one is labelled with a single
  # "service" label identifying the application service.
  sent_transactions_counter = Counter(
      name="synapse_appservice_api_sent_transactions",
      documentation="Number of /transactions/ requests sent",
      labelnames=["service"],
  )
  failed_transactions_counter = Counter(
      name="synapse_appservice_api_failed_transactions",
      documentation="Number of /transactions/ requests that failed to send",
      labelnames=["service"],
  )
  sent_events_counter = Counter(
      name="synapse_appservice_api_sent_events",
      documentation="Number of events sent to the AS",
      labelnames=["service"],
  )
  sent_ephemeral_counter = Counter(
      name="synapse_appservice_api_sent_ephemeral",
      documentation="Number of ephemeral events sent to the AS",
      labelnames=["service"],
  )
  sent_todevice_counter = Counter(
      name="synapse_appservice_api_sent_todevice",
      documentation="Number of todevice messages sent to the AS",
      labelnames=["service"],
  )

  # One hour expressed in milliseconds.
  HOUR_IN_MS = 1000 * 60 * 60

  # Path prefix shared by the versioned (v1) application service endpoints.
  APP_SERVICE_PREFIX = "/_matrix/app/v1"

  # Generic type variables. NOTE(review): neither is referenced elsewhere in this
  # module; presumably kept for external importers - confirm before removing.
  P = ParamSpec("P")
  R = TypeVar("R")
  74. def _is_valid_3pe_metadata(info: JsonDict) -> bool:
  75. if "instances" not in info:
  76. return False
  77. if not isinstance(info["instances"], list):
  78. return False
  79. return True
  def _is_valid_3pe_result(r: object, field: str) -> TypeGuard[JsonDict]:
      """Check whether ``r`` is a well-formed third-party lookup result.

      A valid result is a dict in which both ``field`` and ``"protocol"`` map to
      strings, and ``"fields"`` maps to a dict.

      Args:
          r: A single entry from the application service's response list.
          field: The additional key that must map to a string. The caller in
              this module passes ``"userid"`` or ``"alias"``.

      Returns:
          True if ``r`` has the expected shape, False otherwise.
      """
      if not isinstance(r, dict):
          return False
      # A missing key yields None from .get(), which fails the type checks just
      # as an explicit value of the wrong type does.
      keys_are_strings = all(isinstance(r.get(key), str) for key in (field, "protocol"))
      return keys_are_strings and isinstance(r.get("fields"), dict)
  class ApplicationServiceApi(SimpleHttpClient):
      """This class manages HS -> AS communications, including querying and
      pushing.
      """

      def __init__(self, hs: "HomeServer"):
          super().__init__(hs)
          self.clock = hs.get_clock()
          self.config = hs.config.appservice

          # Caches get_3pe_protocol lookups keyed on (service ID, protocol), with
          # a one-hour timeout. NOTE(review): this presumably also retains failed
          # (None) lookups for that period - confirm against ResponseCache.
          self.protocol_meta_cache: ResponseCache[Tuple[str, str]] = ResponseCache(
              hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
          )

      def _get_headers(self, service: "ApplicationService") -> Dict[bytes, List[bytes]]:
          """Build the request headers for a call to ``service``.

          Always includes an ``Authorization: Bearer <hs_token>`` header, plus
          whatever tracing headers ``opentracing.inject_header_dict`` adds.
          """
          # This is also ensured before in the functions. However this is needed
          # to please the typechecks.
          assert service.hs_token is not None

          headers = {b"Authorization": [b"Bearer " + service.hs_token.encode("ascii")]}
          opentracing.inject_header_dict(headers, check_destination=False)
          return headers

      async def query_user(self, service: "ApplicationService", user_id: str) -> bool:
          """Ask the application service whether it knows about a user.

          Args:
              service: The application service to query.
              user_id: The Matrix user ID to look up.

          Returns:
              True if the application service answered with a non-null JSON body.
              False if the service has no URL, answered 404, or the request failed
              for any other reason. Failures other than 404 are logged.
          """
          if service.url is None:
              return False

          # This is required by the configuration.
          assert service.hs_token is not None

          try:
              args = None
              if self.config.use_appservice_legacy_authorization:
                  # Legacy mode: the token is also sent as a query parameter.
                  args = {"access_token": service.hs_token}

              response = await self.get_json(
                  f"{service.url}{APP_SERVICE_PREFIX}/users/{urllib.parse.quote(user_id)}",
                  args,
                  headers=self._get_headers(service),
              )
              if response is not None:  # just an empty json object
                  return True
          except CodeMessageException as e:
              if e.code == 404:
                  return False
              logger.warning("query_user to %s received %s", service.url, e.code)
          except Exception as ex:
              logger.warning("query_user to %s threw exception %s", service.url, ex)

          return False

      async def query_alias(self, service: "ApplicationService", alias: str) -> bool:
          """Ask the application service whether it knows about a room alias.

          Args:
              service: The application service to query.
              alias: The room alias to look up.

          Returns:
              True if the application service answered with a non-null JSON body.
              False if the service has no URL, answered 404, or the request failed
              for any other reason. Failures other than 404 are logged.
          """
          if service.url is None:
              return False

          # This is required by the configuration.
          assert service.hs_token is not None

          try:
              args = None
              if self.config.use_appservice_legacy_authorization:
                  # Legacy mode: the token is also sent as a query parameter.
                  args = {"access_token": service.hs_token}

              response = await self.get_json(
                  f"{service.url}{APP_SERVICE_PREFIX}/rooms/{urllib.parse.quote(alias)}",
                  args,
                  headers=self._get_headers(service),
              )
              if response is not None:  # just an empty json object
                  return True
          except CodeMessageException as e:
              # 404 is treated as an ordinary negative answer. As in query_user,
              # only other status codes are worth a warning.
              if e.code == 404:
                  return False
              logger.warning("query_alias to %s received %s", service.url, e.code)
          except Exception as ex:
              logger.warning("query_alias to %s threw exception %s", service.url, ex)

          return False

      async def query_3pe(
          self,
          service: "ApplicationService",
          kind: str,
          protocol: str,
          fields: Dict[bytes, List[bytes]],
      ) -> List[JsonDict]:
          """Look up third-party users or locations via the application service.

          Args:
              service: The application service to query.
              kind: ``ThirdPartyEntityKind.USER`` or ``ThirdPartyEntityKind.LOCATION``.
              protocol: The third-party protocol to query.
              fields: Query parameters forwarded to the application service.

          Returns:
              The well-formed entries of the response. Malformed entries are
              logged and dropped. Returns an empty list if the service has no
              URL, the response is not a list, or the request failed.

          Raises:
              ValueError: if ``kind`` is not a recognised entity kind.
          """
          if kind == ThirdPartyEntityKind.USER:
              required_field = "userid"
          elif kind == ThirdPartyEntityKind.LOCATION:
              required_field = "alias"
          else:
              # ValueError does not %-format its arguments, so build the message
              # here rather than passing `kind` as a separate argument.
              raise ValueError(f"Unrecognised 'kind' argument {kind!r} to query_3pe()")

          if service.url is None:
              return []

          # This is required by the configuration.
          assert service.hs_token is not None

          try:
              args: Mapping[bytes, Union[List[bytes], str]] = fields
              if self.config.use_appservice_legacy_authorization:
                  args = {
                      **fields,
                      b"access_token": service.hs_token,
                  }

              response = await self.get_json(
                  f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/{kind}/{urllib.parse.quote(protocol)}",
                  args=args,
                  headers=self._get_headers(service),
              )
              if not isinstance(response, list):
                  logger.warning(
                      "query_3pe to %s returned an invalid response %r",
                      service.url,
                      response,
                  )
                  return []

              ret = []
              for r in response:
                  if _is_valid_3pe_result(r, field=required_field):
                      ret.append(r)
                  else:
                      logger.warning(
                          "query_3pe to %s returned an invalid result %r", service.url, r
                      )

              return ret
          except Exception as ex:
              logger.warning("query_3pe to %s threw exception %s", service.url, ex)
              return []

      async def get_3pe_protocol(
          self, service: "ApplicationService", protocol: str
      ) -> Optional[JsonDict]:
          """Fetch metadata about a third-party protocol from the application service.

          Each entry in the returned ``instances`` list that has a ``network_id``
          gains an ``instance_id`` built from the service ID and that network ID.

          Args:
              service: The application service to query.
              protocol: The third-party protocol to describe.

          Returns:
              The protocol metadata, or None if the service has no URL, returned
              malformed metadata, or the request failed. Lookups go through
              ``protocol_meta_cache``.
          """
          if service.url is None:
              # Return None rather than `{}`: an empty dict has no "instances"
              # list and is not valid protocol metadata. None is how every other
              # no-data path in this method reports absence.
              return None

          async def _get() -> Optional[JsonDict]:
              # This is required by the configuration.
              assert service.hs_token is not None
              try:
                  args = None
                  if self.config.use_appservice_legacy_authorization:
                      args = {"access_token": service.hs_token}

                  info = await self.get_json(
                      f"{service.url}{APP_SERVICE_PREFIX}/thirdparty/protocol/{urllib.parse.quote(protocol)}",
                      args,
                      headers=self._get_headers(service),
                  )
                  if not _is_valid_3pe_metadata(info):
                      logger.warning(
                          "query_3pe_protocol to %s did not return a valid result",
                          service.url,
                      )
                      return None

                  for instance in info.get("instances", []):
                      network_id = instance.get("network_id", None)
                      if network_id is not None:
                          instance["instance_id"] = ThirdPartyInstanceID(
                              service.id, network_id
                          ).to_string()

                  return info
              except Exception as ex:
                  logger.warning(
                      "query_3pe_protocol to %s threw exception %s", service.url, ex
                  )
                  return None

          key = (service.id, protocol)
          return await self.protocol_meta_cache.wrap(key, _get)

      async def ping(self, service: "ApplicationService", txn_id: Optional[str]) -> None:
          """POST to the application service's ``/ping`` endpoint.

          Args:
              service: The application service to ping. Its URL must be set.
              txn_id: Sent as ``transaction_id`` in the request body. May be None.

          Raises:
              Whatever ``post_json_get_json`` raises; errors are not caught here.
          """
          # The caller should check that url is set
          assert service.url is not None, "ping called without URL being set"

          # This is required by the configuration.
          assert service.hs_token is not None

          await self.post_json_get_json(
              uri=f"{service.url}{APP_SERVICE_PREFIX}/ping",
              post_json={"transaction_id": txn_id},
              headers=self._get_headers(service),
          )

      async def push_bulk(
          self,
          service: "ApplicationService",
          events: Sequence[EventBase],
          ephemeral: List[JsonDict],
          to_device_messages: List[JsonDict],
          one_time_keys_count: TransactionOneTimeKeysCount,
          unused_fallback_keys: TransactionUnusedFallbackKeys,
          device_list_summary: DeviceListUpdates,
          txn_id: Optional[int] = None,
      ) -> bool:
          """
          Push data to an application service.

          Args:
              service: The application service to send to.
              events: The persistent events to send.
              ephemeral: The ephemeral events to send. Only included if the
                  service supports ephemeral events.
              to_device_messages: The to-device messages to send. Only included
                  if the service supports ephemeral events.
              one_time_keys_count: One-time key counts (MSC3202). Included only
                  when non-empty and the service has MSC3202 transaction
                  extensions enabled.
              unused_fallback_keys: Unused fallback key types (MSC3202). Same
                  inclusion conditions as ``one_time_keys_count``.
              device_list_summary: Users whose device lists changed or who left
                  (MSC3202). Same inclusion conditions as ``one_time_keys_count``.
              txn_id: A unique ID to assign to this transaction. Application
                  services should deduplicate transactions received with
                  identical IDs. If omitted, a warning is logged and 0 is used.

          Returns:
              True if the task succeeded (or the service has no URL), False if
              it failed.
          """
          if service.url is None:
              return True

          # This is required by the configuration.
          assert service.hs_token is not None

          serialized_events = self._serialize(service, events)

          if txn_id is None:
              logger.warning(
                  "push_bulk: Missing txn ID sending events to %s", service.url
              )
              txn_id = 0

          # Never send ephemeral events to appservices that do not support it
          body: JsonDict = {"events": serialized_events}
          if service.supports_ephemeral:
              body.update(
                  {
                      # TODO: Update to stable prefixes once MSC2409 completes FCP merge.
                      "de.sorunome.msc2409.ephemeral": ephemeral,
                      "de.sorunome.msc2409.to_device": to_device_messages,
                  }
              )

          # TODO: Update to stable prefixes once MSC3202 completes FCP merge
          if service.msc3202_transaction_extensions:
              if one_time_keys_count:
                  # The counts are sent under two key spellings. NOTE(review):
                  # presumably for compatibility with different MSC3202
                  # revisions - confirm before dropping either.
                  body[
                      "org.matrix.msc3202.device_one_time_key_counts"
                  ] = one_time_keys_count
                  body[
                      "org.matrix.msc3202.device_one_time_keys_count"
                  ] = one_time_keys_count
              if unused_fallback_keys:
                  body[
                      "org.matrix.msc3202.device_unused_fallback_key_types"
                  ] = unused_fallback_keys
              if device_list_summary:
                  body["org.matrix.msc3202.device_lists"] = {
                      "changed": list(device_list_summary.changed),
                      "left": list(device_list_summary.left),
                  }

          try:
              args = None
              if self.config.use_appservice_legacy_authorization:
                  args = {"access_token": service.hs_token}

              await self.put_json(
                  f"{service.url}{APP_SERVICE_PREFIX}/transactions/{urllib.parse.quote(str(txn_id))}",
                  json_body=body,
                  args=args,
                  headers=self._get_headers(service),
              )
              if logger.isEnabledFor(logging.DEBUG):
                  logger.debug(
                      "push_bulk to %s succeeded! events=%s",
                      service.url,
                      [event.get("event_id") for event in events],
                  )
              sent_transactions_counter.labels(service.id).inc()
              sent_events_counter.labels(service.id).inc(len(serialized_events))
              # Only count ephemeral/to-device data that was actually included in
              # the transaction body (see the supports_ephemeral check above).
              if service.supports_ephemeral:
                  sent_ephemeral_counter.labels(service.id).inc(len(ephemeral))
                  sent_todevice_counter.labels(service.id).inc(len(to_device_messages))
              return True
          except CodeMessageException as e:
              logger.warning(
                  "push_bulk to %s received code=%s msg=%s",
                  service.url,
                  e.code,
                  e.msg,
                  exc_info=logger.isEnabledFor(logging.DEBUG),
              )
          except Exception as ex:
              logger.warning(
                  "push_bulk to %s threw exception(%s) %s args=%s",
                  service.url,
                  type(ex).__name__,
                  ex,
                  ex.args,
                  exc_info=logger.isEnabledFor(logging.DEBUG),
              )

          failed_transactions_counter.labels(service.id).inc()
          return False

      async def claim_client_keys(
          self, service: "ApplicationService", query: List[Tuple[str, str, str, int]]
      ) -> Tuple[
          Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]]
      ]:
          """Claim one time keys from an application service.

          Note that any error (including a timeout or a response that is not a
          JSON object) is treated as the application service having no
          information.

          Args:
              service: The application service to query.
              query: A list of tuples of (user ID, device ID, algorithm, count),
                  where count is the number of keys to claim for that algorithm.

          Returns:
              A tuple of:
                  A map of user ID -> a map device ID -> a map of key ID -> JSON dict.

                  The requests that have not been fulfilled because the appservice
                  doesn't support this endpoint or has not returned (enough) data
                  for that tuple. Partially fulfilled tuples have their count
                  reduced by the number of keys received. On failure, this is the
                  input list itself.
          """
          if service.url is None:
              return {}, query

          # This is required by the configuration.
          assert service.hs_token is not None

          # Create the expected payload shape: user ID -> device ID -> a list of
          # algorithms, with each algorithm repeated once per key requested.
          body: Dict[str, Dict[str, List[str]]] = {}
          for user_id, device, algorithm, count in query:
              body.setdefault(user_id, {}).setdefault(device, []).extend(
                  [algorithm] * count
              )

          uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
          try:
              response = await self.post_json_get_json(
                  uri,
                  body,
                  headers=self._get_headers(service),
              )
          except HttpResponseException as e:
              # The appservice doesn't support this endpoint.
              if is_unknown_endpoint(e):
                  return {}, query
              logger.warning("claim_keys to %s received %s", uri, e.code)
              return {}, query
          except Exception as ex:
              logger.warning("claim_keys to %s threw exception %s", uri, ex)
              return {}, query

          # A non-object response cannot be interpreted. Treat it like any other
          # error instead of crashing on `.get` below.
          if not isinstance(response, dict):
              logger.warning(
                  "claim_keys to %s returned an invalid response %r", uri, response
              )
              return {}, query

          # Check if the appservice fulfilled all of the queried user/device/algorithms
          # or if some are still missing.
          #
          # TODO This still places a lot of faith in the nested response shape
          # being correct.
          missing = []
          for user_id, device, algorithm, count in query:
              # Count the number of keys in the response for this algorithm by
              # checking which key IDs start with the algorithm. This uses that
              # True == 1 in Python to generate a count.
              response_count = sum(
                  key_id.startswith(f"{algorithm}:")
                  for key_id in response.get(user_id, {}).get(device, {})
              )
              count -= response_count
              # If the appservice responds with fewer keys than requested, then
              # consider the request unfulfilled.
              if count > 0:
                  missing.append((user_id, device, algorithm, count))

          return response, missing

      async def query_keys(
          self, service: "ApplicationService", query: Dict[str, List[str]]
      ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
          """Query the application service for keys.

          Note that any error (including a timeout or a response that is not a
          JSON object) is treated as the application service having no
          information.

          Args:
              service: The application service to query.
              query: The request body, a map of str -> list of str. NOTE(review):
                  presumably user ID -> device IDs; confirm against callers.

          Returns:
              A map of device_keys/master_keys/self_signing_keys/user_signing_keys:
                  device_keys is a map of user ID -> a map device ID -> device info.
              An empty dict if the service has no URL or on any error.
          """
          if service.url is None:
              return {}

          # This is required by the configuration.
          assert service.hs_token is not None

          uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3984/keys/query"
          try:
              response = await self.post_json_get_json(
                  uri,
                  query,
                  headers=self._get_headers(service),
              )
          except HttpResponseException as e:
              # The appservice doesn't support this endpoint.
              if is_unknown_endpoint(e):
                  return {}
              logger.warning("query_keys to %s received %s", uri, e.code)
              return {}
          except Exception as ex:
              logger.warning("query_keys to %s threw exception %s", uri, ex)
              return {}

          # Honour the "any error means no information" contract for responses
          # that are not JSON objects.
          if not isinstance(response, dict):
              logger.warning(
                  "query_keys to %s returned an invalid response %r", uri, response
              )
              return {}

          return response

      def _serialize(
          self, service: "ApplicationService", events: Iterable[EventBase]
      ) -> List[JsonDict]:
          """Serialize persistent events in client format for ``service``.

          Invite and knock membership events whose target user the service is
          interested in also carry the stripped room state.
          """
          time_now = self.clock.time_msec()
          return [
              serialize_event(
                  e,
                  time_now,
                  config=SerializeEventConfig(
                      as_client_event=True,
                      # If this is an invite or a knock membership event, and we're interested
                      # in this user, then include any stripped state alongside the event.
                      include_stripped_room_state=(
                          e.type == EventTypes.Member
                          and (
                              e.membership == Membership.INVITE
                              or e.membership == Membership.KNOCK
                          )
                          and service.is_interested_in_user(e.state_key)
                      ),
                  ),
              )
              for e in events
          ]