# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
from collections import OrderedDict, namedtuple
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Generator,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
)

import attr
from prometheus_client import Counter

import synapse.metrics
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase  # noqa: F401
from synapse.events.snapshot import EventContext  # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically

if TYPE_CHECKING:
    from synapse.server import HomeServer
    from synapse.storage.databases.main import DataStore


logger = logging.getLogger(__name__)

persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
event_counter = Counter(
    "synapse_storage_events_persisted_events_sep",
    "",
    ["type", "origin_type", "origin_entity"],
)

_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))


@attr.s(slots=True)
class DeltaState:
    """Deltas to use to update the `current_state_events` table.

    Attributes:
        to_delete: List of type/state_keys to delete from current state
        to_insert: Map of state to upsert into current state
        no_longer_in_room: The server is no longer in the room, so the room
            should e.g. be removed from the `current_state_events` table.
    """

    to_delete = attr.ib(type=List[Tuple[str, str]])
    to_insert = attr.ib(type=StateMap[str])
    no_longer_in_room = attr.ib(type=bool, default=False)


class PersistEventsStore:
    """Contains all the functions for writing events to the database.

    Should only be instantiated on one process (when using a worker mode setup).

    Note: This is not part of the `DataStore` mixin.
    """

    def __init__(
        self, hs: "HomeServer", db: DatabasePool, main_data_store: "DataStore"
    ):
        self.hs = hs
        self.db_pool = db
        self.store = main_data_store
        self.database_engine = db.engine
        self._clock = hs.get_clock()
        self._instance_name = hs.get_instance_name()

        def get_chain_id_txn(txn):
            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
            return txn.fetchone()[0]

        self._event_chain_id_gen = build_sequence_generator(
            db.engine, get_chain_id_txn, "event_auth_chain_id"
        )

        self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
        self.is_mine_id = hs.is_mine_id

        # Ideally we'd move these ID gens here, unfortunately some other ID
        # generators are chained off them so doing so is a bit of a PITA.
        self._backfill_id_gen = (
            self.store._backfill_id_gen
        )  # type: MultiWriterIdGenerator
        self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator

        # This should only exist on instances that are configured to write
        assert (
            hs.get_instance_name() in hs.config.worker.writers.events
        ), "Can only instantiate EventsStore on master"

    async def _persist_events_and_state_updates(
        self,
        events_and_contexts: List[Tuple[EventBase, EventContext]],
        current_state_for_room: Dict[str, StateMap[str]],
        state_delta_for_room: Dict[str, DeltaState],
        new_forward_extremeties: Dict[str, List[str]],
        backfilled: bool = False,
    ) -> None:
        """Persist a set of events alongside updates to the current state and
        forward extremities tables.

        Args:
            events_and_contexts:
            current_state_for_room: Map from room_id to the current state of
                the room based on forward extremities
            state_delta_for_room: Map from room_id to the delta to apply to
                room state
            new_forward_extremities: Map from room_id to list of event IDs
                that are the new forward extremities of the room.
            backfilled: True if the events were backfilled

        Returns:
            Resolves when the events have been persisted
        """

        # We want to calculate the stream orderings as late as possible, as
        # we only notify after all events with a lesser stream ordering have
        # been persisted. I.e. if we spend 10s inside the with block then
        # that will delay all subsequent events from being notified about.
        # Hence why we do it down here rather than wrapping the entire
        # function.
        #
        # It's safe to do this after calculating the state deltas etc as we
        # only need to protect the *persistence* of the events. This is to
        # ensure that queries of the form "fetch events since X" don't
        # return events and stream positions after events that are still in
        # flight, as otherwise subsequent requests "fetch events since Y"
        # will not return those events.
        #
        # Note: Multiple instances of this function cannot be in flight at
        # the same time for the same room.
        if backfilled:
            stream_ordering_manager = self._backfill_id_gen.get_next_mult(
                len(events_and_contexts)
            )
        else:
            stream_ordering_manager = self._stream_id_gen.get_next_mult(
                len(events_and_contexts)
            )

        async with stream_ordering_manager as stream_orderings:
            for (event, context), stream in zip(events_and_contexts, stream_orderings):
                event.internal_metadata.stream_ordering = stream

            await self.db_pool.runInteraction(
                "persist_events",
                self._persist_events_txn,
                events_and_contexts=events_and_contexts,
                backfilled=backfilled,
                state_delta_for_room=state_delta_for_room,
                new_forward_extremeties=new_forward_extremeties,
            )
            persist_event_counter.inc(len(events_and_contexts))

            if not backfilled:
                # backfilled events have negative stream orderings, so we don't
                # want to set the event_persisted_position to that.
                synapse.metrics.event_persisted_position.set(
                    events_and_contexts[-1][0].internal_metadata.stream_ordering
                )

            for event, context in events_and_contexts:
                if context.app_service:
                    origin_type = "local"
                    origin_entity = context.app_service.id
                elif self.hs.is_mine_id(event.sender):
                    origin_type = "local"
                    origin_entity = "*client*"
                else:
                    origin_type = "remote"
                    origin_entity = get_domain_from_id(event.sender)

                event_counter.labels(event.type, origin_type, origin_entity).inc()

            for room_id, new_state in current_state_for_room.items():
                self.store.get_current_state_ids.prefill((room_id,), new_state)

            for room_id, latest_event_ids in new_forward_extremeties.items():
                self.store.get_latest_event_ids_in_room.prefill(
                    (room_id,), list(latest_event_ids)
                )

    async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
        """Filter the supplied list of event_ids to get those which are prev_events of
        existing (non-outlier/rejected) events.

        Args:
            event_ids: event ids to filter

        Returns:
            Filtered event ids
        """
        results = []  # type: List[str]

        def _get_events_which_are_prevs_txn(txn, batch):
            sql = """
            SELECT prev_event_id, internal_metadata
            FROM event_edges
                INNER JOIN events USING (event_id)
                LEFT JOIN rejections USING (event_id)
                LEFT JOIN event_json USING (event_id)
            WHERE
                NOT events.outlier
                AND rejections.event_id IS NULL
                AND
            """

            clause, args = make_in_list_sql_clause(
                self.database_engine, "prev_event_id", batch
            )

            txn.execute(sql + clause, args)
            results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))

        for chunk in batch_iter(event_ids, 100):
            await self.db_pool.runInteraction(
                "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
            )

        return results

    async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
        """Get soft-failed ancestors to remove from the extremities.

        Given a set of events, find all those that have been soft-failed or
        rejected. Returns those soft failed/rejected events and their prev
        events (whether soft-failed/rejected or not), and recurses up the
        prev-event graph until it finds no more soft-failed/rejected events.

        This is used to find extremities that are ancestors of new events, but
        are separated by soft failed events.

        Args:
            event_ids: Events to find prev events for. Note that these must have
                already been persisted.

        Returns:
            The previous events.
        """

        # The set of event_ids to return. This includes all soft-failed events
        # and their prev events.
        existing_prevs = set()

        def _get_prevs_before_rejected_txn(txn, batch):
            to_recursively_check = batch

            while to_recursively_check:
                sql = """
                SELECT
                    event_id, prev_event_id, internal_metadata,
                    rejections.event_id IS NOT NULL
                FROM event_edges
                    INNER JOIN events USING (event_id)
                    LEFT JOIN rejections USING (event_id)
                    LEFT JOIN event_json USING (event_id)
                WHERE
                    NOT events.outlier
                    AND
                """

                clause, args = make_in_list_sql_clause(
                    self.database_engine, "event_id", to_recursively_check
                )

                txn.execute(sql + clause, args)
                to_recursively_check = []

                for event_id, prev_event_id, metadata, rejected in txn:
                    if prev_event_id in existing_prevs:
                        continue

                    soft_failed = db_to_json(metadata).get("soft_failed")
                    if soft_failed or rejected:
                        to_recursively_check.append(prev_event_id)
                        existing_prevs.add(prev_event_id)

        for chunk in batch_iter(event_ids, 100):
            await self.db_pool.runInteraction(
                "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
            )

        return existing_prevs

    @log_function
    def _persist_events_txn(
        self,
        txn: LoggingTransaction,
        events_and_contexts: List[Tuple[EventBase, EventContext]],
        backfilled: bool,
        state_delta_for_room: Dict[str, DeltaState] = {},
        new_forward_extremeties: Dict[str, List[str]] = {},
    ):
        """Insert some number of room events into the necessary database tables.

        Rejected events are only inserted into the events table, the events_json table,
        and the rejections table. Things reading from those tables will need to check
        whether the event was rejected.

        Args:
            txn
            events_and_contexts: events to persist
            backfilled: True if the events were backfilled
            state_delta_for_room: The current-state delta for each room.
            new_forward_extremeties: The new forward extremities for each room.
                For each room, a list of the event ids which are the forward
                extremities.
        """
        all_events_and_contexts = events_and_contexts

        min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
        max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering

        # stream orderings should have been assigned by now
        assert min_stream_order
        assert max_stream_order

        self._update_forward_extremities_txn(
            txn,
            new_forward_extremities=new_forward_extremeties,
            max_stream_order=max_stream_order,
        )

        # Ensure that we don't have the same event twice.
        events_and_contexts = self._filter_events_and_contexts_for_duplicates(
            events_and_contexts
        )

        self._update_room_depths_txn(
            txn, events_and_contexts=events_and_contexts, backfilled=backfilled
        )

        # _update_outliers_txn filters out any events which have already been
        # persisted, and returns the filtered list.
        events_and_contexts = self._update_outliers_txn(
            txn, events_and_contexts=events_and_contexts
        )

        # From this point onwards the events are only events that we haven't
        # seen before.

        self._store_event_txn(txn, events_and_contexts=events_and_contexts)

        self._persist_transaction_ids_txn(txn, events_and_contexts)

        # Insert into event_to_state_groups.
        self._store_event_state_mappings_txn(txn, events_and_contexts)

        self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])

        # _store_rejected_events_txn filters out any events which were
        # rejected, and returns the filtered list.
        events_and_contexts = self._store_rejected_events_txn(
            txn, events_and_contexts=events_and_contexts
        )

        # From this point onwards the events are only ones that weren't
        # rejected.

        self._update_metadata_tables_txn(
            txn,
            events_and_contexts=events_and_contexts,
            all_events_and_contexts=all_events_and_contexts,
            backfilled=backfilled,
        )

        # We call this last as it assumes we've inserted the events into
        # room_memberships, where applicable.
        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)

    def _persist_event_auth_chain_txn(
        self, txn: LoggingTransaction, events: List[EventBase],
    ) -> None:
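        """Store `event_auth` rows and chain cover index entries (chain
        ID/sequence numbers and chain links) for any state events being
        persisted. See docs/auth_chain_difference_algorithm.md.
        """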

        # We only care about state events, so we can return early if there
        # are no state events.
        if not any(e.is_state() for e in events):
            return

        # We want to store event_auth mappings for rejected events, as they're
        # used in state res v2.
        # This is only necessary if the rejected event appears in an accepted
        # event's auth chain, but it's easier for now just to store them (and
        # it doesn't take much storage compared to storing the entire event
        # anyway).
        self.db_pool.simple_insert_many_txn(
            txn,
            table="event_auth",
            values=[
                {
                    "event_id": event.event_id,
                    "room_id": event.room_id,
                    "auth_id": auth_id,
                }
                for event in events
                for auth_id in event.auth_event_ids()
                if event.is_state()
            ],
        )

        # We now calculate chain ID/sequence numbers for any state events we're
        # persisting. We ignore out of band memberships as we're not in the room
        # and won't have their auth chain (we'll fix it up later if we join the
        # room).
        #
        # See: docs/auth_chain_difference_algorithm.md

        # We ignore legacy rooms that we aren't filling the chain cover index
        # for.
        rows = self.db_pool.simple_select_many_txn(
            txn,
            table="rooms",
            column="room_id",
            iterable={event.room_id for event in events if event.is_state()},
            keyvalues={},
            retcols=("room_id", "has_auth_chain_index"),
        )
        rooms_using_chain_index = {
            row["room_id"] for row in rows if row["has_auth_chain_index"]
        }

        state_events = {
            event.event_id: event
            for event in events
            if event.is_state() and event.room_id in rooms_using_chain_index
        }

        if not state_events:
            return

        # Map from event ID to chain ID/sequence number.
        chain_map = {}  # type: Dict[str, Tuple[int, int]]

        # We need to know the type/state_key and auth events of the events we're
        # calculating chain IDs for. We don't rely on having the full Event
        # instances as we'll potentially be pulling more events from the DB and
        # we don't need the overhead of fetching/parsing the full event JSON.
        event_to_types = {
            e.event_id: (e.type, e.state_key) for e in state_events.values()
        }
        event_to_auth_chain = {
            e.event_id: e.auth_event_ids() for e in state_events.values()
        }

        # Set of event IDs to calculate chain ID/seq numbers for.
        events_to_calc_chain_id_for = set(state_events)

        # We check if there are any events that need to be handled in the rooms
        # we're looking at. These should just be out of band memberships, where
        # we didn't have the auth chain when we first persisted.
        rows = self.db_pool.simple_select_many_txn(
            txn,
            table="event_auth_chain_to_calculate",
            keyvalues={},
            column="room_id",
            iterable={e.room_id for e in state_events.values()},
            retcols=("event_id", "type", "state_key"),
        )
        for row in rows:
            event_id = row["event_id"]
            event_type = row["type"]
            state_key = row["state_key"]

            # (We could pull out the auth events for all rows at once using
            # simple_select_many, but this case happens rarely and almost always
            # with a single row.)
            auth_events = self.db_pool.simple_select_onecol_txn(
                txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
            )

            events_to_calc_chain_id_for.add(event_id)
            event_to_types[event_id] = (event_type, state_key)
            event_to_auth_chain[event_id] = auth_events

        # First we get the chain ID and sequence numbers for the events'
        # auth events (that aren't also currently being persisted).
        #
        # Note that there is an edge case here where we might not have
        # calculated chains and sequence numbers for events that were "out
        # of band". We handle this case by fetching the necessary info and
        # adding it to the set of events to calculate chain IDs for.
        missing_auth_chains = {
            a_id
            for auth_events in event_to_auth_chain.values()
            for a_id in auth_events
            if a_id not in events_to_calc_chain_id_for
        }

        # We loop here in case we find an out of band membership and need to
        # fetch its auth event info.
        while missing_auth_chains:
            sql = """
                SELECT event_id, events.type, state_key, chain_id, sequence_number
                FROM events
                INNER JOIN state_events USING (event_id)
                LEFT JOIN event_auth_chains USING (event_id)
                WHERE
            """
            clause, args = make_in_list_sql_clause(
                txn.database_engine, "event_id", missing_auth_chains,
            )
            txn.execute(sql + clause, args)

            missing_auth_chains.clear()

            for auth_id, event_type, state_key, chain_id, sequence_number in txn:
                event_to_types[auth_id] = (event_type, state_key)

                if chain_id is None:
                    # No chain ID, so the event was persisted out of band.
                    # We add it to the list of events to calculate auth chains for.
                    events_to_calc_chain_id_for.add(auth_id)

                    event_to_auth_chain[
                        auth_id
                    ] = self.db_pool.simple_select_onecol_txn(
                        txn,
                        "event_auth",
                        keyvalues={"event_id": auth_id},
                        retcol="auth_id",
                    )

                    missing_auth_chains.update(
                        e
                        for e in event_to_auth_chain[auth_id]
                        if e not in event_to_types
                    )
                else:
                    chain_map[auth_id] = (chain_id, sequence_number)

        # Now we check if we have any events where we don't have the auth
        # chain; these should only be out of band memberships.
        for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
            for auth_id in event_to_auth_chain[event_id]:
                if (
                    auth_id not in chain_map
                    and auth_id not in events_to_calc_chain_id_for
                ):
                    events_to_calc_chain_id_for.discard(event_id)

                    # If this is an event we're trying to persist we add it to
                    # the list of events to calculate chain IDs for next time
                    # around. (Otherwise we will have already added it to the
                    # table).
                    event = state_events.get(event_id)
                    if event:
                        self.db_pool.simple_insert_txn(
                            txn,
                            table="event_auth_chain_to_calculate",
                            values={
                                "event_id": event.event_id,
                                "room_id": event.room_id,
                                "type": event.type,
                                "state_key": event.state_key,
                            },
                        )

                    # We stop checking the event's auth events since we've
                    # discarded it.
                    break

        if not events_to_calc_chain_id_for:
            return

        # We now calculate the chain IDs/sequence numbers for the events. We
        # do this by looking at the chain ID and sequence number of any auth
        # event with the same type/state_key and incrementing the sequence
        # number by one. If there was no match or the chain ID/sequence
        # number is already taken we generate a new chain.
        #
        # We need to do this in a topologically sorted order as we want to
        # generate chain IDs/sequence numbers of an event's auth events
        # before the event itself.
        chains_tuples_allocated = set()  # type: Set[Tuple[int, int]]
        new_chain_tuples = {}  # type: Dict[str, Tuple[int, int]]
        for event_id in sorted_topologically(
            events_to_calc_chain_id_for, event_to_auth_chain
        ):
            existing_chain_id = None
            for auth_id in event_to_auth_chain[event_id]:
                if event_to_types.get(event_id) == event_to_types.get(auth_id):
                    existing_chain_id = chain_map[auth_id]
                    break

            new_chain_tuple = None
            if existing_chain_id:
                # We found a chain ID/sequence number candidate, check it's
                # not already taken.
                proposed_new_id = existing_chain_id[0]
                proposed_new_seq = existing_chain_id[1] + 1
                if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
                    already_allocated = self.db_pool.simple_select_one_onecol_txn(
                        txn,
                        table="event_auth_chains",
                        keyvalues={
                            "chain_id": proposed_new_id,
                            "sequence_number": proposed_new_seq,
                        },
                        retcol="event_id",
                        allow_none=True,
                    )
                    if already_allocated:
                        # Mark it as already allocated so we don't need to hit
                        # the DB again.
                        chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
                    else:
                        new_chain_tuple = (
                            proposed_new_id,
                            proposed_new_seq,
                        )

            if not new_chain_tuple:
                new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)

            chains_tuples_allocated.add(new_chain_tuple)

            chain_map[event_id] = new_chain_tuple
            new_chain_tuples[event_id] = new_chain_tuple

        self.db_pool.simple_insert_many_txn(
            txn,
            table="event_auth_chains",
            values=[
                {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
                for event_id, (c_id, seq) in new_chain_tuples.items()
            ],
        )

        self.db_pool.simple_delete_many_txn(
            txn,
            table="event_auth_chain_to_calculate",
            keyvalues={},
            column="event_id",
            iterable=new_chain_tuples,
        )

        # Now we need to calculate any new links between chains caused by
        # the new events.
        #
        # Links are pairs of chain ID/sequence numbers such that for any
        # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
        # if and only if there is at least one link (CA, S1) -> (CB, S2)
        # where SA >= S1 and S2 >= SB.
        #
        # We try and avoid adding redundant links to the table, e.g. if we
        # have two links between two chains which both start/end at the same
        # sequence number (or cross), then one can be safely dropped.
        #
        # To calculate new links we look at every new event and:
        #   1. Fetch the chain ID/sequence numbers of its auth events,
        #      discarding any that are reachable by other auth events, or
        #      that have the same chain ID as the event.
        #   2. For each retained auth event we:
        #       a. Add a link from the event's chain ID/sequence number to
        #          the auth event's; and
        #       b. Add a link from the event to every chain reachable by the
        #          auth event.

        # Step 1, fetch all existing links from all the chains we've seen
        # referenced.
        chain_links = _LinkMap()
        rows = self.db_pool.simple_select_many_txn(
            txn,
            table="event_auth_chain_links",
            column="origin_chain_id",
            iterable={chain_id for chain_id, _ in chain_map.values()},
            keyvalues={},
            retcols=(
                "origin_chain_id",
                "origin_sequence_number",
                "target_chain_id",
                "target_sequence_number",
            ),
        )
        for row in rows:
            chain_links.add_link(
                (row["origin_chain_id"], row["origin_sequence_number"]),
                (row["target_chain_id"], row["target_sequence_number"]),
                new=False,
            )

        # We do this in topological order to avoid adding redundant links.
        for event_id in sorted_topologically(
            events_to_calc_chain_id_for, event_to_auth_chain
        ):
            chain_id, sequence_number = chain_map[event_id]

            # Filter out auth events that are reachable by other auth
            # events. We do this by looking at every permutation of pairs of
            # auth events (A, B) to check if B is reachable from A.
            reduction = {
                a_id
                for a_id in event_to_auth_chain[event_id]
                if chain_map[a_id][0] != chain_id
            }
            for start_auth_id, end_auth_id in itertools.permutations(
                event_to_auth_chain[event_id], r=2,
            ):
                if chain_links.exists_path_from(
                    chain_map[start_auth_id], chain_map[end_auth_id]
                ):
                    reduction.discard(end_auth_id)

            # Step 2, figure out what the new links are from the reduced
            # list of auth events.
            for auth_id in reduction:
                auth_chain_id, auth_sequence_number = chain_map[auth_id]

                # Step 2a, add link between the event and auth event
                chain_links.add_link(
                    (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
                )

                # Step 2b, add a link to chains reachable from the auth
                # event.
                for target_id, target_seq in chain_links.get_links_from(
                    (auth_chain_id, auth_sequence_number)
                ):
                    if target_id == chain_id:
                        continue

                    chain_links.add_link(
                        (chain_id, sequence_number), (target_id, target_seq)
                    )

        self.db_pool.simple_insert_many_txn(
            txn,
            table="event_auth_chain_links",
            values=[
                {
                    "origin_chain_id": source_id,
                    "origin_sequence_number": source_seq,
                    "target_chain_id": target_id,
                    "target_sequence_number": target_seq,
                }
                for (
                    source_id,
                    source_seq,
                    target_id,
                    target_seq,
                ) in chain_links.get_additions()
            ],
        )

    def _persist_transaction_ids_txn(
        self,
        txn: LoggingTransaction,
        events_and_contexts: List[Tuple[EventBase, EventContext]],
    ):
        """Persist the mapping from transaction IDs to event IDs (if defined)."""

        to_insert = []
        for event, _ in events_and_contexts:
            token_id = getattr(event.internal_metadata, "token_id", None)
            txn_id = getattr(event.internal_metadata, "txn_id", None)
            if token_id and txn_id:
                to_insert.append(
                    {
                        "event_id": event.event_id,
                        "room_id": event.room_id,
                        "user_id": event.sender,
                        "token_id": token_id,
                        "txn_id": txn_id,
                        "inserted_ts": self._clock.time_msec(),
                    }
                )

        if to_insert:
            self.db_pool.simple_insert_many_txn(
                txn, table="event_txn_id", values=to_insert,
            )

    def _update_current_state_txn(
        self,
        txn: LoggingTransaction,
        state_delta_by_room: Dict[str, DeltaState],
        stream_id: int,
    ):
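        """Apply the per-room current-state deltas: update the
        `current_state_events`, `current_state_delta_stream` and
        `local_current_membership` tables and invalidate the affected caches.
        """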
        for room_id, delta_state in state_delta_by_room.items():
            to_delete = delta_state.to_delete
            to_insert = delta_state.to_insert

            if delta_state.no_longer_in_room:
                # Server is no longer in the room so we delete the room from
                # current_state_events, being careful we've already updated the
                # rooms.room_version column (which gets populated in a
                # background task).
                self._upsert_room_version_txn(txn, room_id)

                # Before deleting we populate the current_state_delta_stream
                # so that async background tasks get told what happened.
                sql = """
                    INSERT INTO current_state_delta_stream
                        (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                    SELECT ?, ?, room_id, type, state_key, null, event_id
                        FROM current_state_events
                        WHERE room_id = ?
                """
                txn.execute(sql, (stream_id, self._instance_name, room_id))

                self.db_pool.simple_delete_txn(
                    txn, table="current_state_events", keyvalues={"room_id": room_id},
                )
            else:
                # We're still in the room, so we update the current state as normal.

                # First we add entries to the current_state_delta_stream. We
                # do this before updating the current_state_events table so
                # that we can use it to calculate the `prev_event_id`. (This
                # allows us to not have to pull out the existing state
                # unnecessarily).
                #
                # The stream_id for the update is chosen to be the minimum of the stream_ids
                # for the batch of the events that we are persisting; that means we do not
                # end up in a situation where workers see events before the
                # current_state_delta updates.
                #
                sql = """
                    INSERT INTO current_state_delta_stream
                        (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                    SELECT ?, ?, ?, ?, ?, ?, (
                        SELECT event_id FROM current_state_events
                        WHERE room_id = ? AND type = ? AND state_key = ?
                    )
                """
                txn.executemany(
                    sql,
                    (
                        (
                            stream_id,
                            self._instance_name,
                            room_id,
                            etype,
                            state_key,
                            to_insert.get((etype, state_key)),
                            room_id,
                            etype,
                            state_key,
                        )
                        for etype, state_key in itertools.chain(to_delete, to_insert)
                    ),
                )

                # Now we actually update the current_state_events table
                txn.executemany(
                    "DELETE FROM current_state_events"
                    " WHERE room_id = ? AND type = ? AND state_key = ?",
                    (
                        (room_id, etype, state_key)
                        for etype, state_key in itertools.chain(to_delete, to_insert)
                    ),
                )

                # We include the membership in the current state table, hence we do
                # a lookup when we insert. This assumes that all events have already
                # been inserted into room_memberships.
                txn.executemany(
                    """INSERT INTO current_state_events
                        (room_id, type, state_key, event_id, membership)
                    VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
                    """,
                    [
                        (room_id, key[0], key[1], ev_id, ev_id)
                        for key, ev_id in to_insert.items()
                    ],
                )

            # We now update `local_current_membership`. We do this regardless
            # of whether we're still in the room or not to handle the case where
            # e.g. we just got banned (where we need to record that fact here).

            # Note: Do we really want to delete rows here (that we do not
            # subsequently reinsert below)? While technically correct it means
            # we have no record of the fact the user *was* a member of the
            # room but got, say, state reset out of it.
            if to_delete or to_insert:
                txn.executemany(
                    "DELETE FROM local_current_membership"
                    " WHERE room_id = ? AND user_id = ?",
                    (
                        (room_id, state_key)
                        for etype, state_key in itertools.chain(to_delete, to_insert)
                        if etype == EventTypes.Member and self.is_mine_id(state_key)
                    ),
                )

            if to_insert:
                txn.executemany(
                    """INSERT INTO local_current_membership
                        (room_id, user_id, event_id, membership)
                    VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
                    """,
                    [
                        (room_id, key[1], ev_id, ev_id)
                        for key, ev_id in to_insert.items()
                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
                    ],
                )

            txn.call_after(
                self.store._curr_state_delta_stream_cache.entity_has_changed,
                room_id,
                stream_id,
            )

            # Invalidate the various caches

            # Figure out the changes of membership to invalidate the
            # `get_rooms_for_user` cache.
            # We find out which membership events we may have deleted
            # and which we have added, then we invalidate the caches for all
            # those users.
            members_changed = {
                state_key
                for ev_type, state_key in itertools.chain(to_delete, to_insert)
                if ev_type == EventTypes.Member
            }

            for member in members_changed:
                txn.call_after(
                    self.store.get_rooms_for_user_with_stream_ordering.invalidate,
                    (member,),
                )

            self.store._invalidate_state_caches_and_stream(
                txn, room_id, members_changed
            )

    def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
        """Update the room version in the database based off current state
        events.

        This is used when we're about to delete current state and we want to
        ensure that the `rooms.room_version` column is up to date.
        """

        sql = """
            SELECT json FROM event_json
            INNER JOIN current_state_events USING (room_id, event_id)
            WHERE room_id = ? AND type = ? AND state_key = ?
        """
        txn.execute(sql, (room_id, EventTypes.Create, ""))

        row = txn.fetchone()
        if row:
            event_json = db_to_json(row[0])
            content = event_json.get("content", {})

            creator = content.get("creator")
            room_version_id = content.get("room_version", RoomVersions.V1.identifier)

            self.db_pool.simple_upsert_txn(
                txn,
                table="rooms",
                keyvalues={"room_id": room_id},
                values={"room_version": room_version_id},
                insertion_values={"is_public": False, "creator": creator},
            )

    def _update_forward_extremities_txn(
        self, txn, new_forward_extremities, max_stream_order
    ):
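        """Replace the stored forward extremities of each room with the new
        ones, and record them against `max_stream_order` in
        `stream_ordering_to_exterm`.
        """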
        for room_id, new_extrem in new_forward_extremities.items():
            self.db_pool.simple_delete_txn(
                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
            )
            txn.call_after(
                self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
            )

        self.db_pool.simple_insert_many_txn(
            txn,
            table="event_forward_extremities",
            values=[
                {"event_id": ev_id, "room_id": room_id}
                for room_id, new_extrem in new_forward_extremities.items()
                for ev_id in new_extrem
            ],
        )

        # We now insert into stream_ordering_to_exterm a mapping from room_id,
        # new stream_ordering to new forward extremities in the room.
        # This allows us to later efficiently look up the forward extremities
        # for a room before a given stream_ordering
        self.db_pool.simple_insert_many_txn(
            txn,
            table="stream_ordering_to_exterm",
            values=[
                {
                    "room_id": room_id,
                    "event_id": event_id,
                    "stream_ordering": max_stream_order,
                }
                for room_id, new_extrem in new_forward_extremities.items()
                for event_id in new_extrem
            ],
        )

    @classmethod
    def _filter_events_and_contexts_for_duplicates(
        cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
    ) -> List[Tuple[EventBase, EventContext]]:
        """Ensure that we don't have the same event twice.

        Pick the earliest non-outlier if there is one, else the earliest one.

        Args:
            events_and_contexts (list[(EventBase, EventContext)]):

        Returns:
            list[(EventBase, EventContext)]: filtered list
        """
        new_events_and_contexts = (
            OrderedDict()
        )  # type: OrderedDict[str, Tuple[EventBase, EventContext]]
        for event, context in events_and_contexts:
            prev_event_context = new_events_and_contexts.get(event.event_id)
            if prev_event_context:
                if not event.internal_metadata.is_outlier():
                    if prev_event_context[0].internal_metadata.is_outlier():
                        # To ensure correct ordering we pop, as OrderedDict is
                        # ordered by first insertion.
                        new_events_and_contexts.pop(event.event_id, None)
                        new_events_and_contexts[event.event_id] = (event, context)
            else:
                new_events_and_contexts[event.event_id] = (event, context)

        return list(new_events_and_contexts.values())

    def _update_room_depths_txn(
        self,
        txn,
        events_and_contexts: List[Tuple[EventBase, EventContext]],
        backfilled: bool,
    ):
        """Update min_depth for each room

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            events_and_contexts (list[(EventBase, EventContext)]): events
                we are persisting
            backfilled (bool): True if the events were backfilled
        """
        depth_updates = {}  # type: Dict[str, int]
        for event, context in events_and_contexts:
            # Remove any existing cache entries for the event_ids
            txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
            if not backfilled:
                txn.call_after(
                    self.store._events_stream_cache.entity_has_changed,
                    event.room_id,
                    event.internal_metadata.stream_ordering,
                )

            if not event.internal_metadata.is_outlier() and not context.rejected:
                depth_updates[event.room_id] = max(
                    event.depth, depth_updates.get(event.room_id, event.depth)
                )

        for room_id, depth in depth_updates.items():
            self._update_min_depth_for_room_txn(txn, room_id, depth)

    def _update_outliers_txn(self, txn, events_and_contexts):
        """Update any outliers with new event info.

        This turns outliers into ex-outliers (unless the new event was
        rejected).

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            events_and_contexts (list[(EventBase, EventContext)]): events
                we are persisting

        Returns:
            list[(EventBase, EventContext)] new list, without events which
            are already in the events table.
        """
        txn.execute(
            "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
            % (",".join(["?"] * len(events_and_contexts)),),
            [event.event_id for event, _ in events_and_contexts],
        )

        have_persisted = {event_id: outlier for event_id, outlier in txn}

        to_remove = set()
        for event, context in events_and_contexts:
            if event.event_id not in have_persisted:
                continue

            to_remove.add(event)

            if context.rejected:
                # If the event is rejected then we don't care if the event
                # was an outlier or not.
                continue

            outlier_persisted = have_persisted[event.event_id]
            if not event.internal_metadata.is_outlier() and outlier_persisted:
                # We received a copy of an event that we had already stored as
                # an outlier in the database. We now have some state at that
                # event, so we need to update the state_groups table with that
                # state.

                # insert into event_to_state_groups.
                try:
                    self._store_event_state_mappings_txn(txn, ((event, context),))
                except Exception:
                    logger.exception("")
                    raise

                metadata_json = json_encoder.encode(event.internal_metadata.get_dict())

                sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
                txn.execute(sql, (metadata_json, event.event_id))

                # Add an entry to the ex_outlier_stream table to replicate the
                # change in outlier status to our workers.
                stream_order = event.internal_metadata.stream_ordering
                state_group_id = context.state_group
                self.db_pool.simple_insert_txn(
                    txn,
                    table="ex_outlier_stream",
                    values={
                        "event_stream_ordering": stream_order,
                        "event_id": event.event_id,
                        "state_group": state_group_id,
                        "instance_name": self._instance_name,
                    },
                )

                sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
                txn.execute(sql, (False, event.event_id))

                # Update the event_backward_extremities table now that this
                # event isn't an outlier any more.
                self._update_backward_extremeties(txn, [event])

        return [ec for ec in events_and_contexts if ec[0] not in to_remove]

    def _store_event_txn(self, txn, events_and_contexts):
        """Insert new events into the event, event_json, redaction and
        state_events tables.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            events_and_contexts (list[(EventBase, EventContext)]): events
                we are persisting
        """

        if not events_and_contexts:
            # nothing to do here
            return

        def event_dict(event):
            d = event.get_dict()
            d.pop("redacted", None)
            d.pop("redacted_because", None)
            return d

        self.db_pool.simple_insert_many_txn(
            txn,
            table="event_json",
            values=[
                {
                    "event_id": event.event_id,
                    "room_id": event.room_id,
                    "internal_metadata": json_encoder.encode(
                        event.internal_metadata.get_dict()
                    ),
                    "json": json_encoder.encode(event_dict(event)),
                    "format_version": event.format_version,
                }
                for event, _ in events_and_contexts
            ],
        )

        self.db_pool.simple_insert_many_txn(
            txn,
            table="events",
            values=[
                {
                    "instance_name": self._instance_name,
                    "stream_ordering": event.internal_metadata.stream_ordering,
                    "topological_ordering": event.depth,
                    "depth": event.depth,
                    "event_id": event.event_id,
                    "room_id": event.room_id,
                    "type": event.type,
                    "processed": True,
                    "outlier": event.internal_metadata.is_outlier(),
                    "origin_server_ts": int(event.origin_server_ts),
                    "received_ts": self._clock.time_msec(),
                    "sender": event.sender,
                    "contains_url": (
                        "url" in event.content and isinstance(event.content["url"], str)
                    ),
                }
                for event, _ in events_and_contexts
            ],
        )

        for event, _ in events_and_contexts:
            if not event.internal_metadata.is_redacted():
                # If we're persisting an unredacted event we go and ensure
                # that we mark any redactions that reference this event as
                # requiring censoring.
                self.db_pool.simple_update_txn(
                    txn,
                    table="redactions",
                    keyvalues={"redacts": event.event_id},
                    updatevalues={"have_censored": False},
                )

        state_events_and_contexts = [
            ec for ec in events_and_contexts if ec[0].is_state()
        ]

        state_values = []
        for event, context in state_events_and_contexts:
            vals = {
                "event_id": event.event_id,
                "room_id": event.room_id,
                "type": event.type,
                "state_key": event.state_key,
            }

            # TODO: How does this work with backfilling?
            if hasattr(event, "replaces_state"):
                vals["prev_state"] = event.replaces_state

            state_values.append(vals)

        self.db_pool.simple_insert_many_txn(
            txn, table="state_events", values=state_values
        )

    def _store_rejected_events_txn(self, txn, events_and_contexts):
        """Add rows to the 'rejections' table for received events which were
        rejected

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            events_and_contexts (list[(EventBase, EventContext)]): events
                we are persisting

        Returns:
            list[(EventBase, EventContext)] new list, without the rejected
            events.
        """
        # Remove the rejected events from the list now that we've added them
        # to the events table and the events_json table.
        to_remove = set()
        for event, context in events_and_contexts:
            if context.rejected:
                # Insert the event_id into the rejections table
                self._store_rejections_txn(txn, event.event_id, context.rejected)
                to_remove.add(event)

        return [ec for ec in events_and_contexts if ec[0] not in to_remove]

    def _update_metadata_tables_txn(
        self, txn, events_and_contexts, all_events_and_contexts, backfilled
    ):
        """Update all the miscellaneous tables for new events

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            events_and_contexts (list[(EventBase, EventContext)]): events
                we are persisting
            all_events_and_contexts (list[(EventBase, EventContext)]): all
                events that we were going to persist. This includes events
                we've already persisted, etc, that wouldn't appear in
                events_and_contexts.
            backfilled (bool): True if the events were backfilled
        """

        # Insert all the push actions into the event_push_actions table.
        self._set_push_actions_for_event_and_users_txn(
            txn,
            events_and_contexts=events_and_contexts,
            all_events_and_contexts=all_events_and_contexts,
        )

        if not events_and_contexts:
            # nothing to do here
            return

        for event, context in events_and_contexts:
            if event.type == EventTypes.Redaction and event.redacts is not None:
                # Remove the entries in the event_push_actions table for the
                # redacted event.
                self._remove_push_actions_for_event_id_txn(
                    txn, event.room_id, event.redacts
                )

                # Remove from relations table.
                self._handle_redaction(txn, event.redacts)

        # Update the event_forward_extremities, event_backward_extremities and
        # event_edges tables.
        self._handle_mult_prev_events(
            txn, events=[event for event, _ in events_and_contexts]
        )

        for event, _ in events_and_contexts:
            if event.type == EventTypes.Name:
                # Insert into the event_search table.
                self._store_room_name_txn(txn, event)
            elif event.type == EventTypes.Topic:
                # Insert into the event_search table.
                self._store_room_topic_txn(txn, event)
            elif event.type == EventTypes.Message:
                # Insert into the event_search table.
                self._store_room_message_txn(txn, event)
            elif event.type == EventTypes.Redaction and event.redacts is not None:
                # Insert into the redactions table.
                self._store_redaction(txn, event)
            elif event.type == EventTypes.Retention:
                # Update the room_retention table.
                self._store_retention_policy_for_room_txn(txn, event)

            self._handle_event_relations(txn, event)

            # Store the labels for this event.
            labels = event.content.get(EventContentFields.LABELS)
            if labels:
                self.insert_labels_for_event_txn(
                    txn, event.event_id, labels, event.room_id, event.depth
                )

            if self._ephemeral_messages_enabled:
                # If there's an expiry timestamp on the event, store it.
                expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
                if isinstance(expiry_ts, int) and not event.is_state():
                    self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)

        # Insert into the room_memberships table.
        self._store_room_members_txn(
            txn,
            [
                event
                for event, _ in events_and_contexts
                if event.type == EventTypes.Member
            ],
            backfilled=backfilled,
        )

        # Insert into the event_reference_hashes table.
        self._store_event_reference_hashes_txn(
            txn, [event for event, _ in events_and_contexts]
        )

        # Prefill the event cache
        self._add_to_cache(txn, events_and_contexts)

    def _add_to_cache(self, txn, events_and_contexts):
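        """Prefill the in-memory event cache with the events we have just
        persisted, skipping any that were rejected or redacted.
        """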
        to_prefill = []

        rows = []
        N = 200
        for i in range(0, len(events_and_contexts), N):
            ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]}
            if not ev_map:
                break

            sql = (
                "SELECT "
                " e.event_id as event_id, "
                " r.redacts as redacts,"
                " rej.event_id as rejects "
                " FROM events as e"
                " LEFT JOIN rejections as rej USING (event_id)"
                " LEFT JOIN redactions as r ON e.event_id = r.redacts"
                " WHERE "
            )

            clause, args = make_in_list_sql_clause(
                self.database_engine, "e.event_id", list(ev_map)
            )

            txn.execute(sql + clause, args)
            rows = self.db_pool.cursor_to_dict(txn)
            for row in rows:
                event = ev_map[row["event_id"]]
                if not row["rejects"] and not row["redacts"]:
                    to_prefill.append(
                        _EventCacheEntry(event=event, redacted_event=None)
                    )

        def prefill():
            for cache_entry in to_prefill:
                self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)

        txn.call_after(prefill)

    def _store_redaction(self, txn, event):
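        """Insert a row into the redactions table for a redaction event and
        invalidate the cached copy of the redacted event.
        """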
        # invalidate the cache for the redacted event
        txn.call_after(self.store._invalidate_get_event_cache, event.redacts)

        self.db_pool.simple_insert_txn(
            txn,
            table="redactions",
            values={
                "event_id": event.event_id,
                "redacts": event.redacts,
                "received_ts": self._clock.time_msec(),
            },
        )

    def insert_labels_for_event_txn(
        self, txn, event_id, labels, room_id, topological_ordering
    ):
        """Store the mapping between an event's ID and its labels, with one row per
        (event_id, label) tuple.

        Args:
            txn (LoggingTransaction): The transaction to execute.
            event_id (str): The event's ID.
            labels (list[str]): A list of text labels.
            room_id (str): The ID of the room the event was sent to.
            topological_ordering (int): The position of the event in the room's topology.
        """
        return self.db_pool.simple_insert_many_txn(
            txn=txn,
            table="event_labels",
            values=[
                {
                    "event_id": event_id,
                    "label": label,
                    "room_id": room_id,
                    "topological_ordering": topological_ordering,
                }
                for label in labels
            ],
        )

    def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
        """Save the expiry timestamp associated with a given event ID.

        Args:
            txn (LoggingTransaction): The database transaction to use.
            event_id (str): The event ID the expiry timestamp is associated with.
            expiry_ts (int): The timestamp at which to expire (delete) the event.
        """
        return self.db_pool.simple_insert_txn(
            txn=txn,
            table="event_expiry",
            values={"event_id": event_id, "expiry_ts": expiry_ts},
        )

    def _store_event_reference_hashes_txn(self, txn, events):
        """Store a hash for a PDU

        Args:
            txn (cursor):
            events (list): list of Events.
        """
        vals = []
        for event in events:
            ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
            vals.append(
                {
                    "event_id": event.event_id,
                    "algorithm": ref_alg,
                    "hash": memoryview(ref_hash_bytes),
                }
            )

        self.db_pool.simple_insert_many_txn(
            txn, table="event_reference_hashes", values=vals
        )

    def _store_room_members_txn(self, txn, events, backfilled):
        """Store a room member in the database."""

        def str_or_none(val: Any) -> Optional[str]:
            return val if isinstance(val, str) else None

        self.db_pool.simple_insert_many_txn(
            txn,
            table="room_memberships",
            values=[
                {
                    "event_id": event.event_id,
                    "user_id": event.state_key,
                    "sender": event.user_id,
                    "room_id": event.room_id,
                    "membership": event.membership,
                    "display_name": str_or_none(event.content.get("displayname")),
                    "avatar_url": str_or_none(event.content.get("avatar_url")),
                }
                for event in events
            ],
        )

        for event in events:
            txn.call_after(
                self.store._membership_stream_cache.entity_has_changed,
                event.state_key,
                event.internal_metadata.stream_ordering,
            )
            txn.call_after(
                self.store.get_invited_rooms_for_local_user.invalidate,
                (event.state_key,),
            )

            # We update the local_current_membership table only if the event is
            # "current", i.e., it's something that has just happened.
            #
            # This will usually get updated by the `current_state_events` handling,
            # unless it's an outlier, and an outlier is only "current" if it's an "out of
            # band membership", like a remote invite or a rejection of a remote invite.
            if (
                self.is_mine_id(event.state_key)
                and not backfilled
                and event.internal_metadata.is_outlier()
                and event.internal_metadata.is_out_of_band_membership()
            ):
                self.db_pool.simple_upsert_txn(
                    txn,
                    table="local_current_membership",
                    keyvalues={"room_id": event.room_id, "user_id": event.state_key},
                    values={
                        "event_id": event.event_id,
                        "membership": event.membership,
                    },
                )

    def _handle_event_relations(self, txn, event):
        """Handles inserting relation data during persistence of events

        Args:
            txn
            event (EventBase)
        """
        relation = event.content.get("m.relates_to")
        if not relation:
            # No relations
            return

        rel_type = relation.get("rel_type")
        if rel_type not in (
            RelationTypes.ANNOTATION,
            RelationTypes.REFERENCE,
            RelationTypes.REPLACE,
        ):
            # Unknown relation type
            return

        parent_id = relation.get("event_id")
        if not parent_id:
            # Invalid relation
            return

        aggregation_key = relation.get("key")

        self.db_pool.simple_insert_txn(
            txn,
            table="event_relations",
            values={
                "event_id": event.event_id,
                "relates_to_id": parent_id,
                "relation_type": rel_type,
                "aggregation_key": aggregation_key,
            },
        )

        txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
        txn.call_after(
            self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
        )

        if rel_type == RelationTypes.REPLACE:
            txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
    def _handle_redaction(self, txn, redacted_event_id):
        """Handles receiving a redaction and checking whether we need to remove
        any redacted relations from the database.

        Args:
            txn
            redacted_event_id (str): The event that was redacted.
        """
        self.db_pool.simple_delete_txn(
            txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
        )
    def _store_room_topic_txn(self, txn, event):
        if hasattr(event, "content") and "topic" in event.content:
            self.store_event_search_txn(
                txn, event, "content.topic", event.content["topic"]
            )

    def _store_room_name_txn(self, txn, event):
        if hasattr(event, "content") and "name" in event.content:
            self.store_event_search_txn(
                txn, event, "content.name", event.content["name"]
            )

    def _store_room_message_txn(self, txn, event):
        if hasattr(event, "content") and "body" in event.content:
            self.store_event_search_txn(
                txn, event, "content.body", event.content["body"]
            )
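
    # A hedged example of the "m.room.retention" state content handled by the
    # next method; lifetimes are expressed in milliseconds, and the numbers
    # below are illustrative only (roughly one day and thirty days):
    #
    #   {"min_lifetime": 86400000, "max_lifetime": 2592000000}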
    def _store_retention_policy_for_room_txn(self, txn, event):
        if not event.is_state():
            logger.debug("Ignoring non-state m.room.retention event")
            return

        if hasattr(event, "content") and (
            "min_lifetime" in event.content or "max_lifetime" in event.content
        ):
            if (
                "min_lifetime" in event.content
                and not isinstance(event.content.get("min_lifetime"), int)
            ) or (
                "max_lifetime" in event.content
                and not isinstance(event.content.get("max_lifetime"), int)
            ):
                # Ignore the event if one of the values isn't an integer.
                return

            self.db_pool.simple_insert_txn(
                txn=txn,
                table="room_retention",
                values={
                    "room_id": event.room_id,
                    "event_id": event.event_id,
                    "min_lifetime": event.content.get("min_lifetime"),
                    "max_lifetime": event.content.get("max_lifetime"),
                },
            )

            self.store._invalidate_cache_and_stream(
                txn, self.store.get_retention_policy_for_room, (event.room_id,)
            )
    def store_event_search_txn(self, txn, event, key, value):
        """Add event to the search table

        Args:
            txn (cursor):
            event (EventBase):
            key (str):
            value (str):
        """
        self.store.store_search_entries_txn(
            txn,
            (
                SearchEntry(
                    key=key,
                    value=value,
                    event_id=event.event_id,
                    room_id=event.room_id,
                    stream_ordering=event.internal_metadata.stream_ordering,
                    origin_server_ts=event.origin_server_ts,
                ),
            ),
        )
    def _set_push_actions_for_event_and_users_txn(
        self, txn, events_and_contexts, all_events_and_contexts
    ):
        """Handles moving push actions from the staging table to the main
        event_push_actions table for all events in `events_and_contexts`.

        Also ensures that all events in `all_events_and_contexts` are removed
        from the push action staging area.

        Args:
            events_and_contexts (list[(EventBase, EventContext)]): events
                we are persisting
            all_events_and_contexts (list[(EventBase, EventContext)]): all
                events that we were going to persist. This includes events
                we've already persisted, etc, that wouldn't appear in
                events_and_contexts.
        """
        sql = """
            INSERT INTO event_push_actions (
                room_id, event_id, user_id, actions, stream_ordering,
                topological_ordering, notif, highlight, unread
            )
            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
            FROM event_push_actions_staging
            WHERE event_id = ?
        """
        if events_and_contexts:
            txn.executemany(
                sql,
                (
                    (
                        event.room_id,
                        event.internal_metadata.stream_ordering,
                        event.depth,
                        event.event_id,
                    )
                    for event, _ in events_and_contexts
                ),
            )

        for event, _ in events_and_contexts:
            user_ids = self.db_pool.simple_select_onecol_txn(
                txn,
                table="event_push_actions_staging",
                keyvalues={"event_id": event.event_id},
                retcol="user_id",
            )

            for uid in user_ids:
                txn.call_after(
                    self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
                    (event.room_id, uid),
                )

        # Now we delete the staging area for *all* events that were being
        # persisted.
        txn.executemany(
            "DELETE FROM event_push_actions_staging WHERE event_id = ?",
            ((event.event_id,) for event, _ in all_events_and_contexts),
        )
    def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
        # Sad that we have to blow away the cache for the whole room here
        txn.call_after(
            self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
            (room_id,),
        )
        txn.execute(
            "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
            (room_id, event_id),
        )
    def _store_rejections_txn(self, txn, event_id, reason):
        self.db_pool.simple_insert_txn(
            txn,
            table="rejections",
            values={
                "event_id": event_id,
                "reason": reason,
                "last_check": self._clock.time_msec(),
            },
        )
    def _store_event_state_mappings_txn(
        self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
    ):
        state_groups = {}
        for event, context in events_and_contexts:
            if event.internal_metadata.is_outlier():
                continue

            # If the event was rejected, just give it the same state as its
            # predecessor.
            if context.rejected:
                state_groups[event.event_id] = context.state_group_before_event
                continue

            state_groups[event.event_id] = context.state_group

        self.db_pool.simple_insert_many_txn(
            txn,
            table="event_to_state_groups",
            values=[
                {"state_group": state_group_id, "event_id": event_id}
                for event_id, state_group_id in state_groups.items()
            ],
        )

        for event_id, state_group_id in state_groups.items():
            txn.call_after(
                self.store._get_state_group_for_event.prefill,
                (event_id,),
                state_group_id,
            )
    def _update_min_depth_for_room_txn(self, txn, room_id, depth):
        min_depth = self.store._get_min_depth_interaction(txn, room_id)

        if min_depth is not None and depth >= min_depth:
            return

        self.db_pool.simple_upsert_txn(
            txn,
            table="room_depth",
            keyvalues={"room_id": room_id},
            values={"min_depth": depth},
        )
    def _handle_mult_prev_events(self, txn, events):
        """
        For the given events, update the event edges table and the forward and
        backward extremities tables.
        """
        self.db_pool.simple_insert_many_txn(
            txn,
            table="event_edges",
            values=[
                {
                    "event_id": ev.event_id,
                    "prev_event_id": e_id,
                    "room_id": ev.room_id,
                    "is_state": False,
                }
                for ev in events
                for e_id in ev.prev_event_ids()
            ],
        )

        self._update_backward_extremeties(txn, events)
    def _update_backward_extremeties(self, txn, events):
        """Updates the event_backward_extremities tables based on the new/updated
        events being persisted.

        This is called for new events *and* for events that were outliers, but
        are now being persisted as non-outliers.

        Forward extremities are handled when we first start persisting the events.
        """
        events_by_room = {}  # type: Dict[str, List[EventBase]]
        for ev in events:
            events_by_room.setdefault(ev.room_id, []).append(ev)

        query = (
            "INSERT INTO event_backward_extremities (event_id, room_id)"
            " SELECT ?, ? WHERE NOT EXISTS ("
            " SELECT 1 FROM event_backward_extremities"
            " WHERE event_id = ? AND room_id = ?"
            " )"
            " AND NOT EXISTS ("
            " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
            " AND outlier = ?"
            " )"
        )
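
        # Each prev_event of a non-outlier event being persisted becomes a
        # backward extremity, unless it is already recorded as one or we
        # already have the referenced event stored as a non-outlier.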
        txn.executemany(
            query,
            [
                (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
                for ev in events
                for e_id in ev.prev_event_ids()
                if not ev.internal_metadata.is_outlier()
            ],
        )

        query = (
            "DELETE FROM event_backward_extremities"
            " WHERE event_id = ? AND room_id = ?"
        )
        txn.executemany(
            query,
            [
                (ev.event_id, ev.room_id)
                for ev in events
                if not ev.internal_metadata.is_outlier()
            ],
        )
@attr.s(slots=True)
class _LinkMap:
    """A helper type for tracking links between chains."""

    # Stores the set of links as nested maps: source chain ID -> target chain ID
    # -> source sequence number -> target sequence number.
    maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)

    # Stores the links that have been added (with `new` set to true), as tuples of
    # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
    additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
    def add_link(
        self,
        src_tuple: Tuple[int, int],
        target_tuple: Tuple[int, int],
        new: bool = True,
    ) -> bool:
        """Add a new link between two chains, ensuring no redundant links are added.

        New links should be added in topological order.

        Args:
            src_tuple: The chain ID/sequence number of the source of the link.
            target_tuple: The chain ID/sequence number of the target of the link.
            new: Whether this is a "new" link, i.e. should it be returned
                by `get_additions`.

        Returns:
            True if a link was added, false if the given link was dropped as redundant
        """
        src_chain, src_seq = src_tuple
        target_chain, target_seq = target_tuple

        current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})

        assert src_chain != target_chain

        if new:
            # Check if the new link is redundant
            for current_seq_src, current_seq_target in current_links.items():
                # If a link "crosses" another link then it's redundant. For example,
                # in the following diagram link 1 (L1) is redundant, as any event
                # reachable via L1 is *also* reachable via L2.
                #
                #      Chain A     Chain B
                #         |           |
                #      L1 |------     |
                #         |      |    |
                #      L2 |----  | -->|
                #         |      |    |
                #         |      |--->|
                #         |           |
                #         |           |
                #
                # So we only need to keep links which *do not* cross, i.e. links
                # that both start and end above or below an existing link.
                #
                # Note, since we add links in topological ordering we should never
                # see `src_seq` less than `current_seq_src`.
                if current_seq_src <= src_seq and target_seq <= current_seq_target:
                    # This new link is redundant, nothing to do.
                    return False

            self.additions.add((src_chain, src_seq, target_chain, target_seq))

        current_links[src_seq] = target_seq
        return True
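
    # A hedged worked example of the redundancy rule in add_link above (the
    # chain IDs and sequence numbers are illustrative):
    #
    #   links = _LinkMap()
    #   links.add_link((1, 1), (2, 3))  # True: first link between chains 1 and 2
    #   links.add_link((1, 2), (2, 2))  # False: it "crosses" the (1, 1) -> (2, 3)
    #                                   # link, so everything it reaches is already
    #                                   # reachable via the existing link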
    def get_links_from(
        self, src_tuple: Tuple[int, int]
    ) -> Generator[Tuple[int, int], None, None]:
        """Gets the chains reachable from the given chain/sequence number.

        Yields:
            The chain ID and sequence number the link points to.
        """
        src_chain, src_seq = src_tuple
        for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
            for link_src_seq, target_seq in sequence_numbers.items():
                if link_src_seq <= src_seq:
                    yield target_id, target_seq
    def get_links_between(
        self, source_chain: int, target_chain: int
    ) -> Generator[Tuple[int, int], None, None]:
        """Gets the links between two chains.

        Yields:
            The source and target sequence numbers.
        """
        yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
    def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
        """Gets any newly added links.

        Yields:
            The source chain ID/sequence number and target chain ID/sequence number
        """
        for src_chain, src_seq, target_chain, _ in self.additions:
            target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
            if target_seq is not None:
                yield (src_chain, src_seq, target_chain, target_seq)
    def exists_path_from(
        self,
        src_tuple: Tuple[int, int],
        target_tuple: Tuple[int, int],
    ) -> bool:
        """Checks if there is a path between the source chain ID/sequence and
        target chain ID/sequence.
        """
        src_chain, src_seq = src_tuple
        target_chain, target_seq = target_tuple

        if src_chain == target_chain:
            return target_seq <= src_seq

        links = self.get_links_between(src_chain, target_chain)
        for link_start_seq, link_end_seq in links:
            if link_start_seq <= src_seq and target_seq <= link_end_seq:
                return True

        return False
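
# A hedged illustration of exists_path_from (chain IDs and sequence numbers
# are made up):
#
#   links = _LinkMap()
#   links.add_link((1, 1), (2, 3))
#   links.exists_path_from((1, 2), (2, 2))   # True: covered by the (1, 1) -> (2, 3) link
#   links.exists_path_from((5, 10), (5, 3))  # True: same chain, lower sequence number
#   links.exists_path_from((2, 1), (1, 1))   # False: no link from chain 2 to chain 1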