# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    cast,
)

from immutabledict import immutabledict

from synapse.api.constants import EduTypes
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.util.id_generators import (
    AbstractStreamIdGenerator,
    MultiWriterIdGenerator,
    StreamIdGenerator,
)
from synapse.types import (
    JsonDict,
    JsonMapping,
    MultiWriterStreamToken,
    PersistedPosition,
)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class ReceiptsWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self._instance_name = hs.get_instance_name()

        # In the worker store this is an ID tracker which we overwrite in the non-worker
        # class below that is used on the main process.
        self._receipts_id_gen: AbstractStreamIdGenerator

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_receipts = (
                self._instance_name in hs.config.worker.writers.receipts
            )

            self._receipts_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                notifier=hs.get_replication_notifier(),
                stream_name="receipts",
                instance_name=self._instance_name,
                tables=[("receipts_linearized", "instance_name", "stream_id")],
                sequence_name="receipts_sequence",
                writers=hs.config.worker.writers.receipts,
            )
        else:
            self._can_write_to_receipts = True

            # Multiple writers are not supported for SQLite.
            #
            # We shouldn't be running in worker mode with SQLite, but it's useful
            # to support it for unit tests.
            self._receipts_id_gen = StreamIdGenerator(
                db_conn,
                hs.get_replication_notifier(),
                "receipts_linearized",
                "stream_id",
                is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
            )

        super().__init__(database, db_conn, hs)

        max_receipts_stream_id = self.get_max_receipt_stream_id()
        receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
            db_conn,
            "receipts_linearized",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=max_receipts_stream_id.stream,
            limit=10000,
        )
        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache",
            min_receipts_stream_id,
            prefilled_cache=receipts_stream_prefill,
        )
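
    # The StreamChangeCache built above lets the read paths below cheaply answer
    # "has this room had any new receipts since stream position X?" without
    # querying the database (see the get_entities_changed / has_entity_changed
    # usage further down).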

    def get_max_receipt_stream_id(self) -> MultiWriterStreamToken:
        """Get the current max stream ID for receipts stream"""

        min_pos = self._receipts_id_gen.get_current_token()

        positions = {}
        if isinstance(self._receipts_id_gen, MultiWriterIdGenerator):
            # The `min_pos` is the minimum position that we know all instances
            # have finished persisting to, so we only care about instances whose
            # positions are ahead of that. (Instance positions can be behind the
            # min position as there are times we can work out that the minimum
            # position is ahead of the naive minimum across all current
            # positions. See MultiWriterIdGenerator for details)
            positions = {
                i: p
                for i, p in self._receipts_id_gen.get_positions().items()
                if p > min_pos
            }

        return MultiWriterStreamToken(
            stream=min_pos, instance_map=immutabledict(positions)
        )

    def get_receipt_stream_id_for_instance(self, instance_name: str) -> int:
        return self._receipts_id_gen.get_current_token_for_writer(instance_name)
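
    # Illustrative example (hypothetical writer names): if the fully-persisted
    # minimum position is 10 and the per-writer positions are
    # {"writer1": 10, "writer2": 12}, get_max_receipt_stream_id() would return
    # MultiWriterStreamToken(stream=10, instance_map=immutabledict({"writer2": 12})),
    # since only writers ahead of the minimum need to be tracked explicitly.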

    def get_last_unthreaded_receipt_for_user_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        room_id: str,
        receipt_types: Collection[str],
    ) -> Optional[Tuple[str, int]]:
        """
        Fetch the event ID and stream_ordering for the latest unthreaded receipt
        in a room with one of the given receipt types.

        Args:
            user_id: The user to fetch receipts for.
            room_id: The room ID to fetch the receipt for.
            receipt_types: The receipt types to fetch.

        Returns:
            The event ID and stream ordering of the latest receipt, if one exists.
        """

        clause, args = make_in_list_sql_clause(
            self.database_engine, "receipt_type", receipt_types
        )

        sql = f"""
            SELECT event_id, stream_ordering
            FROM receipts_linearized
            INNER JOIN events USING (room_id, event_id)
            WHERE {clause}
            AND user_id = ?
            AND room_id = ?
            AND thread_id IS NULL
            ORDER BY stream_ordering DESC
            LIMIT 1
        """

        args.extend((user_id, room_id))
        txn.execute(sql, args)

        return cast(Optional[Tuple[str, int]], txn.fetchone())

    async def get_receipts_for_user(
        self, user_id: str, receipt_types: Iterable[str]
    ) -> Dict[str, str]:
        """
        Fetch the event IDs for the latest receipts sent by the given user.

        Args:
            user_id: The user to fetch receipts for.
            receipt_types: The receipt types to check.

        Returns:
            A map of room ID to the event ID of the latest receipt for that room.

            If the user has not sent a receipt to a room then it will not appear
            in the returned dictionary.
        """
        results = await self.get_receipts_for_user_with_orderings(
            user_id, receipt_types
        )

        # Reduce the result to room ID -> event ID.
        return {
            room_id: room_result["event_id"]
            for room_id, room_result in results.items()
        }

    async def get_receipts_for_user_with_orderings(
        self, user_id: str, receipt_types: Iterable[str]
    ) -> JsonDict:
        """
        Fetch receipts for all rooms that the given user is joined to.

        Args:
            user_id: The user to fetch receipts for.
            receipt_types: The receipt types to fetch. Earlier receipt types
                are given priority if multiple receipts point to the same event.

        Returns:
            A map of room ID to the latest receipt (for the given types).
        """
        results: JsonDict = {}
        for receipt_type in receipt_types:
            partial_result = await self._get_receipts_for_user_with_orderings(
                user_id, receipt_type
            )
            for room_id, room_result in partial_result.items():
                # If the room has not yet been seen, or the receipt is newer,
                # use it.
                if (
                    room_id not in results
                    or results[room_id]["stream_ordering"]
                    < room_result["stream_ordering"]
                ):
                    results[room_id] = room_result

        return results

    @cached()
    async def _get_receipts_for_user_with_orderings(
        self, user_id: str, receipt_type: str
    ) -> JsonMapping:
        """
        Fetch receipts for all rooms that the given user is joined to.

        Args:
            user_id: The user to fetch receipts for.
            receipt_type: The receipt type to fetch.

        Returns:
            A map of room ID to the latest receipt information.
        """

        def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
            sql = (
                "SELECT rl.room_id, rl.event_id,"
                " e.topological_ordering, e.stream_ordering"
                " FROM receipts_linearized AS rl"
                " INNER JOIN events AS e USING (room_id, event_id)"
                " WHERE rl.room_id = e.room_id"
                " AND rl.event_id = e.event_id"
                " AND user_id = ?"
                " AND receipt_type = ?"
            )
            txn.execute(sql, (user_id, receipt_type))
            return cast(List[Tuple[str, str, int, int]], txn.fetchall())

        rows = await self.db_pool.runInteraction(
            "get_receipts_for_user_with_orderings", f
        )

        return {
            row[0]: {
                "event_id": row[1],
                "topological_ordering": row[2],
                "stream_ordering": row[3],
            }
            for row in rows
        }
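
    # Illustrative shape of the mapping returned above (hypothetical IDs):
    #   {"!room:host": {"event_id": "$event:host",
    #                   "topological_ordering": 5,
    #                   "stream_ordering": 42}}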

    async def get_linearized_receipts_for_rooms(
        self,
        room_ids: Iterable[str],
        to_key: MultiWriterStreamToken,
        from_key: Optional[MultiWriterStreamToken] = None,
    ) -> List[JsonMapping]:
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids: The room IDs to fetch receipts of.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key is not None:
            # Only ask the database about rooms where there have been new
            # receipts added since `from_key`
            room_ids = self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key.stream
            )

        results = await self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        return [ev for res in results.values() for ev in res]

    async def get_linearized_receipts_for_room(
        self,
        room_id: str,
        to_key: MultiWriterStreamToken,
        from_key: Optional[MultiWriterStreamToken] = None,
    ) -> Sequence[JsonMapping]:
        """Get receipts for a single room for sending to clients.

        Args:
            room_id: The room id.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        if from_key is not None:
            # Check the cache first to see if any new receipts have been added
            # since `from_key`. If not we can no-op.
            if not self._receipts_stream_cache.has_entity_changed(
                room_id, from_key.stream
            ):
                return []

        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)

    @cached(tree=True)
    async def _get_linearized_receipts_for_room(
        self,
        room_id: str,
        to_key: MultiWriterStreamToken,
        from_key: Optional[MultiWriterStreamToken] = None,
    ) -> Sequence[JsonMapping]:
        """See get_linearized_receipts_for_room"""

        def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str]]:
            if from_key:
                sql = """
                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
                    FROM receipts_linearized
                    WHERE room_id = ? AND stream_id > ? AND stream_id <= ?
                """

                txn.execute(
                    sql, (room_id, from_key.stream, to_key.get_max_stream_pos())
                )
            else:
                sql = """
                    SELECT stream_id, instance_name, receipt_type, user_id, event_id, data
                    FROM receipts_linearized WHERE
                    room_id = ? AND stream_id <= ?
                """

                txn.execute(sql, (room_id, to_key.get_max_stream_pos()))

            return [
                (receipt_type, user_id, event_id, data)
                for stream_id, instance_name, receipt_type, user_id, event_id, data in txn
                if MultiWriterStreamToken.is_stream_position_in_range(
                    from_key, to_key, instance_name, stream_id
                )
            ]

        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)

        if not rows:
            return []

        content: JsonDict = {}
        for receipt_type, user_id, event_id, data in rows:
            content.setdefault(event_id, {}).setdefault(receipt_type, {})[
                user_id
            ] = db_to_json(data)

        return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]
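
    # The single-element list returned above is a receipt EDU of the form
    # (illustrative values):
    #   [{"type": "m.receipt", "room_id": "!room:host",
    #     "content": {"$event:host": {"m.read": {"@user:host": {"ts": 1700000000000}}}}}]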

    @cachedList(
        cached_method_name="_get_linearized_receipts_for_room",
        list_name="room_ids",
        num_args=3,
    )
    async def _get_linearized_receipts_for_rooms(
        self,
        room_ids: Collection[str],
        to_key: MultiWriterStreamToken,
        from_key: Optional[MultiWriterStreamToken] = None,
    ) -> Mapping[str, Sequence[JsonMapping]]:
        if not room_ids:
            return {}

        def f(
            txn: LoggingTransaction,
        ) -> List[Tuple[str, str, str, str, Optional[str], str]]:
            if from_key:
                sql = """
                    SELECT stream_id, instance_name, room_id, receipt_type,
                        user_id, event_id, thread_id, data
                    FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ? AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(
                    sql + clause,
                    [from_key.stream, to_key.get_max_stream_pos()] + list(args),
                )
            else:
                sql = """
                    SELECT stream_id, instance_name, room_id, receipt_type,
                        user_id, event_id, thread_id, data
                    FROM receipts_linearized WHERE
                    stream_id <= ? AND
                """

                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(sql + clause, [to_key.get_max_stream_pos()] + list(args))

            return [
                (room_id, receipt_type, user_id, event_id, thread_id, data)
                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, thread_id, data in txn
                if MultiWriterStreamToken.is_stream_position_in_range(
                    from_key, to_key, instance_name, stream_id
                )
            ]

        txn_results = await self.db_pool.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        results: JsonDict = {}
        for room_id, receipt_type, user_id, event_id, thread_id, data in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                room_id,
                {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(event_id, {})
            receipt_type_dict = event_entry.setdefault(receipt_type, {})

            receipt_type_dict[user_id] = db_to_json(data)
            if thread_id:
                receipt_type_dict[user_id]["thread_id"] = thread_id

        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }

        return results
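
    # Note: the @cachedList decorator above is the batched counterpart of the
    # per-room @cached method named in `cached_method_name`; callers get an entry
    # for every requested room, with an empty list for rooms that had no receipts
    # in the requested range.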

    @cached(
        num_args=2,
    )
    async def get_linearized_receipts_for_all_rooms(
        self,
        to_key: MultiWriterStreamToken,
        from_key: Optional[MultiWriterStreamToken] = None,
    ) -> Mapping[str, JsonMapping]:
        """Get receipts for all rooms between two stream_ids, up
        to a limit of the latest 100 read receipts.

        Args:
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A dictionary of room IDs to the receipt EDU for that room.
        """

        def f(txn: LoggingTransaction) -> List[Tuple[str, str, str, str, str]]:
            if from_key:
                sql = """
                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
                    FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """
                txn.execute(sql, [from_key.stream, to_key.get_max_stream_pos()])
            else:
                sql = """
                    SELECT stream_id, instance_name, room_id, receipt_type, user_id, event_id, data
                    FROM receipts_linearized WHERE
                    stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """

                txn.execute(sql, [to_key.get_max_stream_pos()])

            return [
                (room_id, receipt_type, user_id, event_id, data)
                for stream_id, instance_name, room_id, receipt_type, user_id, event_id, data in txn
                if MultiWriterStreamToken.is_stream_position_in_range(
                    from_key, to_key, instance_name, stream_id
                )
            ]

        txn_results = await self.db_pool.runInteraction(
            "get_linearized_receipts_for_all_rooms", f
        )

        results: JsonDict = {}
        for room_id, receipt_type, user_id, event_id, data in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                room_id,
                {"type": EduTypes.RECEIPT, "room_id": room_id, "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(event_id, {})
            receipt_type_dict = event_entry.setdefault(receipt_type, {})

            receipt_type_dict[user_id] = db_to_json(data)

        return results

    async def get_users_sent_receipts_between(
        self, last_id: int, current_id: int
    ) -> List[str]:
        """Get all users who sent receipts between `last_id` exclusive and
        `current_id` inclusive.

        Returns:
            The list of users.
        """

        if last_id == current_id:
            return []

        def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
            sql = """
                SELECT DISTINCT user_id FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
            """
            txn.execute(sql, (last_id, current_id))

            return [r[0] for r in txn]

        return await self.db_pool.runInteraction(
            "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
        )

    async def get_all_updated_receipts(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[
        List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]], int, bool
    ]:
        """Get updates for receipts replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exists
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_updated_receipts_txn(
            txn: LoggingTransaction,
        ) -> Tuple[
            List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
            int,
            bool,
        ]:
            sql = """
                SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
                FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
                AND instance_name = ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, instance_name, limit))

            updates = cast(
                List[Tuple[int, Tuple[str, str, str, str, Optional[str], JsonDict]]],
                [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn],
            )

            limited = False
            upper_bound = current_id

            if len(updates) == limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )

    def invalidate_caches_for_receipt(
        self, room_id: str, receipt_type: str, user_id: str
    ) -> None:
        self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
        self._get_linearized_receipts_for_room.invalidate((room_id,))

        # We use this method to invalidate so that we don't end up with circular
        # dependencies between the receipts and push action stores.
        self._attempt_to_invalidate_cache(
            "get_unread_event_push_actions_by_room_for_user", (room_id,)
        )

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[Any],
    ) -> None:
        if stream_name == ReceiptsStream.NAME:
            self._receipts_id_gen.advance(instance_name, token)
            for row in rows:
                self.invalidate_caches_for_receipt(
                    row.room_id, row.receipt_type, row.user_id
                )
                self._receipts_stream_cache.entity_has_changed(row.room_id, token)

        return super().process_replication_rows(stream_name, instance_name, token, rows)

    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        if stream_name == ReceiptsStream.NAME:
            self._receipts_id_gen.advance(instance_name, token)
        super().process_replication_position(stream_name, instance_name, token)

    def _insert_linearized_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_id: str,
        thread_id: Optional[str],
        data: JsonDict,
        stream_id: int,
    ) -> Optional[int]:
        """Inserts a receipt into the database if it's newer than the current one.

        Returns:
            None if the receipt is older than the current receipt
            otherwise, the rx timestamp of the event that the receipt corresponds to
                (or 0 if the event is unknown)
        """
        assert self._can_write_to_receipts

        res = self.db_pool.simple_select_one_txn(
            txn,
            table="events",
            retcols=["stream_ordering", "received_ts"],
            keyvalues={"event_id": event_id},
            allow_none=True,
        )

        stream_ordering = int(res["stream_ordering"]) if res else None
        rx_ts = res["received_ts"] if res else 0

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        if stream_ordering is not None:
            if thread_id is None:
                thread_clause = "r.thread_id IS NULL"
                thread_args: Tuple[str, ...] = ()
            else:
                thread_clause = "r.thread_id = ?"
                thread_args = (thread_id,)

            sql = f"""
            SELECT stream_ordering, event_id FROM events
            INNER JOIN receipts_linearized AS r USING (event_id, room_id)
            WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND {thread_clause}
            """
            txn.execute(
                sql,
                (
                    room_id,
                    receipt_type,
                    user_id,
                )
                + thread_args,
            )

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

        txn.call_after(
            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
        )

        txn.call_after(
            self._receipts_stream_cache.entity_has_changed, room_id, stream_id
        )

        keyvalues = {
            "room_id": room_id,
            "receipt_type": receipt_type,
            "user_id": user_id,
        }
        where_clause = ""
        if thread_id is None:
            where_clause = "thread_id IS NULL"
        else:
            keyvalues["thread_id"] = thread_id

        self.db_pool.simple_upsert_txn(
            txn,
            table="receipts_linearized",
            keyvalues=keyvalues,
            values={
                "stream_id": stream_id,
                "instance_name": self._instance_name,
                "event_id": event_id,
                "event_stream_ordering": stream_ordering,
                "data": json_encoder.encode(data),
            },
            where_clause=where_clause,
        )

        return rx_ts
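
    # Note on the upsert above: unthreaded receipts are keyed on
    # (room_id, receipt_type, user_id) with a "thread_id IS NULL" where-clause,
    # while threaded receipts additionally key on thread_id, so a user keeps at
    # most one row per receipt type per thread plus one unthreaded row per room.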

    def _graph_to_linear(
        self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
    ) -> str:
        """
        Generate a linearized event from a list of events (i.e. a list of forward
        extremities in the room).

        This should allow for calculation of the correct read receipt even if
        servers have different event ordering.

        Args:
            txn: The transaction
            room_id: The room ID the events are in.
            event_ids: The list of event IDs to linearize.

        Returns:
            The linearized event ID.
        """
        # TODO: Make this better.
        clause, args = make_in_list_sql_clause(
            self.database_engine, "event_id", event_ids
        )

        # Pick the event with the greatest stream ordering out of the given set.
        sql = """
            SELECT event_id FROM events WHERE room_id = ? AND stream_ordering IN (
                SELECT max(stream_ordering) FROM events WHERE %s
            )
        """ % (
            clause,
        )

        txn.execute(sql, [room_id] + list(args))

        rows = txn.fetchall()
        if rows:
            return rows[0][0]
        else:
            raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

    async def insert_receipt(
        self,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        thread_id: Optional[str],
        data: dict,
    ) -> Optional[PersistedPosition]:
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.

        Returns:
            The new receipts stream ID and token, if the receipt is newer than
            what was previously persisted. None, otherwise.
        """
        assert self._can_write_to_receipts

        if not event_ids:
            return None

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # We need to map the points in the graph to a linearized form.
            linearized_event_id = await self.db_pool.runInteraction(
                "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
            )

        async with self._receipts_id_gen.get_next() as stream_id:
            event_ts = await self.db_pool.runInteraction(
                "insert_linearized_receipt",
                self._insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                thread_id,
                data,
                stream_id=stream_id,
                # Read committed is actually beneficial here because we check for a receipt with
                # greater stream order, and checking the very latest data at select time is better
                # than the data at transaction start time.
                isolation_level=IsolationLevel.READ_COMMITTED,
            )

        # If the receipt was older than the currently persisted one, nothing to do.
        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "Receipt %s for event %s in %s (%i ms old)",
            receipt_type,
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        await self._insert_graph_receipt(
            room_id,
            receipt_type,
            user_id,
            event_ids,
            thread_id,
            data,
        )

        return PersistedPosition(self._instance_name, stream_id)
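
    # Illustrative usage (hypothetical identifiers): persisting an unthreaded
    # public read receipt for a single event might look like:
    #
    #     await store.insert_receipt(
    #         "!room:example.org",
    #         "m.read",
    #         "@alice:example.org",
    #         ["$event:example.org"],
    #         None,  # thread_id: unthreaded receipt
    #         {"ts": 1700000000000},
    #     )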

    async def _insert_graph_receipt(
        self,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        thread_id: Optional[str],
        data: JsonDict,
    ) -> None:
        assert self._can_write_to_receipts

        keyvalues = {
            "room_id": room_id,
            "receipt_type": receipt_type,
            "user_id": user_id,
        }
        where_clause = ""
        if thread_id is None:
            where_clause = "thread_id IS NULL"
        else:
            keyvalues["thread_id"] = thread_id

        await self.db_pool.simple_upsert(
            desc="insert_graph_receipt",
            table="receipts_graph",
            keyvalues=keyvalues,
            values={
                "event_ids": json_encoder.encode(event_ids),
                "data": json_encoder.encode(data),
            },
            where_clause=where_clause,
        )

        self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))

        # FIXME: This shouldn't invalidate the whole cache
        self._get_linearized_receipts_for_room.invalidate((room_id,))


class ReceiptsBackgroundUpdateStore(SQLBaseStore):
    POPULATE_RECEIPT_EVENT_STREAM_ORDERING = "populate_event_stream_ordering"
    RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME = "receipts_linearized_unique_index"
    RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME = "receipts_graph_unique_index"

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_update_handler(
            self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING,
            self._populate_receipt_event_stream_ordering,
        )
        self.db_pool.updates.register_background_update_handler(
            self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME,
            self._background_receipts_linearized_unique_index,
        )
        self.db_pool.updates.register_background_update_handler(
            self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME,
            self._background_receipts_graph_unique_index,
        )

    async def _populate_receipt_event_stream_ordering(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        def _populate_receipt_event_stream_ordering_txn(
            txn: LoggingTransaction,
        ) -> bool:
            if "max_stream_id" in progress:
                max_stream_id = progress["max_stream_id"]
            else:
                txn.execute("SELECT max(stream_id) FROM receipts_linearized")
                res = txn.fetchone()

                if res is None or res[0] is None:
                    return True
                else:
                    max_stream_id = res[0]

            start = progress.get("stream_id", 0)
            stop = start + batch_size

            sql = """
                UPDATE receipts_linearized
                SET event_stream_ordering = (
                    SELECT stream_ordering
                    FROM events
                    WHERE event_id = receipts_linearized.event_id
                )
                WHERE stream_id >= ? AND stream_id < ?
            """
            txn.execute(sql, (start, stop))

            self.db_pool.updates._background_update_progress_txn(
                txn,
                self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING,
                {
                    "stream_id": stop,
                    "max_stream_id": max_stream_id,
                },
            )

            return stop > max_stream_id

        finished = await self.db_pool.runInteraction(
            "_populate_receipt_event_stream_ordering_txn",
            _populate_receipt_event_stream_ordering_txn,
        )

        if finished:
            await self.db_pool.updates._end_background_update(
                self.POPULATE_RECEIPT_EVENT_STREAM_ORDERING
            )

        return batch_size

    async def _background_receipts_linearized_unique_index(
        self, progress: dict, batch_size: int
    ) -> int:
        """Removes duplicate receipts and adds a unique index on
        `(room_id, receipt_type, user_id)` to `receipts_linearized`, for non-thread
        receipts."""

        def _remove_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
            ROW_ID_NAME = self.database_engine.row_id_name

            # Identify any duplicate receipts arising from
            # https://github.com/matrix-org/synapse/issues/14406.
            # The following query takes less than a minute on matrix.org.
            sql = """
                SELECT MAX(stream_id), room_id, receipt_type, user_id
                FROM receipts_linearized
                WHERE thread_id IS NULL
                GROUP BY room_id, receipt_type, user_id
                HAVING COUNT(*) > 1
            """
            txn.execute(sql)
            duplicate_keys = cast(List[Tuple[int, str, str, str]], list(txn))

            # Then remove duplicate receipts, keeping the one with the highest
            # `stream_id`. Since there might be duplicate rows with the same
            # `stream_id`, we delete by the ctid instead.
            for stream_id, room_id, receipt_type, user_id in duplicate_keys:
                sql = f"""
                SELECT {ROW_ID_NAME}
                FROM receipts_linearized
                WHERE
                    room_id = ? AND
                    receipt_type = ? AND
                    user_id = ? AND
                    thread_id IS NULL AND
                    stream_id = ?
                LIMIT 1
                """
                txn.execute(sql, (room_id, receipt_type, user_id, stream_id))
                row_id = cast(Tuple[str], txn.fetchone())[0]

                sql = f"""
                DELETE FROM receipts_linearized
                WHERE
                    room_id = ? AND
                    receipt_type = ? AND
                    user_id = ? AND
                    thread_id IS NULL AND
                    {ROW_ID_NAME} != ?
                """
                txn.execute(sql, (room_id, receipt_type, user_id, row_id))

        await self.db_pool.runInteraction(
            self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME,
            _remove_duplicate_receipts_txn,
        )

        await self.db_pool.updates.create_index_in_background(
            index_name="receipts_linearized_unique_index",
            table="receipts_linearized",
            columns=["room_id", "receipt_type", "user_id"],
            where_clause="thread_id IS NULL",
            unique=True,
        )

        await self.db_pool.updates._end_background_update(
            self.RECEIPTS_LINEARIZED_UNIQUE_INDEX_UPDATE_NAME
        )

        return 1

    async def _background_receipts_graph_unique_index(
        self, progress: dict, batch_size: int
    ) -> int:
        """Removes duplicate receipts and adds a unique index on
        `(room_id, receipt_type, user_id)` to `receipts_graph`, for non-thread
        receipts."""

        def _remove_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
            # Identify any duplicate receipts arising from
            # https://github.com/matrix-org/synapse/issues/14406.
            # We expect the following query to use the per-thread receipt index and take
            # less than a minute.
            sql = """
                SELECT room_id, receipt_type, user_id FROM receipts_graph
                WHERE thread_id IS NULL
                GROUP BY room_id, receipt_type, user_id
                HAVING COUNT(*) > 1
            """
            txn.execute(sql)
            duplicate_keys = cast(List[Tuple[str, str, str]], list(txn))

            # Then remove all duplicate receipts.
            # We could be clever and try to keep the latest receipt out of every set of
            # duplicates, but it's far simpler to remove them all.
            for room_id, receipt_type, user_id in duplicate_keys:
                sql = """
                    DELETE FROM receipts_graph
                    WHERE
                        room_id = ? AND
                        receipt_type = ? AND
                        user_id = ? AND
                        thread_id IS NULL
                """
                txn.execute(sql, (room_id, receipt_type, user_id))

        await self.db_pool.runInteraction(
            self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME,
            _remove_duplicate_receipts_txn,
        )

        await self.db_pool.updates.create_index_in_background(
            index_name="receipts_graph_unique_index",
            table="receipts_graph",
            columns=["room_id", "receipt_type", "user_id"],
            where_clause="thread_id IS NULL",
            unique=True,
        )

        await self.db_pool.updates._end_background_update(
            self.RECEIPTS_GRAPH_UNIQUE_INDEX_UPDATE_NAME
        )

        return 1


class ReceiptsStore(ReceiptsWorkerStore, ReceiptsBackgroundUpdateStore):
    pass