# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, Optional

from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
    SynapseTags,
    get_active_span_text_map,
    log_kv,
    set_tag,
)
from synapse.replication.http.devices import (
    ReplicationMultiUserDevicesResyncRestServlet,
)
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.stringutils import random_string

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class DeviceMessageHandler:
    def __init__(self, hs: "HomeServer"):
        """
        Args:
            hs: server
        """
        self.store = hs.get_datastores().main
        self.notifier = hs.get_notifier()
        self.is_mine = hs.is_mine
        if hs.config.experimental.msc3814_enabled:
            self.event_sources = hs.get_event_sources()
            self.device_handler = hs.get_device_handler()

        # We only need to poke the federation sender explicitly if it's on the
        # same instance. Other federation sender instances will get notified by
        # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
        # in the to-device replication stream.
        self.federation_sender = None
        if hs.should_send_federation():
            self.federation_sender = hs.get_federation_sender()

        # If we can handle the to device EDUs we do so, otherwise we route them
        # to the appropriate worker.
        if hs.get_instance_name() in hs.config.worker.writers.to_device:
            hs.get_federation_registry().register_edu_handler(
                EduTypes.DIRECT_TO_DEVICE, self.on_direct_to_device_edu
            )
        else:
            hs.get_federation_registry().register_instances_for_edu(
                EduTypes.DIRECT_TO_DEVICE,
                hs.config.worker.writers.to_device,
            )

        # The handler to call when we think a user's device list might be out of
        # sync. We do all device list resyncing on the master instance, so if
        # we're on a worker we hit the device resync replication API.
        if hs.config.worker.worker_app is None:
            self._multi_user_device_resync = (
                hs.get_device_handler().device_list_updater.multi_user_device_resync
            )
        else:
            self._multi_user_device_resync = (
                ReplicationMultiUserDevicesResyncRestServlet.make_client(hs)
            )

        # a rate limiter for room key requests. The keys are
        # (sending_user_id, sending_device_id).
        self._ratelimiter = Ratelimiter(
            store=self.store,
            clock=hs.get_clock(),
            cfg=hs.config.ratelimiting.rc_key_requests,
        )

    async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
        """
        Handle receiving to-device messages from remote homeservers.

        Args:
            origin: The remote homeserver.
            content: The JSON dictionary containing the to-device messages.
        """
        local_messages = {}
        sender_user_id = content["sender"]
        if origin != get_domain_from_id(sender_user_id):
            logger.warning(
                "Dropping device message from %r with spoofed sender %r",
                origin,
                sender_user_id,
            )

        message_type = content["type"]
        message_id = content["message_id"]
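        # `message_id` is an opaque identifier chosen by the sending server; it is
        # used for idempotency, so a retransmitted EDU from the same origin with the
        # same message_id is not processed twice.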
        for user_id, by_device in content["messages"].items():
            # we use UserID.from_string to catch invalid user ids
            if not self.is_mine(UserID.from_string(user_id)):
                logger.warning("To-device message to non-local user %s", user_id)
                raise SynapseError(400, "Not a user here")

            if not by_device:
                continue

            # Ratelimit key requests by the sending user.
            if message_type == ToDeviceEventTypes.RoomKeyRequest:
                allowed, _ = await self._ratelimiter.can_do_action(
                    None, (sender_user_id, None)
                )
                if not allowed:
                    logger.info(
                        "Dropping room_key_request from %s to %s due to rate limit",
                        sender_user_id,
                        user_id,
                    )
                    continue

            messages_by_device = {
                device_id: {
                    "content": message_content,
                    "type": message_type,
                    "sender": sender_user_id,
                }
                for device_id, message_content in by_device.items()
            }

            local_messages[user_id] = messages_by_device

            await self._check_for_unknown_devices(
                message_type, sender_user_id, by_device
            )

        # Add messages to the database.
        # Retrieve the stream id of the last-processed to-device message.
        last_stream_id = await self.store.add_messages_from_remote_to_device_inbox(
            origin, message_id, local_messages
        )

        # Notify listeners that there are new to-device messages to process,
        # handing them the latest stream id.
        self.notifier.on_new_event(
            StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
        )

    async def _check_for_unknown_devices(
        self,
        message_type: str,
        sender_user_id: str,
        by_device: Dict[str, Dict[str, Any]],
    ) -> None:
        """Checks inbound device messages for unknown remote devices, and if
        found marks the remote cache for the user as stale.
        """
        if message_type != "m.room_key_request":
            return

        # Get the sending device IDs
        requesting_device_ids = set()
        for message_content in by_device.values():
            device_id = message_content.get("requesting_device_id")
            requesting_device_ids.add(device_id)

        # Check if we are tracking the devices of the remote user.
        room_ids = await self.store.get_rooms_for_user(sender_user_id)
        if not room_ids:
            logger.info(
                "Received device message from remote device we don't"
                " share a room with: %s %s",
                sender_user_id,
                requesting_device_ids,
            )
            return

        # If we are tracking check that we know about the sending
        # devices.
        cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)

        unknown_devices = requesting_device_ids - set(cached_devices)
        if unknown_devices:
            logger.info(
                "Received device message from remote device not in our cache: %s %s",
                sender_user_id,
                unknown_devices,
            )
            await self.store.mark_remote_users_device_caches_as_stale((sender_user_id,))

            # Immediately attempt a resync in the background
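            # (run_in_background schedules the resync without awaiting it, so we
            # don't block handling of the incoming messages on its completion.)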
            run_in_background(self._multi_user_device_resync, user_ids=[sender_user_id])

    async def send_device_message(
        self,
        requester: Requester,
        message_type: str,
        messages: Dict[str, Dict[str, JsonDict]],
    ) -> None:
        """
        Handle a request from a user to send to-device message(s).

        Args:
            requester: The user that is sending the to-device messages.
            message_type: The type of to-device messages that are being sent.
            messages: A dictionary containing recipients mapped to messages intended for them.
        """
        sender_user_id = requester.user.to_string()

        set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
        set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
        local_messages = {}
        remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
        for user_id, by_device in messages.items():
            # add an opentracing log entry for each message
            for device_id, message_content in by_device.items():
                log_kv(
                    {
                        "event": "send_to_device_message",
                        "user_id": user_id,
                        "device_id": device_id,
                        EventContentFields.TO_DEVICE_MSGID: message_content.get(
                            EventContentFields.TO_DEVICE_MSGID
                        ),
                    }
                )

            # Ratelimit local cross-user key requests by the sending device.
            if (
                message_type == ToDeviceEventTypes.RoomKeyRequest
                and user_id != sender_user_id
            ):
                allowed, _ = await self._ratelimiter.can_do_action(
                    requester, (sender_user_id, requester.device_id)
                )
                if not allowed:
                    log_kv({"message": f"dropping key requests to {user_id}"})
                    logger.info(
                        "Dropping room_key_request from %s to %s due to rate limit",
                        sender_user_id,
                        user_id,
                    )
                    continue

            # we use UserID.from_string to catch invalid user ids
            if self.is_mine(UserID.from_string(user_id)):
                messages_by_device = {
                    device_id: {
                        "content": message_content,
                        "type": message_type,
                        "sender": sender_user_id,
                    }
                    for device_id, message_content in by_device.items()
                }
                if messages_by_device:
                    local_messages[user_id] = messages_by_device
            else:
                destination = get_domain_from_id(user_id)
                remote_messages.setdefault(destination, {})[user_id] = by_device

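        # Capture the active opentracing span context (if any) so that it can be
        # forwarded to the destination servers alongside the messages.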
        context = get_active_span_text_map()

        remote_edu_contents = {}
        for destination, messages in remote_messages.items():
            # The EDU contains a "message_id" property which is used for
            # idempotence. Make up a random one.
            message_id = random_string(16)
            log_kv({"destination": destination, "message_id": message_id})

            remote_edu_contents[destination] = {
                "messages": messages,
                "sender": sender_user_id,
                "type": message_type,
                "message_id": message_id,
                "org.matrix.opentracing_context": json_encoder.encode(context),
            }

        # Add messages to the database.
        # Retrieve the stream id of the last-processed to-device message.
        last_stream_id = await self.store.add_messages_to_device_inbox(
            local_messages, remote_edu_contents
        )

        # Notify listeners that there are new to-device messages to process,
        # handing them the latest stream id.
        self.notifier.on_new_event(
            StreamKeyType.TO_DEVICE, last_stream_id, users=local_messages.keys()
        )

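        # Only poke the federation sender if it runs on this instance; other
        # instances are woken up via the to-device replication stream (see __init__).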
        if self.federation_sender:
            # Enqueue a new federation transaction to send the new
            # device messages to each remote destination.
            await self.federation_sender.send_device_messages(remote_messages.keys())

    async def get_events_for_dehydrated_device(
        self,
        requester: Requester,
        device_id: str,
        since_token: Optional[str],
        limit: int,
    ) -> JsonDict:
        """Fetches up to `limit` events sent to `device_id` starting from `since_token`
        and returns the new since token. If there are no more messages, returns an empty
        array.

        Args:
            requester: the user requesting the messages
            device_id: ID of the dehydrated device
            since_token: stream id to start from when fetching messages
            limit: the number of messages to fetch

        Returns:
            A dict containing the to-device messages, as well as a token that the client
            can provide in the next call to fetch the next batch of messages
        """
        user_id = requester.user.to_string()

        # only allow fetching messages for the dehydrated device id currently associated
        # with the user
        dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
        if dehydrated_device is None:
            raise SynapseError(
                HTTPStatus.FORBIDDEN,
                "No dehydrated device exists",
                Codes.FORBIDDEN,
            )

        dehydrated_device_id, _ = dehydrated_device
        if device_id != dehydrated_device_id:
            raise SynapseError(
                HTTPStatus.FORBIDDEN,
                "You may only fetch messages for your dehydrated device",
                Codes.FORBIDDEN,
            )

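        # The since token has the form "d<stream id>", matching the `next_batch`
        # token returned below.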
        since_stream_id = 0
        if since_token:
            if not since_token.startswith("d"):
                raise SynapseError(
                    HTTPStatus.BAD_REQUEST,
                    "from parameter %r has an invalid format" % (since_token,),
                    errcode=Codes.INVALID_PARAM,
                )

            try:
                since_stream_id = int(since_token[1:])
            except Exception:
                raise SynapseError(
                    HTTPStatus.BAD_REQUEST,
                    "from parameter %r has an invalid format" % (since_token,),
                    errcode=Codes.INVALID_PARAM,
                )

        to_token = self.event_sources.get_current_token().to_device_key

        messages, stream_id = await self.store.get_messages_for_device(
            user_id, device_id, since_stream_id, to_token, limit
        )

        for message in messages:
            # Remove the message id before sending to client
            message_id = message.pop("message_id", None)
            if message_id:
                set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id)

        logger.debug(
            "Returning %d to-device messages between %d and %d (current token: %d) for "
            "dehydrated device %s, user_id %s",
            len(messages),
            since_stream_id,
            stream_id,
            to_token,
            device_id,
            user_id,
        )

        return {
            "events": messages,
            "next_batch": f"d{stream_id}",
        }