# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Dict,
    Iterable,
    List,
    Optional,
    Tuple,
    cast,
)

from synapse.api.constants import EduTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import ReceiptsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.engines._base import IsolationLevel
from synapse.storage.util.id_generators import (
    AbstractStreamIdTracker,
    MultiWriterIdGenerator,
    StreamIdGenerator,
)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

class ReceiptsWorkerStore(SQLBaseStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        self._instance_name = hs.get_instance_name()
        self._receipts_id_gen: AbstractStreamIdTracker

        if isinstance(database.engine, PostgresEngine):
            self._can_write_to_receipts = (
                self._instance_name in hs.config.worker.writers.receipts
            )

            self._receipts_id_gen = MultiWriterIdGenerator(
                db_conn=db_conn,
                db=database,
                stream_name="receipts",
                instance_name=self._instance_name,
                tables=[("receipts_linearized", "instance_name", "stream_id")],
                sequence_name="receipts_sequence",
                writers=hs.config.worker.writers.receipts,
            )
        else:
            self._can_write_to_receipts = True

            # We shouldn't be running in worker mode with SQLite, but it's useful
            # to support it for unit tests.
            #
            # If this process is the writer then we need to use
            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
            # updated over replication. (Multiple writers are not supported for
            # SQLite).
            if hs.get_instance_name() in hs.config.worker.writers.receipts:
                self._receipts_id_gen = StreamIdGenerator(
                    db_conn, "receipts_linearized", "stream_id"
                )
            else:
                self._receipts_id_gen = SlavedIdTracker(
                    db_conn, "receipts_linearized", "stream_id"
                )

        super().__init__(database, db_conn, hs)

        max_receipts_stream_id = self.get_max_receipt_stream_id()
        receipts_stream_prefill, min_receipts_stream_id = self.db_pool.get_cache_dict(
            db_conn,
            "receipts_linearized",
            entity_column="room_id",
            stream_column="stream_id",
            max_value=max_receipts_stream_id,
            limit=10000,
        )
        self._receipts_stream_cache = StreamChangeCache(
            "ReceiptsRoomChangeCache",
            min_receipts_stream_id,
            prefilled_cache=receipts_stream_prefill,
        )

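    # The StreamChangeCache above lets read paths cheaply answer "has room X
    # seen a receipt since stream position N?" without hitting the database.
    # It is prefilled from the most recent receipts_linearized rows (up to the
    # 10,000-row limit passed to get_cache_dict).
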
    def get_max_receipt_stream_id(self) -> int:
        """Get the current max stream ID for receipts stream"""
        return self._receipts_id_gen.get_current_token()

    async def get_last_receipt_event_id_for_user(
        self, user_id: str, room_id: str, receipt_types: Collection[str]
    ) -> Optional[str]:
        """
        Fetch the event ID for the latest receipt in a room with one of the given receipt types.

        Args:
            user_id: The user to fetch receipts for.
            room_id: The room ID to fetch the receipt for.
            receipt_types: The receipt types to fetch.

        Returns:
            The event ID of the latest receipt, if one exists.
        """
        result = await self.db_pool.runInteraction(
            "get_last_receipt_event_id_for_user",
            self.get_last_receipt_for_user_txn,
            user_id,
            room_id,
            receipt_types,
        )
        if not result:
            return None

        event_id, _ = result
        return event_id

    def get_last_receipt_for_user_txn(
        self,
        txn: LoggingTransaction,
        user_id: str,
        room_id: str,
        receipt_types: Collection[str],
    ) -> Optional[Tuple[str, int]]:
        """
        Fetch the event ID and stream_ordering for the latest receipt in a room
        with one of the given receipt types.

        Args:
            user_id: The user to fetch receipts for.
            room_id: The room ID to fetch the receipt for.
            receipt_types: The receipt types to fetch.

        Returns:
            The event ID and stream ordering of the latest receipt, if one exists.
        """
        clause, args = make_in_list_sql_clause(
            self.database_engine, "receipt_type", receipt_types
        )

        sql = f"""
            SELECT event_id, stream_ordering
            FROM receipts_linearized
            INNER JOIN events USING (room_id, event_id)
            WHERE {clause}
            AND user_id = ?
            AND room_id = ?
            ORDER BY stream_ordering DESC
            LIMIT 1
        """

        args.extend((user_id, room_id))
        txn.execute(sql, args)

        return cast(Optional[Tuple[str, int]], txn.fetchone())

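    # As a rough illustration of the helper used above: on a SQLite-style
    # engine, make_in_list_sql_clause expands to something like
    #
    #     clause, args = make_in_list_sql_clause(
    #         engine, "receipt_type", ["m.read", "m.read.private"]
    #     )
    #     # clause == "receipt_type IN (?,?)"
    #     # args == ["m.read", "m.read.private"]
    #
    # while on PostgreSQL it may instead emit an `receipt_type = ANY(?)` form
    # with a single array-valued parameter.
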
    async def get_receipts_for_user(
        self, user_id: str, receipt_types: Iterable[str]
    ) -> Dict[str, str]:
        """
        Fetch the event IDs for the latest receipts sent by the given user.

        Args:
            user_id: The user to fetch receipts for.
            receipt_types: The receipt types to check.

        Returns:
            A map of room ID to the event ID of the latest receipt for that room.

            If the user has not sent a receipt to a room then it will not appear
            in the returned dictionary.
        """
        results = await self.get_receipts_for_user_with_orderings(
            user_id, receipt_types
        )

        # Reduce the result to room ID -> event ID.
        return {
            room_id: room_result["event_id"] for room_id, room_result in results.items()
        }

    async def get_receipts_for_user_with_orderings(
        self, user_id: str, receipt_types: Iterable[str]
    ) -> JsonDict:
        """
        Fetch receipts for all rooms that the given user is joined to.

        Args:
            user_id: The user to fetch receipts for.
            receipt_types: The receipt types to fetch. Earlier receipt types
                are given priority if multiple receipts point to the same event.

        Returns:
            A map of room ID to the latest receipt (for the given types).
        """
        results: JsonDict = {}
        for receipt_type in receipt_types:
            partial_result = await self._get_receipts_for_user_with_orderings(
                user_id, receipt_type
            )
            for room_id, room_result in partial_result.items():
                # If the room has not yet been seen, or the receipt is newer,
                # use it.
                if (
                    room_id not in results
                    or results[room_id]["stream_ordering"]
                    < room_result["stream_ordering"]
                ):
                    results[room_id] = room_result

        return results

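    # For example (hypothetical values): if "m.read" points at stream_ordering
    # 10 in a room and "m.read.private" points at stream_ordering 12 in the
    # same room, the merged result keeps the "m.read.private" entry, since only
    # the newest receipt per room survives; on a tie, the earlier receipt type
    # wins because the comparison above is strict.
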
    @cached()
    async def _get_receipts_for_user_with_orderings(
        self, user_id: str, receipt_type: str
    ) -> JsonDict:
        """
        Fetch receipts for all rooms that the given user is joined to.

        Args:
            user_id: The user to fetch receipts for.
            receipt_type: The receipt type to fetch.

        Returns:
            A map of room ID to the latest receipt information.
        """

        def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]:
            sql = (
                "SELECT rl.room_id, rl.event_id,"
                " e.topological_ordering, e.stream_ordering"
                " FROM receipts_linearized AS rl"
                " INNER JOIN events AS e USING (room_id, event_id)"
                " WHERE rl.room_id = e.room_id"
                " AND rl.event_id = e.event_id"
                " AND user_id = ?"
                " AND receipt_type = ?"
            )
            txn.execute(sql, (user_id, receipt_type))
            return cast(List[Tuple[str, str, int, int]], txn.fetchall())

        rows = await self.db_pool.runInteraction(
            "get_receipts_for_user_with_orderings", f
        )
        return {
            row[0]: {
                "event_id": row[1],
                "topological_ordering": row[2],
                "stream_ordering": row[3],
            }
            for row in rows
        }

    async def get_linearized_receipts_for_rooms(
        self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for multiple rooms for sending to clients.

        Args:
            room_ids: The room IDs to fetch receipts of.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        room_ids = set(room_ids)

        if from_key is not None:
            # Only ask the database about rooms where there have been new
            # receipts added since `from_key`
            room_ids = self._receipts_stream_cache.get_entities_changed(
                room_ids, from_key
            )

        results = await self._get_linearized_receipts_for_rooms(
            room_ids, to_key, from_key=from_key
        )

        return [ev for res in results.values() for ev in res]

    async def get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[dict]:
        """Get receipts for a single room for sending to clients.

        Args:
            room_id: The room ID.
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A list of receipts.
        """
        if from_key is not None:
            # Check the cache first to see if any new receipts have been added
            # since `from_key`. If not we can no-op.
            if not self._receipts_stream_cache.has_entity_changed(room_id, from_key):
                return []

        return await self._get_linearized_receipts_for_room(room_id, to_key, from_key)

    @cached(tree=True)
    async def _get_linearized_receipts_for_room(
        self, room_id: str, to_key: int, from_key: Optional[int] = None
    ) -> List[JsonDict]:
        """See get_linearized_receipts_for_room"""

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id > ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, from_key, to_key))
            else:
                sql = (
                    "SELECT * FROM receipts_linearized WHERE"
                    " room_id = ? AND stream_id <= ?"
                )

                txn.execute(sql, (room_id, to_key))

            rows = self.db_pool.cursor_to_dict(txn)

            return rows

        rows = await self.db_pool.runInteraction("get_linearized_receipts_for_room", f)

        if not rows:
            return []

        content: JsonDict = {}
        for row in rows:
            content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
                row["user_id"]
            ] = db_to_json(row["data"])

        return [{"type": EduTypes.RECEIPT, "room_id": room_id, "content": content}]

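    # The value returned above is a one-element list holding an m.receipt EDU,
    # shaped roughly like this (hypothetical IDs):
    #
    #     [{
    #         "type": "m.receipt",
    #         "room_id": "!room:example.org",
    #         "content": {
    #             "$event:example.org": {
    #                 "m.read": {"@alice:example.org": {"ts": 1656000000000}},
    #             },
    #         },
    #     }]
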
    @cachedList(
        cached_method_name="_get_linearized_receipts_for_room",
        list_name="room_ids",
        num_args=3,
    )
    async def _get_linearized_receipts_for_rooms(
        self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None
    ) -> Dict[str, List[JsonDict]]:
        if not room_ids:
            return {}

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ? AND
                """
                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(sql + clause, [from_key, to_key] + list(args))
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ? AND
                """

                clause, args = make_in_list_sql_clause(
                    self.database_engine, "room_id", room_ids
                )

                txn.execute(sql + clause, [to_key] + list(args))

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "_get_linearized_receipts_for_rooms", f
        )

        results: JsonDict = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        results = {
            room_id: [results[room_id]] if room_id in results else []
            for room_id in room_ids
        }
        return results

    @cached(
        num_args=2,
    )
    async def get_linearized_receipts_for_all_rooms(
        self, to_key: int, from_key: Optional[int] = None
    ) -> Dict[str, JsonDict]:
        """Get receipts for all rooms between two stream_ids, up
        to a limit of the latest 100 read receipts.

        Args:
            to_key: Max stream id to fetch receipts up to.
            from_key: Min stream id to fetch receipts from. None fetches
                from the start.

        Returns:
            A dictionary of room IDs to their receipts.
        """

        def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            if from_key:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id > ? AND stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """
                txn.execute(sql, [from_key, to_key])
            else:
                sql = """
                    SELECT * FROM receipts_linearized WHERE
                    stream_id <= ?
                    ORDER BY stream_id DESC
                    LIMIT 100
                """

                txn.execute(sql, [to_key])

            return self.db_pool.cursor_to_dict(txn)

        txn_results = await self.db_pool.runInteraction(
            "get_linearized_receipts_for_all_rooms", f
        )

        results: JsonDict = {}
        for row in txn_results:
            # We want a single event per room, since we want to batch the
            # receipts by room, event and type.
            room_event = results.setdefault(
                row["room_id"],
                {"type": EduTypes.RECEIPT, "room_id": row["room_id"], "content": {}},
            )

            # The content is of the form:
            # {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
            event_entry = room_event["content"].setdefault(row["event_id"], {})
            receipt_type = event_entry.setdefault(row["receipt_type"], {})

            receipt_type[row["user_id"]] = db_to_json(row["data"])

        return results

    async def get_users_sent_receipts_between(
        self, last_id: int, current_id: int
    ) -> List[str]:
        """Get all users who sent receipts between `last_id` exclusive and
        `current_id` inclusive.

        Returns:
            The list of users.
        """

        if last_id == current_id:
            return []

        def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]:
            sql = """
                SELECT DISTINCT user_id FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
            """
            txn.execute(sql, (last_id, current_id))

            return [r[0] for r in txn]

        return await self.db_pool.runInteraction(
            "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
        )

    async def get_all_updated_receipts(
        self, instance_name: str, last_id: int, current_id: int, limit: int
    ) -> Tuple[List[Tuple[int, list]], int, bool]:
        """Get updates for receipts replication stream.

        Args:
            instance_name: The writer we want to fetch updates from. Unused
                here since there is only ever one writer.
            last_id: The token to fetch updates from. Exclusive.
            current_id: The token to fetch updates up to. Inclusive.
            limit: The requested limit for the number of rows to return. The
                function may return more or fewer rows.

        Returns:
            A tuple consisting of: the updates, a token to use to fetch
            subsequent updates, and whether we returned fewer rows than exist
            between the requested tokens due to the limit.

            The token returned can be used in a subsequent call to this
            function to get further updates.

            The updates are a list of 2-tuples of stream ID and the row data.
        """

        if last_id == current_id:
            return [], current_id, False

        def get_all_updated_receipts_txn(
            txn: LoggingTransaction,
        ) -> Tuple[List[Tuple[int, list]], int, bool]:
            sql = """
                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
                FROM receipts_linearized
                WHERE ? < stream_id AND stream_id <= ?
                ORDER BY stream_id ASC
                LIMIT ?
            """
            txn.execute(sql, (last_id, current_id, limit))

            updates = cast(
                List[Tuple[int, list]],
                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
            )

            limited = False
            upper_bound = current_id

            if len(updates) == limit:
                limited = True
                upper_bound = updates[-1][0]

            return updates, upper_bound, limited

        return await self.db_pool.runInteraction(
            "get_all_updated_receipts", get_all_updated_receipts_txn
        )

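    # Each update produced above is a 2-tuple of stream ID and row data, e.g.
    # (hypothetical values):
    #
    #     (42, ("!room:example.org", "m.read", "@alice:example.org",
    #           "$event:example.org", {"ts": 1656000000000}))
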
    def invalidate_caches_for_receipt(
        self, room_id: str, receipt_type: str, user_id: str
    ) -> None:
        self._get_receipts_for_user_with_orderings.invalidate((user_id, receipt_type))
        self._get_linearized_receipts_for_room.invalidate((room_id,))

        # We use this method to invalidate so that we don't end up with circular
        # dependencies between the receipts and push action stores.
        self._attempt_to_invalidate_cache(
            "get_unread_event_push_actions_by_room_for_user", (room_id,)
        )

    def process_replication_rows(
        self,
        stream_name: str,
        instance_name: str,
        token: int,
        rows: Iterable[Any],
    ) -> None:
        if stream_name == ReceiptsStream.NAME:
            self._receipts_id_gen.advance(instance_name, token)
            for row in rows:
                self.invalidate_caches_for_receipt(
                    row.room_id, row.receipt_type, row.user_id
                )
                self._receipts_stream_cache.entity_has_changed(row.room_id, token)

        return super().process_replication_rows(stream_name, instance_name, token, rows)

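    # On a worker, each ReceiptsStream row handled above both advances the
    # local stream ID tracker and invalidates the per-user/per-room caches, so
    # subsequent reads on this process observe the writer's changes.
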
    def _insert_linearized_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_id: str,
        data: JsonDict,
        stream_id: int,
    ) -> Optional[int]:
        """Inserts a receipt into the database if it's newer than the current one.

        Returns:
            None if the receipt is older than the current receipt; otherwise,
            the rx timestamp of the event that the receipt corresponds to
            (or 0 if the event is unknown).
        """
        assert self._can_write_to_receipts

        res = self.db_pool.simple_select_one_txn(
            txn,
            table="events",
            retcols=["stream_ordering", "received_ts"],
            keyvalues={"event_id": event_id},
            allow_none=True,
        )

        stream_ordering = int(res["stream_ordering"]) if res else None
        rx_ts = res["received_ts"] if res else 0

        # We don't want to clobber receipts for more recent events, so we
        # have to compare orderings of existing receipts
        if stream_ordering is not None:
            sql = (
                "SELECT stream_ordering, event_id FROM events"
                " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
            )
            txn.execute(sql, (room_id, receipt_type, user_id))

            for so, eid in txn:
                if int(so) >= stream_ordering:
                    logger.debug(
                        "Ignoring new receipt for %s in favour of existing "
                        "one for later event %s",
                        event_id,
                        eid,
                    )
                    return None

        txn.call_after(
            self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
        )

        txn.call_after(
            self._receipts_stream_cache.entity_has_changed, room_id, stream_id
        )

        self.db_pool.simple_upsert_txn(
            txn,
            table="receipts_linearized",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
            values={
                "stream_id": stream_id,
                "event_id": event_id,
                "data": json_encoder.encode(data),
            },
            # receipts_linearized has a unique constraint on
            # (user_id, room_id, receipt_type), so no need to lock
            lock=False,
        )

        return rx_ts

    def _graph_to_linear(
        self, txn: LoggingTransaction, room_id: str, event_ids: List[str]
    ) -> str:
        """
        Generate a linearized event from a list of events (i.e. a list of forward
        extremities in the room).

        This should allow for calculation of the correct read receipt even if
        servers have different event ordering.

        Args:
            txn: The transaction
            room_id: The room ID the events are in.
            event_ids: The list of event IDs to linearize.

        Returns:
            The linearized event ID.
        """
        # TODO: Make this better.
        clause, args = make_in_list_sql_clause(
            self.database_engine, "event_id", event_ids
        )

        sql = """
            SELECT event_id FROM events
            WHERE room_id = ? AND stream_ordering IN (
                SELECT max(stream_ordering) FROM events WHERE %s
            )
        """ % (
            clause,
        )

        txn.execute(sql, [room_id] + list(args))
        rows = txn.fetchall()
        if rows:
            return rows[0][0]
        else:
            raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))

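    # E.g. given two forward extremities ["$a", "$b"] (hypothetical IDs), the
    # query above returns whichever of them has the highest stream_ordering on
    # this server, so the receipt lands on a single, locally-latest event.
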
    async def insert_receipt(
        self,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: dict,
    ) -> Optional[Tuple[int, int]]:
        """Insert a receipt, either from local client or remote server.

        Automatically does conversion between linearized and graph
        representations.

        Returns:
            The new receipts stream ID and token, if the receipt is newer than
            what was previously persisted. None, otherwise.
        """
        assert self._can_write_to_receipts

        if not event_ids:
            return None

        if len(event_ids) == 1:
            linearized_event_id = event_ids[0]
        else:
            # we need to convert the points in the graph into a linearized form.
            linearized_event_id = await self.db_pool.runInteraction(
                "insert_receipt_conv", self._graph_to_linear, room_id, event_ids
            )

        async with self._receipts_id_gen.get_next() as stream_id:  # type: ignore[attr-defined]
            event_ts = await self.db_pool.runInteraction(
                "insert_linearized_receipt",
                self._insert_linearized_receipt_txn,
                room_id,
                receipt_type,
                user_id,
                linearized_event_id,
                data,
                stream_id=stream_id,
                # Read committed is actually beneficial here because we check for a receipt with
                # greater stream order, and checking the very latest data at select time is better
                # than the data at transaction start time.
                isolation_level=IsolationLevel.READ_COMMITTED,
            )

        # If the receipt was older than the currently persisted one, nothing to do.
        if event_ts is None:
            return None

        now = self._clock.time_msec()
        logger.debug(
            "RR for event %s in %s (%i ms old)",
            linearized_event_id,
            room_id,
            now - event_ts,
        )

        await self.db_pool.runInteraction(
            "insert_graph_receipt",
            self._insert_graph_receipt_txn,
            room_id,
            receipt_type,
            user_id,
            event_ids,
            data,
        )

        max_persisted_id = self._receipts_id_gen.get_current_token()

        return stream_id, max_persisted_id

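    # A rough usage sketch (hypothetical store instance and identifiers):
    #
    #     result = await store.insert_receipt(
    #         "!room:example.org",
    #         "m.read",
    #         "@alice:example.org",
    #         ["$event:example.org"],
    #         data={"ts": 1656000000000},
    #     )
    #     # result is (stream_id, max_persisted_id), or None if a receipt for
    #     # a later event was already persisted.
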
    def _insert_graph_receipt_txn(
        self,
        txn: LoggingTransaction,
        room_id: str,
        receipt_type: str,
        user_id: str,
        event_ids: List[str],
        data: JsonDict,
    ) -> None:
        assert self._can_write_to_receipts

        txn.call_after(
            self._get_receipts_for_user_with_orderings.invalidate,
            (user_id, receipt_type),
        )
        # FIXME: This shouldn't invalidate the whole cache
        txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

        self.db_pool.simple_delete_txn(
            txn,
            table="receipts_graph",
            keyvalues={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
            },
        )
        self.db_pool.simple_insert_txn(
            txn,
            table="receipts_graph",
            values={
                "room_id": room_id,
                "receipt_type": receipt_type,
                "user_id": user_id,
                "event_ids": json_encoder.encode(event_ids),
                "data": json_encoder.encode(data),
            },
        )


class ReceiptsStore(ReceiptsWorkerStore):
    pass