# Copyright 2016 OpenMarket Ltd
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
    TYPE_CHECKING,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    cast,
)

from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
    make_in_list_sql_clause,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
    AbstractStreamIdGenerator,
    MultiWriterIdGenerator,
    StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class DeviceInboxWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._instance_name = hs.get_instance_name()

        # Map of (user_id, device_id) to the last stream_id that has been
        # deleted up to. This is so that we can no-op deletions.
        self._last_device_delete_cache: ExpiringCache[
            Tuple[str, Optional[str]], int
        ] = ExpiringCache(
            cache_name="last_device_delete_cache",
            clock=self._clock,
            max_len=10000,
            expiry_ms=30 * 60 * 1000,
        )

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_device = (
                self._instance_name in hs.config.worker.writers.to_device
            )

            self._device_inbox_id_gen: AbstractStreamIdGenerator = (
                MultiWriterIdGenerator(
                    db_conn=db_conn,
                    db=database,
                    stream_name="to_device",
                    instance_name=self._instance_name,
                    tables=[("device_inbox", "instance_name", "stream_id")],
                    sequence_name="device_inbox_sequence",
                    writers=hs.config.worker.writers.to_device,
                )
            )
        else:
            self._can_write_to_device = True
            self._device_inbox_id_gen = StreamIdGenerator(
                db_conn, "device_inbox", "stream_id"
            )

        max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
        device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_inbox",
            entity_column="user_id",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_inbox_stream_cache = StreamChangeCache(
            "DeviceInboxStreamChangeCache",
            min_device_inbox_id,
            prefilled_cache=device_inbox_prefill,
        )

        # The federation outbox and the local device inbox use the same
        # stream_id generator.
        device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict(
            db_conn,
            "device_federation_outbox",
            entity_column="destination",
            stream_column="stream_id",
            max_value=max_device_inbox_id,
            limit=1000,
        )
        self._device_federation_outbox_stream_cache = StreamChangeCache(
            "DeviceFederationOutboxStreamChangeCache",
            min_device_outbox_id,
            prefilled_cache=device_outbox_prefill,
        )
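
    # Illustrative sketch (editor's annotation, not part of the upstream file):
    # the two StreamChangeCache instances built above let callers cheaply skip
    # database queries. `has_entity_changed(entity, token)` returns False only
    # when the cache is certain the entity has had no writes after `token`:
    #
    #     if not self._device_inbox_stream_cache.has_entity_changed(
    #         "@alice:example.org", since_token
    #     ):
    #         return []  # nothing new for this user; no SQL issued
    #
    # `entity_has_changed(entity, token)`, used throughout this class, feeds
    # the cache as new writes happen locally or arrive over replication.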

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
    ) -> None:
        if stream_name == ToDeviceStream.NAME:
            # If replication is happening then Postgres must be in use.
            assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
            self._device_inbox_id_gen.advance(instance_name, token)
            for row in rows:
                if row.entity.startswith("@"):
                    self._device_inbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
                else:
                    self._device_federation_outbox_stream_cache.entity_has_changed(
                        row.entity, token
                    )
        return super().process_replication_rows(stream_name, instance_name, token, rows)

    def get_to_device_stream_token(self) -> int:
        return self._device_inbox_id_gen.get_current_token()

    async def get_messages_for_user_devices(
        self,
        user_ids: Collection[str],
        from_stream_id: int,
        to_stream_id: int,
    ) -> Dict[Tuple[str, str], List[JsonDict]]:
        """
        Retrieve to-device messages for a given set of users.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Args:
            user_ids: The users to retrieve to-device messages for.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).

        Returns:
            A dictionary of (user id, device id) -> list of to-device messages.
        """
        # We expect the stream ID returned by _get_device_messages to always
        # be to_stream_id. So, no need to return it from this function.
        (
            user_id_device_id_to_messages,
            last_processed_stream_id,
        ) = await self._get_device_messages(
            user_ids=user_ids,
            from_stream_id=from_stream_id,
            to_stream_id=to_stream_id,
        )

        assert (
            last_processed_stream_id == to_stream_id
        ), "Expected _get_device_messages to process all to-device messages up to `to_stream_id`"

        return user_id_device_id_to_messages

    async def get_messages_for_device(
        self,
        user_id: str,
        device_id: str,
        from_stream_id: int,
        to_stream_id: int,
        limit: int = 100,
    ) -> Tuple[List[JsonDict], int]:
        """
        Retrieve to-device messages for a single user device.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Args:
            user_id: The ID of the user to retrieve messages for.
            device_id: The ID of the device to retrieve to-device messages for.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).
            limit: A limit on the number of to-device messages returned.

        Returns:
            A tuple containing:
                * A list of to-device messages within the given stream id range intended for
                  the given user / device combo.
                * The last-processed stream ID. Subsequent calls of this function with the
                  same device should pass this value as 'from_stream_id'.
        """
        (
            user_id_device_id_to_messages,
            last_processed_stream_id,
        ) = await self._get_device_messages(
            user_ids=[user_id],
            device_id=device_id,
            from_stream_id=from_stream_id,
            to_stream_id=to_stream_id,
            limit=limit,
        )

        if not user_id_device_id_to_messages:
            # There were no messages!
            return [], to_stream_id

        # Extract the messages, no need to return the user and device ID again
        to_device_messages = user_id_device_id_to_messages.get((user_id, device_id), [])

        return to_device_messages, last_processed_stream_id
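
    # Illustrative pagination sketch (editor's annotation, not upstream code;
    # `store` is a hypothetical instance of this class). The returned stream
    # ID is fed back in as the next exclusive lower bound until it reaches the
    # upper bound, exactly as the docstring above prescribes:
    #
    #     from_id = last_acknowledged_stream_id  # hypothetical starting point
    #     to_id = store.get_to_device_stream_token()
    #     while True:
    #         messages, from_id = await store.get_messages_for_device(
    #             "@alice:example.org", "SOMEDEVICE", from_id, to_id, limit=100
    #         )
    #         ...  # hand `messages` to the client
    #         if from_id == to_id:
    #             break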

    async def _get_device_messages(
        self,
        user_ids: Collection[str],
        from_stream_id: int,
        to_stream_id: int,
        device_id: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
        """
        Retrieve pending to-device messages for a collection of user devices.

        Only to-device messages with stream ids between the given boundaries
        (from < X <= to) are returned.

        Note that a stream ID can be shared by multiple copies of the same message with
        different recipient devices. Stream IDs are only unique in the context of a single
        user ID / device ID pair. Thus, applying a limit (of messages to return) when working
        with a sliding window of stream IDs is only possible when querying messages of a
        single user device.

        Finally, note that device IDs are not unique across users.

        Args:
            user_ids: The user IDs to filter device messages by.
            from_stream_id: The lower boundary of stream id to filter with (exclusive).
            to_stream_id: The upper boundary of stream id to filter with (inclusive).
            device_id: A device ID to query to-device messages for. If not provided, to-device
                messages from all device IDs for the given user IDs will be queried. May not be
                provided if `user_ids` contains more than one entry.
            limit: The maximum number of to-device messages to return. Can only be used when
                passing a single user ID / device ID tuple.

        Returns:
            A tuple containing:
                * A dict of (user_id, device_id) -> list of to-device messages
                * The last-processed stream ID. If this is less than `to_stream_id`, then
                  there may be more messages to retrieve. If `limit` is not set, then this
                  is always equal to 'to_stream_id'.
        """
        if not user_ids:
            logger.warning("No users provided upon querying for device IDs")
            return {}, to_stream_id

        # Prevent a query for one user's device also retrieving another user's device with
        # the same device ID (device IDs are not unique across users).
        if len(user_ids) > 1 and device_id is not None:
            raise AssertionError(
                "Programming error: 'device_id' cannot be supplied to "
                "_get_device_messages when >1 user_id has been provided"
            )

        # A limit can only be applied when querying for a single user ID / device ID tuple.
        # See the docstring of this function for more details.
        if limit is not None and device_id is None:
            raise AssertionError(
                "Programming error: _get_device_messages was passed 'limit' "
                "without a specific user_id/device_id"
            )

        user_ids_to_query: Set[str] = set()
        device_ids_to_query: Set[str] = set()

        # Note that a device ID could be an empty str
        if device_id is not None:
            # If a device ID was passed, use it to filter results.
            # Otherwise, device IDs will be derived from the given collection of user IDs.
            device_ids_to_query.add(device_id)

        # Determine which users have devices with pending messages
        for user_id in user_ids:
            if self._device_inbox_stream_cache.has_entity_changed(
                user_id, from_stream_id
            ):
                # This user has new messages sent to them. Query messages for them
                user_ids_to_query.add(user_id)

        if not user_ids_to_query:
            return {}, to_stream_id

        def get_device_messages_txn(
            txn: LoggingTransaction,
        ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
            # Build a query to select messages from any of the given devices that
            # are between the given stream id bounds.

            # If a list of device IDs was not provided, retrieve all device IDs
            # for the given users. We explicitly do not query hidden devices, as
            # hidden devices should not receive to-device messages.
            # Note that this is more efficient than just dropping `device_id` from the query,
            # since device_inbox has an index on `(user_id, device_id, stream_id)`
            if not device_ids_to_query:
                # Filter on `hidden` only; the user IDs are supplied via `iterable`.
                # (The original text also passed a `user_id` key here, which would
                # have reused a stale loop variable and filtered to a single user.)
                user_device_dicts = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    column="user_id",
                    iterable=user_ids_to_query,
                    keyvalues={"hidden": False},
                    retcols=("device_id",),
                )

                device_ids_to_query.update(
                    {row["device_id"] for row in user_device_dicts}
                )

            if not device_ids_to_query:
                # We've ended up with no devices to query.
                return {}, to_stream_id

            # We include both user IDs and device IDs in this query, as we have an index
            # (device_inbox_user_stream_id) for them.
            user_id_many_clause_sql, user_id_many_clause_args = make_in_list_sql_clause(
                self.database_engine, "user_id", user_ids_to_query
            )
            (
                device_id_many_clause_sql,
                device_id_many_clause_args,
            ) = make_in_list_sql_clause(
                self.database_engine, "device_id", device_ids_to_query
            )

            sql = f"""
                SELECT stream_id, user_id, device_id, message_json FROM device_inbox
                WHERE {user_id_many_clause_sql}
                AND {device_id_many_clause_sql}
                AND ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
            """
            sql_args = (
                *user_id_many_clause_args,
                *device_id_many_clause_args,
                from_stream_id,
                to_stream_id,
            )

            # If a limit was provided, limit the data retrieved from the database
            if limit is not None:
                sql += "LIMIT ?"
                sql_args += (limit,)

            txn.execute(sql, sql_args)

            # Create and fill a dictionary of (user ID, device ID) -> list of messages
            # intended for each device.
            last_processed_stream_pos = to_stream_id
            recipient_device_to_messages: Dict[Tuple[str, str], List[JsonDict]] = {}
            rowcount = 0
            for row in txn:
                rowcount += 1

                last_processed_stream_pos = row[0]
                recipient_user_id = row[1]
                recipient_device_id = row[2]
                message_dict = db_to_json(row[3])

                # Store the device details
                recipient_device_to_messages.setdefault(
                    (recipient_user_id, recipient_device_id), []
                ).append(message_dict)

            if limit is not None and rowcount == limit:
                # We ended up bumping up against the message limit. There may be more messages
                # to retrieve. Return what we have, as well as the last stream position that
                # was processed.
                #
                # The caller is expected to set this as the lower (exclusive) bound
                # for the next query of this device.
                return recipient_device_to_messages, last_processed_stream_pos

            # The limit was not reached, thus we know that recipient_device_to_messages
            # contains all to-device messages for the given device and stream id range.
            #
            # We return to_stream_id, which the caller should then provide as the lower
            # (exclusive) bound on the next query of this device.
            return recipient_device_to_messages, to_stream_id

        return await self.db_pool.runInteraction(
            "get_device_messages", get_device_messages_txn
        )
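
    # For reference (editor's annotation, not upstream code):
    # make_in_list_sql_clause expands a column/iterable pair into an
    # engine-appropriate SQL fragment plus its bind arguments, roughly:
    #
    #     Postgres: ("user_id = ANY(?)", [["@a:hs", "@b:hs"]])
    #     SQLite:   ("user_id IN (?,?)", ["@a:hs", "@b:hs"])
    #
    # which is why the query above can interpolate the fragments directly into
    # the f-string and splat the returned argument lists into `sql_args`.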

    @trace
    async def delete_messages_for_device(
        self, user_id: str, device_id: Optional[str], up_to_stream_id: int
    ) -> int:
        """
        Args:
            user_id: The recipient user_id.
            device_id: The recipient device_id.
            up_to_stream_id: Where to delete messages up to.

        Returns:
            The number of messages deleted.
        """
        # If we have cached the last stream id we've deleted up to, we can
        # check if there is likely to be anything that needs deleting
        last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), None
        )

        set_tag("last_deleted_stream_id", last_deleted_stream_id)

        if last_deleted_stream_id:
            has_changed = self._device_inbox_stream_cache.has_entity_changed(
                user_id, last_deleted_stream_id
            )
            if not has_changed:
                log_kv({"message": "No changes in cache since last check"})
                return 0

        def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
            sql = (
                "DELETE FROM device_inbox"
                " WHERE user_id = ? AND device_id = ?"
                " AND stream_id <= ?"
            )
            txn.execute(sql, (user_id, device_id, up_to_stream_id))
            return txn.rowcount

        count = await self.db_pool.runInteraction(
            "delete_messages_for_device", delete_messages_for_device_txn
        )

        log_kv({"message": f"deleted {count} messages for device", "count": count})

        # Update the cache, ensuring that we only ever increase the value
        updated_last_deleted_stream_id = self._last_device_delete_cache.get(
            (user_id, device_id), 0
        )
        self._last_device_delete_cache[(user_id, device_id)] = max(
            updated_last_deleted_stream_id, up_to_stream_id
        )

        return count

    @trace
    async def get_new_device_msgs_for_remote(
        self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
    ) -> Tuple[List[JsonDict], int]:
        """
        Args:
            destination: The name of the remote server.
            last_stream_id: The last position of the device message stream
                that the server sent up to.
            current_stream_id: The current position of the device message stream.
            limit: The maximum number of messages to retrieve.

        Returns:
            A list of messages for the device and where in the stream the messages got to.
        """
        set_tag("destination", destination)
        set_tag("last_stream_id", last_stream_id)
        set_tag("current_stream_id", current_stream_id)
        set_tag("limit", limit)

        has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
            destination, last_stream_id
        )
        if not has_changed or last_stream_id == current_stream_id:
            log_kv({"message": "No new messages in stream"})
            return [], current_stream_id

        if limit <= 0:
            # This can happen if we run out of room for EDUs in the transaction.
            return [], last_stream_id

        @trace
        def get_new_messages_for_remote_destination_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[JsonDict], int]:
            sql = (
                "SELECT stream_id, messages_json FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND ? < stream_id AND stream_id <= ?"
                " ORDER BY stream_id ASC"
                " LIMIT ?"
            )
            txn.execute(sql, (destination, last_stream_id, current_stream_id, limit))

            messages = []
            stream_pos = current_stream_id

            for row in txn:
                stream_pos = row[0]
                messages.append(db_to_json(row[1]))

            # If the limit was not reached we know that there's no more data for this
            # user/device pair up to current_stream_id.
            if len(messages) < limit:
                log_kv({"message": "Set stream position to current position"})
                stream_pos = current_stream_id

            return messages, stream_pos

        return await self.db_pool.runInteraction(
            "get_new_device_msgs_for_remote",
            get_new_messages_for_remote_destination_txn,
        )
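
    # Illustrative federation-sender sketch (editor's annotation, not upstream
    # code): a sender would drain the outbox for a destination and only
    # acknowledge what was actually delivered, e.g.:
    #
    #     edus, stream_pos = await store.get_new_device_msgs_for_remote(
    #         "remote.example.org",
    #         last_sent_pos,  # hypothetical per-destination bookmark
    #         store.get_to_device_stream_token(),
    #         limit=50,
    #     )
    #     ...  # put `edus` in the next federation transaction
    #     await store.delete_device_msgs_for_remote("remote.example.org", stream_pos)
    #
    # delete_device_msgs_for_remote (defined just below) should only be called
    # once the destination has acknowledged the transaction.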

    @trace
    async def delete_device_msgs_for_remote(
        self, destination: str, up_to_stream_id: int
    ) -> None:
        """Used to delete messages when the remote destination acknowledges
        their receipt.

        Args:
            destination: The destination server_name
            up_to_stream_id: Where to delete messages up to.
        """

        def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
            sql = (
                "DELETE FROM device_federation_outbox"
                " WHERE destination = ?"
                " AND stream_id <= ?"
            )
            txn.execute(sql, (destination, up_to_stream_id))

        await self.db_pool.runInteraction(
            "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
        )

    async def get_all_new_device_messages(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
        """Get updates for the to-device replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_new_device_messages_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
            # We limit like this as we might have multiple rows per stream_id, and
            # we want to make sure we always get all entries for any stream_id
            # we return.
            upper_pos = min(current_id, last_id + limit)
            sql = (
                "SELECT max(stream_id), user_id"
                " FROM device_inbox"
                " WHERE ? < stream_id AND stream_id <= ?"
                " GROUP BY user_id"
            )
            txn.execute(sql, (last_id, upper_pos))
            updates = [(row[0], row[1:]) for row in txn]

            sql = (
                "SELECT max(stream_id), destination"
                " FROM device_federation_outbox"
                " WHERE ? < stream_id AND stream_id <= ?"
                " GROUP BY destination"
            )
            txn.execute(sql, (last_id, upper_pos))
            updates.extend((row[0], row[1:]) for row in txn)

            # Order by ascending stream ordering
            updates.sort()

            limited = False
            upto_token = current_id
            if len(updates) >= limit:
                upto_token = updates[-1][0]
                limited = True

            return updates, upto_token, limited

        return await self.db_pool.runInteraction(
            "get_all_new_device_messages", get_all_new_device_messages_txn
        )

    @trace
    async def add_messages_to_device_inbox(
        self,
        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
        remote_messages_by_destination: Dict[str, JsonDict],
    ) -> int:
        """Used to send messages from this server.

        Args:
            local_messages_by_user_then_device:
                Dictionary of recipient user_id to recipient device_id to message.
            remote_messages_by_destination:
                Dictionary of destination server_name to the EDU JSON to send.

        Returns:
            The new stream_id.
        """

        assert self._can_write_to_device

        def add_messages_txn(
            txn: LoggingTransaction, now_ms: int, stream_id: int
        ) -> None:
            # Add the local messages directly to the local inbox.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

            # Add the remote messages to the federation outbox.
            # We'll send them to a remote server when we next send a
            # federation transaction to that destination.
            self.db_pool.simple_insert_many_txn(
                txn,
                table="device_federation_outbox",
                keys=(
                    "destination",
                    "stream_id",
                    "queued_ts",
                    "messages_json",
                    "instance_name",
                ),
                values=[
                    (
                        destination,
                        stream_id,
                        now_ms,
                        json_encoder.encode(edu),
                        self._instance_name,
                    )
                    for destination, edu in remote_messages_by_destination.items()
                ],
            )

            if remote_messages_by_destination:
                issue9533_logger.debug(
                    "Queued outgoing to-device messages with stream_id %i for %s",
                    stream_id,
                    list(remote_messages_by_destination.keys()),
                )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
            for destination in remote_messages_by_destination.keys():
                self._device_federation_outbox_stream_cache.entity_has_changed(
                    destination, stream_id
                )

        return self._device_inbox_id_gen.get_current_token()
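
    # Illustrative call sketch (editor's annotation, not upstream code),
    # showing the nested-dict shapes the docstring above describes; the message
    # and EDU bodies are elided as the exact JSON depends on the caller:
    #
    #     stream_id = await store.add_messages_to_device_inbox(
    #         local_messages_by_user_then_device={
    #             "@alice:example.org": {"DEVICEID": {...}},  # message JSON
    #         },
    #         remote_messages_by_destination={
    #             "remote.example.org": {...},  # EDU JSON
    #         },
    #     )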

    async def add_messages_from_remote_to_device_inbox(
        self,
        origin: str,
        message_id: str,
        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
    ) -> int:
        assert self._can_write_to_device

        def add_messages_txn(
            txn: LoggingTransaction, now_ms: int, stream_id: int
        ) -> None:
            # Check if we've already inserted a matching message_id for that
            # origin. This can happen if the origin doesn't receive our
            # acknowledgement from the first time we received the message.
            already_inserted = self.db_pool.simple_select_one_txn(
                txn,
                table="device_federation_inbox",
                keyvalues={"origin": origin, "message_id": message_id},
                retcols=("message_id",),
                allow_none=True,
            )
            if already_inserted is not None:
                return

            # Add an entry for this message_id so that we know we've processed
            # it.
            self.db_pool.simple_insert_txn(
                txn,
                table="device_federation_inbox",
                values={
                    "origin": origin,
                    "message_id": message_id,
                    "received_ts": now_ms,
                },
            )

            # Add the messages to the appropriate local device inboxes so that
            # they'll be sent to the devices when they next sync.
            self._add_messages_to_local_device_inbox_txn(
                txn, stream_id, local_messages_by_user_then_device
            )

        async with self._device_inbox_id_gen.get_next() as stream_id:
            now_ms = self._clock.time_msec()
            await self.db_pool.runInteraction(
                "add_messages_from_remote_to_device_inbox",
                add_messages_txn,
                now_ms,
                stream_id,
            )
            for user_id in local_messages_by_user_then_device.keys():
                self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)

        return stream_id

    def _add_messages_to_local_device_inbox_txn(
        self,
        txn: LoggingTransaction,
        stream_id: int,
        messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
    ) -> None:
        assert self._can_write_to_device

        local_by_user_then_device = {}
        for user_id, messages_by_device in messages_by_user_then_device.items():
            messages_json_for_user = {}
            devices = list(messages_by_device.keys())
            if len(devices) == 1 and devices[0] == "*":
                # Handle wildcard device_ids.
                # We exclude hidden devices (such as cross-signing keys) here as they are
                # not expected to receive to-device messages.
                devices = self.db_pool.simple_select_onecol_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id, "hidden": False},
                    retcol="device_id",
                )

                message_json = json_encoder.encode(messages_by_device["*"])
                for device_id in devices:
                    # Add the message for all devices for this user on this
                    # server.
                    messages_json_for_user[device_id] = message_json
            else:
                if not devices:
                    continue

                # We exclude hidden devices (such as cross-signing keys) here as they are
                # not expected to receive to-device messages.
                rows = self.db_pool.simple_select_many_txn(
                    txn,
                    table="devices",
                    keyvalues={"user_id": user_id, "hidden": False},
                    column="device_id",
                    iterable=devices,
                    retcols=("device_id",),
                )

                for row in rows:
                    # Only insert into the local inbox if the device exists on
                    # this server
                    device_id = row["device_id"]
                    message_json = json_encoder.encode(messages_by_device[device_id])
                    messages_json_for_user[device_id] = message_json

            if messages_json_for_user:
                local_by_user_then_device[user_id] = messages_json_for_user

        if not local_by_user_then_device:
            return

        self.db_pool.simple_insert_many_txn(
            txn,
            table="device_inbox",
            keys=("user_id", "device_id", "stream_id", "message_json", "instance_name"),
            values=[
                (user_id, device_id, stream_id, message_json, self._instance_name)
                for user_id, messages_by_device in local_by_user_then_device.items()
                for device_id, message_json in messages_by_device.items()
            ],
        )

        issue9533_logger.debug(
            "Stored to-device messages with stream_id %i for %s",
            stream_id,
            [
                (user_id, device_id)
                for (user_id, messages_by_device) in local_by_user_then_device.items()
                for device_id in messages_by_device.keys()
            ],
        )
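
    # Illustrative note (editor's annotation, not upstream code): a wildcard
    # device ID of "*" fans one message out to every unhidden device the user
    # has on this server, so a caller can pass:
    #
    #     messages_by_user_then_device = {
    #         "@alice:example.org": {"*": {...}},  # message JSON, elided
    #     }
    #
    # and _add_messages_to_local_device_inbox_txn will expand "*" to the
    # user's device list before inserting the copies into `device_inbox`.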


class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
    DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
    REMOVE_DEAD_DEVICES_FROM_INBOX = "remove_dead_devices_from_device_inbox"

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_index_update(
            "device_inbox_stream_index",
            index_name="device_inbox_stream_id_user_id",
            table="device_inbox",
            columns=["stream_id", "user_id"],
        )

        self.db_pool.updates.register_background_update_handler(
            self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
        )

        self.db_pool.updates.register_background_update_handler(
            self.REMOVE_DEAD_DEVICES_FROM_INBOX,
            self._remove_dead_devices_from_device_inbox,
        )

    async def _background_drop_index_device_inbox(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
            txn = conn.cursor()
            txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
            txn.close()

        await self.db_pool.runWithConnection(reindex_txn)

        await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)

        return 1

    async def _remove_dead_devices_from_device_inbox(
        self,
        progress: JsonDict,
        batch_size: int,
    ) -> int:
        """A background update to remove devices that were either deleted or hidden from
        the device_inbox table.

        Args:
            progress: The update's progress dict.
            batch_size: The batch size for this update.

        Returns:
            The number of rows processed in this batch.
        """

        def _remove_dead_devices_from_device_inbox_txn(
            txn: LoggingTransaction,
        ) -> bool:
            # Returns True once every row has been processed. (The result is
            # consumed as a plain bool below, so the annotation matches that.)
            if "max_stream_id" in progress:
                max_stream_id = progress["max_stream_id"]
            else:
                txn.execute("SELECT max(stream_id) FROM device_inbox")
                # There's a type mismatch here between how we want to type the row and
                # what fetchone says it returns, but we silence it because we know that
                # res can't be None.
                res = cast(Tuple[Optional[int]], txn.fetchone())
                if res[0] is None:
                    # this can only happen if the `device_inbox` table is empty, in which
                    # case we have no work to do.
                    return True
                else:
                    max_stream_id = res[0]

            start = progress.get("stream_id", 0)
            stop = start + batch_size

            # delete rows in `device_inbox` which do *not* correspond to a known,
            # unhidden device.
            sql = """
                DELETE FROM device_inbox
                WHERE
                    stream_id >= ? AND stream_id < ?
                    AND NOT EXISTS (
                        SELECT * FROM devices d
                        WHERE
                            d.device_id=device_inbox.device_id
                            AND d.user_id=device_inbox.user_id
                            AND NOT hidden
                    )
                """

            txn.execute(sql, (start, stop))

            self.db_pool.updates._background_update_progress_txn(
                txn,
                self.REMOVE_DEAD_DEVICES_FROM_INBOX,
                {
                    "stream_id": stop,
                    "max_stream_id": max_stream_id,
                },
            )

            return stop > max_stream_id

        finished = await self.db_pool.runInteraction(
            "_remove_devices_from_device_inbox_txn",
            _remove_dead_devices_from_device_inbox_txn,
        )

        if finished:
            await self.db_pool.updates._end_background_update(
                self.REMOVE_DEAD_DEVICES_FROM_INBOX,
            )

        return batch_size


class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
    pass